Improve Milenage library for [R1-R5] (#1153)

This commit is contained in:
Sukchan Lee 2021-09-01 19:38:36 +09:00
parent 233db575ea
commit 1354947164
1 changed files with 99 additions and 4 deletions

View File

@ -22,10 +22,12 @@
#define os_memcmp memcmp
#define os_memcmp_const memcmp
int aes_128_encrypt_block(const uint8_t *key,
const uint8_t *in, uint8_t *out);
static void ShiftBits(uint8_t r, uint8_t rijndaelInput[16],
uint8_t temp[16], const uint8_t opc[16]);
static uint8_t *bits_shift(uint32_t bit_valid, uint8_t *dst,
uint8_t *src, uint32_t numBits);
int aes_128_encrypt_block(const uint8_t *key,
static int aes_128_encrypt_block(const uint8_t *key,
const uint8_t *in, uint8_t *out)
{
const int key_bits = 128;
@ -55,8 +57,10 @@ int milenage_f1(const uint8_t *opc, const uint8_t *k,
{
uint8_t tmp1[16], tmp2[16], tmp3[16];
int i;
#if 1 /* R1-R5 issues1153 */
uint8_t r1 = 64;
#endif
/* tmp1 = TEMP = E_K(RAND XOR OP_C) */
for (i = 0; i < 16; i++)
tmp1[i] = _rand[i] ^ opc[i];
if (aes_128_encrypt_block(k, tmp1, tmp1))
@ -70,8 +74,12 @@ int milenage_f1(const uint8_t *opc, const uint8_t *k,
/* OUT1 = E_K(TEMP XOR rot(IN1 XOR OP_C, r1) XOR c1) XOR OP_C */
/* rotate (tmp2 XOR OP_C) by r1 (= 0x40 = 8 bytes) */
#if 0 /* R1-R5 issues1153 */
for (i = 0; i < 16; i++)
tmp3[(i + 8) % 16] = tmp2[i] ^ opc[i];
#else
ShiftBits(r1, tmp3, tmp2, opc);
#endif
/* XOR with TEMP = E_K(RAND XOR OP_C) */
for (i = 0; i < 16; i++)
tmp3[i] ^= tmp1[i];
@ -109,6 +117,13 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k,
uint8_t tmp1[16], tmp2[16], tmp3[16];
int i;
#if 1 /* R1-R5 issues1153 */
uint8_t r2 = 0;
uint8_t r3 = 32;
uint8_t r4 = 64;
uint8_t r5 = 96;
#endif
/* tmp2 = TEMP = E_K(RAND XOR OP_C) */
for (i = 0; i < 16; i++)
tmp1[i] = _rand[i] ^ opc[i];
@ -122,8 +137,12 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k,
/* f2 and f5 */
/* rotate by r2 (= 0, i.e., NOP) */
#if 0 /* R1-R5 issues1153 */
for (i = 0; i < 16; i++)
tmp1[i] = tmp2[i] ^ opc[i];
#else
ShiftBits(r2, tmp1, tmp2, opc);
#endif
tmp1[15] ^= 1; /* XOR c2 (= ..01) */
/* f5 || f2 = E_K(tmp1) XOR OP_c */
if (aes_128_encrypt_block(k, tmp1, tmp3))
@ -138,8 +157,12 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k,
/* f3 */
if (ck) {
/* rotate by r3 = 0x20 = 4 bytes */
#if 0 /* R1-R5 issues1153 */
for (i = 0; i < 16; i++)
tmp1[(i + 12) % 16] = tmp2[i] ^ opc[i];
#else
ShiftBits(r3, tmp1, tmp2, opc);
#endif
tmp1[15] ^= 2; /* XOR c3 (= ..02) */
if (aes_128_encrypt_block(k, tmp1, ck))
return -1;
@ -150,8 +173,12 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k,
/* f4 */
if (ik) {
/* rotate by r4 = 0x40 = 8 bytes */
#if 0 /* R1-R5 issues1153 */
for (i = 0; i < 16; i++)
tmp1[(i + 8) % 16] = tmp2[i] ^ opc[i];
#else
ShiftBits(r4, tmp1, tmp2, opc);
#endif
tmp1[15] ^= 4; /* XOR c4 (= ..04) */
if (aes_128_encrypt_block(k, tmp1, ik))
return -1;
@ -162,8 +189,12 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k,
/* f5* */
if (akstar) {
/* rotate by r5 = 0x60 = 12 bytes */
#if 0 /* R1-R5 issues1153 */
for (i = 0; i < 16; i++)
tmp1[(i + 4) % 16] = tmp2[i] ^ opc[i];
#else
ShiftBits(r5, tmp1, tmp2, opc);
#endif
tmp1[15] ^= 8; /* XOR c5 (= ..08) */
if (aes_128_encrypt_block(k, tmp1, tmp1))
return -1;
@ -364,3 +395,67 @@ void milenage_opc(const uint8_t *k, const uint8_t *op, uint8_t *opc)
opc[i] ^= op[i];
}
}
static void ShiftBits(uint8_t r, uint8_t rijndaelInput[16],
uint8_t temp[16], const uint8_t opc[16])
{
uint32_t deltlen = 16 - (r / 8);
uint32_t leftout = r % 8;
uint32_t i;
if (leftout == 0) {
for (i = 0; i < 16; i++) {
rijndaelInput[(i+deltlen) % 16] = temp[i] ^ opc[i];
}
} else {
uint8_t temp1[16];
uint32_t move_bits;
uint8_t temp2;
for (i = 0; i < 16; i++) {
temp1[(i + deltlen) % 16] = temp[i] ^ opc[i];
}
rijndaelInput[15] = 0;
move_bits = 8 - leftout;
bits_shift(move_bits, &rijndaelInput[0], temp1, (128 - leftout));
temp2 = temp1[0] >> (8-leftout);
rijndaelInput[15] |= temp2;
}
}
static uint8_t *bits_shift(uint32_t bit_valid, uint8_t *dst,
uint8_t *src, uint32_t numBits)
{
uint32_t bit_used = bit_valid;
uint32_t bit_empty = 8 - bit_used;
uint32_t numBytes = numBits >> 3;
uint32_t leftBits = numBits & 0x7;
uint32_t i = 0;
uint8_t *newDst = 0;
for (i = 0; i < numBytes; i++) {
dst[i] = (src[i] << bit_empty) | (src[i+1] >> bit_used);
}
if (leftBits) {
if (leftBits == bit_used) {
dst[numBytes] = src[numBytes] << bit_empty;
bit_valid = 8;
newDst = &src[numBytes+1];
} else if (leftBits < bit_used) {
dst[numBytes] = src[numBytes] << bit_empty;
bit_valid = bit_used - leftBits;
newDst = &src[numBytes];
} else {
dst[numBytes] = src[numBytes] << bit_empty |
(src[numBytes+1] >> bit_used);
bit_valid = 8 - (leftBits - bit_used);
newDst = &src[numBytes+1];
}
} else {
bit_valid = bit_used;
newDst = &src[numBytes];
}
return newDst;
}