Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions ninja/params/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -39,10 +38,6 @@
TModels = List[TModel]


def NestedDict() -> DictStrAny:
return defaultdict(NestedDict)


class ParamModel(BaseModel, ABC):
__ninja_param_source__ = None

Expand All @@ -65,11 +60,6 @@ def resolve(
return cls()

data = cls._map_data_paths(data)
# Convert defaultdict to dict for pydantic 2.12+ compatibility
# In pydantic 2.12+, accessing missing keys in defaultdict creates nested
# defaultdicts which then fail validation
if isinstance(data, defaultdict):
data = dict(data)
return cls.model_validate(data, context={"request": request})

@classmethod
Expand All @@ -78,22 +68,20 @@ def _map_data_paths(cls, data: DictStrAny) -> DictStrAny:
if not flatten_map:
return data

mapped_data: DictStrAny = NestedDict()
for k in flatten_map:
if k in data:
cls._map_data_path(mapped_data, data[k], flatten_map[k])
else:
cls._map_data_path(mapped_data, None, flatten_map[k])

mapped_data: DictStrAny = {}
for key, path in flatten_map.items():
cls._map_data_path(mapped_data, data.get(key), path)
return mapped_data

@classmethod
def _map_data_path(cls, data: DictStrAny, value: Any, path: Tuple) -> None:
if len(path) == 1:
if value is not None:
data[path[0]] = value
else:
cls._map_data_path(data[path[0]], value, path[1:])
def _map_data_path(
cls, data: DictStrAny, value: Any, path: Tuple[str, ...]
) -> None:
current = data
for key in path[:-1]:
current = current.setdefault(key, {})
if value is not None:
current[path[-1]] = value


class QueryModel(ParamModel):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_params_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Optional

from ninja.params.models import DictStrAny, ParamModel


class _NestedParamModel(ParamModel):
outer: DictStrAny
leaf: Optional[int]

__ninja_flatten_map__ = {
"foo": ("outer", "foo"),
"bar": ("outer", "bar"),
"leaf": ("leaf",),
}


def test_map_data_paths_creates_parent_for_missing_nested_values():
assert _NestedParamModel._map_data_paths({}) == {"outer": {}}


def test_map_data_paths_sets_values_when_present():
data = _NestedParamModel._map_data_paths({"foo": 1, "leaf": 2})
assert data == {"outer": {"foo": 1}, "leaf": 2}
132 changes: 78 additions & 54 deletions tests/test_query_schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import datetime
from enum import IntEnum

from pydantic import Field
from pydantic import BaseModel, Field

from ninja import NinjaAPI, Query, Schema
from ninja.testing.client import TestClient


class Range(IntEnum):
Expand All @@ -12,7 +13,7 @@ class Range(IntEnum):
TWO_HUNDRED = 200


class Filter(Schema):
class Filter(BaseModel):
to_datetime: datetime = Field(alias="to")
from_datetime: datetime = Field(alias="from")
range: Range = Range.TWENTY
Expand All @@ -28,7 +29,7 @@ class Data(Schema):

@api.get("/test")
def query_params_schema(request, filters: Filter = Query(...)):
return filters.dict()
return filters.model_dump()


@api.get("/test-mixed")
Expand All @@ -39,57 +40,80 @@ def query_params_mixed_schema(
filters: Filter = Query(...),
data: Data = Query(...),
):
return dict(query1=query1, query2=query2, filters=filters.dict(), data=data.dict())


# def test_request():
# client = TestClient(api)
# response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50")
# print("!", response.json())
# assert response.json() == {
# "to_datetime": "1970-01-01T00:00:02Z",
# "from_datetime": "1970-01-01T00:00:01Z",
# "range": 20,
# }

# response = client.get("/test?from=1&to=2&range=21")
# assert response.status_code == 422


# def test_request_mixed():
# client = TestClient(api)
# response = client.get(
# "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&int=3&float=1.6"
# )
# print(response.json())
# assert response.json() == {
# "data": {"a_float": 1.6, "an_int": 3},
# "filters": {
# "from_datetime": "1970-01-01T00:00:01Z",
# "range": 20,
# "to_datetime": "1970-01-01T00:00:02Z",
# },
# "query1": 2,
# "query2": 5,
# }

# response = client.get(
# "/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&query2=10"
# )
# print(response.json())
# assert response.json() == {
# "data": {"a_float": 1.5, "an_int": 0},
# "filters": {
# "from_datetime": "1970-01-01T00:00:01Z",
# "range": 20,
# "to_datetime": "1970-01-01T00:00:02Z",
# },
# "query1": 2,
# "query2": 10,
# }

# response = client.get("/test-mixed?from=1&to=2")
# assert response.status_code == 422
return dict(
query1=query1,
query2=query2,
filters=filters.model_dump(),
data=data.model_dump(),
)


def test_request():
client = TestClient(api)
response = client.get("/test?from=1&to=2&range=20&foo=1&range2=50")
print("!", response.json())
assert response.json() == {
"to_datetime": "1970-01-01T00:00:02Z",
"from_datetime": "1970-01-01T00:00:01Z",
"range": 20,
}

response = client.get("/test?from=1&to=2&range=21")
assert response.status_code == 422


def test_request_mixed():
client = TestClient(api)
response = client.get(
"/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&int=3&float=1.6"
)
print(response.json())
assert response.json() == {
"data": {"a_float": 1.6, "an_int": 3},
"filters": {
"from_datetime": "1970-01-01T00:00:01Z",
"range": 20,
"to_datetime": "1970-01-01T00:00:02Z",
},
"query1": 2,
"query2": 5,
}

response = client.get(
"/test-mixed?from=1&to=2&range=20&foo=1&range2=50&query1=2&query2=10"
)
print(response.json())
assert response.json() == {
"data": {"a_float": 1.5, "an_int": 0},
"filters": {
"from_datetime": "1970-01-01T00:00:01Z",
"range": 20,
"to_datetime": "1970-01-01T00:00:02Z",
},
"query1": 2,
"query2": 10,
}

response = client.get("/test-mixed?from=1&to=2")
assert response.status_code == 422


def test_request_query_params_using_basemodel():
class Foo(BaseModel):
start: int
optional: int = 42

temp_api = NinjaAPI()

@temp_api.get("/foo")
def view(request, foo: Foo = Query(...)):
return foo.model_dump()

client = TestClient(temp_api)
resp = client.get("/foo?start=1")

assert resp.status_code == 200
assert resp.json() == {"start": 1, "optional": 42}


def test_schema():
Expand Down