Skip to content

Commit

Permalink
feat: pydantic v1 support
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Dec 20, 2024
1 parent 4a202ff commit 06a8a0b
Show file tree
Hide file tree
Showing 10 changed files with 339 additions and 23 deletions.
10 changes: 10 additions & 0 deletions codi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from inline_snapshot import snapshot
from pydantic import BaseModel


class M(BaseModel):
a: int


def test():
assert M(a=5) == snapshot()
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,21 @@ matrix.pytest.dependencies = [


[tool.hatch.envs.hatch-test]
# Info if you package this library:
# The following dependencies are installed with uv
# and used for specific tests in specific versions:
# - pydantic v1 & v2
# - attrs
extra-dependencies = [
"dirty-equals>=0.7.0",
"hypothesis>=6.75.5",
"mypy>=1.2.0",
"pyright>=1.1.359",
"pytest-subtests>=0.11.0",
"pytest-freezer>=0.4.8",
"pydantic",
"attrs",
"pytest-mock>=3.14.0"
]

env-vars.TOP = "{root}"

[tool.hatch.envs.hatch-test.scripts]
Expand Down
23 changes: 19 additions & 4 deletions src/inline_snapshot/_adapter/generic_call_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,26 @@ def argument(self, value, pos_or_name):


try:
from pydantic import BaseModel
import pydantic
except ImportError: # pragma: no cover
pass
else:
from pydantic_core import PydanticUndefined
# import pydantic
if pydantic.version.VERSION.startswith("1."):
# pydantic v1
from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined,no-redef]

def get_fields(value):
return value.__fields__

else:
# pydantic v2
from pydantic_core import PydanticUndefined

def get_fields(value):
return value.model_fields

from pydantic import BaseModel

class PydanticContainer(GenericCallAdapter):

Expand All @@ -313,8 +328,8 @@ def arguments(cls, value):

kwargs = {}

for name, field in value.model_fields.items(): # type: ignore
if field.repr:
for name, field in get_fields(value).items(): # type: ignore
if getattr(field, "repr", True):
field_value = getattr(value, name)
is_default = False

Expand Down
30 changes: 30 additions & 0 deletions src/inline_snapshot/pydantic_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from . import Snapshot

is_fixed = False


def pydantic_fix():
global is_fixed
if is_fixed:
return
is_fixed = True

try:
from pydantic import BaseModel
except ImportError:
return

import pydantic

if not pydantic.version.VERSION.startswith("1."):
return

origin_eq = BaseModel.__eq__

def new_eq(self, other):
if isinstance(other, Snapshot): # type: ignore
return other == self
else:
return origin_eq(self, other)

BaseModel.__eq__ = new_eq
3 changes: 3 additions & 0 deletions src/inline_snapshot/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from inline_snapshot._problems import report_problems
from inline_snapshot.pydantic_fix import pydantic_fix
from rich import box
from rich.console import Console
from rich.panel import Panel
Expand Down Expand Up @@ -125,6 +126,8 @@ def pytest_configure(config):
e for e in sys.meta_path if type(e).__name__ != "AssertionRewritingHook"
]

pydantic_fix()

_external.storage.prune_new_files()


