diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index de72e6fe..18dd28c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,12 +40,12 @@ repos: hooks: - id: setup-cfg-fmt -- repo: https://github.com/asottile/reorder-python-imports - rev: v3.13.0 +- repo: https://github.com/pycqa/isort + rev: 5.13.2 hooks: - - args: - - --py38-plus - id: reorder-python-imports + - id: isort + name: isort (python) + - hooks: - args: - --py38-plus diff --git a/changelog.d/20250118_083011_15r10nk-git_refactor.md b/changelog.d/20250118_083011_15r10nk-git_refactor.md new file mode 100644 index 00000000..eb250176 --- /dev/null +++ b/changelog.d/20250118_083011_15r10nk-git_refactor.md @@ -0,0 +1,3 @@ +### Fixed + +- fixed some issues with dataclass arguments diff --git a/pyproject.toml b/pyproject.toml index 0301e13e..4af94662 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,7 @@ dependencies = [ installer="uv" [tool.hatch.envs.cov.scripts] -gh=[ +github=[ "- rm htmlcov/*", "gh run download -n html-report -D htmlcov", "xdg-open htmlcov/index.html", @@ -220,3 +220,7 @@ version = "command: cz bump --get-next" [tool.pytest.ini_options] markers=["no_rewriting: marks tests which need no code rewriting and can be used with pypy"] + +[tool.isort] +profile="black" +force_single_line=true diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index a8050adb..b24167e1 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -1,5 +1,5 @@ -from ._code_repr import customize_repr from ._code_repr import HasRepr +from ._code_repr import customize_repr from ._external import external from ._external import outsource from ._inline_snapshot import snapshot diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py index d65f4b5e..5879307c 100644 --- a/src/inline_snapshot/_adapter/adapter.py +++ b/src/inline_snapshot/_adapter/adapter.py @@ -78,7 +78,7 @@ def get_adapter(self, old_value, new_value) -> Adapter: assert False def assign(self, old_value, old_node, new_value): - raise NotImplementedError(cls) + raise NotImplementedError(self) def value_assign(self, old_value, old_node, new_value): from .value_adapter import ValueAdapter diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py index 79be6d69..cab512b1 100644 --- a/src/inline_snapshot/_adapter/dict_adapter.py +++ b/src/inline_snapshot/_adapter/dict_adapter.py @@ -7,8 +7,8 @@ from .._change import DictInsert from ..syntax_warnings import InlineSnapshotSyntaxWarning from .adapter import Adapter -from .adapter import adapter_map from .adapter import Item +from .adapter import adapter_map class DictAdapter(Adapter): @@ -86,7 +86,7 @@ def assign(self, old_value, old_node, new_value): old_value.keys(), (old_node.values if old_node is not None else [None] * len(old_value)), ): - if not key in new_value: + if key not in new_value: # delete entries yield Delete("fix", self.context.file._source, node, old_value[key]) diff --git a/src/inline_snapshot/_adapter/generic_call_adapter.py b/src/inline_snapshot/_adapter/generic_call_adapter.py index 1f2e4487..e405d6e7 100644 --- a/src/inline_snapshot/_adapter/generic_call_adapter.py +++ b/src/inline_snapshot/_adapter/generic_call_adapter.py @@ -4,17 +4,17 @@ import warnings from abc import ABC from collections import defaultdict +from dataclasses import MISSING from dataclasses import fields from dataclasses import is_dataclass -from dataclasses import MISSING from typing import Any from .._change import CallArg from .._change import Delete from ..syntax_warnings import InlineSnapshotSyntaxWarning from .adapter import Adapter -from .adapter import adapter_map from .adapter import Item +from .adapter import adapter_map def get_adapter_for_type(typ): @@ -83,10 +83,17 @@ def items(cls, value, node): assert isinstance(node, ast.Call) assert all(kw.arg for kw in node.keywords) kw_arg_node = {kw.arg: kw.value for kw in node.keywords if kw.arg}.get - pos_arg_node = lambda pos: node.args[pos] + + def pos_arg_node(pos): + return node.args[pos] + else: - kw_arg_node = lambda _: None - pos_arg_node = lambda _: None + + def kw_arg_node(_): + return None + + def pos_arg_node(_): + return None return [ Item(value=arg.value, node=pos_arg_node(i)) @@ -166,7 +173,7 @@ def assign(self, old_value, old_node, new_value): # keyword arguments result_kwargs = {} for kw in old_node.keywords: - if (missing := not kw.arg in new_kwargs) or new_kwargs[kw.arg].is_default: + if (missing := kw.arg not in new_kwargs) or new_kwargs[kw.arg].is_default: # delete entries yield Delete( "fix" if missing else "update", @@ -258,8 +265,11 @@ def arguments(cls, value): return ([], kwargs) def argument(self, value, pos_or_name): - assert isinstance(pos_or_name, str) - return getattr(value, pos_or_name) + if isinstance(pos_or_name, str): + return getattr(value, pos_or_name) + else: + args = [field for field in fields(value) if field.init] + return args[pos_or_name] try: diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py index cd3e8e05..0b6020ec 100644 --- a/src/inline_snapshot/_adapter/sequence_adapter.py +++ b/src/inline_snapshot/_adapter/sequence_adapter.py @@ -11,8 +11,8 @@ from .._compare_context import compare_context from ..syntax_warnings import InlineSnapshotSyntaxWarning from .adapter import Adapter -from .adapter import adapter_map from .adapter import Item +from .adapter import adapter_map class SequenceAdapter(Adapter): diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index 05c888f7..7caeec57 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -2,16 +2,17 @@ from collections import defaultdict from dataclasses import dataclass from typing import Any -from typing import cast from typing import DefaultDict from typing import Dict from typing import List from typing import Optional from typing import Tuple from typing import Union +from typing import cast from asttokens.util import Token from executing.executing import EnhancedAST + from inline_snapshot._source_file import SourceFile from ._rewrite_code import ChangeRecorder diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 3b5252bc..47c677e8 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -4,7 +4,6 @@ from functools import singledispatch from unittest import mock - real_repr = repr diff --git a/src/inline_snapshot/_config.py b/src/inline_snapshot/_config.py index aad23d1f..15cc1b0a 100644 --- a/src/inline_snapshot/_config.py +++ b/src/inline_snapshot/_config.py @@ -7,7 +7,6 @@ from typing import List from typing import Optional - if sys.version_info >= (3, 11): from tomllib import loads else: diff --git a/src/inline_snapshot/_find_external.py b/src/inline_snapshot/_find_external.py index 6d0a6bc2..c2fd528e 100644 --- a/src/inline_snapshot/_find_external.py +++ b/src/inline_snapshot/_find_external.py @@ -5,7 +5,7 @@ from executing import Source from . import _external -from . import _inline_snapshot +from ._global_state import state from ._rewrite_code import ChangeRecorder from ._rewrite_code import end_of from ._rewrite_code import start_of @@ -47,7 +47,7 @@ def used_externals_in(source) -> Set[str]: def used_externals() -> Set[str]: result = set() - for filename in _inline_snapshot._files_with_snapshots: + for filename in state().files_with_snapshots: result |= used_externals_in(pathlib.Path(filename).read_text("utf-8")) return result diff --git a/src/inline_snapshot/_flags.py b/src/inline_snapshot/_flags.py new file mode 100644 index 00000000..06f9e6fc --- /dev/null +++ b/src/inline_snapshot/_flags.py @@ -0,0 +1,24 @@ +from typing import Set + +from ._types import Category + + +class Flags: + """ + fix: the value needs to be changed to pass the tests + update: the value should be updated because the token-stream has changed + create: the snapshot is empty `snapshot()` + trim: the snapshot contains more values than neccessary. 1 could be trimmed in `5 in snapshot([1,5])`. + """ + + def __init__(self, flags: Set[Category] = set()): + self.fix = "fix" in flags + self.update = "update" in flags + self.create = "create" in flags + self.trim = "trim" in flags + + def to_set(self): + return {k for k, v in self.__dict__.items() if v} + + def __repr__(self): + return f"Flags({self.to_set()})" diff --git a/src/inline_snapshot/_global_state.py b/src/inline_snapshot/_global_state.py new file mode 100644 index 00000000..bea11c2c --- /dev/null +++ b/src/inline_snapshot/_global_state.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import contextlib +from dataclasses import dataclass +from dataclasses import field +from typing import Generator + +from ._flags import Flags + + +@dataclass +class State: + # snapshot + missing_values: int = 0 + incorrect_values: int = 0 + + snapshots: dict = field(default_factory=dict) + update_flags: Flags = field(default_factory=Flags) + active: bool = True + files_with_snapshots: set[str] = field(default_factory=set) + + # external + storage = None + + +_current = State() +_current.active = False + + +def state() -> State: + global _current + return _current + + +@contextlib.contextmanager +def snapshot_env() -> Generator[State]: + + global _current + old = _current + _current = State() + + try: + yield _current + finally: + _current = old diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 5aed5ed2..67c5f5c4 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -1,551 +1,19 @@ import ast -import copy import inspect from typing import Any -from typing import cast -from typing import Dict # noqa -from typing import Iterator -from typing import List -from typing import Set -from typing import Tuple # noqa from typing import TypeVar +from typing import cast from executing import Source -from inline_snapshot._adapter.adapter import Adapter -from inline_snapshot._adapter.adapter import adapter_map + from inline_snapshot._source_file import SourceFile from ._adapter.adapter import AdapterContext from ._adapter.adapter import FrameContext -from ._adapter.adapter import get_adapter_type from ._change import CallArg -from ._change import Change -from ._change import Delete -from ._change import DictInsert -from ._change import ListInsert -from ._change import Replace -from ._code_repr import code_repr -from ._compare_context import compare_only -from ._exceptions import UsageError +from ._global_state import state from ._sentinels import undefined -from ._types import Category -from ._types import Snapshot -from ._unmanaged import map_unmanaged -from ._unmanaged import Unmanaged -from ._unmanaged import update_allowed -from ._utils import value_to_token - - -snapshots = {} # type: Dict[Tuple[int, int], SnapshotReference] - -_active = False - -_files_with_snapshots: Set[str] = set() - -_missing_values = 0 -_incorrect_values = 0 - - -def _return(result): - global _incorrect_values - if not result: - _incorrect_values += 1 - return result - - -class Flags: - """ - fix: the value needs to be changed to pass the tests - update: the value should be updated because the token-stream has changed - create: the snapshot is empty `snapshot()` - trim: the snapshot contains more values than neccessary. 1 could be trimmed in `5 in snapshot([1,5])`. - """ - - def __init__(self, flags: Set[Category] = set()): - self.fix = "fix" in flags - self.update = "update" in flags - self.create = "create" in flags - self.trim = "trim" in flags - - def to_set(self): - return {k for k, v in self.__dict__.items() if v} - - def __repr__(self): - return f"Flags({self.to_set()})" - - -_update_flags = Flags() - - -def ignore_old_value(): - return _update_flags.fix or _update_flags.update - - -class GenericValue(Snapshot): - _new_value: Any - _old_value: Any - _current_op = "undefined" - _ast_node: ast.Expr - _context: AdapterContext - - @property - def _file(self): - return self._context.file - - def get_adapter(self, value): - return get_adapter_type(value)(self._context) - - def _re_eval(self, value, context: AdapterContext): - self._context = context - - def re_eval(old_value, node, value): - if isinstance(old_value, Unmanaged): - old_value.value = value - return - - assert type(old_value) is type(value) - - adapter = self.get_adapter(old_value) - if adapter is not None and hasattr(adapter, "items"): - old_items = adapter.items(old_value, node) - new_items = adapter.items(value, node) - assert len(old_items) == len(new_items) - - for old_item, new_item in zip(old_items, new_items): - re_eval(old_item.value, old_item.node, new_item.value) - - else: - if update_allowed(old_value): - if not old_value == value: - raise UsageError( - "snapshot value should not change. Use Is(...) for dynamic snapshot parts." - ) - else: - assert False, "old_value should be converted to Unmanaged" - - re_eval(self._old_value, self._ast_node, value) - - def _ignore_old(self): - return ( - _update_flags.fix - or _update_flags.update - or _update_flags.create - or self._old_value is undefined - ) - - def _visible_value(self): - if self._ignore_old(): - return self._new_value - else: - return self._old_value - - def _get_changes(self) -> Iterator[Change]: - raise NotImplementedError() - - def _new_code(self): - raise NotImplementedError() - - def __repr__(self): - return repr(self._visible_value()) - - def _type_error(self, op): - __tracebackhide__ = True - raise TypeError( - f"This snapshot cannot be use with `{op}`, because it was previously used with `{self._current_op}`" - ) - - def __eq__(self, _other): - __tracebackhide__ = True - self._type_error("==") - - def __le__(self, _other): - __tracebackhide__ = True - self._type_error("<=") - - def __ge__(self, _other): - __tracebackhide__ = True - self._type_error(">=") - - def __contains__(self, _other): - __tracebackhide__ = True - self._type_error("in") - - def __getitem__(self, _item): - __tracebackhide__ = True - self._type_error("snapshot[key]") - - -class UndecidedValue(GenericValue): - def __init__(self, old_value, ast_node, context: AdapterContext): - - old_value = adapter_map(old_value, map_unmanaged) - self._old_value = old_value - self._new_value = undefined - self._ast_node = ast_node - self._context = context - - def _change(self, cls): - self.__class__ = cls - - def _new_code(self): - assert False - - def _get_changes(self) -> Iterator[Change]: - - def handle(node, obj): - - adapter = get_adapter_type(obj) - if adapter is not None and hasattr(adapter, "items"): - for item in adapter.items(obj, node): - yield from handle(item.node, item.value) - return - - if not isinstance(obj, Unmanaged) and node is not None: - new_token = value_to_token(obj) - if self._file._token_of_node(node) != new_token: - new_code = self._file._token_to_code(new_token) - - yield Replace( - node=self._ast_node, - file=self._file, - new_code=new_code, - flag="update", - old_value=self._old_value, - new_value=self._old_value, - ) - - if self._file._source is not None: - yield from handle(self._ast_node, self._old_value) - - # functions which determine the type - - def __eq__(self, other): - self._change(EqValue) - return self == other - - def __le__(self, other): - self._change(MinValue) - return self <= other - - def __ge__(self, other): - self._change(MaxValue) - return self >= other - - def __contains__(self, other): - self._change(CollectionValue) - return other in self - - def __getitem__(self, item): - self._change(DictValue) - return self[item] - - -def clone(obj): - new = copy.deepcopy(obj) - if not obj == new: - raise UsageError( - f"""\ -inline-snapshot uses `copy.deepcopy` to copy objects, -but the copied object is not equal to the original one: - -original: {code_repr(obj)} -copied: {code_repr(new)} - -Please fix the way your object is copied or your __eq__ implementation. -""" - ) - return new - - -class EqValue(GenericValue): - _current_op = "x == snapshot" - _changes: List[Change] - - def __eq__(self, other): - global _missing_values - if self._old_value is undefined: - _missing_values += 1 - - if not compare_only() and self._new_value is undefined: - adapter = Adapter(self._context).get_adapter(self._old_value, other) - it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) - self._changes = [] - while True: - try: - self._changes.append(next(it)) - except StopIteration as ex: - self._new_value = ex.value - break - - return _return(self._visible_value() == other) - - # if self._new_value is undefined: - # self._new_value = use_valid_old_values(self._old_value, clone(other)) - # if self._old_value is undefined or ignore_old_value(): - # return True - # return _return(self._old_value == other) - # else: - # return _return(self._new_value == other) - - def _new_code(self): - return self._file._value_to_code(self._new_value) - - def _get_changes(self) -> Iterator[Change]: - return iter(self._changes) - - -class MinMaxValue(GenericValue): - """Generic implementation for <=, >=""" - - @staticmethod - def cmp(a, b): - raise NotImplemented - - def _generic_cmp(self, other): - global _missing_values - if self._old_value is undefined: - _missing_values += 1 - - if self._new_value is undefined: - self._new_value = clone(other) - if self._old_value is undefined or ignore_old_value(): - return True - return _return(self.cmp(self._old_value, other)) - else: - if not self.cmp(self._new_value, other): - self._new_value = clone(other) - - return _return(self.cmp(self._visible_value(), other)) - - def _new_code(self): - return self._file._value_to_code(self._new_value) - - def _get_changes(self) -> Iterator[Change]: - new_token = value_to_token(self._new_value) - if not self.cmp(self._old_value, self._new_value): - flag = "fix" - elif not self.cmp(self._new_value, self._old_value): - flag = "trim" - elif ( - self._ast_node is not None - and self._file._token_of_node(self._ast_node) != new_token - ): - flag = "update" - else: - return - - new_code = self._file._token_to_code(new_token) - - yield Replace( - node=self._ast_node, - file=self._file, - new_code=new_code, - flag=flag, - old_value=self._old_value, - new_value=self._new_value, - ) - - -class MinValue(MinMaxValue): - """ - handles: - - >>> snapshot(5) <= 6 - True - - >>> 6 >= snapshot(5) - True - - """ - - _current_op = "x >= snapshot" - - @staticmethod - def cmp(a, b): - return a <= b - - __le__ = MinMaxValue._generic_cmp - - -class MaxValue(MinMaxValue): - """ - handles: - - >>> snapshot(5) >= 4 - True - - >>> 4 <= snapshot(5) - True - - """ - - _current_op = "x <= snapshot" - - @staticmethod - def cmp(a, b): - return a >= b - - __ge__ = MinMaxValue._generic_cmp - - -class CollectionValue(GenericValue): - _current_op = "x in snapshot" - - def __contains__(self, item): - global _missing_values - if self._old_value is undefined: - _missing_values += 1 - - if self._new_value is undefined: - self._new_value = [clone(item)] - else: - if item not in self._new_value: - self._new_value.append(clone(item)) - - if ignore_old_value() or self._old_value is undefined: - return True - else: - return _return(item in self._old_value) - - def _new_code(self): - return self._file._value_to_code(self._new_value) - - def _get_changes(self) -> Iterator[Change]: - - if self._ast_node is None: - elements = [None] * len(self._old_value) - else: - assert isinstance(self._ast_node, ast.List) - elements = self._ast_node.elts - - for old_value, old_node in zip(self._old_value, elements): - if old_value not in self._new_value: - yield Delete( - flag="trim", - file=self._file, - node=old_node, - old_value=old_value, - ) - continue - - # check for update - new_token = value_to_token(old_value) - - if ( - old_node is not None - and self._file._token_of_node(old_node) != new_token - ): - new_code = self._file._token_to_code(new_token) - - yield Replace( - node=old_node, - file=self._file, - new_code=new_code, - flag="update", - old_value=old_value, - new_value=old_value, - ) - - new_values = [v for v in self._new_value if v not in self._old_value] - if new_values: - yield ListInsert( - flag="fix", - file=self._file, - node=self._ast_node, - position=len(self._old_value), - new_code=[self._file._value_to_code(v) for v in new_values], - new_values=new_values, - ) - - -class DictValue(GenericValue): - _current_op = "snapshot[key]" - - def __getitem__(self, index): - global _missing_values - - if self._new_value is undefined: - self._new_value = {} - - if index not in self._new_value: - old_value = self._old_value - if old_value is undefined: - _missing_values += 1 - old_value = {} - - child_node = None - if self._ast_node is not None: - assert isinstance(self._ast_node, ast.Dict) - if index in old_value: - pos = list(old_value.keys()).index(index) - child_node = self._ast_node.values[pos] - - self._new_value[index] = UndecidedValue( - old_value.get(index, undefined), child_node, self._context - ) - - return self._new_value[index] - - def _re_eval(self, value, context: AdapterContext): - super()._re_eval(value, context) - - if self._new_value is not undefined and self._old_value is not undefined: - for key, s in self._new_value.items(): - if key in self._old_value: - s._re_eval(self._old_value[key], context) - - def _new_code(self): - return ( - "{" - + ", ".join( - [ - f"{self._file._value_to_code(k)}: {v._new_code()}" - for k, v in self._new_value.items() - if not isinstance(v, UndecidedValue) - ] - ) - + "}" - ) - - def _get_changes(self) -> Iterator[Change]: - - assert self._old_value is not undefined - - if self._ast_node is None: - values = [None] * len(self._old_value) - else: - assert isinstance(self._ast_node, ast.Dict) - values = self._ast_node.values - - for key, node in zip(self._old_value.keys(), values): - if key in self._new_value: - # check values with same keys - yield from self._new_value[key]._get_changes() - else: - # delete entries - yield Delete("trim", self._file, node, self._old_value[key]) - - to_insert = [] - for key, new_value_element in self._new_value.items(): - if key not in self._old_value and not isinstance( - new_value_element, UndecidedValue - ): - # add new values - to_insert.append((key, new_value_element._new_code())) - - if to_insert: - new_code = [(self._file._value_to_code(k), v) for k, v in to_insert] - yield DictInsert( - "create", - self._file, - self._ast_node, - len(self._old_value), - new_code, - to_insert, - ) - - -T = TypeVar("T") +from ._snapshot.undecided_value import UndecidedValue class ReprWrapper: @@ -585,7 +53,7 @@ def snapshot(obj: Any = undefined) -> Any: `snapshot(value)` has general the semantic of an noop which returns `value`. """ - if not _active: + if not state().active: if obj is undefined: raise AssertionError( "your snapshot is missing a value run pytest with --inline-snapshot=create" @@ -610,22 +78,22 @@ def snapshot(obj: Any = undefined) -> Any: module = inspect.getmodule(frame) if module is not None and module.__file__ is not None: - _files_with_snapshots.add(module.__file__) + state().files_with_snapshots.add(module.__file__) key = id(frame.f_code), frame.f_lasti - if key not in snapshots: + if key not in state().snapshots: node = expr.node if node is None: # we can run without knowing of the calling expression but we will not be able to fix code - snapshots[key] = SnapshotReference(obj, None, context) + state().snapshots[key] = SnapshotReference(obj, None, context) else: assert isinstance(node, ast.Call) - snapshots[key] = SnapshotReference(obj, expr, context) + state().snapshots[key] = SnapshotReference(obj, expr, context) else: - snapshots[key]._re_eval(obj, context) + state().snapshots[key]._re_eval(obj, context) - return snapshots[key]._value + return state().snapshots[key]._value def used_externals(tree): @@ -644,7 +112,6 @@ class SnapshotReference: def __init__(self, value, expr, context: AdapterContext): self._expr = expr node = expr.node.args[0] if expr is not None and expr.node.args else None - source = expr.source if expr is not None else None self._value = UndecidedValue(value, node, context) def _changes(self): diff --git a/src/inline_snapshot/_rewrite_code.py b/src/inline_snapshot/_rewrite_code.py index 109e4b90..54cbe475 100644 --- a/src/inline_snapshot/_rewrite_code.py +++ b/src/inline_snapshot/_rewrite_code.py @@ -16,7 +16,6 @@ from ._format import enforce_formatting from ._format import format_code - if sys.version_info >= (3, 10): from itertools import pairwise else: diff --git a/src/inline_snapshot/_snapshot/collection_value.py b/src/inline_snapshot/_snapshot/collection_value.py new file mode 100644 index 00000000..951ca021 --- /dev/null +++ b/src/inline_snapshot/_snapshot/collection_value.py @@ -0,0 +1,82 @@ +import ast +from typing import Iterator + +from .._change import Change +from .._change import Delete +from .._change import ListInsert +from .._change import Replace +from .._global_state import state +from .._sentinels import undefined +from .._utils import value_to_token +from .generic_value import GenericValue +from .generic_value import clone +from .generic_value import ignore_old_value + + +class CollectionValue(GenericValue): + _current_op = "x in snapshot" + + def __contains__(self, item): + if self._old_value is undefined: + state().missing_values += 1 + + if self._new_value is undefined: + self._new_value = [clone(item)] + else: + if item not in self._new_value: + self._new_value.append(clone(item)) + + if ignore_old_value() or self._old_value is undefined: + return True + else: + return self._return(item in self._old_value) + + def _new_code(self): + return self._file._value_to_code(self._new_value) + + def _get_changes(self) -> Iterator[Change]: + + if self._ast_node is None: + elements = [None] * len(self._old_value) + else: + assert isinstance(self._ast_node, ast.List) + elements = self._ast_node.elts + + for old_value, old_node in zip(self._old_value, elements): + if old_value not in self._new_value: + yield Delete( + flag="trim", + file=self._file, + node=old_node, + old_value=old_value, + ) + continue + + # check for update + new_token = value_to_token(old_value) + + if ( + old_node is not None + and self._file._token_of_node(old_node) != new_token + ): + new_code = self._file._token_to_code(new_token) + + yield Replace( + node=old_node, + file=self._file, + new_code=new_code, + flag="update", + old_value=old_value, + new_value=old_value, + ) + + new_values = [v for v in self._new_value if v not in self._old_value] + if new_values: + yield ListInsert( + flag="fix", + file=self._file, + node=self._ast_node, + position=len(self._old_value), + new_code=[self._file._value_to_code(v) for v in new_values], + new_values=new_values, + ) diff --git a/src/inline_snapshot/_snapshot/dict_value.py b/src/inline_snapshot/_snapshot/dict_value.py new file mode 100644 index 00000000..afed0073 --- /dev/null +++ b/src/inline_snapshot/_snapshot/dict_value.py @@ -0,0 +1,97 @@ +import ast +from typing import Iterator + +from .._adapter.adapter import AdapterContext +from .._change import Change +from .._change import Delete +from .._change import DictInsert +from .._global_state import state +from .._inline_snapshot import UndecidedValue +from .._sentinels import undefined +from .generic_value import GenericValue + + +class DictValue(GenericValue): + _current_op = "snapshot[key]" + + def __getitem__(self, index): + + if self._new_value is undefined: + self._new_value = {} + + if index not in self._new_value: + old_value = self._old_value + if old_value is undefined: + state().missing_values += 1 + old_value = {} + + child_node = None + if self._ast_node is not None: + assert isinstance(self._ast_node, ast.Dict) + if index in old_value: + pos = list(old_value.keys()).index(index) + child_node = self._ast_node.values[pos] + + self._new_value[index] = UndecidedValue( + old_value.get(index, undefined), child_node, self._context + ) + + return self._new_value[index] + + def _re_eval(self, value, context: AdapterContext): + super()._re_eval(value, context) + + if self._new_value is not undefined and self._old_value is not undefined: + for key, s in self._new_value.items(): + if key in self._old_value: + s._re_eval(self._old_value[key], context) + + def _new_code(self): + return ( + "{" + + ", ".join( + [ + f"{self._file._value_to_code(k)}: {v._new_code()}" + for k, v in self._new_value.items() + if not isinstance(v, UndecidedValue) + ] + ) + + "}" + ) + + def _get_changes(self) -> Iterator[Change]: + + assert self._old_value is not undefined + + if self._ast_node is None: + values = [None] * len(self._old_value) + else: + assert isinstance(self._ast_node, ast.Dict) + values = self._ast_node.values + + for key, node in zip(self._old_value.keys(), values): + if key in self._new_value: + # check values with same keys + yield from self._new_value[key]._get_changes() + else: + # delete entries + yield Delete("trim", self._file, node, self._old_value[key]) + + to_insert = [] + for key, new_value_element in self._new_value.items(): + if key not in self._old_value and not isinstance( + new_value_element, UndecidedValue + ): + # add new values + to_insert.append((key, new_value_element._new_code())) + + if to_insert: + new_code = [(self._file._value_to_code(k), v) for k, v in to_insert] + yield DictInsert( + "create", + self._file, + self._ast_node, + len(self._old_value), + new_code, + to_insert, + ) diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py new file mode 100644 index 00000000..8648ffab --- /dev/null +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -0,0 +1,39 @@ +from typing import Iterator +from typing import List + +from inline_snapshot._adapter.adapter import Adapter + +from .._change import Change +from .._compare_context import compare_only +from .._global_state import state +from .._sentinels import undefined +from .generic_value import GenericValue +from .generic_value import clone + + +class EqValue(GenericValue): + _current_op = "x == snapshot" + _changes: List[Change] + + def __eq__(self, other): + if self._old_value is undefined: + state().missing_values += 1 + + if not compare_only() and self._new_value is undefined: + adapter = Adapter(self._context).get_adapter(self._old_value, other) + it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) + self._changes = [] + while True: + try: + self._changes.append(next(it)) + except StopIteration as ex: + self._new_value = ex.value + break + + return self._return(self._visible_value() == other) + + def _new_code(self): + return self._file._value_to_code(self._new_value) + + def _get_changes(self) -> Iterator[Change]: + return iter(self._changes) diff --git a/src/inline_snapshot/_snapshot/generic_value.py b/src/inline_snapshot/_snapshot/generic_value.py new file mode 100644 index 00000000..7e47c1d2 --- /dev/null +++ b/src/inline_snapshot/_snapshot/generic_value.py @@ -0,0 +1,136 @@ +import ast +import copy +from typing import Any +from typing import Iterator + +from .._adapter.adapter import AdapterContext +from .._adapter.adapter import get_adapter_type +from .._change import Change +from .._code_repr import code_repr +from .._exceptions import UsageError +from .._global_state import state +from .._sentinels import undefined +from .._types import Snapshot +from .._unmanaged import Unmanaged +from .._unmanaged import update_allowed + + +def clone(obj): + new = copy.deepcopy(obj) + if not obj == new: + raise UsageError( + f"""\ +inline-snapshot uses `copy.deepcopy` to copy objects, +but the copied object is not equal to the original one: + +original: {code_repr(obj)} +copied: {code_repr(new)} + +Please fix the way your object is copied or your __eq__ implementation. +""" + ) + return new + + +def ignore_old_value(): + return state().update_flags.fix or state().update_flags.update + + +class GenericValue(Snapshot): + _new_value: Any + _old_value: Any + _current_op = "undefined" + _ast_node: ast.Expr + _context: AdapterContext + + @staticmethod + def _return(result): + if not result: + state().incorrect_values += 1 + return result + + @property + def _file(self): + return self._context.file + + def get_adapter(self, value): + return get_adapter_type(value)(self._context) + + def _re_eval(self, value, context: AdapterContext): + self._context = context + + def re_eval(old_value, node, value): + if isinstance(old_value, Unmanaged): + old_value.value = value + return + + assert type(old_value) is type(value) + + adapter = self.get_adapter(old_value) + if adapter is not None and hasattr(adapter, "items"): + old_items = adapter.items(old_value, node) + new_items = adapter.items(value, node) + assert len(old_items) == len(new_items) + + for old_item, new_item in zip(old_items, new_items): + re_eval(old_item.value, old_item.node, new_item.value) + + else: + if update_allowed(old_value): + if not old_value == value: + raise UsageError( + "snapshot value should not change. Use Is(...) for dynamic snapshot parts." + ) + else: + assert False, "old_value should be converted to Unmanaged" + + re_eval(self._old_value, self._ast_node, value) + + def _ignore_old(self): + return ( + state().update_flags.fix + or state().update_flags.update + or state().update_flags.create + or self._old_value is undefined + ) + + def _visible_value(self): + if self._ignore_old(): + return self._new_value + else: + return self._old_value + + def _get_changes(self) -> Iterator[Change]: + raise NotImplementedError() + + def _new_code(self): + raise NotImplementedError() + + def __repr__(self): + return repr(self._visible_value()) + + def _type_error(self, op): + __tracebackhide__ = True + raise TypeError( + f"This snapshot cannot be use with `{op}`, because it was previously used with `{self._current_op}`" + ) + + def __eq__(self, _other): + __tracebackhide__ = True + self._type_error("==") + + def __le__(self, _other): + __tracebackhide__ = True + self._type_error("<=") + + def __ge__(self, _other): + __tracebackhide__ = True + self._type_error(">=") + + def __contains__(self, _other): + __tracebackhide__ = True + self._type_error("in") + + def __getitem__(self, _item): + __tracebackhide__ = True + self._type_error("snapshot[key]") diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py new file mode 100644 index 00000000..9ef0a65c --- /dev/null +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -0,0 +1,103 @@ +from typing import Iterator + +from .._change import Change +from .._change import Replace +from .._global_state import state +from .._sentinels import undefined +from .._utils import value_to_token +from .generic_value import GenericValue +from .generic_value import clone +from .generic_value import ignore_old_value + + +class MinMaxValue(GenericValue): + """Generic implementation for <=, >=""" + + @staticmethod + def cmp(a, b): + raise NotImplementedError + + def _generic_cmp(self, other): + if self._old_value is undefined: + state().missing_values += 1 + + if self._new_value is undefined: + self._new_value = clone(other) + if self._old_value is undefined or ignore_old_value(): + return True + return self._return(self.cmp(self._old_value, other)) + else: + if not self.cmp(self._new_value, other): + self._new_value = clone(other) + + return self._return(self.cmp(self._visible_value(), other)) + + def _new_code(self): + return self._file._value_to_code(self._new_value) + + def _get_changes(self) -> Iterator[Change]: + new_token = value_to_token(self._new_value) + if not self.cmp(self._old_value, self._new_value): + flag = "fix" + elif not self.cmp(self._new_value, self._old_value): + flag = "trim" + elif ( + self._ast_node is not None + and self._file._token_of_node(self._ast_node) != new_token + ): + flag = "update" + else: + return + + new_code = self._file._token_to_code(new_token) + + yield Replace( + node=self._ast_node, + file=self._file, + new_code=new_code, + flag=flag, + old_value=self._old_value, + new_value=self._new_value, + ) + + +class MinValue(MinMaxValue): + """ + handles: + + >>> snapshot(5) <= 6 + True + + >>> 6 >= snapshot(5) + True + + """ + + _current_op = "x >= snapshot" + + @staticmethod + def cmp(a, b): + return a <= b + + __le__ = MinMaxValue._generic_cmp + + +class MaxValue(MinMaxValue): + """ + handles: + + >>> snapshot(5) >= 4 + True + + >>> 4 <= snapshot(5) + True + + """ + + _current_op = "x <= snapshot" + + @staticmethod + def cmp(a, b): + return a >= b + + __ge__ = MinMaxValue._generic_cmp diff --git a/src/inline_snapshot/_snapshot/undecided_value.py b/src/inline_snapshot/_snapshot/undecided_value.py new file mode 100644 index 00000000..eb92a573 --- /dev/null +++ b/src/inline_snapshot/_snapshot/undecided_value.py @@ -0,0 +1,88 @@ +from typing import Iterator + +from inline_snapshot._adapter.adapter import adapter_map + +from .._adapter.adapter import AdapterContext +from .._adapter.adapter import get_adapter_type +from .._change import Change +from .._change import Replace +from .._sentinels import undefined +from .._unmanaged import Unmanaged +from .._unmanaged import map_unmanaged +from .._utils import value_to_token +from .generic_value import GenericValue + + +class UndecidedValue(GenericValue): + def __init__(self, old_value, ast_node, context: AdapterContext): + + old_value = adapter_map(old_value, map_unmanaged) + self._old_value = old_value + self._new_value = undefined + self._ast_node = ast_node + self._context = context + + def _change(self, cls): + self.__class__ = cls + + def _new_code(self): + assert False + + def _get_changes(self) -> Iterator[Change]: + + def handle(node, obj): + + adapter = get_adapter_type(obj) + if adapter is not None and hasattr(adapter, "items"): + for item in adapter.items(obj, node): + yield from handle(item.node, item.value) + return + + if not isinstance(obj, Unmanaged) and node is not None: + new_token = value_to_token(obj) + if self._file._token_of_node(node) != new_token: + new_code = self._file._token_to_code(new_token) + + yield Replace( + node=self._ast_node, + file=self._file, + new_code=new_code, + flag="update", + old_value=self._old_value, + new_value=self._old_value, + ) + + if self._file._source is not None: + yield from handle(self._ast_node, self._old_value) + + # functions which determine the type + + def __eq__(self, other): + from .._snapshot.eq_value import EqValue + + self._change(EqValue) + return self == other + + def __le__(self, other): + from .._snapshot.min_max_value import MinValue + + self._change(MinValue) + return self <= other + + def __ge__(self, other): + from .._snapshot.min_max_value import MaxValue + + self._change(MaxValue) + return self >= other + + def __contains__(self, other): + from .._snapshot.collection_value import CollectionValue + + self._change(CollectionValue) + return other in self + + def __getitem__(self, item): + from .._snapshot.dict_value import DictValue + + self._change(DictValue) + return self[item] diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py index 09a60849..1a788d7d 100644 --- a/src/inline_snapshot/_source_file.py +++ b/src/inline_snapshot/_source_file.py @@ -2,6 +2,7 @@ from pathlib import Path from executing import Source + from inline_snapshot._format import enforce_formatting from inline_snapshot._format import format_code from inline_snapshot._utils import normalize diff --git a/src/inline_snapshot/_types.py b/src/inline_snapshot/_types.py index 50807a76..68209b02 100644 --- a/src/inline_snapshot/_types.py +++ b/src/inline_snapshot/_types.py @@ -1,10 +1,8 @@ """The following types are for type checking only.""" -... # prevent lint error with black and reorder-python-imports - +from typing import TYPE_CHECKING from typing import Generic from typing import Literal -from typing import TYPE_CHECKING from typing import TypeVar if TYPE_CHECKING: diff --git a/src/inline_snapshot/extra.py b/src/inline_snapshot/extra.py index 0861ed40..aed18abb 100644 --- a/src/inline_snapshot/extra.py +++ b/src/inline_snapshot/extra.py @@ -5,13 +5,16 @@ not depend on other libraries. """ -... # prevent lint error with black and reorder-python-imports import contextlib -from typing import List, Tuple, Union -from inline_snapshot._types import Snapshot -from contextlib import redirect_stdout, redirect_stderr import io import warnings +from contextlib import redirect_stderr +from contextlib import redirect_stdout +from typing import List +from typing import Tuple +from typing import Union + +from inline_snapshot._types import Snapshot @contextlib.contextmanager diff --git a/src/inline_snapshot/pytest_plugin.py b/src/inline_snapshot/pytest_plugin.py index afda98ab..7195aa7a 100644 --- a/src/inline_snapshot/pytest_plugin.py +++ b/src/inline_snapshot/pytest_plugin.py @@ -4,23 +4,26 @@ from pathlib import Path import pytest -from inline_snapshot._problems import report_problems -from inline_snapshot.pydantic_fix import pydantic_fix from rich import box from rich.console import Console from rich.panel import Panel from rich.prompt import Confirm from rich.syntax import Syntax +from inline_snapshot._problems import report_problems +from inline_snapshot.pydantic_fix import pydantic_fix + from . import _config from . import _external from . import _find_external -from . import _inline_snapshot from ._change import apply_all from ._code_repr import used_hasrepr from ._find_external import ensure_import +from ._flags import Flags +from ._global_state import state from ._inline_snapshot import used_externals from ._rewrite_code import ChangeRecorder +from ._snapshot.generic_value import GenericValue pytest.register_assert_rewrite("inline_snapshot.extra") pytest.register_assert_rewrite("inline_snapshot.testing._example") @@ -105,19 +108,17 @@ def pytest_configure(config): ) if xdist_running(config) or not is_implementation_supported(): - _inline_snapshot._active = False + state().active = False elif flags & {"review"}: - _inline_snapshot._active = True + state().active = True - _inline_snapshot._update_flags = _inline_snapshot.Flags( - {"fix", "create", "update", "trim"} - ) + state().update_flags = Flags({"fix", "create", "update", "trim"}) else: - _inline_snapshot._active = "disable" not in flags + state().active = "disable" not in flags - _inline_snapshot._update_flags = _inline_snapshot.Flags(flags & categories) + state().update_flags = Flags(flags & categories) external_storage = ( _config.config.storage_dir or config.rootpath / ".inline-snapshot" @@ -140,24 +141,24 @@ def pytest_configure(config): @pytest.fixture(autouse=True) def snapshot_check(): - _inline_snapshot._missing_values = 0 - _inline_snapshot._incorrect_values = 0 + state().missing_values = 0 + state().incorrect_values = 0 yield - missing_values = _inline_snapshot._missing_values - incorrect_values = _inline_snapshot._incorrect_values + missing_values = state().missing_values + incorrect_values = state().incorrect_values - if missing_values != 0 and not _inline_snapshot._update_flags.create: + if missing_values != 0 and not state().update_flags.create: pytest.fail( ( - f"your snapshot is missing one value." + "your snapshot is missing one value." if missing_values == 1 else f"your snapshot is missing {missing_values} values." ), pytrace=False, ) - if incorrect_values != 0 and not _inline_snapshot._update_flags.fix: + if incorrect_values != 0 and not state().update_flags.fix: pytest.fail( "some snapshots in this test have incorrect values.", pytrace=False, @@ -166,12 +167,12 @@ def snapshot_check(): def pytest_assertrepr_compare(config, op, left, right): results = [] - if isinstance(left, _inline_snapshot.GenericValue): + if isinstance(left, GenericValue): results = config.hook.pytest_assertrepr_compare( config=config, op=op, left=left._visible_value(), right=right ) - if isinstance(right, _inline_snapshot.GenericValue): + if isinstance(right, GenericValue): results = config.hook.pytest_assertrepr_compare( config=config, op=op, left=left, right=right._visible_value() ) @@ -218,7 +219,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): ) return - if not _inline_snapshot._active: + if not state().active: return terminalreporter.section("inline snapshot") @@ -243,7 +244,7 @@ def apply_changes(flag): console.print() return result else: - console.print(f"These changes are not applied.") + console.print("These changes are not applied.") console.print( f"Use [b]--inline-snapshot={flag}[/] to apply them, or use the interactive mode with [b]--inline-snapshot=review[/]", highlight=False, @@ -266,7 +267,7 @@ def apply_changes(flag): "create": 0, } - for snapshot in _inline_snapshot.snapshots.values(): + for snapshot in state().snapshots.values(): all_categories = set() for change in snapshot._changes(): changes[change.flag].append(change) @@ -285,7 +286,7 @@ def apply_changes(flag): def report(flag, message, message_n): num = snapshot_changes[flag] - if num and not getattr(_inline_snapshot._update_flags, flag): + if num and not getattr(state().update_flags, flag): console.print( message if num == 1 else message.format(num=num), highlight=False, @@ -391,7 +392,7 @@ def report(flag, message, message_n): unused_externals = _find_external.unused_externals() - if unused_externals and _inline_snapshot._update_flags.trim: + if unused_externals and state().update_flags.trim: for name in unused_externals: assert _external.storage _external.storage.remove(name) diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index 3badf584..3aa8b3db 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -1,6 +1,5 @@ from __future__ import annotations -import contextlib import os import platform import re @@ -13,55 +12,18 @@ from tempfile import TemporaryDirectory from typing import Any +from rich.console import Console + import inline_snapshot._external -import inline_snapshot._external as external from inline_snapshot._problems import report_problems -from rich.console import Console -from .. import _inline_snapshot from .._change import apply_all -from .._inline_snapshot import Flags +from .._flags import Flags +from .._global_state import snapshot_env from .._rewrite_code import ChangeRecorder from .._types import Category from .._types import Snapshot - -@contextlib.contextmanager -def snapshot_env(): - import inline_snapshot._inline_snapshot as inline_snapshot - - current = ( - inline_snapshot.snapshots, - inline_snapshot._update_flags, - inline_snapshot._active, - external.storage, - inline_snapshot._files_with_snapshots, - inline_snapshot._missing_values, - inline_snapshot._incorrect_values, - ) - - inline_snapshot.snapshots = {} - inline_snapshot._update_flags = inline_snapshot.Flags() - inline_snapshot._active = True - external.storage = None - inline_snapshot._files_with_snapshots = set() - inline_snapshot._missing_values = 0 - inline_snapshot._incorrect_values = 0 - - try: - yield - finally: - ( - inline_snapshot.snapshots, - inline_snapshot._update_flags, - inline_snapshot._active, - external.storage, - inline_snapshot._files_with_snapshots, - inline_snapshot._missing_values, - inline_snapshot._incorrect_values, - ) = current - - ansi_escape = re.compile( r""" \x1B # ESC @@ -171,9 +133,9 @@ def run_inline( self._write_files(tmp_path) raised_exception = None - with snapshot_env(): + with snapshot_env() as state: with ChangeRecorder().activate() as recorder: - _inline_snapshot._update_flags = Flags({*flags}) + state.update_flags = Flags({*flags}) inline_snapshot._external.storage = ( inline_snapshot._external.DiscStorage(tmp_path / ".storage") ) @@ -196,10 +158,10 @@ def run_inline( raised_exception = e finally: - _inline_snapshot._active = False + state.active = False changes = [] - for snapshot in _inline_snapshot.snapshots.values(): + for snapshot in state.snapshots.values(): changes += snapshot._changes() snapshot_flags = {change.flag for change in changes} @@ -208,7 +170,7 @@ def run_inline( [ change for change in changes - if change.flag in _inline_snapshot._update_flags.to_set() + if change.flag in state.update_flags.to_set() ] ) recorder.fix_all() diff --git a/tests/adapter/test_change_types.py b/tests/adapter/test_change_types.py new file mode 100644 index 00000000..5b264c6f --- /dev/null +++ b/tests/adapter/test_change_types.py @@ -0,0 +1,54 @@ +import pytest + +from inline_snapshot.testing._example import Example + +values = ["1", '"2\'"', "[5]", "{1: 2}", "F(i=5)", "F.make1('2')", "f(7)"] + + +@pytest.mark.parametrize("a", values) +@pytest.mark.parametrize("b", values + ["F.make2(Is(5))"]) +def test_change_types(a, b): + context = """\ +from inline_snapshot import snapshot, Is +from dataclasses import dataclass + +@dataclass +class F: + i: int + + @staticmethod + def make1(s): + return F(i=int(s)) + + @staticmethod + def make2(s): + return F(i=s) + +def f(v): + return v + +""" + + def code_repr(v): + g = {} + exec(context + f"r=repr({a})", g) + return g["r"] + + def code(a, b): + return f"""\ +{context} + +def test_change(): + for _ in [1,2]: + assert {a} == snapshot({b}) +""" + + print(a, b) + print(code_repr(a), code_repr(b)) + + Example(code(a, b)).run_inline( + ["--inline-snapshot=fix,update"], + changed_files=( + {"test_something.py": code(a, code_repr(a))} if code_repr(a) != b else {} + ), + ) diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py index da2eedce..9c88e1dd 100644 --- a/tests/adapter/test_dataclass.py +++ b/tests/adapter/test_dataclass.py @@ -95,7 +95,8 @@ class A: c:list=field(default_factory=list) def test_something(): - assert A(a=1) == snapshot(A(a=1,b=2,c=[])) + for _ in [1,2]: + assert A(a=1) == snapshot(A(a=1,b=2,c=[])) """ ).run_inline( ["--inline-snapshot=update"], @@ -112,7 +113,47 @@ class A: c:list=field(default_factory=list) def test_something(): - assert A(a=1) == snapshot(A(a=1)) + for _ in [1,2]: + assert A(a=1) == snapshot(A(a=1)) +""" + } + ), + ) + + +def test_dataclass_positional_arguments(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:list=field(default_factory=list) + +def test_something(): + for _ in [1,2]: + assert A(a=1) == snapshot(A(1,2,c=[])) +""" + ).run_inline( + ["--inline-snapshot=update"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:list=field(default_factory=list) + +def test_something(): + for _ in [1,2]: + assert A(a=1) == snapshot(A(1,2)) """ } ), @@ -400,12 +441,18 @@ def argument(cls, value, pos_or_name): return value.l[pos_or_name] def test_L1(): - assert L(1,2) == snapshot(L(1)), "not equal" + for _ in [1,2]: + assert L(1,2) == snapshot(L(1)), "not equal" def test_L2(): - assert L(1,2) == snapshot(L(1, 2, 3)), "not equal" + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2, 3)), "not equal" + +def test_L3(): + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2)), "not equal" """ - ).run_pytest( + ).run_pytest().run_pytest( ["--inline-snapshot=fix"], changed_files=snapshot( { @@ -439,10 +486,16 @@ def argument(cls, value, pos_or_name): return value.l[pos_or_name] def test_L1(): - assert L(1,2) == snapshot(L(1, 2)), "not equal" + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2)), "not equal" def test_L2(): - assert L(1,2) == snapshot(L(1, 2)), "not equal" + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2)), "not equal" + +def test_L3(): + for _ in [1,2]: + assert L(1,2) == snapshot(L(1, 2)), "not equal" """ } ), diff --git a/tests/adapter/test_sequence.py b/tests/adapter/test_sequence.py index c7e967c9..e9c487cb 100644 --- a/tests/adapter/test_sequence.py +++ b/tests/adapter/test_sequence.py @@ -1,4 +1,5 @@ import pytest + from inline_snapshot._inline_snapshot import snapshot from inline_snapshot.testing._example import Example diff --git a/tests/conftest.py b/tests/conftest.py index 08993d16..796443bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,12 +14,12 @@ import black import executing -import inline_snapshot._external import pytest -from inline_snapshot import _inline_snapshot + +import inline_snapshot._external from inline_snapshot._change import apply_all +from inline_snapshot._flags import Flags from inline_snapshot._format import format_code -from inline_snapshot._inline_snapshot import Flags from inline_snapshot._rewrite_code import ChangeRecorder from inline_snapshot._types import Category from inline_snapshot.testing._example import snapshot_env @@ -101,9 +101,9 @@ def run(self, *flags_arg: Category): print("input:") print(textwrap.indent(source, " |", lambda line: True).rstrip()) - with snapshot_env(): + with snapshot_env() as state: with ChangeRecorder().activate() as recorder: - _inline_snapshot._update_flags = flags + state.update_flags = flags inline_snapshot._external.storage = ( inline_snapshot._external.DiscStorage(tmp_path / ".storage") ) @@ -116,12 +116,12 @@ def run(self, *flags_arg: Category): traceback.print_exc() error = True finally: - _inline_snapshot._active = False + state.active = False - number_snapshots = len(_inline_snapshot.snapshots) + number_snapshots = len(state.snapshots) changes = [] - for snapshot in _inline_snapshot.snapshots.values(): + for snapshot in state.snapshots.values(): changes += snapshot._changes() snapshot_flags = {change.flag for change in changes} @@ -130,7 +130,7 @@ def run(self, *flags_arg: Category): [ change for change in changes - if change.flag in _inline_snapshot._update_flags.to_set() + if change.flag in state.update_flags.to_set() ] ) @@ -139,7 +139,7 @@ def run(self, *flags_arg: Category): source = filename.read_text("utf-8")[len(prefix) :] print("reported_flags:", snapshot_flags) print( - f"run: pytest" + f' --inline-snapshot={",".join(flags.to_set())}' + f'run: pytest --inline-snapshot={",".join(flags.to_set())}' if flags else "" ) diff --git a/tests/test_change.py b/tests/test_change.py index cbe82589..610c1ccc 100644 --- a/tests/test_change.py +++ b/tests/test_change.py @@ -2,10 +2,11 @@ import pytest from executing import Source -from inline_snapshot._change import apply_all + from inline_snapshot._change import CallArg from inline_snapshot._change import Delete from inline_snapshot._change import Replace +from inline_snapshot._change import apply_all from inline_snapshot._inline_snapshot import snapshot from inline_snapshot._rewrite_code import ChangeRecorder from inline_snapshot._source_file import SourceFile diff --git a/tests/test_code_repr.py b/tests/test_code_repr.py index b33cbba2..7c0b50ab 100644 --- a/tests/test_code_repr.py +++ b/tests/test_code_repr.py @@ -1,7 +1,15 @@ import dataclasses +from collections import Counter +from collections import OrderedDict +from collections import UserDict +from collections import UserList +from collections import defaultdict +from collections import namedtuple from dataclasses import dataclass +from typing import NamedTuple import pytest + from inline_snapshot import HasRepr from inline_snapshot import snapshot from inline_snapshot._code_repr import code_repr @@ -240,16 +248,6 @@ class Color(Enum): ).run_inline() -from collections import ( - Counter, - OrderedDict, - UserDict, - UserList, - defaultdict, - namedtuple, -) -from typing import NamedTuple - A = namedtuple("A", "a,b", defaults=[0]) B = namedtuple("B", "a,b", defaults=[0, 0]) diff --git a/tests/test_dirty_equals.py b/tests/test_dirty_equals.py index a5e7b4fc..bdbe2ceb 100644 --- a/tests/test_dirty_equals.py +++ b/tests/test_dirty_equals.py @@ -1,4 +1,5 @@ import pytest + from inline_snapshot._inline_snapshot import snapshot from inline_snapshot.testing._example import Example diff --git a/tests/test_docs.py b/tests/test_docs.py index ff70364c..f698464d 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -8,9 +8,10 @@ from pathlib import Path from typing import Optional -import inline_snapshot._inline_snapshot import pytest + from inline_snapshot import snapshot +from inline_snapshot._global_state import state from inline_snapshot.extra import raises @@ -32,7 +33,6 @@ def map_code_blocks(file, func, fix=False): current_code = file.read_text("utf-8") new_lines = [] block_lines = [] - options = set() is_block = False code = None indent = "" @@ -156,7 +156,7 @@ def test_block(block): if recorded_markdown_code != markdown_code: assert new_markdown_code == recorded_markdown_code else: - assert new_markdown_code == None + assert new_markdown_code is None test_doc( """ @@ -365,6 +365,4 @@ def test_block(block: Block): last_code = code return block - map_code_blocks( - file, test_block, inline_snapshot._inline_snapshot._update_flags.fix - ) + map_code_blocks(file, test_block, state().update_flags.fix) diff --git a/tests/test_external.py b/tests/test_external.py index 05958108..5ec6de0f 100644 --- a/tests/test_external.py +++ b/tests/test_external.py @@ -1,12 +1,15 @@ import ast +from inline_snapshot import _inline_snapshot from inline_snapshot import external from inline_snapshot import outsource from inline_snapshot import snapshot +from inline_snapshot._find_external import ensure_import from inline_snapshot.extra import raises - from tests.utils import config +from .utils import apply_changes + def test_basic(check_update): assert check_update( @@ -170,7 +173,7 @@ def test_a(): """ ) - result = project.run("--inline-snapshot=create") + project.run("--inline-snapshot=create") assert project.source == snapshot( """\ @@ -314,9 +317,6 @@ def test_errors(): assert external("123*.txt") != external("123*.bin") -from inline_snapshot import _inline_snapshot - - def test_uses_external(): assert _inline_snapshot.used_externals(ast.parse("[external('111*.txt')]")) assert not _inline_snapshot.used_externals(ast.parse("[external()]")) @@ -354,11 +354,6 @@ def test_something(): ) -from inline_snapshot._find_external import ensure_import - -from .utils import apply_changes - - def test_ensure_imports(tmp_path): file = tmp_path / "file.py" file.write_text( diff --git a/tests/test_formating.py b/tests/test_formating.py index 9b347a7c..f640e656 100644 --- a/tests/test_formating.py +++ b/tests/test_formating.py @@ -3,9 +3,9 @@ from types import SimpleNamespace from click.testing import CliRunner + from inline_snapshot import snapshot from inline_snapshot.testing import Example - from tests._is_normalized import normalization diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index 7b5c57ef..9f40c326 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -7,9 +7,9 @@ from typing import Union import pytest -from inline_snapshot import _inline_snapshot + from inline_snapshot import snapshot -from inline_snapshot._inline_snapshot import Flags +from inline_snapshot._flags import Flags from inline_snapshot.testing import Example from inline_snapshot.testing._example import snapshot_env @@ -23,8 +23,8 @@ def test_snapshot_eq(): @pytest.mark.no_rewriting def test_disabled(): - with snapshot_env(): - _inline_snapshot._active = False + with snapshot_env() as state: + state.active = False with pytest.raises(AssertionError) as excinfo: assert 2 == snapshot() diff --git a/tests/test_preserve_values.py b/tests/test_preserve_values.py index 654eec8b..b97bc228 100644 --- a/tests/test_preserve_values.py +++ b/tests/test_preserve_values.py @@ -2,6 +2,7 @@ import sys import pytest + from inline_snapshot import snapshot diff --git a/tests/test_pypy.py b/tests/test_pypy.py index e76ba960..4094fc67 100644 --- a/tests/test_pypy.py +++ b/tests/test_pypy.py @@ -1,6 +1,7 @@ import sys import pytest + from inline_snapshot import snapshot from inline_snapshot.testing import Example diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 49bc5041..d7305a85 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -1,4 +1,5 @@ import pytest + from inline_snapshot import snapshot from inline_snapshot.testing import Example diff --git a/tests/test_rewrite_code.py b/tests/test_rewrite_code.py index dd9c1259..e34f1073 100644 --- a/tests/test_rewrite_code.py +++ b/tests/test_rewrite_code.py @@ -1,9 +1,10 @@ import pytest + from inline_snapshot._rewrite_code import ChangeRecorder -from inline_snapshot._rewrite_code import end_of -from inline_snapshot._rewrite_code import range_of from inline_snapshot._rewrite_code import SourcePosition from inline_snapshot._rewrite_code import SourceRange +from inline_snapshot._rewrite_code import end_of +from inline_snapshot._rewrite_code import range_of from inline_snapshot._rewrite_code import start_of @@ -18,7 +19,7 @@ def test_range(): assert range_of(r) == r with pytest.raises(ValueError): - rr = SourceRange(b, a) + SourceRange(b, a) def test_rewrite(tmp_path): diff --git a/tests/test_string.py b/tests/test_string.py index 301b6230..aa39782a 100644 --- a/tests/test_string.py +++ b/tests/test_string.py @@ -2,6 +2,7 @@ from hypothesis import given from hypothesis.strategies import text + from inline_snapshot import snapshot from inline_snapshot._utils import triple_quote diff --git a/tests/utils.py b/tests/utils.py index 90dd970a..9e3bdca8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,13 +2,13 @@ import sys from contextlib import contextmanager +import pytest + import inline_snapshot._config as _config import inline_snapshot._external as external -import pytest from inline_snapshot._rewrite_code import ChangeRecorder from inline_snapshot.testing._example import snapshot_env - __all__ = ("snapshot_env",) pytest_compatible = sys.version_info >= (3, 11) and pytest.version_tuple >= (8, 3, 4)