Skip to content

Commit

Permalink
refactor: extracted EqValue into eq_value.py
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Jan 16, 2025
1 parent e7a7322 commit 28b1a4b
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 41 deletions.
43 changes: 2 additions & 41 deletions src/inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
import inspect
from typing import Any
from typing import cast
from typing import Dict # noqa
from typing import Iterator
from typing import List
from typing import Tuple # noqa
from typing import TypeVar

from executing import Source
from inline_snapshot._adapter.adapter import Adapter
from inline_snapshot._adapter.adapter import adapter_map
from inline_snapshot._source_file import SourceFile

Expand All @@ -21,7 +17,6 @@
from ._change import Change
from ._change import Replace
from ._code_repr import code_repr
from ._compare_context import compare_only
from ._exceptions import UsageError
from ._sentinels import undefined
from ._types import Snapshot
Expand Down Expand Up @@ -181,6 +176,8 @@ def handle(node, obj):
# functions which determine the type

def __eq__(self, other):
from ._snapshot.eq_value import EqValue

self._change(EqValue)
return self == other

Expand Down Expand Up @@ -226,42 +223,6 @@ def clone(obj):
return new


class EqValue(GenericValue):
_current_op = "x == snapshot"
_changes: List[Change]

def __eq__(self, other):
if self._old_value is undefined:
state()._missing_values += 1

if not compare_only() and self._new_value is undefined:
adapter = Adapter(self._context).get_adapter(self._old_value, other)
it = iter(adapter.assign(self._old_value, self._ast_node, clone(other)))
self._changes = []
while True:
try:
self._changes.append(next(it))
except StopIteration as ex:
self._new_value = ex.value
break

return _return(self._visible_value() == other)

# if self._new_value is undefined:
# self._new_value = use_valid_old_values(self._old_value, clone(other))
# if self._old_value is undefined or ignore_old_value():
# return True
# return _return(self._old_value == other)
# else:
# return _return(self._new_value == other)

def _new_code(self):
return self._file._value_to_code(self._new_value)

def _get_changes(self) -> Iterator[Change]:
return iter(self._changes)


T = TypeVar("T")


Expand Down
40 changes: 40 additions & 0 deletions src/inline_snapshot/_snapshot/eq_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Iterator
from typing import List

from inline_snapshot._adapter.adapter import Adapter

from .._change import Change
from .._compare_context import compare_only
from .._inline_snapshot import _return
from .._inline_snapshot import clone
from .._inline_snapshot import GenericValue
from .._sentinels import undefined
from ..global_state import state


class EqValue(GenericValue):
_current_op = "x == snapshot"
_changes: List[Change]

def __eq__(self, other):
if self._old_value is undefined:
state()._missing_values += 1

if not compare_only() and self._new_value is undefined:
adapter = Adapter(self._context).get_adapter(self._old_value, other)
it = iter(adapter.assign(self._old_value, self._ast_node, clone(other)))
self._changes = []
while True:
try:
self._changes.append(next(it))
except StopIteration as ex:
self._new_value = ex.value
break

return _return(self._visible_value() == other)

def _new_code(self):
return self._file._value_to_code(self._new_value)

def _get_changes(self) -> Iterator[Change]:
return iter(self._changes)
104 changes: 104 additions & 0 deletions src/inline_snapshot/_snapshot/min_max_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Iterator

from .._change import Change
from .._change import Replace
from .._inline_snapshot import _return
from .._inline_snapshot import clone
from .._inline_snapshot import GenericValue
from .._inline_snapshot import ignore_old_value
from .._sentinels import undefined
from .._utils import value_to_token
from ..global_state import state


class MinMaxValue(GenericValue):
"""Generic implementation for <=, >="""

@staticmethod
def cmp(a, b):
raise NotImplemented

def _generic_cmp(self, other):
if self._old_value is undefined:
state()._missing_values += 1

if self._new_value is undefined:
self._new_value = clone(other)
if self._old_value is undefined or ignore_old_value():
return True
return _return(self.cmp(self._old_value, other))
else:
if not self.cmp(self._new_value, other):
self._new_value = clone(other)

return _return(self.cmp(self._visible_value(), other))

def _new_code(self):
return self._file._value_to_code(self._new_value)

def _get_changes(self) -> Iterator[Change]:
new_token = value_to_token(self._new_value)
if not self.cmp(self._old_value, self._new_value):
flag = "fix"
elif not self.cmp(self._new_value, self._old_value):
flag = "trim"
elif (
self._ast_node is not None
and self._file._token_of_node(self._ast_node) != new_token
):
flag = "update"
else:
return

new_code = self._file._token_to_code(new_token)

yield Replace(
node=self._ast_node,
file=self._file,
new_code=new_code,
flag=flag,
old_value=self._old_value,
new_value=self._new_value,
)


class MinValue(MinMaxValue):
"""
handles:
>>> snapshot(5) <= 6
True
>>> 6 >= snapshot(5)
True
"""

_current_op = "x >= snapshot"

@staticmethod
def cmp(a, b):
return a <= b

__le__ = MinMaxValue._generic_cmp


class MaxValue(MinMaxValue):
"""
handles:
>>> snapshot(5) >= 4
True
>>> 4 <= snapshot(5)
True
"""

_current_op = "x <= snapshot"

@staticmethod
def cmp(a, b):
return a >= b

__ge__ = MinMaxValue._generic_cmp

0 comments on commit 28b1a4b

Please sign in to comment.