Expand Down
10 changes: 9 additions & 1 deletion src/inline_snapshot/testing/_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def run_pytest(
self,
args: list[str] = [],
*,
extra_dependencies: list[str] = [],
env: dict[str, str] = {},
changed_files: Snapshot[dict[str, str]] | None = None,
report: Snapshot[str] | None = None,
Expand Down Expand Up @@ -256,7 +257,7 @@ def run_pytest(
tmp_path = Path(dir)
self._write_files(tmp_path)

cmd = ["pytest", *args]
cmd = ["python", "-m", "pytest", *args]

term_columns = 80

Expand All @@ -269,6 +270,13 @@ def run_pytest(

command_env.update(env)

if extra_dependencies:
uv_cmd = ["uv", "run"]
for dependency in extra_dependencies:
uv_cmd.append(f"--with={dependency}")

cmd = uv_cmd + cmd

result = sp.run(cmd, cwd=tmp_path, capture_output=True, env=command_env)

print("run>", *cmd)
Expand Down
34 changes: 21 additions & 13 deletions tests/adapter/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_something():
)


def test_pydantic_default_value():
def test_pydantic_default_value(pydantic_version):
Example(
"""\
from inline_snapshot import snapshot,Is
Expand All @@ -97,8 +97,9 @@ class A(BaseModel):
def test_something():
assert A(a=1) == snapshot(A(a=1,b=2,c=[]))
"""
).run_inline(
).run_pytest(
["--inline-snapshot=update"],
extra_dependencies=[pydantic_version],
changed_files=snapshot(
{
"test_something.py": """\
Expand Down Expand Up @@ -173,8 +174,9 @@ def test_something():
assert A(a=1) == snapshot(A(a=1,b=2,c=[],d=11))
assert A(a=2,b=3) == snapshot(A(a=1,b=2,c=[],d=11))
"""
).run_inline(
).run_pytest(
["--inline-snapshot=fix"],
extra_dependencies=["attrs"],
changed_files=snapshot(
{
"test_something.py": """\
Expand All @@ -194,8 +196,9 @@ def test_something():
"""
}
),
).run_inline(
).run_pytest(
["--inline-snapshot=update"],
extra_dependencies=["attrs"],
changed_files=snapshot(
{
"test_something.py": """\
Expand Down Expand Up @@ -223,7 +226,6 @@ def test_attrs_field_repr():
Example(
"""\
from inline_snapshot import snapshot
from pydantic import BaseModel,Field
import attrs
@attrs.define
Expand All @@ -233,13 +235,13 @@ class container:
assert container(a=1,b=5) == snapshot()
"""
).run_inline(
).run_pytest(
["--inline-snapshot=create"],
extra_dependencies=["attrs"],
changed_files=snapshot(
{
"test_something.py": """\
from inline_snapshot import snapshot
from pydantic import BaseModel,Field
import attrs
@attrs.define
Expand All @@ -251,7 +253,9 @@ class container:
"""
}
),
).run_inline()
).run_pytest(
extra_dependencies=["attrs"],
)


def test_attrs_unmanaged():
Expand All @@ -278,10 +282,13 @@ def test():
dt.datetime.now(), id
)
"""
).run_inline(
).run_pytest(
["--inline-snapshot=create,fix"],
extra_dependencies=["attrs"],
changed_files=snapshot({}),
).run_inline()
).run_pytest(
extra_dependencies=["attrs"],
)


def test_disabled(executing_used):
Expand Down Expand Up @@ -582,7 +589,7 @@ class container:
).run_inline()


def test_pydantic_field_repr():
def test_pydantic_field_repr(pydantic_version):

Example(
"""\
Expand All @@ -595,8 +602,9 @@ class container(BaseModel):
assert container(a=1,b=5) == snapshot()
"""
).run_inline(
).run_pytest(
["--inline-snapshot=create"],
extra_dependencies=[pydantic_version],
changed_files=snapshot(
{
"test_something.py": """\
Expand All @@ -611,7 +619,7 @@ class container(BaseModel):
"""
}
),
).run_inline()
).run_pytest()


def test_dataclass_var():
Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def check_pypy(request):
yield


@pytest.fixture(params=["pydantic>=2.0.0", "pydantic<2.0.0"])
def pydantic_version(request):
yield request.param


@pytest.fixture()
def check_update(source):
def w(source_code, *, flags="", reported_flags=None, number=1):
Expand Down
7 changes: 4 additions & 3 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from inline_snapshot.testing import Example


def test_pydantic_repr():
def test_pydantic_repr(pydantic_version):

Example(
"""
Expand All @@ -18,8 +18,9 @@ def test_pydantic():
assert M(size=5,name="Tom")==snapshot()
"""
).run_inline(
).run_pytest(
["--inline-snapshot=create"],
extra_dependencies=[pydantic_version],
changed_files=snapshot(
{
"test_something.py": """\
Expand All @@ -39,4 +40,4 @@ def test_pydantic():
"""
}
),
).run_inline()
).run_pytest()
Loading

0 comments on commit 06a8a0b

Please sign in to comment.