Skip to content

Commit

Permalink
add support for counted-variadic
Browse files Browse the repository at this point in the history
  • Loading branch information
popenta committed Jun 25, 2024
1 parent c2ef1c1 commit 3205894
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 1 deletion.
4 changes: 4 additions & 0 deletions multiversx_sdk/abi/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from multiversx_sdk.abi.biguint_value import BigUIntValue
from multiversx_sdk.abi.bool_value import BoolValue
from multiversx_sdk.abi.bytes_value import BytesValue
from multiversx_sdk.abi.counted_variadic_values import CountedVariadicValues
from multiversx_sdk.abi.enum_value import EnumValue
from multiversx_sdk.abi.fields import Field
from multiversx_sdk.abi.interface import IPayloadHolder
Expand Down Expand Up @@ -231,6 +232,9 @@ def _create_prototype(self, type_formula: TypeFormula) -> Any:
if name == "variadic":
type_parameter = type_formula.type_parameters[0]
return VariadicValues([], item_creator=lambda: self._create_prototype(type_parameter))
if name == "counted-variadic":
type_parameter = type_formula.type_parameters[0]
return CountedVariadicValues([], item_creator=lambda: self._create_prototype(type_parameter))
if name == "multi":
return MultiValue([self._create_prototype(type_parameter) for type_parameter in type_formula.type_parameters])

Expand Down
19 changes: 18 additions & 1 deletion multiversx_sdk/abi/abi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from multiversx_sdk.abi.address_value import AddressValue
from multiversx_sdk.abi.biguint_value import BigUIntValue
from multiversx_sdk.abi.bytes_value import BytesValue
from multiversx_sdk.abi.counted_variadic_values import CountedVariadicValues
from multiversx_sdk.abi.enum_value import EnumValue
from multiversx_sdk.abi.fields import Field
from multiversx_sdk.abi.list_value import ListValue
from multiversx_sdk.abi.option_value import OptionValue
from multiversx_sdk.abi.small_int_values import U64Value
from multiversx_sdk.abi.small_int_values import U32Value, U64Value
from multiversx_sdk.abi.string_value import StringValue
from multiversx_sdk.abi.struct_value import StructValue
from multiversx_sdk.abi.variadic_values import VariadicValues
Expand Down Expand Up @@ -54,6 +55,22 @@ def test_abi():
assert abi.endpoints_prototypes_by_name["add"].output_parameters == []


def test_load_abi_with_counted_variadic():
abi = Abi.load(testdata / "counted-variadic.abi.json")

bar_prototype = abi.endpoints_prototypes_by_name["bar"]

assert isinstance(bar_prototype.input_parameters[0], CountedVariadicValues)
assert bar_prototype.input_parameters[0].items == []
assert bar_prototype.input_parameters[0].item_creator
assert bar_prototype.input_parameters[0].item_creator() == U32Value()

assert isinstance(bar_prototype.input_parameters[1], CountedVariadicValues)
assert bar_prototype.input_parameters[1].items == []
assert bar_prototype.input_parameters[1].item_creator
assert bar_prototype.input_parameters[1].item_creator() == BytesValue()


def test_encode_endpoint_input_parameters_artificial_contract():
abi = Abi.load(testdata / "artificial.abi.json")

Expand Down
42 changes: 42 additions & 0 deletions multiversx_sdk/abi/counted_variadic_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Any, Callable, List, Optional, Union

from multiversx_sdk.abi.interface import IPayloadHolder, ISingleValue
from multiversx_sdk.abi.multi_value import MultiValue
from multiversx_sdk.abi.shared import convert_native_value_to_list


class CountedVariadicValues(IPayloadHolder):
def __init__(self,
items: Optional[List[Union[ISingleValue, MultiValue]]] = None,
item_creator: Optional[Callable[[], Union[ISingleValue, MultiValue]]] = None) -> None:
self.items = items or []
self.length = len(self.items)

self.item_creator = item_creator

def set_payload(self, value: Any):
if not self.item_creator:
raise ValueError("populating variadic values from a native object requires the item creator to be set")

native_items, _ = convert_native_value_to_list(value)
self.length = len(native_items)

self.items.clear()

for native_item in native_items:
item = self.item_creator()
item.set_payload(native_item)
self.items.append(item)

def get_payload(self) -> Any:
return [item.get_payload() for item in self.items]

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, CountedVariadicValues)
and self.items == other.items
and self.item_creator == other.item_creator
)

def __iter__(self) -> Any:
return iter(self.items)
39 changes: 39 additions & 0 deletions multiversx_sdk/abi/counted_variadic_values_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import re

import pytest

from multiversx_sdk.abi.counted_variadic_values import CountedVariadicValues
from multiversx_sdk.abi.multi_value import MultiValue
from multiversx_sdk.abi.small_int_values import U32Value
from multiversx_sdk.abi.string_value import StringValue


def test_set_payload_and_get_payload():
# Simple
values = CountedVariadicValues(item_creator=lambda: U32Value())
values.set_payload([1, 2, 3])

assert values.length == 3
assert values.items == [U32Value(1), U32Value(2), U32Value(3)]
assert values.get_payload() == [1, 2, 3]

# Nested
values = CountedVariadicValues(item_creator=lambda: MultiValue([U32Value(), StringValue()]))
values.set_payload([[42, "hello"], [43, "world"]])

assert values.length == 2
assert values.items == [
MultiValue([U32Value(42), StringValue("hello")]),
MultiValue([U32Value(43), StringValue("world")])
]

assert values.get_payload() == [[42, "hello"], [43, "world"]]

# With errors
with pytest.raises(ValueError, match="populating variadic values from a native object requires the item creator to be set"):
CountedVariadicValues().set_payload([1, 2, 3])

