Skip to content
This repository was archived by the owner on Dec 10, 2018. It is now read-only.

Commit e363d07

Browse files
committed
fix read_struct
Inconsistency between server and client side thrift struct may cause undefined behavior. Traverse whole struct fields to ignore missed fields. Use `skip` instead of raising exception to avoid wrong type fields.
1 parent 61ac867 commit e363d07

File tree

3 files changed

+135
-10
lines changed

3 files changed

+135
-10
lines changed

tests/test_type_mismatch.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
from unittest import TestCase
2+
3+
from thriftpy.thrift import TType, TPayload
4+
5+
from thriftpy.transport.memory import TMemoryBuffer
6+
from thriftpy.protocol.binary import TBinaryProtocol
7+
8+
from thriftpy._compat import CYTHON
9+
10+
11+
class Struct(TPayload):
12+
thrift_spec = {
13+
1: (TType.I32, 'a', False),
14+
2: (TType.STRING, 'b', False),
15+
3: (TType.DOUBLE, 'c', False)
16+
}
17+
default_spec = [('a', None), ('b', None), ('c', None)]
18+
19+
20+
class TItem(TPayload):
21+
thrift_spec = {
22+
1: (TType.I32, "id", False),
23+
2: (TType.LIST, "phones", TType.STRING, False),
24+
3: (TType.MAP, "addr", (TType.I32, TType.STRING), False),
25+
4: (TType.LIST, "data", (TType.STRUCT, Struct), False)
26+
}
27+
default_spec = [("id", None), ("phones", None), ("addr", None),
28+
("data", None)]
29+
30+
31+
class MismatchTestCase(TestCase):
32+
BUFFER = TMemoryBuffer
33+
PROTO = TBinaryProtocol
34+
35+
def test_list_type_mismatch(self):
36+
class TMismatchItem(TPayload):
37+
thrift_spec = {
38+
1: (TType.I32, "id", False),
39+
2: (TType.LIST, "phones", (TType.I32, False), False),
40+
}
41+
default_spec = [("id", None), ("phones", None)]
42+
43+
t = self.BUFFER()
44+
p = self.PROTO(t)
45+
46+
item = TItem(id=37, phones=["23424", "235125"])
47+
p.write_struct(item)
48+
p.write_message_end()
49+
50+
item2 = TMismatchItem()
51+
p.read_struct(item2)
52+
53+
assert item2.phones == []
54+
55+
def test_map_type_mismatch(self):
56+
class TMismatchItem(TPayload):
57+
thrift_spec = {
58+
1: (TType.I32, "id", False),
59+
3: (TType.MAP, "addr", (TType.STRING, TType.STRING), False)
60+
}
61+
default_spec = [("id", None), ("addr", None)]
62+
63+
t = self.BUFFER()
64+
p = self.PROTO(t)
65+
66+
item = TItem(id=37, addr={1: "hello", 2: "world"})
67+
p.write_struct(item)
68+
p.write_message_end()
69+
70+
item2 = TMismatchItem()
71+
p.read_struct(item2)
72+
73+
assert item2.addr == {}
74+
75+
def test_struct_mismatch(self):
76+
class MismatchStruct(TPayload):
77+
thrift_spec = {
78+
1: (TType.STRING, 'a', False),
79+
2: (TType.STRING, 'b', False)
80+
}
81+
default_spec = [('a', None), ('b', None)]
82+
83+
class TMismatchItem(TPayload):
84+
thrift_spec = {
85+
1: (TType.I32, "id", False),
86+
2: (TType.LIST, "phones", TType.STRING, False),
87+
3: (TType.MAP, "addr", (TType.I32, TType.STRING), False),
88+
4: (TType.LIST, "data", (TType.STRUCT, MismatchStruct), False)
89+
}
90+
default_spec = [("id", None), ("phones", None), ("addr", None)]
91+
92+
t = self.BUFFER()
93+
p = self.PROTO(t)
94+
95+
item = TItem(id=37, data=[Struct(a=1, b="hello", c=0.123),
96+
Struct(a=2, b="world", c=34.342346),
97+
Struct(a=3, b="when", c=25235.14)])
98+
p.write_struct(item)
99+
p.write_message_end()
100+
101+
item2 = TMismatchItem()
102+
p.read_struct(item2)
103+
104+
assert len(item2.data) == 3
105+
assert all([i.b for i in item2.data])
106+
107+
108+
if CYTHON:
109+
from thriftpy.transport.memory import TCyMemoryBuffer
110+
from thriftpy.protocol.cybin import TCyBinaryProtocol
111+
112+
class CyMismatchTestCase(MismatchTestCase):
113+
BUFFER = TCyMemoryBuffer
114+
PROTO = TCyBinaryProtocol

