diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index c6464fc..715c1f8 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -3,14 +3,10 @@ import inspect from typing import Any from typing import cast -from typing import Dict # noqa from typing import Iterator -from typing import List -from typing import Tuple # noqa from typing import TypeVar from executing import Source -from inline_snapshot._adapter.adapter import Adapter from inline_snapshot._adapter.adapter import adapter_map from inline_snapshot._source_file import SourceFile @@ -21,7 +17,6 @@ from ._change import Change from ._change import Replace from ._code_repr import code_repr -from ._compare_context import compare_only from ._exceptions import UsageError from ._sentinels import undefined from ._types import Snapshot @@ -181,6 +176,8 @@ def handle(node, obj): # functions which determine the type def __eq__(self, other): + from ._snapshot.eq_value import EqValue + self._change(EqValue) return self == other @@ -226,42 +223,6 @@ def clone(obj): return new -class EqValue(GenericValue): - _current_op = "x == snapshot" - _changes: List[Change] - - def __eq__(self, other): - if self._old_value is undefined: - state()._missing_values += 1 - - if not compare_only() and self._new_value is undefined: - adapter = Adapter(self._context).get_adapter(self._old_value, other) - it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) - self._changes = [] - while True: - try: - self._changes.append(next(it)) - except StopIteration as ex: - self._new_value = ex.value - break - - return _return(self._visible_value() == other) - - # if self._new_value is undefined: - # self._new_value = use_valid_old_values(self._old_value, clone(other)) - # if self._old_value is undefined or ignore_old_value(): - # return True - # return _return(self._old_value == other) - # else: - # return _return(self._new_value == other) - - def _new_code(self): - return self._file._value_to_code(self._new_value) - - def _get_changes(self) -> Iterator[Change]: - return iter(self._changes) - - T = TypeVar("T") diff --git a/src/inline_snapshot/_snapshot/eq_value.py b/src/inline_snapshot/_snapshot/eq_value.py new file mode 100644 index 0000000..5980863 --- /dev/null +++ b/src/inline_snapshot/_snapshot/eq_value.py @@ -0,0 +1,40 @@ +from typing import Iterator +from typing import List + +from inline_snapshot._adapter.adapter import Adapter + +from .._change import Change +from .._compare_context import compare_only +from .._inline_snapshot import _return +from .._inline_snapshot import clone +from .._inline_snapshot import GenericValue +from .._sentinels import undefined +from ..global_state import state + + +class EqValue(GenericValue): + _current_op = "x == snapshot" + _changes: List[Change] + + def __eq__(self, other): + if self._old_value is undefined: + state()._missing_values += 1 + + if not compare_only() and self._new_value is undefined: + adapter = Adapter(self._context).get_adapter(self._old_value, other) + it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) + self._changes = [] + while True: + try: + self._changes.append(next(it)) + except StopIteration as ex: + self._new_value = ex.value + break + + return _return(self._visible_value() == other) + + def _new_code(self): + return self._file._value_to_code(self._new_value) + + def _get_changes(self) -> Iterator[Change]: + return iter(self._changes) diff --git a/src/inline_snapshot/_snapshot/min_max_value.py b/src/inline_snapshot/_snapshot/min_max_value.py new file mode 100644 index 0000000..fd48948 --- /dev/null +++ b/src/inline_snapshot/_snapshot/min_max_value.py @@ -0,0 +1,104 @@ +from typing import Iterator + +from .._change import Change +from .._change import Replace +from .._inline_snapshot import _return +from .._inline_snapshot import clone +from .._inline_snapshot import GenericValue +from .._inline_snapshot import ignore_old_value +from .._sentinels import undefined +from .._utils import value_to_token +from ..global_state import state + + +class MinMaxValue(GenericValue): + """Generic implementation for <=, >=""" + + @staticmethod + def cmp(a, b): + raise NotImplemented + + def _generic_cmp(self, other): + if self._old_value is undefined: + state()._missing_values += 1 + + if self._new_value is undefined: + self._new_value = clone(other) + if self._old_value is undefined or ignore_old_value(): + return True + return _return(self.cmp(self._old_value, other)) + else: + if not self.cmp(self._new_value, other): + self._new_value = clone(other) + + return _return(self.cmp(self._visible_value(), other)) + + def _new_code(self): + return self._file._value_to_code(self._new_value) + + def _get_changes(self) -> Iterator[Change]: + new_token = value_to_token(self._new_value) + if not self.cmp(self._old_value, self._new_value): + flag = "fix" + elif not self.cmp(self._new_value, self._old_value): + flag = "trim" + elif ( + self._ast_node is not None + and self._file._token_of_node(self._ast_node) != new_token + ): + flag = "update" + else: + return + + new_code = self._file._token_to_code(new_token) + + yield Replace( + node=self._ast_node, + file=self._file, + new_code=new_code, + flag=flag, + old_value=self._old_value, + new_value=self._new_value, + ) + + +class MinValue(MinMaxValue): + """ + handles: + + >>> snapshot(5) <= 6 + True + + >>> 6 >= snapshot(5) + True + + """ + + _current_op = "x >= snapshot" + + @staticmethod + def cmp(a, b): + return a <= b + + __le__ = MinMaxValue._generic_cmp + + +class MaxValue(MinMaxValue): + """ + handles: + + >>> snapshot(5) >= 4 + True + + >>> 4 <= snapshot(5) + True + + """ + + _current_op = "x <= snapshot" + + @staticmethod + def cmp(a, b): + return a >= b + + __ge__ = MinMaxValue._generic_cmp