Skip to content

Commit

Permalink
Merge pull request #2 from minormending/fix-non-repeated-msg-nullable
Browse files Browse the repository at this point in the history
Fix non-repeated message defaults
  • Loading branch information
minormending authored Apr 18, 2022
2 parents 51137f8 + 3b7d115 commit 19ef569
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 19 deletions.
11 changes: 8 additions & 3 deletions protobuf2arr/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def msg_to_arr(obj: Message) -> List[Any]:
list_vals.append(msg_to_arr(item))
val = list_vals
else:
val = msg_to_arr(val)
if default_values and str(val).strip() in default_values:
val = None
else:
val = msg_to_arr(val)
elif default_values:
if field.type == field.TYPE_BYTES and str(val, "UTF-8") in default_values:
val = None
Expand All @@ -51,8 +54,8 @@ def arr_to_msg(arr: List[Any], msg: Message) -> Message:
num = idx + 1
field = msg.DESCRIPTOR.fields_by_number[num]
if field.type == field.TYPE_MESSAGE:
cls = field.message_type._concrete_class
if field.label == field.LABEL_REPEATED:
cls = field.message_type._concrete_class
models = []
for sub_item in item:
# None-type is Message with default values
Expand All @@ -66,8 +69,10 @@ def arr_to_msg(arr: List[Any], msg: Message) -> Message:
models.append(model)
_assign_field_value(msg, field, models)
else:
# None-type is Message with default values
value = [None for _ in cls.DESCRIPTOR.fields] if item is None else item
model = getattr(msg, field.name)
arr_to_msg(item, model) # fill model
arr_to_msg(value, model) # fill model
elif field.type == field.TYPE_BYTES and isinstance(item, str):
_assign_field_value(msg, field, item.encode("UTF-8"))
elif item == None and (options := field.GetOptions()):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "protobuf2arr"
version = "0.1.3"
version = "0.1.4"
description = "Translate a protobuf message to Google's RPC array format."
authors = ["Kevin Ramdath <krpent@gmail.com>"]
repository = "https://github.com/minormending/protobuf2arr"
Expand Down
1 change: 1 addition & 0 deletions tests/test_basic.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ message TestQueue {
repeated int32 repeated_int = 7 [(nullable) = '[]'];

repeated TestItem items = 8 [(nullable) = ''];
TestItem field_item = 9 [(nullable) = ''];

message TestItem {
int32 item_field_int = 1 [(nullable) = '0'];
Expand Down
14 changes: 8 additions & 6 deletions tests/test_basic_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 30 additions & 9 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,24 @@ def _test_deserialization(

def test_serialization_basic_default_empty(self):
queue = TestQueueBasic()
queue.field_item.item_field_int = 0
arr = msg_to_arr(queue)
self.assertEqual(arr, [None, None, None, None, None, None, None, []])
self.assertEqual(arr, [None, None, None, None, None, None, None, [], None])

serial = serialize_msg2arr(queue)
self.assertEqual(serial, "[null,null,null,null,null,null,null,[]]")
self.assertEqual(serial, "[null,null,null,null,null,null,null,[],null]")

self._test_deserialization(queue, arr, serial, TestQueueBasic)

def test_serialization_basic_default_empty_subitems(self):
queue = TestQueueBasic()
queue.items.append(TestQueueBasic.TestItem())
queue.field_item.item_field_int = 0
arr = msg_to_arr(queue)
self.assertEqual(arr, [None, None, None, None, None, None, None, [None]])
self.assertEqual(arr, [None, None, None, None, None, None, None, [None], None])

serial = serialize_msg2arr(queue)
self.assertEqual(serial, "[null,null,null,null,null,null,null,[null]]")
self.assertEqual(serial, "[null,null,null,null,null,null,null,[null],null]")

self._test_deserialization(queue, arr, serial, TestQueueBasic)

Expand All @@ -54,12 +56,13 @@ def test_serialization_basic_default_filled(self):
queue.field_enum = TestQueueBasic.TestEnum.TEST_ENUM_0
queue.repeated_int.append(1)
queue.repeated_int.pop()
queue.field_item.item_field_int = 0

arr = msg_to_arr(queue)
self.assertEqual(arr, [None, None, None, None, None, None, None, []])
self.assertEqual(arr, [None, None, None, None, None, None, None, [], None])

serial = serialize_msg2arr(queue)
self.assertEqual(serial, "[null,null,null,null,null,null,null,[]]")
self.assertEqual(serial, "[null,null,null,null,null,null,null,[],null]")

self._test_deserialization(queue, arr, serial, TestQueueBasic)

Expand All @@ -73,14 +76,29 @@ def test_serialization_basic(self):
queue.field_enum = TestQueueBasic.TestEnum.TEST_ENUM_1
queue.repeated_int.append(24)
queue.repeated_int.append(37)
queue.field_item.item_field_int = 43

arr = msg_to_arr(queue)
self.assertEqual(
arr, [100, 77.89, "Hello World", True, b"bytes", 1, [24, 37], []]
arr,
[
100,
77.89,
"Hello World",
True,
b"bytes",
1,
[24, 37],
[],
[43, None, None, None, None, None],
],
)

serial = serialize_msg2arr(queue)
self.assertEqual(serial, '[100,77.89,"Hello World",true,"bytes",1,[24,37],[]]')
self.assertEqual(
serial,
'[100,77.89,"Hello World",true,"bytes",1,[24,37],[],[43,null,null,null,null,null]]',
)

self._test_deserialization(queue, arr, serial, TestQueueBasic)

Expand All @@ -104,6 +122,8 @@ def test_serialization_basic_with_subitems(self):
subitem.item_field_enum = TestQueueBasic.TestEnum.TEST_ENUM_2
queue.items.append(subitem)

queue.field_item.item_field_int = 0

arr = msg_to_arr(queue)
self.assertEqual(
arr,
Expand All @@ -116,13 +136,14 @@ def test_serialization_basic_with_subitems(self):
1,
[24, 37],
[[45, 23.67, "Hello Inner World", True, b"bytes inner", 2]],
None,
],
)

serial = serialize_msg2arr(queue)
self.assertEqual(
serial,
'[100,77.89,"Hello World",true,"bytes",1,[24,37],[[45,23.67,"Hello Inner World",true,"bytes inner",2]]]',
'[100,77.89,"Hello World",true,"bytes",1,[24,37],[[45,23.67,"Hello Inner World",true,"bytes inner",2]],null]',
)

self._test_deserialization(queue, arr, serial, TestQueueBasic)
Expand Down

0 comments on commit 19ef569

Please sign in to comment.