Skip to content

Commit

Permalink
fix: fixed crash caused by custom factory methods
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Jan 15, 2025
1 parent 0802d38 commit ef707af
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 30 deletions.
36 changes: 31 additions & 5 deletions src/inline_snapshot/_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import typing
from dataclasses import dataclass

from inline_snapshot._source_file import SourceFile

Expand Down Expand Up @@ -38,21 +39,46 @@ class Item(typing.NamedTuple):
node: ast.expr


@dataclass
class FrameContext:
globals: dict
locals: dict


@dataclass
class AdapterContext:
file: SourceFile
frame: FrameContext | None

def eval(self, node):
assert self.frame is not None

return eval(
compile(ast.Expression(node), self.file.filename, "eval"),
self.frame.globals,
self.frame.locals,
)


class Adapter:
# TODO remove context
context: SourceFile

def __init__(self, context):
self.context = context
adapter_context: AdapterContext

def __init__(self, context: AdapterContext):
self.adapter_context = context
self.context = context.file

def get_adapter(self, old_value, new_value) -> Adapter:
if type(old_value) is not type(new_value):
from .value_adapter import ValueAdapter

return ValueAdapter(self.context)
return ValueAdapter(self.adapter_context)

adapter_type = get_adapter_type(old_value)
if adapter_type is not None:
return adapter_type(self.context)
return adapter_type(self.adapter_context)
assert False

def assign(self, old_value, old_node, new_value):
Expand All @@ -61,7 +87,7 @@ def assign(self, old_value, old_node, new_value):
def value_assign(self, old_value, old_node, new_value):
from .value_adapter import ValueAdapter

adapter = ValueAdapter(self.context)
adapter = ValueAdapter(self.adapter_context)
result = yield from adapter.assign(old_value, old_node, new_value)
return result

Expand Down
3 changes: 2 additions & 1 deletion src/inline_snapshot/_adapter/dict_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def repr(cls, value):
def map(cls, value, map_function):
return {k: adapter_map(v, map_function) for k, v in value.items()}

def items(self, value, node):
@classmethod
def items(cls, value, node):
if node is None:
return [Item(value=value, node=None) for value in value.values()]

Expand Down
11 changes: 9 additions & 2 deletions src/inline_snapshot/_adapter/generic_call_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def map(cls, value, map_function):
},
)

def items(self, value, node):
new_args, new_kwargs = self.arguments(value)
@classmethod
def items(cls, value, node):
new_args, new_kwargs = cls.arguments(value)

if node is not None:
assert isinstance(node, ast.Call)
Expand All @@ -101,6 +102,12 @@ def assign(self, old_value, old_node, new_value):
result = yield from self.value_assign(old_value, old_node, new_value)
return result

call_type = self.adapter_context.eval(old_node.func)

if not (isinstance(call_type, type) and self.check_type(call_type)):
result = yield from self.value_assign(old_value, old_node, new_value)
return result

# positional arguments
for pos_arg in old_node.args:
if isinstance(pos_arg, ast.Starred):
Expand Down
5 changes: 3 additions & 2 deletions src/inline_snapshot/_adapter/sequence_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def map(cls, value, map_function):
result = [adapter_map(v, map_function) for v in value]
return cls.value_type(result)

def items(self, value, node):
@classmethod
def items(cls, value, node):
if node is None:
return [Item(value=v, node=None) for v in value]

assert isinstance(node, self.node_type), (node, self)
assert isinstance(node, cls.node_type), (node, cls)
assert len(value) == len(node.elts)

return [Item(value=v, node=n) for v, n in zip(value, node.elts)]
Expand Down
57 changes: 37 additions & 20 deletions src/inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from inline_snapshot._adapter.adapter import adapter_map
from inline_snapshot._source_file import SourceFile

from ._adapter.adapter import AdapterContext
from ._adapter.adapter import FrameContext
from ._adapter.adapter import get_adapter_type
from ._change import CallArg
from ._change import Change
Expand Down Expand Up @@ -83,12 +85,17 @@ class GenericValue(Snapshot):
_old_value: Any
_current_op = "undefined"
_ast_node: ast.Expr
_file: SourceFile
_context: AdapterContext

@property
def _file(self):
return self._context.file

def get_adapter(self, value):
return get_adapter_type(value)(self._file)
return get_adapter_type(value)(self._context)

def _re_eval(self, value):
def _re_eval(self, value, context: AdapterContext):
self._context = context

def re_eval(old_value, node, value):
if isinstance(old_value, Unmanaged):
Expand Down Expand Up @@ -168,13 +175,13 @@ def __getitem__(self, _item):


class UndecidedValue(GenericValue):
def __init__(self, old_value, ast_node, source):
def __init__(self, old_value, ast_node, context: AdapterContext):

