-
Notifications
You must be signed in to change notification settings - Fork 0
/
protocol_tlscontext.py
154 lines (128 loc) · 7.14 KB
/
protocol_tlscontext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# ------------------------------------------------------------------------------
# TLSの接続状態を管理して、鍵導出を支援するためのクラス
# ------------------------------------------------------------------------------
from protocol_types import HandshakeType
from protocol_extensions import ExtensionType
from protocol_ciphersuite import CipherSuite
from protocol_ext_supportedgroups import NamedGroup
import crypto_hkdf as hkdf
class TLSContext:
    """Per-connection TLS 1.3 state: handshake transcript, key exchange, key schedule.

    One instance represents one end ('client' or 'server') of a connection.
    It records the handshake messages exchanged, computes the shared key from
    the negotiated key_share, and runs the key schedule (the labels used --
    b'derived', b'c hs traffic', b's ap traffic', ... -- are those of the
    RFC 8446 section 7.1 key schedule) to build the traffic ciphers.
    """

    def __init__(self, peer_name: str):
        """Create an empty context.

        Args:
            peer_name: Which end of the connection this context belongs to;
                must be 'client' or 'server'.

        Raises:
            ValueError: if peer_name is neither 'client' nor 'server'.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if peer_name not in ('client', 'server'):
            raise ValueError(
                f"peer_name must be 'client' or 'server', not {peer_name!r}")
        self.peer_name = peer_name
        # Handshake messages seen so far, keyed by their msg_type.
        # A later message with the same msg_type replaces the earlier one.
        self.tls_messages: dict = {}
        # Serialized form of every appended handshake message, in arrival
        # order; joined together this is the transcript passed to derive_secret.
        self.tls_messages_bytes: list[bytes] = []

    def append_msg(self, handshake) -> None:
        """Record a handshake message.

        Args:
            handshake: A handshake message object exposing ``msg_type`` and
                convertible to its wire form with ``bytes()``.
        """
        self.tls_messages[handshake.msg_type] = handshake
        self.tls_messages_bytes.append(bytes(handshake))

    def get_messages_byte(self) -> bytes:
        """Return the concatenation of all recorded handshake messages."""
        return b''.join(self.tls_messages_bytes)

    def set_key_exchange(self, dhkex_classes: dict, secret_keys: dict):
        """Store key-exchange material and derive the negotiated parameters.

        ClientHello and ServerHello must already have been passed to
        append_msg(), since they are looked up from the recorded messages.

        Args:
            dhkex_classes: Maps a NamedGroup to a callable
                ``f(secret_key, peer_public_bytes)`` returning the shared key.
            secret_keys: Maps a NamedGroup to this side's private key for it.
        """
        self.client_hello = self.tls_messages.get(HandshakeType.client_hello)
        self.server_hello = self.tls_messages.get(HandshakeType.server_hello)
        self.dhkex_classes = dhkex_classes
        self.secret_keys = secret_keys
        self._derive_negotiated_params()

    def _derive_negotiated_params(self):
        """Compute the shared key and hash parameters of the negotiated suite.

        Sets cipher_suite, shared_key, hash_name, secret_size and hash_size.

        Raises:
            ValueError: if the peer sent no key_share for a supported group.
        """
        self.cipher_suite = self.server_hello.msg.cipher_suite
        peer_share, group = self._select_peer_share()
        dhkex_class = self.dhkex_classes[group]
        secret_key = self.secret_keys[group]
        self.shared_key = dhkex_class(
            secret_key, peer_share.key_exchange.get_raw_bytes())
        self.hash_name = CipherSuite.get_hash_name(self.cipher_suite)
        self.secret_size = CipherSuite.get_hash_size(self.cipher_suite)
        self.hash_size = hkdf.hash_size(self.hash_name)

    def _select_peer_share(self):
        """Return ``(key_share_entry, group)`` for the peer's usable key share.

        A client uses the share from the first key_share extension of the
        ServerHello; a server takes the first ClientHello share, in the
        client's order, whose group is supported. ``group`` is the matching
        NamedGroup constant and is the key into dhkex_classes / secret_keys.

        Raises:
            ValueError: if no share for a supported group is present.
        """
        candidates = []
        if self.peer_name == 'client':
            for ext in self.server_hello.msg.extensions:
                if ext.extension_type == ExtensionType.key_share:
                    # The ServerHello key_share's `shares` is a single entry.
                    candidates = [ext.extension_data.shares]
                    break
        else:
            ext = self.client_hello.msg.extensions \
                .find(lambda ext: ext.extension_type == ExtensionType.key_share)
            # NOTE(review): assumes find() returns None when absent -- confirm.
            if ext is not None:
                candidates = ext.extension_data.shares
        for share in candidates:
            group = self._supported_group(share.group)
            if group is not None:
                return share, group
        sender = 'server' if self.peer_name == 'client' else 'client'
        raise ValueError(
            f'no key_share for a supported group received from the {sender}')

    @staticmethod
    def _supported_group(group):
        """Return the supported NamedGroup equal to *group*, or None."""
        for supported in (NamedGroup.x25519, NamedGroup.ffdhe4096):
            if group == supported:
                return supported
        return None

    def _traffic_crypto(self, traffic_secret):
        """Build a cipher_class instance keyed from *traffic_secret*."""
        key, iv = hkdf.gen_key_and_iv(
            traffic_secret, self.cipher_class.key_size,
            self.cipher_class.nonce_size, self.hash_name)
        return self.cipher_class(key=key, nonce=iv)

    def key_schedule_in_handshake(self):
        """Derive the early/handshake secrets and the handshake traffic ciphers.

        Must run after set_key_exchange(). The transcript hashed into the
        traffic secrets is get_messages_byte() at call time.

        Sets early_secret, handshake_secret, client_hs_traffic_secret,
        server_hs_traffic_secret, cipher_class, client_traffic_crypto and
        server_traffic_crypto.
        """
        messages = self.get_messages_byte()
        # No PSK: both inputs of the early-secret extract are all-zero strings.
        # Assumes hkdf.HKDF_extract(salt, ikm, hash_name) -- TODO confirm.
        salt = bytearray(self.secret_size)
        psk = bytearray(self.secret_size)
        secret = hkdf.HKDF_extract(salt, psk, self.hash_name)
        self.early_secret = secret
        # NOTE: prints secret key material to stdout (debug output).
        print('[+] early secret:', secret.hex())

        # Handshake secret = extract(derive_secret(early, "derived"), shared key).
        secret = hkdf.derive_secret(secret, b'derived', b'', self.hash_name)
        secret = hkdf.HKDF_extract(secret, self.shared_key, self.hash_name)
        self.handshake_secret = secret
        print('[+] handshake secret:', secret.hex())

        self.client_hs_traffic_secret = \
            hkdf.derive_secret(secret, b'c hs traffic', messages, self.hash_name)
        self.server_hs_traffic_secret = \
            hkdf.derive_secret(secret, b's hs traffic', messages, self.hash_name)

        self.cipher_class = CipherSuite.get_cipher_class(self.cipher_suite)
        self.client_traffic_crypto = \
            self._traffic_crypto(self.client_hs_traffic_secret)
        self.server_traffic_crypto = \
            self._traffic_crypto(self.server_hs_traffic_secret)

    def key_schedule_in_app_data(self):
        """Derive the master secret and the application traffic ciphers.

        Must run after key_schedule_in_handshake() (it reads handshake_secret
        and cipher_class). The transcript hashed into the traffic secrets is
        get_messages_byte() at call time.

        Sets master_secret, client_app_traffic_secret,
        server_app_traffic_secret, client_app_data_crypto and
        server_app_data_crypto.
        """
        messages = self.get_messages_byte()
        # All-zero IKM for the master-secret extract.
        zeros = bytearray(self.secret_size)
        # Pass the negotiated hash explicitly, like every other derive_secret
        # call; relying on a default would break non-default-hash suites.
        secret = hkdf.derive_secret(
            self.handshake_secret, b'derived', b'', self.hash_name)
        secret = hkdf.HKDF_extract(secret, zeros, self.hash_name)
        self.master_secret = secret
        # NOTE: prints secret key material to stdout (debug output).
        print('[+] master secret:', secret.hex())

        self.client_app_traffic_secret = \
            hkdf.derive_secret(secret, b'c ap traffic', messages, self.hash_name)
        self.server_app_traffic_secret = \
            hkdf.derive_secret(secret, b's ap traffic', messages, self.hash_name)

        self.client_app_data_crypto = \
            self._traffic_crypto(self.client_app_traffic_secret)
        self.server_app_data_crypto = \
            self._traffic_crypto(self.server_app_traffic_secret)