Skip to content

Commit b755248

Browse files
committed
fix: fixed crash caused by custom factory methods
1 parent 0802d38 commit b755248

File tree

6 files changed

+105
-29
lines changed

6 files changed

+105
-29
lines changed

src/inline_snapshot/_adapter/adapter.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ast
44
import typing
5+
from dataclasses import dataclass
56

67
from inline_snapshot._source_file import SourceFile
78

@@ -38,21 +39,46 @@ class Item(typing.NamedTuple):
3839
node: ast.expr
3940

4041

42+
@dataclass
43+
class FrameContext:
44+
globals: dict
45+
locals: dict
46+
47+
48+
@dataclass
49+
class AdapterContext:
50+
file: SourceFile
51+
frame: FrameContext | None
52+
53+
def eval(self, node):
54+
assert self.frame is not None
55+
56+
return eval(
57+
compile(ast.Expression(node), self.file.filename, "eval"),
58+
self.frame.globals,
59+
self.frame.locals,
60+
)
61+
62+
4163
class Adapter:
64+
# TODO remove context
4265
context: SourceFile
4366

44-
def __init__(self, context):
45-
self.context = context
67+
adapter_context: AdapterContext
68+
69+
def __init__(self, context: AdapterContext):
70+
self.adapter_context = context
71+
self.context = context.file
4672

4773
def get_adapter(self, old_value, new_value) -> Adapter:
4874
if type(old_value) is not type(new_value):
4975
from .value_adapter import ValueAdapter
5076

51-
return ValueAdapter(self.context)
77+
return ValueAdapter(self.adapter_context)
5278

5379
adapter_type = get_adapter_type(old_value)
5480
if adapter_type is not None:
55-
return adapter_type(self.context)
81+
return adapter_type(self.adapter_context)
5682
assert False
5783

5884
def assign(self, old_value, old_node, new_value):
@@ -61,7 +87,7 @@ def assign(self, old_value, old_node, new_value):
6187
def value_assign(self, old_value, old_node, new_value):
6288
from .value_adapter import ValueAdapter
6389

64-
adapter = ValueAdapter(self.context)
90+
adapter = ValueAdapter(self.adapter_context)
6591
result = yield from adapter.assign(old_value, old_node, new_value)
6692
return result
6793

src/inline_snapshot/_adapter/dict_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def repr(cls, value):
3030
def map(cls, value, map_function):
3131
return {k: adapter_map(v, map_function) for k, v in value.items()}
3232

33-
def items(self, value, node):
33+
@classmethod
34+
def items(cls, value, node):
3435
if node is None:
3536
return [Item(value=value, node=None) for value in value.values()]
3637

src/inline_snapshot/_adapter/generic_call_adapter.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ def map(cls, value, map_function):
7575
},
7676
)
7777

78-
def items(self, value, node):
79-
new_args, new_kwargs = self.arguments(value)
78+
@classmethod
79+
def items(cls, value, node):
80+
new_args, new_kwargs = cls.arguments(value)
8081

8182
if node is not None:
8283
assert isinstance(node, ast.Call)
@@ -101,6 +102,12 @@ def assign(self, old_value, old_node, new_value):
101102
result = yield from self.value_assign(old_value, old_node, new_value)
102103
return result
103104

105+
call_type = self.adapter_context.eval(old_node.func)
106+
107+
if not self.check_type(call_type):
108+
result = yield from self.value_assign(old_value, old_node, new_value)
109+
return result
110+
104111
# positional arguments
105112
for pos_arg in old_node.args:
106113
if isinstance(pos_arg, ast.Starred):

src/inline_snapshot/_adapter/sequence_adapter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,12 @@ def map(cls, value, map_function):
3434
result = [adapter_map(v, map_function) for v in value]
3535
return cls.value_type(result)
3636

37-
def items(self, value, node):
37+
@classmethod
38+
def items(cls, value, node):
3839
if node is None:
3940
return [Item(value=v, node=None) for v in value]
4041

41-
assert isinstance(node, self.node_type), (node, self)
42+
assert isinstance(node, cls.node_type), (node, cls)
4243
assert len(value) == len(node.elts)
4344

4445
return [Item(value=v, node=n) for v, n in zip(value, node.elts)]

src/inline_snapshot/_inline_snapshot.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from inline_snapshot._adapter.adapter import adapter_map
1515
from inline_snapshot._source_file import SourceFile
1616

17+
from ._adapter.adapter import AdapterContext
18+
from ._adapter.adapter import FrameContext
1719
from ._adapter.adapter import get_adapter_type
1820
from ._change import CallArg
1921
from ._change import Change
@@ -83,12 +85,17 @@ class GenericValue(Snapshot):
8385
_old_value: Any
8486
_current_op = "undefined"
8587
_ast_node: ast.Expr
86-
_file: SourceFile
88+
_context: AdapterContext
89+
90+
@property
91+
def _file(self):
92+
return self._context.file
8793

8894
def get_adapter(self, value):
89-
return get_adapter_type(value)(self._file)
95+
return get_adapter_type(value)(self._context)
9096

91-
def _re_eval(self, value):
97+
def _re_eval(self, value, context: AdapterContext):
98+
self._context = context
9299

93100
def re_eval(old_value, node, value):
94101
if isinstance(old_value, Unmanaged):
@@ -168,13 +175,13 @@ def __getitem__(self, _item):
168175

