From 099a2d6ba62ba676b09192ce6a219819a00dd6db Mon Sep 17 00:00:00 2001 From: Frank Hoffmann <15r10nk-git@polarbit.de> Date: Fri, 20 Oct 2023 10:43:44 +0200 Subject: [PATCH] fix: show better error messages --- docs/index.md | 13 ++++++-- inline_snapshot/_inline_snapshot.py | 47 ++++++++++++++++++++++++++--- tests/test_inline_snapshot.py | 20 ++++++++++++ tests/test_pytest_plugin.py | 5 +++ 4 files changed, 77 insertions(+), 8 deletions(-) diff --git a/docs/index.md b/docs/index.md index d8f0562c..b003bb0a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -93,10 +93,17 @@ You can use `snapshot(x)` like you can use `x` in your assertion with a limited !!! warning One snapshot can only be used with one operation. The following code will not work: + ``` python - s = snapshot(5) - assert 5 <= s - assert 5 == s # Error: s is already used with <= + def test_something(): + s = snapshot(5) + assert 5 <= s + assert 5 == s + + + # Error: + # > assert 5 == s + # E TypeError: This snapshot cannot be use with `==`, because it was previously used with `x <= snapshot` ``` ## Supported usage diff --git a/inline_snapshot/_inline_snapshot.py b/inline_snapshot/_inline_snapshot.py index 8a2df9ab..f2355e91 100644 --- a/inline_snapshot/_inline_snapshot.py +++ b/inline_snapshot/_inline_snapshot.py @@ -86,6 +86,7 @@ def snapshot_env(): class GenericValue: _new_value: Any _old_value: Any + _current_op = "undefined" def _needs_trim(self): return False @@ -116,8 +117,34 @@ def get_result(self, flags): def __repr__(self): return repr(self._visible_value()) + def _type_error(self, op): + __tracebackhide__ = True + raise TypeError( + f"This snapshot cannot be use with `{op}`, because it was previously used with `{self._current_op}`" + ) + + def __eq__(self, _other): + __tracebackhide__ = True + self._type_error("==") + + def __le__(self, _other): + __tracebackhide__ = True + self._type_error("<=") + + def __ge__(self, _other): + __tracebackhide__ = True + self._type_error(">=") + + def __contains__(self, _other): + __tracebackhide__ = True + self._type_error("in") + + def __getitem__(self, _item): + __tracebackhide__ = True + self._type_error("snapshot[key]") -class Value(GenericValue): + +class UndecidedValue(GenericValue): def __init__(self, _old_value): self._old_value = _old_value self._new_value = undefined @@ -131,7 +158,7 @@ def _needs_fix(self): # functions which determine the type def __eq__(self, other): - self._change(FixValue) + self._change(EqValue) return self == other def __le__(self, other): @@ -151,7 +178,9 @@ def __getitem__(self, item): return self[item] -class FixValue(GenericValue): +class EqValue(GenericValue): + _current_op = "x == snapshot" + def __eq__(self, other): other = copy.deepcopy(other) @@ -228,6 +257,8 @@ class MinValue(MinMaxValue): """ + _current_op = "x >= snapshot" + @staticmethod def cmp(a, b): return a <= b @@ -251,6 +282,8 @@ class MaxValue(MinMaxValue): """ + _current_op = "x <= snapshot" + @staticmethod def cmp(a, b): return a >= b @@ -259,6 +292,8 @@ def cmp(a, b): class CollectionValue(GenericValue): + _current_op = "x in snapshot" + def __contains__(self, item): item = copy.deepcopy(item) @@ -300,6 +335,8 @@ def get_result(self, flags): class DictValue(GenericValue): + _current_op = "snapshot[key]" + def __getitem__(self, index): if self._new_value is undefined: self._new_value = {} @@ -309,7 +346,7 @@ def __getitem__(self, index): old_value = {} if index not in self._new_value: - self._new_value[index] = Value(old_value.get(index, undefined)) + self._new_value[index] = UndecidedValue(old_value.get(index, undefined)) return self._new_value[index] @@ -487,7 +524,7 @@ def triple_quote(string): class Snapshot: def __init__(self, value, expr): self._expr = expr - self._value = Value(value) + self._value = UndecidedValue(value) @property def _filename(self): diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index c8e3344f..97c45729 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -2,6 +2,7 @@ import itertools import textwrap from collections import namedtuple +from contextlib import nullcontext from dataclasses import dataclass from dataclasses import field from typing import Set @@ -840,3 +841,22 @@ def test_format_value(check_update): def test_unused_snapshot(check_update): assert check_update("snapshot()\n", flags="create") == "snapshot()\n" + + +def test_type_error(check_update): + tests = ["5 == s", "5 <= s", "5 >= s", "5 in s", "5 == s[0]"] + + for test1, test2 in itertools.product(tests, tests): + with pytest.raises(TypeError) if test1 != test2 else nullcontext() as error: + check_update( + f""" +s = snapshot() +assert {test1} +assert {test2} + """, + reported_flags="create", + ) + if error is not None: + assert "This snapshot cannot be use with" in str(error.value) + else: + assert test1 == test2 diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index d92fb1e6..956bc23a 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -507,6 +507,10 @@ def test_docs(project, file, subtests): if flags: new_code = project.source + if "show_error" in options: + new_code = new_code.split("# Error:")[0] + new_code += "# Error:" + textwrap.indent(result.errorLines(), "# ") + if ( inline_snapshot._inline_snapshot._update_flags.fix ): # pragma: no cover @@ -523,6 +527,7 @@ def test_docs(project, file, subtests): assert { f"outcome-{k}={v}" for k, v in result.parseoutcomes().items() + if k in ("failed", "errors", "passed") } == {flag for flag in options if flag.startswith("outcome-")} assert code == new_code else: # pragma: no cover