thriftpy/protocol/binary.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def read_val(inbuf, ttype, spec=None):
243243
r_type, sz = read_list_begin(inbuf)
244244
# the v_type is useless here since we already get it from spec
245245
if r_type != v_type:
246-
raise Exception("Message Corrupt")
246+
for _ in range(sz):
247+
skip(inbuf, r_type)
248+
return []
247249

248250
for i in range(sz):
249251
result.append(read_val(inbuf, v_type, v_spec))
@@ -265,7 +267,10 @@ def read_val(inbuf, ttype, spec=None):
265267
result = {}
266268
sk_type, sv_type, sz = read_map_begin(inbuf)
267269
if sk_type != k_type or sv_type != v_type:
268-
raise Exception("Message Corrupt")
270+
for _ in range(sz):
271+
skip(inbuf, sk_type)
272+
skip(inbuf, sv_type)
273+
return {}
269274

270275
for i in range(sz):
271276
k_val = read_val(inbuf, k_type, k_spec)
@@ -281,8 +286,7 @@ def read_val(inbuf, ttype, spec=None):
281286

282287

283288
def read_struct(inbuf, obj):
284-
# The max loop count equals field count + a final stop byte.
285-
for i in range(len(obj.thrift_spec) + 1):
289+
while True:
286290
f_type, fid = read_field_begin(inbuf)
287291
if f_type == TType.STOP:
288292
break
@@ -300,7 +304,8 @@ def read_struct(inbuf, obj):
300304
# it really should equal here. but since we already wasted
301305
# space storing the duplicate info, let's check it.
302306
if f_type != sf_type:
303-
raise Exception("Message Corrupt")
307+
skip(inbuf, f_type)
308+
continue
304309

305310
setattr(obj, f_name, read_val(inbuf, f_type, f_container_spec))
306311

thriftpy/protocol/cybin/cybin.pyx

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ cdef inline int write_double(CyTransportBase buf, double val) except -1:
9696

9797
cdef inline read_struct(CyTransportBase buf, obj):
9898
cdef dict field_specs = obj.thrift_spec
99-
cdef int fid, i
99+
cdef int fid
100100
cdef TType field_type, ttype
101101
cdef tuple field_spec
102102
cdef str name
103103

104-
for i in range(len(field_specs) + 1):
104+
while True:
105105
field_type = <TType>read_i08(buf)
106106
if field_type == T_STOP:
107107
break
@@ -114,7 +114,8 @@ cdef inline read_struct(CyTransportBase buf, obj):
114114
field_spec = field_specs[fid]
115115
ttype = field_spec[0]
116116
if field_type != ttype:
117-
raise ProtocolError("Message Corrupt")
117+
skip(buf, field_type)
118+
continue
118119

119120
name = field_spec[1]
120121
if len(field_spec) == 2:
@@ -211,7 +212,9 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):
211212
size = read_i32(buf)
212213

213214
if orig_type != v_type:
214-
raise ProtocolError("Message Corrupt")
215+
for _ in range(size):
216+
skip(buf, orig_type)
217+
return []
215218

216219
return [c_read_val(buf, v_type, v_spec) for _ in range(size)]
217220

@@ -237,7 +240,10 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None):
237240
size = read_i32(buf)
238241

239242
if orig_key_type != k_type or orig_type != v_type:
240-
raise ProtocolError("Message Corrupt")
243+
for _ in range(size):
244+
skip(buf, orig_key_type)
245+
skip(buf, orig_type)
246+
return {}
241247

242248
return {c_read_val(buf, k_type, k_spec): c_read_val(buf, v_type, v_spec)
243249
for _ in range(size)}

0 commit comments

Comments
 (0)