# With errors
with pytest.raises(ValueError, match=re.escape("invalid literal for int() with base 10: 'foo'")):
values = CountedVariadicValues(item_creator=lambda: U32Value())
values.set_payload(["foo"])
22 changes: 22 additions & 0 deletions multiversx_sdk/abi/serializer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, List, Sequence

from multiversx_sdk.abi.codec import Codec
from multiversx_sdk.abi.counted_variadic_values import CountedVariadicValues
from multiversx_sdk.abi.interface import ISingleValue
from multiversx_sdk.abi.multi_value import *
from multiversx_sdk.abi.optional_value import OptionalValue
from multiversx_sdk.abi.parts import PartsHolder
from multiversx_sdk.abi.small_int_values import U32Value
from multiversx_sdk.abi.variadic_values import VariadicValues


Expand Down Expand Up @@ -50,6 +52,10 @@ def _do_serialize(self, parts_holder: PartsHolder, input_values: Sequence[Any]):
if i != len(input_values) - 1:
raise ValueError("variadic values must be last among input values")

self._do_serialize(parts_holder, value.items)
elif isinstance(value, CountedVariadicValues):
length = U32Value(value.length)
self._do_serialize(parts_holder, [length])
self._do_serialize(parts_holder, value.items)
elif isinstance(value, ISingleValue):
parts_holder.append_empty_part()
Expand Down Expand Up @@ -92,6 +98,8 @@ def _do_deserialize(self, parts_holder: PartsHolder, output_values: Sequence[Any
raise ValueError("variadic values must be last among output values")

self._deserialize_variadic_values(parts_holder, value)
elif isinstance(value, CountedVariadicValues):
self._deserialize_counted_variadic_values(parts_holder, value)
elif isinstance(value, ISingleValue):
self._deserialize_single_value(parts_holder, value)
else:
Expand All @@ -108,6 +116,20 @@ def _deserialize_variadic_values(self, parts_holder: PartsHolder, value: Variadi

value.items.append(new_item)

def _deserialize_counted_variadic_values(self, parts_holder: PartsHolder, value: CountedVariadicValues):
if value.item_creator is None:
raise Exception("cannot decode list: item creator is None")

self._deserialize_single_value(parts_holder, U32Value())

while not parts_holder.is_focused_beyond_last_part():
new_item = value.item_creator()

self._do_deserialize(parts_holder, [new_item])

value.items.append(new_item)
value.length += 1

def _deserialize_single_value(self, parts_holder: PartsHolder, value: ISingleValue):
part = parts_holder.read_whole_focused_part()
self.codec.decode_top_level(part, value)
Expand Down
63 changes: 63 additions & 0 deletions multiversx_sdk/abi/serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from multiversx_sdk.abi.address_value import AddressValue
from multiversx_sdk.abi.biguint_value import BigUIntValue
from multiversx_sdk.abi.bytes_value import BytesValue
from multiversx_sdk.abi.counted_variadic_values import CountedVariadicValues
from multiversx_sdk.abi.enum_value import EnumValue
from multiversx_sdk.abi.fields import *
from multiversx_sdk.abi.list_value import ListValue
Expand Down Expand Up @@ -143,6 +144,27 @@ def test_serialize():

assert data == "41@42@43"

# counted-variadic, of different types
data = serializer.serialize([
CountedVariadicValues(
items=[
U8Value(0x42),
U16Value(0x4243),
]),
])
assert data == "02@42@4243"

# variadic<u8>
data = serializer.serialize([
CountedVariadicValues(
items=[
U8Value(0x42),
U8Value(0x43),
]),
U8Value(0x44),
])
assert data == "02@42@43@44"


def test_deserialize():
serializer = Serializer(parts_separator="@")
Expand Down Expand Up @@ -317,6 +339,47 @@ def test_deserialize():

serializer.deserialize("0100", [destination])

# counted-variadic<u8>
destination = CountedVariadicValues(
item_creator=lambda: U8Value()
)

serializer.deserialize("03@2A@2B@2C", [destination])
assert destination.length == 3
assert destination.items == [
U8Value(0x2A),
U8Value(0x2B),
U8Value(0x2C),
]

# counted-variadic<u8>, with empty items
destination = CountedVariadicValues(
item_creator=lambda: U8Value()
)

serializer.deserialize("04@@01@00@", [destination])

assert destination.length == 4
assert destination.items == [
U8Value(0x00),
U8Value(0x01),
U8Value(0x00),
U8Value(0x00),
]

# variadic<u32>
destination = CountedVariadicValues(
item_creator=lambda: U32Value()
)

serializer.deserialize("02@AABBCCDD@DDCCBBAA", [destination])

assert destination.length == 2
assert destination.items == [
U32Value(0xAABBCCDD),
U32Value(0xDDCCBBAA),
]


def test_real_world_multisig_propose_batch():
"""
Expand Down
43 changes: 43 additions & 0 deletions multiversx_sdk/testutils/testdata/counted-variadic.abi.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"endpoints": [
{
"name": "foo",
"inputs": [
{
"type": "counted-variadic<Dummy>"
}
],
"outputs": []
},
{
"name": "bar",
"inputs": [
{
"type": "counted-variadic<u32>"
},
{
"type": "counted-variadic<bytes>"
}
],
"outputs": [
{
"type": "counted-variadic<u32>"
},
{
"type": "counted-variadic<bytes>"
}
]
}
],
"types": {
"Dummy": {
"type": "struct",
"fields": [
{
"name": "a",
"type": "u32"
}
]
}
}
}

0 comments on commit 3205894

Please sign in to comment.