Skip to content

Commit

Permalink
Add Element.get_wrap()
Browse files Browse the repository at this point in the history
  • Loading branch information
TeamSpen210 committed Jan 23, 2025
1 parent 3c4fb8b commit d3c9545
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Version (dev)
* Handle entities being added/removed during iteration of :py:meth:`VMF.search() <srctools.vmf.VMF.search>`.
* Share common strings in the engine database to save some space.
* Fix saving the `PHYSCOLLIDE` BSP lump.
* Add :py:meth:`srctools.dmx.Element.get_wrap()`, allowing handling defaults more conveniently.
* Make :py:attr:`EntityDef.kv <srctools.fgd.EntityDef.kv>`, :py:attr:`.inp <srctools.fgd.EntityDef.inp>`
and :py:attr:`.out <srctools.fgd.EntityDef.out>` views settable, improve behaviour.

Expand Down
15 changes: 15 additions & 0 deletions src/srctools/dmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,6 +2051,21 @@ def clear(self) -> None:
"""Remove all attributes from the element."""
self._members.clear()

def get_wrap(self, name: str, default: Union[Attribute, ConvValue, Sequence[ConvValue]] = '', /) -> Attribute:
"""Retrieve the specified attribute, or return a default.
If the defalt is an Attribute it is returned unchanged, otherwise it is wrapped in
a temporary attribute.
"""
try:
return self._members[name.casefold()]
except KeyError:
if isinstance(default, Attribute):
return default
else:
typ, val = deduce_type(default)
return Attribute(name, typ, val)

def pop(self, name: str, default: Union[Attribute, ConvValue, Sequence[ConvValue]] = _UNSET) -> Attribute:
"""Remove the specified attribute and return it.
Expand Down
36 changes: 26 additions & 10 deletions tests/test_dmx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test the datamodel exchange implementation."""
from typing import cast
from typing import Any, List, Tuple, cast
from collections.abc import Callable
from io import BytesIO
from pathlib import Path
Expand Down Expand Up @@ -489,7 +489,7 @@ def test_binary_text_conversion(typ: ValueType, attr: str, value, binary: bytes,
'int', 'float', 'bool', 'str', 'bytes', 'time', 'color',
'vec2', 'vec3', 'vec4', 'angle', 'quaternion', 'matrix',
])
def test_attr_array_constructor(typ: ValueType, iterable: list, expected: list) -> None:
def test_attr_array_constructor(typ: ValueType, iterable: List[Any], expected: List[Any]) -> None:
"""Test constructing an array attribute."""
expected_type = type(expected[0])
attr = Attribute.array('some_array', typ, iterable)
Expand All @@ -515,7 +515,7 @@ def test_attr_array_elem_constructor() -> None:
assert attr[2].val_elem is elem1


deduce_type_tests = [
deduce_type_tests: List[Tuple[Any, ValueType, Any]] = [
(5, ValueType.INT, 5),
(5.0, ValueType.FLOAT, 5.0),
(False, ValueType.BOOL, False),
Expand All @@ -537,16 +537,16 @@ def test_attr_array_elem_constructor() -> None:
]


@pytest.mark.parametrize('input, val_type, output', deduce_type_tests)
def test_deduce_type_basic(input, val_type, output) -> None:
@pytest.mark.parametrize('input_val, val_type, output', deduce_type_tests)
def test_deduce_type_basic(input_val: Any, val_type: ValueType, output: Any) -> None:
"""Test type deduction behaves correctly."""
[test_type, test_val] = deduce_type(input)
[test_type, test_val] = deduce_type(input_val)
assert test_type is val_type
assert type(test_val) is type(output)
assert test_val == output


@pytest.mark.parametrize('input, val_type, output', [
@pytest.mark.parametrize('input_val, val_type, output', [
# Add the above tests here too.
([inp, inp, inp], val_typ, [out, out, out])
for inp, val_typ, out in
Expand Down Expand Up @@ -574,11 +574,11 @@ def test_deduce_type_basic(input, val_type, output) -> None:
(collections.deque([1.0, 2.0, 3.0]), ValueType.FLOAT, [1.0, 2.0, 3.0]),
(range(5), ValueType.INT, [0, 1, 2, 3, 4]),
])
def test_deduce_type_array(input, val_type, output) -> None:
def test_deduce_type_array(input_val: List[Any], val_type: ValueType, output: List[Any]) -> None:
"""Test array deduction, and some special cases."""
[test_type, test_arr] = deduce_type(input)
[test_type, test_arr] = deduce_type(input_val)
assert test_type is val_type
assert len(input) == len(test_arr), repr(test_arr)
assert len(input_val) == len(test_arr), repr(test_arr)
for i, (test, out) in enumerate(zip(test_arr, output)):
assert type(test) is type(out), f'{i}: {test!r} != {out!r}'
assert test == out, f'{i}: {test!r} != {out!r}'
Expand All @@ -601,6 +601,22 @@ def test_deduce_type_adv() -> None:
print(deduce_type(range(0)))


def test_elem_get() -> None:
"""Test the behaviours that get various values."""
elem = Element('Named', 'Type')
elem['keyName'] = attr = Attribute.string('keyName', 'hello')
assert elem['keyName'] is attr
with pytest.raises(KeyError):
elem['missing']
assert elem.get_wrap('keyName') is attr
assert elem.get_wrap('keyName', 'default') is attr
assert elem.get_wrap('miSSing', 45) == Attribute.int('miSSing', 45)
assert elem.get_wrap('miSSing', True) == Attribute.bool('miSSing', True)
assert elem.get_wrap('miSSing', [1.0, 2.0]) == Attribute.array(
'miSSing', ValueType.FLOAT, [1.0, 2.0],
)


def test_special_attr_name() -> None:
"""Test the special behaviour of the "name" attribute.
Expand Down

0 comments on commit d3c9545

Please sign in to comment.