Skip to content

Commit

Permalink
fix: show better error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Oct 20, 2023
1 parent 862ec56 commit 5a4af3e
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 8 deletions.
13 changes: 10 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,17 @@ You can use `snapshot(x)` like you can use `x` in your assertion with a limited
!!! warning
One snapshot can only be used with one operation.
The following code will not work:
<!-- inline-snapshot: show_error outcome-failed=1 -->
``` python
s = snapshot(5)
assert 5 <= s
assert 5 == s # Error: s is already used with <=
def test_something():
s = snapshot(5)
assert 5 <= s
assert 5 == s


# Error:
# > assert 5 == s
# E TypeError: This snapshot cannot be use with `==`, because it was previously used with `x <= snapshot`
```

## Supported usage
Expand Down
47 changes: 42 additions & 5 deletions inline_snapshot/_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def snapshot_env():
class GenericValue:
_new_value: Any
_old_value: Any
_current_op = "undefined"

def _needs_trim(self):
return False
Expand Down Expand Up @@ -116,8 +117,34 @@ def get_result(self, flags):
def __repr__(self):
return repr(self._visible_value())

def _type_error(self, op):
__tracebackhide__ = True
raise TypeError(
f"This snapshot cannot be use with `{op}`, because it was previously used with `{self._current_op}`"
)

def __eq__(self, _other):
__tracebackhide__ = True
self._type_error("==")

def __le__(self, _other):
__tracebackhide__ = True
self._type_error("<=")

def __ge__(self, _other):
__tracebackhide__ = True
self._type_error(">=")

def __contains__(self, _other):
__tracebackhide__ = True
self._type_error("in")

def __getitem__(self, _item):
__tracebackhide__ = True
self._type_error("snapshot[key]")

class Value(GenericValue):

class UndecidedValue(GenericValue):
def __init__(self, _old_value):
self._old_value = _old_value
self._new_value = undefined
Expand All @@ -131,7 +158,7 @@ def _needs_fix(self):
# functions which determine the type

def __eq__(self, other):
self._change(FixValue)
self._change(EqValue)
return self == other

def __le__(self, other):
Expand All @@ -151,7 +178,9 @@ def __getitem__(self, item):
return self[item]


class FixValue(GenericValue):
class EqValue(GenericValue):
_current_op = "x == snapshot"

def __eq__(self, other):
other = copy.deepcopy(other)

Expand Down Expand Up @@ -228,6 +257,8 @@ class MinValue(MinMaxValue):
"""

_current_op = "x >= snapshot"

@staticmethod
def cmp(a, b):
return a <= b
Expand All @@ -251,6 +282,8 @@ class MaxValue(MinMaxValue):
"""

_current_op = "x <= snapshot"

@staticmethod
def cmp(a, b):
return a >= b
Expand All @@ -259,6 +292,8 @@ def cmp(a, b):


class CollectionValue(GenericValue):
_current_op = "x in snapshot"

def __contains__(self, item):
item = copy.deepcopy(item)

Expand Down Expand Up @@ -300,6 +335,8 @@ def get_result(self, flags):


class DictValue(GenericValue):
_current_op = "snapshot[key]"

def __getitem__(self, index):
if self._new_value is undefined:
self._new_value = {}
Expand All @@ -309,7 +346,7 @@ def __getitem__(self, index):
old_value = {}

if index not in self._new_value:
self._new_value[index] = Value(old_value.get(index, undefined))
self._new_value[index] = UndecidedValue(old_value.get(index, undefined))

return self._new_value[index]

Expand Down Expand Up @@ -487,7 +524,7 @@ def triple_quote(string):
class Snapshot:
def __init__(self, value, expr):
self._expr = expr
self._value = Value(value)
self._value = UndecidedValue(value)

@property
def _filename(self):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_inline_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools
import textwrap
from collections import namedtuple
from contextlib import nullcontext
from dataclasses import dataclass
from dataclasses import field
from typing import Set
Expand Down Expand Up @@ -840,3 +841,22 @@ def test_format_value(check_update):

def test_unused_snapshot(check_update):
assert check_update("snapshot()\n", flags="create") == "snapshot()\n"


def test_type_error(check_update):
tests = ["5 == s", "5 <= s", "5 >= s", "5 in s", "5 == s[0]"]

for test1, test2 in itertools.product(tests, tests):
with pytest.raises(TypeError) if test1 != test2 else nullcontext() as error:
check_update(
f"""
s = snapshot()
assert {test1}
assert {test2}
""",
reported_flags="create",
)
if error is not None:
assert "This snapshot cannot be use with" in str(error.value)
else:
assert test1 == test2
5 changes: 5 additions & 0 deletions tests/test_pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,10 @@ def test_docs(project, file, subtests):
if flags:
new_code = project.source

if "show_error" in options:
new_code = new_code.split("# Error:")[0]
new_code += "# Error:" + textwrap.indent(result.errorLines(), "# ")

if (
inline_snapshot._inline_snapshot._update_flags.fix
): # pragma: no cover
Expand All @@ -523,6 +527,7 @@ def test_docs(project, file, subtests):
assert {
f"outcome-{k}={v}"
for k, v in result.parseoutcomes().items()
if k in ("failed", "errors", "passed")
} == {flag for flag in options if flag.startswith("outcome-")}
assert code == new_code
else: # pragma: no cover
Expand Down

0 comments on commit 5a4af3e

Please sign in to comment.