169176

170177
class UndecidedValue(GenericValue):
171-
def __init__(self, old_value, ast_node, source):
178+
def __init__(self, old_value, ast_node, context: AdapterContext):
172179

173180
old_value = adapter_map(old_value, map_unmanaged)
174181
self._old_value = old_value
175182
self._new_value = undefined
176183
self._ast_node = ast_node
177-
self._file = SourceFile(source)
184+
self._context = context
178185

179186
def _change(self, cls):
180187
self.__class__ = cls
@@ -186,7 +193,7 @@ def _get_changes(self) -> Iterator[Change]:
186193

187194
def handle(node, obj):
188195

189-
adapter = self.get_adapter(obj)
196+
adapter = get_adapter_type(obj)
190197
if adapter is not None and hasattr(adapter, "items"):
191198
for item in adapter.items(obj, node):
192199
yield from handle(item.node, item.value)
@@ -259,7 +266,12 @@ def __eq__(self, other):
259266
_missing_values += 1
260267

261268
if not compare_only() and self._new_value is undefined:
262-
adapter = Adapter(self._file).get_adapter(self._old_value, other)
269+
frame = inspect.currentframe()
270+
assert frame is not None
271+
frame = frame.f_back
272+
assert frame is not None
273+
274+
adapter = Adapter(self._context).get_adapter(self._old_value, other)
263275
it = iter(adapter.assign(self._old_value, self._ast_node, clone(other)))
264276
self._changes = []
265277
while True:
@@ -473,18 +485,18 @@ def __getitem__(self, index):
473485
child_node = self._ast_node.values[pos]
474486

475487
self._new_value[index] = UndecidedValue(
476-
old_value.get(index, undefined), child_node, self._file
488+
old_value.get(index, undefined), child_node, self._context
477489
)
478490

479491
return self._new_value[index]
480492

481-
def _re_eval(self, value):
482-
super()._re_eval(value)
493+
def _re_eval(self, value, context: AdapterContext):
494+
super()._re_eval(value, context)
483495

484496
if self._new_value is not undefined and self._old_value is not undefined:
485497
for key, s in self._new_value.items():
486498
if key in self._old_value:
487-
s._re_eval(self._old_value[key])
499+
s._re_eval(self._old_value[key], context)
488500

489501
def _new_code(self):
490502
return (
@@ -594,6 +606,12 @@ def snapshot(obj: Any = undefined) -> Any:
594606

595607
expr = Source.executing(frame)
596608

609+
source = getattr(expr, "source", None) if expr is not None else None
610+
context = AdapterContext(
611+
file=SourceFile(source),
612+
frame=FrameContext(globals=frame.f_globals, locals=frame.f_locals),
613+
)
614+
597615
module = inspect.getmodule(frame)
598616
if module is not None and module.__file__ is not None:
599617
_files_with_snapshots.add(module.__file__)
@@ -604,12 +622,12 @@ def snapshot(obj: Any = undefined) -> Any:
604622
node = expr.node
605623
if node is None:
606624
# we can run without knowing of the calling expression but we will not be able to fix code
607-
snapshots[key] = SnapshotReference(obj, None)
625+
snapshots[key] = SnapshotReference(obj, None, context)
608626
else:
609627
assert isinstance(node, ast.Call)
610-
snapshots[key] = SnapshotReference(obj, expr)
628+
snapshots[key] = SnapshotReference(obj, expr, context)
611629
else:
612-
snapshots[key]._re_eval(obj)
630+
snapshots[key]._re_eval(obj, context)
613631

614632
return snapshots[key]._value
615633

@@ -627,12 +645,11 @@ def used_externals(tree):
627645

628646

629647
class SnapshotReference:
630-
def __init__(self, value, expr):
648+
def __init__(self, value, expr, context: AdapterContext):
631649
self._expr = expr
632650
node = expr.node.args[0] if expr is not None and expr.node.args else None
633651
source = expr.source if expr is not None else None
634-
self._value = UndecidedValue(value, node, source)
635-
self._uses_externals = []
652+
self._value = UndecidedValue(value, node, context)
636653

637654
def _changes(self):
638655

@@ -657,5 +674,5 @@ def _changes(self):
657674

658675
yield from self._value._get_changes()
659676

660-
def _re_eval(self, obj):
661-
self._value._re_eval(obj)
677+
def _re_eval(self, obj, context: AdapterContext):
678+
self._value._re_eval(obj, context)

tests/test_pydantic.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,27 @@ def test_something():
135135
).run_pytest(
136136
changed_files=snapshot({}),
137137
)
138+
139+
140+
def test_pydantic_factory_method():
141+
Example(
142+
"""\
143+
from inline_snapshot import snapshot
144+
from pydantic import BaseModel
145+
146+
class A(BaseModel):
147+
a:int
148+
149+
@classmethod
150+
def from_str(cls,s):
151+
return cls(a=int(s))
152+
153+
def test_something():
154+
assert A(a=2) == snapshot(A.from_str("1"))
155+
"""
156+
).run_pytest(
157+
["--inline-snapshot=fix"],
158+
changed_files=snapshot({}),
159+
stderr=snapshot(""),
160+
report=snapshot(""),
161+
)

0 commit comments

Comments
 (0)