From e0d47202f3a8d515794d3b0f24d4ac16e9ee07be Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Fri, 17 Jan 2025 12:08:50 +0100 Subject: [PATCH] fix: handle positional dataclass arguments --- pyproject.toml | 2 +- .../_adapter/generic_call_adapter.py | 7 +- tests/adapter/test_dataclass.py | 67 +++++++++++++++++-- 3 files changed, 66 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1cf18fc8..4af94662 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,7 @@ dependencies = [ installer="uv" [tool.hatch.envs.cov.scripts] -gh=[ +github=[ "- rm htmlcov/*", "gh run download -n html-report -D htmlcov", "xdg-open htmlcov/index.html", diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py index 7e942319..e405d6e7 100644 --- a/src/inline_snapshot/_adapter/generic_call_adapter.py +++ b/src/inline_snapshot/_adapter/generic_call_adapter.py @@ -265,8 +265,11 @@ def arguments(cls, value): return ([], kwargs) def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) + if isinstance(pos_or_name, str): + return getattr(value, pos_or_name) + else: + args = [field for field in fields(value) if field.init] + return args[pos_or_name] try: diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index da2eedce..9c88e1dd 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -95,7 +95,8 @@ class A: c:list=field(default_factory=list) def test_something(): - assert A(a=1) == snapshot(A(a=1,b=2,c=[])) + for _ in [1,2]: + assert A(a=1) == snapshot(A(a=1,b=2,c=[])) """ ).run_inline( ["--inline-snapshot=update"], @@ -112,7 +113,47 @@ class A: c:list=field(default_factory=list) def test_something(): - assert A(a=1) == snapshot(A(a=1)) + for _ in [1,2]: + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_dataclass_positional_arguments(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:list=field(default_factory=list) + +def test_something(): + for _ in [1,2]: + assert A(a=1) == snapshot(A(1,2,c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:list=field(default_factory=list) + +def test_something(): + for _ in [1,2]: + assert A(a=1) == snapshot(A(1,2)) """ } ), @@ -400,12 +441,18 @@ def argument(cls, value, pos_or_name): return value.l[pos_or_name] def test_L1(): - assert L(1,2) == snapshot(L(1)), "not equal" + for _ in [1,2]: + assert L(1,2) == snapshot(L(1)), "not equal" def test_L2(): - assert L(1,2) == snapshot(L(1, 2, 3)), "not equal" + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2, 3)), "not equal" + +def test_L3(): + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2)), "not equal" """ - ).run_pytest( + ).run_pytest().run_pytest( ["--inline-snapshot=fix"], changed_files=snapshot( { @@ -439,10 +486,16 @@ def argument(cls, value, pos_or_name): return value.l[pos_or_name] def test_L1(): - assert L(1,2) == snapshot(L(1, 2)), "not equal" + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2)), "not equal" def test_L2(): - assert L(1,2) == snapshot(L(1, 2)), "not equal" + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2)), "not equal" + +def test_L3(): + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2)), "not equal" """ } ),