old_value = adapter_map(old_value, map_unmanaged)
self._old_value = old_value
self._new_value = undefined
self._ast_node = ast_node
self._file = SourceFile(source)
self._context = context

def _change(self, cls):
self.__class__ = cls
Expand All @@ -186,13 +193,13 @@ def _get_changes(self) -> Iterator[Change]:

def handle(node, obj):

adapter = self.get_adapter(obj)
adapter = get_adapter_type(obj)
if adapter is not None and hasattr(adapter, "items"):
for item in adapter.items(obj, node):
yield from handle(item.node, item.value)
return

if not isinstance(obj, Unmanaged):
if not isinstance(obj, Unmanaged) and node is not None:
new_token = value_to_token(obj)
if self._file._token_of_node(node) != new_token:
new_code = self._file._token_to_code(new_token)
Expand Down Expand Up @@ -259,7 +266,12 @@ def __eq__(self, other):
_missing_values += 1

if not compare_only() and self._new_value is undefined:
adapter = Adapter(self._file).get_adapter(self._old_value, other)
frame = inspect.currentframe()
assert frame is not None
frame = frame.f_back
assert frame is not None

adapter = Adapter(self._context).get_adapter(self._old_value, other)
it = iter(adapter.assign(self._old_value, self._ast_node, clone(other)))
self._changes = []
while True:
Expand Down Expand Up @@ -473,18 +485,18 @@ def __getitem__(self, index):
child_node = self._ast_node.values[pos]

self._new_value[index] = UndecidedValue(
old_value.get(index, undefined), child_node, self._file
old_value.get(index, undefined), child_node, self._context
)

return self._new_value[index]

def _re_eval(self, value):
super()._re_eval(value)
def _re_eval(self, value, context: AdapterContext):
super()._re_eval(value, context)

if self._new_value is not undefined and self._old_value is not undefined:
for key, s in self._new_value.items():
if key in self._old_value:
s._re_eval(self._old_value[key])
s._re_eval(self._old_value[key], context)

def _new_code(self):
return (
Expand Down Expand Up @@ -594,6 +606,12 @@ def snapshot(obj: Any = undefined) -> Any:

expr = Source.executing(frame)

source = getattr(expr, "source", None) if expr is not None else None
context = AdapterContext(
file=SourceFile(source),
frame=FrameContext(globals=frame.f_globals, locals=frame.f_locals),
)

module = inspect.getmodule(frame)
if module is not None and module.__file__ is not None:
_files_with_snapshots.add(module.__file__)
Expand All @@ -604,12 +622,12 @@ def snapshot(obj: Any = undefined) -> Any:
node = expr.node
if node is None:
# we can run without knowing of the calling expression but we will not be able to fix code
snapshots[key] = SnapshotReference(obj, None)
snapshots[key] = SnapshotReference(obj, None, context)
else:
assert isinstance(node, ast.Call)
snapshots[key] = SnapshotReference(obj, expr)
snapshots[key] = SnapshotReference(obj, expr, context)
else:
snapshots[key]._re_eval(obj)
snapshots[key]._re_eval(obj, context)

return snapshots[key]._value

Expand All @@ -627,12 +645,11 @@ def used_externals(tree):


class SnapshotReference:
def __init__(self, value, expr):
def __init__(self, value, expr, context: AdapterContext):
self._expr = expr
node = expr.node.args[0] if expr is not None and expr.node.args else None
source = expr.source if expr is not None else None
self._value = UndecidedValue(value, node, source)
self._uses_externals = []
self._value = UndecidedValue(value, node, context)

def _changes(self):

Expand All @@ -657,5 +674,5 @@ def _changes(self):

yield from self._value._get_changes()

def _re_eval(self, obj):
self._value._re_eval(obj)
def _re_eval(self, obj, context: AdapterContext):
self._value._re_eval(obj, context)
39 changes: 39 additions & 0 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,42 @@ def test_something():
).run_pytest(
changed_files=snapshot({}),
)


def test_pydantic_factory_method():
Example(
"""\
from inline_snapshot import snapshot
from pydantic import BaseModel
class A(BaseModel):
a:int
@classmethod
def from_str(cls,s):
return cls(a=int(s))
def test_something():
assert A(a=2) == snapshot(A.from_str("1"))
"""
).run_pytest(
["--inline-snapshot=fix"],
changed_files=snapshot(
{
"test_something.py": """\
from inline_snapshot import snapshot
from pydantic import BaseModel
class A(BaseModel):
a:int
@classmethod
def from_str(cls,s):
return cls(a=int(s))
def test_something():
assert A(a=2) == snapshot(A(a=2))
"""
}
),
)

0 comments on commit ef707af

Please sign in to comment.