Skip to content

Commit 8dc1866

Browse files
Merge pull request #25 from jacksonwalters/add_nist_sp_800_38F_key_wrapping
Add NIST SP 800-38F (key wrapping)
2 parents f5d33d9 + 310ae82 commit 8dc1866

File tree

3 files changed

+396
-0
lines changed

3 files changed

+396
-0
lines changed

include/kw.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef KW_H
2+
#define KW_H
3+
4+
#include <stddef.h>
5+
#include <stdint.h>
6+
#include "aes_wrapper.h"
7+
8+
/**
9+
* Wrap a plaintext key using AES Key Wrap (RFC 3394 / NIST SP 800-38F).
10+
*/
11+
int kw_wrap(const uint8_t *plaintext, size_t plen,
12+
uint8_t *ciphertext, const struct aes_ctx *ctx);
13+
14+
/**
15+
* Unwrap a wrapped key using AES Key Wrap.
16+
*/
17+
int kw_unwrap(const uint8_t *ciphertext, size_t clen,
18+
uint8_t *plaintext, const struct aes_ctx *ctx);
19+
20+
#endif

src/kw.c

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#include "kw.h"
2+
#include <string.h>
3+
#include <stdint.h>
4+
5+
static const uint8_t ICV1[8] = { 0xA6,0xA6,0xA6,0xA6,0xA6,0xA6,0xA6,0xA6 };
6+
7+
static void xor64(uint8_t *block, uint64_t t) {
8+
for (int i = 0; i < 8; i++) {
9+
block[i] ^= (uint8_t)(t >> (56 - 8 * i));
10+
}
11+
}
12+
13+
int kw_wrap(const uint8_t *plaintext, size_t plen,
14+
uint8_t *ciphertext, const struct aes_ctx *ctx)
15+
{
16+
if (!plaintext || !ciphertext || !ctx) return -1;
17+
if (plen % 8 != 0 || plen < 16) return -1;
18+
19+
size_t n = plen / 8;
20+
if (n > 32) return -1;
21+
22+
uint8_t A[8];
23+
memcpy(A, ICV1, 8);
24+
uint8_t R[32][8];
25+
26+
for (size_t i = 0; i < n; i++) {
27+
memcpy(R[i], plaintext + 8*i, 8);
28+
}
29+
30+
uint8_t B[16];
31+
for (size_t j = 0; j <= 5; j++) {
32+
for (size_t i = 0; i < n; i++) {
33+
memcpy(B, A, 8);
34+
memcpy(B+8, R[i], 8);
35+
aes_block_wrapper(B, B, ctx);
36+
uint64_t t = (uint64_t)(n * j + i + 1);
37+
memcpy(A, B, 8);
38+
xor64(A, t);
39+
memcpy(R[i], B+8, 8);
40+
}
41+
}
42+
43+
memcpy(ciphertext, A, 8);
44+
for (size_t i = 0; i < n; i++) {
45+
memcpy(ciphertext + 8*(i+1), R[i], 8);
46+
}
47+
48+
return 0;
49+
}
50+
51+
int kw_unwrap(const uint8_t *ciphertext, size_t clen,
52+
uint8_t *plaintext, const struct aes_ctx *ctx)
53+
{
54+
if (!ciphertext || !plaintext || !ctx) return -1;
55+
if (clen % 8 != 0 || clen < 24) return -1;
56+
57+
size_t n = clen/8 - 1;
58+
if (n > 32) return -1;
59+
60+
uint8_t A[8];
61+
memcpy(A, ciphertext, 8);
62+
uint8_t R[32][8];
63+
for (size_t i = 0; i < n; i++) {
64+
memcpy(R[i], ciphertext + 8*(i+1), 8);
65+
}
66+
67+
uint8_t B[16];
68+
for (int j = 5; j >= 0; j--) {
69+
for (int i = (int)n-1; i >= 0; i--) {
70+
uint64_t t = (uint64_t)(n * j + i + 1);
71+
xor64(A, t);
72+
memcpy(B, A, 8);
73+
memcpy(B+8, R[i], 8);
74+
aes_block_wrapper_dec(B, B, ctx);
75+
memcpy(A, B, 8);
76+
memcpy(R[i], B+8, 8);
77+
}
78+
}
79+
80+
if (memcmp(A, ICV1, 8) != 0) {
81+
return -1;
82+
}
83+
84+
for (size_t i = 0; i < n; i++) {
85+
memcpy(plaintext + 8*i, R[i], 8);
86+
}
87+
88+
return 0;
89+
}

