-
-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: extracted EqValue into eq_value.py
- Loading branch information
Showing
3 changed files
with
146 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import Iterator | ||
from typing import List | ||
|
||
from inline_snapshot._adapter.adapter import Adapter | ||
|
||
from .._change import Change | ||
from .._compare_context import compare_only | ||
from .._inline_snapshot import _return | ||
from .._inline_snapshot import clone | ||
from .._inline_snapshot import GenericValue | ||
from .._sentinels import undefined | ||
from ..global_state import state | ||
|
||
|
||
class EqValue(GenericValue): | ||
_current_op = "x == snapshot" | ||
_changes: List[Change] | ||
|
||
def __eq__(self, other): | ||
if self._old_value is undefined: | ||
state()._missing_values += 1 | ||
|
||
if not compare_only() and self._new_value is undefined: | ||
adapter = Adapter(self._context).get_adapter(self._old_value, other) | ||
it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) | ||
self._changes = [] | ||
while True: | ||
try: | ||
self._changes.append(next(it)) | ||
except StopIteration as ex: | ||
self._new_value = ex.value | ||
break | ||
|
||
return _return(self._visible_value() == other) | ||
|
||
def _new_code(self): | ||
return self._file._value_to_code(self._new_value) | ||
|
||
def _get_changes(self) -> Iterator[Change]: | ||
return iter(self._changes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from typing import Iterator | ||
|
||
from .._change import Change | ||
from .._change import Replace | ||
from .._inline_snapshot import _return | ||
from .._inline_snapshot import clone | ||
from .._inline_snapshot import GenericValue | ||
from .._inline_snapshot import ignore_old_value | ||
from .._sentinels import undefined | ||
from .._utils import value_to_token | ||
from ..global_state import state | ||
|
||
|
||
class MinMaxValue(GenericValue): | ||
"""Generic implementation for <=, >=""" | ||
|
||
@staticmethod | ||
def cmp(a, b): | ||
raise NotImplemented | ||
|
||
def _generic_cmp(self, other): | ||
if self._old_value is undefined: | ||
state()._missing_values += 1 | ||
|
||
if self._new_value is undefined: | ||
self._new_value = clone(other) | ||
if self._old_value is undefined or ignore_old_value(): | ||
return True | ||
return _return(self.cmp(self._old_value, other)) | ||
else: | ||
if not self.cmp(self._new_value, other): | ||
self._new_value = clone(other) | ||
|
||
return _return(self.cmp(self._visible_value(), other)) | ||
|
||
def _new_code(self): | ||
return self._file._value_to_code(self._new_value) | ||
|
||
def _get_changes(self) -> Iterator[Change]: | ||
new_token = value_to_token(self._new_value) | ||
if not self.cmp(self._old_value, self._new_value): | ||
flag = "fix" | ||
elif not self.cmp(self._new_value, self._old_value): | ||
flag = "trim" | ||
elif ( | ||
self._ast_node is not None | ||
and self._file._token_of_node(self._ast_node) != new_token | ||
): | ||
flag = "update" | ||
else: | ||
return | ||
|
||
new_code = self._file._token_to_code(new_token) | ||
|
||
yield Replace( | ||
node=self._ast_node, | ||
file=self._file, | ||
new_code=new_code, | ||
flag=flag, | ||
old_value=self._old_value, | ||
new_value=self._new_value, | ||
) | ||
|
||
|
||
class MinValue(MinMaxValue): | ||
""" | ||
handles: | ||
>>> snapshot(5) <= 6 | ||
True | ||
>>> 6 >= snapshot(5) | ||
True | ||
""" | ||
|
||
_current_op = "x >= snapshot" | ||
|
||
@staticmethod | ||
def cmp(a, b): | ||
return a <= b | ||
|
||
__le__ = MinMaxValue._generic_cmp | ||
|
||
|
||
class MaxValue(MinMaxValue): | ||
""" | ||
handles: | ||
>>> snapshot(5) >= 4 | ||
True | ||
>>> 4 <= snapshot(5) | ||
True | ||
""" | ||
|
||
_current_op = "x <= snapshot" | ||
|
||
@staticmethod | ||
def cmp(a, b): | ||
return a >= b | ||
|
||
__ge__ = MinMaxValue._generic_cmp |