tests/test_kw.c

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
#include <stdio.h>
2+
#include <stdint.h>
3+
#include <string.h>
4+
#include <stdlib.h>
5+
6+
#include "../include/kw.h"
7+
#include "../include/aes_wrapper.h"
8+
#include "../include/key_expansion_128.h"
9+
#include "../include/key_expansion_192.h"
10+
#include "../include/key_expansion_256.h"
11+
#include "../include/sbox.h"
12+
13+
/* small helper to print hex for debugging */
14+
static void hexprint(const uint8_t *buf, size_t len) {
15+
for (size_t i = 0; i < len; ++i) printf("%02X", buf[i]);
16+
}
17+
18+
/* prepare aes_ctx from KEK: generate sbox and round_keys appropriate for kek_len */
19+
static int make_aes_ctx(const uint8_t *kek, size_t kek_len, struct aes_ctx *ctx,
20+
uint8_t **round_keys_storage)
21+
{
22+
if (!kek || !ctx || !round_keys_storage) return -1;
23+
24+
uint8_t *sbox = (uint8_t*)malloc(256);
25+
if (!sbox) return -1;
26+
initialize_aes_sbox(sbox);
27+
28+
/* Round-key sizes in your implementation:
29+
- AES-128 -> 176 bytes
30+
- AES-192 -> 208 bytes
31+
- AES-256 -> 240 bytes
32+
*/
33+
uint8_t *rk = NULL;
34+
if (kek_len == 16) {
35+
rk = (uint8_t*)malloc(176);
36+
if (!rk) { free(sbox); return -1; }
37+
aes_key_expansion_128(kek, rk, sbox);
38+
} else if (kek_len == 24) {
39+
rk = (uint8_t*)malloc(208);
40+
if (!rk) { free(sbox); return -1; }
41+
aes_key_expansion_192(kek, rk, sbox);
42+
} else if (kek_len == 32) {
43+
rk = (uint8_t*)malloc(240);
44+
if (!rk) { free(sbox); return -1; }
45+
aes_key_expansion_256(kek, rk, sbox);
46+
} else {
47+
free(sbox);
48+
return -1;
49+
}
50+
51+
ctx->round_keys = rk;
52+
ctx->sbox = sbox;
53+
ctx->key_len = kek_len;
54+
*round_keys_storage = rk;
55+
return 0;
56+
}
57+
58+
static void free_aes_ctx(struct aes_ctx *ctx)
59+
{
60+
if (!ctx) return;
61+
if (ctx->round_keys) free((void*)ctx->round_keys);
62+
if (ctx->sbox) free((void*)ctx->sbox);
63+
ctx->round_keys = NULL;
64+
ctx->sbox = NULL;
65+
}
66+
67+
/* A test vector driver that does wrap -> unwrap and checks round-trip.
68+
If expected_wrapped != NULL, it will also compare the produced wrapped
69+
bytes against that expected value (byte-for-byte). */
70+
static int run_kw_vector(const uint8_t *kek, size_t kek_len,
71+
const uint8_t *plaintext, size_t plen,
72+
const uint8_t *expected_wrapped, size_t expected_wrapped_len)
73+
{
74+
struct aes_ctx ctx;
75+
uint8_t *rk_storage = NULL;
76+
if (make_aes_ctx(kek, kek_len, &ctx, &rk_storage) != 0) {
77+
fprintf(stderr, "Failed to init AES ctx for kek_len=%zu\n", kek_len);
78+
return 1;
79+
}
80+
81+
size_t clen = plen + 8;
82+
uint8_t *wrapped = (uint8_t*)malloc(clen);
83+
uint8_t *unwrapped = (uint8_t*)malloc(plen);
84+
if (!wrapped || !unwrapped) {
85+
fprintf(stderr, "Out of memory\n");
86+
free(wrapped); free(unwrapped);
87+
free_aes_ctx(&ctx);
88+
return 1;
89+
}
90+
memset(wrapped, 0, clen);
91+
memset(unwrapped, 0, plen);
92+
93+
int rc = kw_wrap(plaintext, plen, wrapped, &ctx);
94+
if (rc != 0) {
95+
fprintf(stderr, "kw_wrap failed (kek_len=%zu plen=%zu)\n", kek_len, plen);
96+
free(wrapped); free(unwrapped); free_aes_ctx(&ctx);
97+
return 1;
98+
}
99+
100+
if (expected_wrapped != NULL) {
101+
if (expected_wrapped_len != clen || memcmp(wrapped, expected_wrapped, clen) != 0) {
102+
fprintf(stderr, "KW wrap mismatch (kek_len=%zu plen=%zu)\nExpected: ", kek_len, plen);
103+
hexprint(expected_wrapped, expected_wrapped_len);
104+
fprintf(stderr, "\nGot: ");
105+
hexprint(wrapped, clen);
106+
fprintf(stderr, "\n");
107+
free(wrapped); free(unwrapped); free_aes_ctx(&ctx);
108+
return 1;
109+
}
110+
}
111+
112+
rc = kw_unwrap(wrapped, clen, unwrapped, &ctx);
113+
if (rc != 0) {
114+
fprintf(stderr, "kw_unwrap failed (kek_len=%zu plen=%zu)\n", kek_len, plen);
115+
free(wrapped); free(unwrapped); free_aes_ctx(&ctx);
116+
return 1;
117+
}
118+
119+
if (memcmp(unwrapped, plaintext, plen) != 0) {
120+
fprintf(stderr, "KW round-trip mismatch (kek_len=%zu plen=%zu)\n", kek_len, plen);
121+
fprintf(stderr, "Recovered: ");
122+
hexprint(unwrapped, plen);
123+
fprintf(stderr, "\nExpected : ");
124+
hexprint(plaintext, plen);
125+
fprintf(stderr, "\n");
126+
free(wrapped); free(unwrapped); free_aes_ctx(&ctx);
127+
return 1;
128+
}
129+
130+
free(wrapped);
131+
free(unwrapped);
132+
free_aes_ctx(&ctx);
133+
return 0;
134+
}
135+
136+
int main(void)
137+
{
138+
int fail = 0;
139+
140+
/* --- RFC 3394 examples (sections 4.1 - 4.6) --- */
141+
142+
/* 4.1: 128-bit KEK wrapping 128 bits of Key Data
143+
We also validate the wrapped output equals the RFC canonical bytes. */
144+
{
145+
const uint8_t kek128[16] = {
146+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
147+
0x08,0x09,0x0A,0x0B,0x0C,0x0D,0x0E,0x0F
148+
};
149+
const uint8_t keydata128[16] = {
150+
0x00,0x11,0x22,0x33,0x44,0x55,0x66,0x77,
151+
0x88,0x99,0xAA,0xBB,0xCC,0xDD,0xEE,0xFF
152+
};
153+
const uint8_t expected_wrapped_4_1[24] = {
154+
0x1F,0xA6,0x8B,0x0A,0x81,0x12,0xB4,0x47,
155+
0xAE,0xF3,0x4B,0xD8,0xFB,0x5A,0x7B,0x82,
156+
0x9D,0x3E,0x86,0x23,0x71,0xD2,0xCF,0xE5
157+
};
158+
159+
if (run_kw_vector(kek128, sizeof(kek128), keydata128, sizeof(keydata128),
160+
expected_wrapped_4_1, sizeof(expected_wrapped_4_1)) != 0) {
161+
fprintf(stderr, "RFC 3394 section 4.1 FAILED\n");
162+
fail = 1;
163+
} else {
164+
printf("RFC 4.1 (128-bit KEK, 128-bit keydata): passed\n");
165+
}
166+
}
167+
168+
/* 4.2: 192-bit KEK wrapping 128 bits of Key Data (round-trip check) */
169+
{
170+
const uint8_t kek192[24] = {
171+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
172+
0x08,0x09,0x0A,0x0B,0x0C,0x0D,0x0E,0x0F,
173+
0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17
174+
};
175+
const uint8_t keydata128[16] = {
176+
0x00,0x11,0x22,0x33,0x44,0x55,0x66,0x77,
177+
0x88,0x99,0xAA,0xBB,0xCC,0xDD,0xEE,0xFF
178+
};
179+
180+
if (run_kw_vector(kek192, sizeof(kek192), keydata128, sizeof(keydata128),
181+
NULL, 0) != 0) {
182+
fprintf(stderr, "RFC 3394 section 4.2 FAILED\n");
183+
fail = 1;
184+
} else {
185+
printf("RFC 4.2 (192-bit KEK, 128-bit keydata): passed (round-trip)\n");
186+
}
187+
}
188+
189+
/* 4.3: 256-bit KEK wrapping 128 bits of Key Data (round-trip check) */
190+
{
191+
const uint8_t kek256[32] = {
192+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
193+
0x08,0x09,0x0A,0x0B,0x0C,0x0D,0x0E,0x0F,
194+
0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,
195+
0x18,0x19,0x1A,0x1B,0x1C,0x1D,0x1E,0x1F
196+
};
197+
const uint8_t keydata128[16] = {
198+
0x00,0x11,0x22,0x33,0x44,0x55,0x66,0x77,
199+
0x88,0x99,0xAA,0xBB,0xCC,0xDD,0xEE,0xFF
200+
};
201+
202+
if (run_kw_vector(kek256, sizeof(kek256), keydata128, sizeof(keydata128),
203+
NULL, 0) != 0) {
204+
fprintf(stderr, "RFC 3394 section 4.3 FAILED\n");
205+
fail = 1;
206+
} else {
207+
printf("RFC 4.3 (256-bit KEK, 128-bit keydata): passed (round-trip)\n");
208+
}
209+
}
210+
211+
/* 4.4: 192-bit KEK wrapping 192 bits (24 bytes) of Key Data */
212+
{
213+
const uint8_t kek192[24] = {
214+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
215+
0x08,0x09,0x0A,0x0B,0x0C,0x0D,0x0E,0x0F,
216+
0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17
217+
};
218+
const uint8_t keydata192[24] = {
219+
0x00,0x11,0x22,0x33,0x44,0x55,0x66,0x77,
220+
0x88,0x99,0xAA,0xBB,0xCC,0xDD,0xEE,0xFF,
221+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07
222+
};
223+
224+
if (run_kw_vector(kek192, sizeof(kek192), keydata192, sizeof(keydata192),
225+
NULL, 0) != 0) {
226+
fprintf(stderr, "RFC 3394 section 4.4 FAILED\n");
227+
fail = 1;
228+
} else {
229+
printf("RFC 4.4 (192-bit KEK, 192-bit keydata): passed (round-trip)\n");
230+
}
231+
}
232+
233+
/* 4.5: 256-bit KEK wrapping 192 bits (24 bytes) of Key Data */
234+
{
235+
const uint8_t kek256[32] = {
236+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
237+
0x08,0x09,0x0A,0x0B,0x0C,0x0D,0x0E,0x0F,
238+
0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,
239+
0x18,0x19,0x1A,0x1B,0x1C,0x1D,0x1E,0x1F
240+
};
241+
const uint8_t keydata192[24] = {
242+
0x00,0x11,0x22,0x33,0x44,0x55,0x66,0x77,
243+
0x88,0x99,0xAA,0xBB,0xCC,0xDD,0xEE,0xFF,
244+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07
245+
};
246+
247+
if (run_kw_vector(kek256, sizeof(kek256), keydata192, sizeof(keydata192),
248+
NULL, 0) != 0) {
249+
fprintf(stderr, "RFC 3394 section 4.5 FAILED\n");
250+
fail = 1;
251+
} else {
252+
printf("RFC 4.5 (256-bit KEK, 192-bit keydata): passed (round-trip)\n");
253+
}
254+
}
255+
256+
/* 4.6: 256-bit KEK wrapping 256 bits (32 bytes) of Key Data */
257+
{
258+
const uint8_t kek256[32] = {
259+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
260+
0x08,0x09,0x0A,0x0B,0x0C,0x0D,0x0E,0x0F,
261+
0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,
262+
0x18,0x19,0x1A,0x1B,0x1C,0x1D,0x1E,0x1F
263+
};
264+
const uint8_t keydata256[32] = {
265+
0x00,0x11,0x22,0x33,0x44,0x55,0x66,0x77,
266+
0x88,0x99,0xAA,0xBB,0xCC,0xDD,0xEE,0xFF,
267+
0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
268+
0x08,0x09,0x0A,0x0B,0x0C,0x0D,0x0E,0x0F
269+
};
270+
271+
if (run_kw_vector(kek256, sizeof(kek256), keydata256, sizeof(keydata256),
272+
NULL, 0) != 0) {
273+
fprintf(stderr, "RFC 3394 section 4.6 FAILED\n");
274+
fail = 1;
275+
} else {
276+
printf("RFC 4.6 (256-bit KEK, 256-bit keydata): passed (round-trip)\n");
277+
}
278+
}
279+
280+
if (fail) {
281+
fprintf(stderr, "\nOne or more RFC 3394 vectors FAILED.\n");
282+
return 1;
283+
}
284+
285+
printf("\nAll RFC 3394 test vectors passed (round-trip). ✅\n");
286+
return 0;
287+
}

0 commit comments

Comments
 (0)