Skip to content

Commit

Permalink
adding json-schema generator and python code generator, support typin…
Browse files Browse the repository at this point in the history
…g.Self
  • Loading branch information
voidZXL committed Oct 19, 2024
1 parent a07ae93 commit 42428eb
Show file tree
Hide file tree
Showing 22 changed files with 1,173 additions and 149 deletions.
29 changes: 19 additions & 10 deletions tests/test_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import utype
from utype import (DataClass, Field, Options, Rule, Schema, exc,
register_transformer, types)
from utype.utils.compat import Final
from utype.utils.compat import Final, Self


@pytest.fixture(params=(False, True))
Expand Down Expand Up @@ -319,19 +319,28 @@ class T(Schema):
T(forward_in_dict={1: [2], 2: [1]})

# test not-module-level self ref
class Self(Schema):
class SelfRef(Schema):
name: str
to_self: "Self" = Field(required=False)
self_lst: List["Self"] = Field(default_factory=list)
to_self: "SelfRef" = Field(required=False)
self_lst: List["SelfRef"] = Field(default_factory=list)

sf = Self(name=1, to_self=b'{"name":"test"}')
sf = SelfRef(name=1, to_self=b'{"name":"test"}')
assert sf.to_self.name == "test"
assert sf.self_lst == []

sf2 = Self(name="t2", self_lst=[dict(sf)])
sf2 = SelfRef(name="t2", self_lst=[dict(sf)])
assert sf2.self_lst[0].name == "1"
assert "to_self" not in sf2

class SelfRef2(Schema):
name: str
to_self: Self = Field(required=False)
self_lst: List[Self] = Field(default_factory=list)

sfi = SelfRef2(name=1, to_self=b'{"name":"test"}')
assert sfi.to_self.name == "test"
assert sfi.self_lst == []

# class ForwardSchema(Schema):
# int1: 'types.PositiveInt' = Field(lt=10)
# int2: 'types.PositiveInt' = Field(lt=20)
Expand All @@ -340,11 +349,11 @@ class Self(Schema):

def test_local_forward_ref(self):
def f(u=0):
class Self(Schema):
class LocSelf(Schema):
num: int = u
to_self: Optional["Self"] = None
list_self: List["Self"] = utype.Field(default_factory=list)
data = Self(to_self={'to_self': {}}, list_self=[{'list_self': []}])
to_self: Optional["LocSelf"] = None
list_self: List["LocSelf"] = utype.Field(default_factory=list)
data = LocSelf(to_self={'to_self': {}}, list_self=[{'list_self': []}])
return data.to_self.to_self.num, data.list_self[0].num

assert f(1) == (1, 1)
Expand Down
25 changes: 24 additions & 1 deletion tests/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import utype
from utype import Field, Options, Param, exc, parse, types
from utype.utils.compat import Final
from utype.utils.compat import Final, Self


@pytest.fixture(params=(False, True))
Expand All @@ -21,6 +21,22 @@ def on_error(request):
return request.param


class schemas:
class MySchema(utype.Schema):
a: int
b: int
result: int

@classmethod
@utype.parse
def add(cls, a: int, b: int) -> Self:
return dict(
a=a,
b=b,
result=a+b
)


class TestFunc:
def test_basic(self):
import utype
Expand Down Expand Up @@ -406,6 +422,13 @@ def fib(n: int = utype.Param(ge=0), _current: int = 0, _next: int = 1):
assert fib('10', _current=10, _next=6) == 55
assert fib('10', 10, 5) == 615 # can pass through positional args

def test_self_ref(self):
result = schemas.MySchema.add('1', '2')
assert isinstance(result, schemas.MySchema)
assert result.a == 1
assert result.b == 2
assert result.result == 3

def test_args_parse(self):
@utype.parse
def get(a):
Expand Down
1 change: 1 addition & 0 deletions tests/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def trans_my(trans, d, t):
],
date: [
("2020-02-20", date(2020, 2, 20), True, True),
("20200220", date(2020, 2, 20), True, True),
("2020/02/20", date(2020, 2, 20), True, True),
("2020/2/20", date(2020, 2, 20), True, True),
("20/02/2020", date(2020, 2, 20), True, True),
Expand Down
2 changes: 1 addition & 1 deletion utype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
register_transformer = TypeTransformer.registry.register


VERSION = (0, 5, 6, None)
VERSION = (0, 6, 0, 'alpha')


def _get_version():
Expand Down
7 changes: 6 additions & 1 deletion utype/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def __init__(self, obj, options: Options = None):
def make_context(self, context=None, force_error: bool = False):
return self.options.make_context(context=context, force_error=force_error)

@property
def bound(self):
return self.obj

@property
def kwargs(self):
return {}
Expand All @@ -109,7 +113,8 @@ def parse_annotation(self, annotation):
annotation=annotation,
forward_refs=self.forward_refs,
global_vars=self.globals,
force_clear_refs=self.is_local
force_clear_refs=self.is_local,
bound=self.bound
)

@cached_property
Expand Down
2 changes: 2 additions & 0 deletions utype/parser/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def generate_fields(self):
forward_refs=self.forward_refs,
options=self.options,
force_clear_refs=self.is_local,
bound=self.bound,
**self.kwargs
)
except Exception as e:
Expand Down Expand Up @@ -185,6 +186,7 @@ def generate_fields(self):
forward_refs=self.forward_refs,
options=self.options,
force_clear_refs=self.is_local,
bound=self.bound,
**self.kwargs
)
except Exception as e:
Expand Down
8 changes: 7 additions & 1 deletion utype/parser/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from uuid import UUID

from ..utils import exceptions as exc
from ..utils.base import ParamsCollector
from ..utils.compat import Literal, get_args, is_final, is_annotated, ForwardRef
from ..utils.datastructures import unprovided
from ..utils.functional import copy_value, get_name, multi
Expand All @@ -17,7 +18,7 @@
represent = repr


class Field:
class Field(ParamsCollector):
parser_field_cls = None

def __init__(
Expand Down Expand Up @@ -91,6 +92,8 @@ def __init__(
min_contains: int = None,
unique_items: Union[bool, ConstraintMode] = None,
):
super().__init__(locals())

if mode:
if readonly or writeonly:
raise exc.ConfigError(
Expand Down Expand Up @@ -1094,6 +1097,7 @@ def generate(
positional_only: bool = False,
global_vars=None,
forward_refs=None,
bound=None,
force_clear_refs=False,
**kwargs
):
Expand Down Expand Up @@ -1216,6 +1220,7 @@ def generate(
global_vars=global_vars,
forward_refs=forward_refs,
forward_key=attname,
bound=bound,
constraints=output_field.constraints if output_field else None,
force_clear_refs=force_clear_refs
)
Expand Down Expand Up @@ -1278,6 +1283,7 @@ def generate(
global_vars=global_vars,
forward_refs=forward_refs,
forward_key=attname,
bound=bound,
force_clear_refs=force_clear_refs
)

Expand Down
23 changes: 21 additions & 2 deletions utype/parser/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ def _f_pass():


class FunctionParser(BaseParser):
@property
def bound(self):
# class A:
# class B:
# def f():
# f.__qualname__ = 'A.B.f'
# f.bound -> 'A.B'
name = self.obj.__qualname__
if '.' in name:
return '.'.join(name.split('.')[:-1])
return None

@classmethod
def function_pass(cls, f):
if not inspect.isfunction(f):
Expand Down Expand Up @@ -299,10 +311,12 @@ def generate_return_types(self):
if not self.return_annotation:
return

self.return_type = self.parse_annotation(annotation=self.return_annotation)
self.return_type = self.parse_annotation(
annotation=self.return_annotation
)

# https://docs.python.org/3/library/typing.html#typing.Generator
if self.return_type and issubclass(self.return_type, Rule):
if self.return_type and isinstance(self.return_type, type) and issubclass(self.return_type, Rule):
if self.is_generator:
if self.return_type.__origin__ in (Iterable, Iterator):
self.generator_yield_type = self.return_type.__args__[0]
Expand Down Expand Up @@ -406,6 +420,7 @@ def generate_fields(self):
forward_refs=self.forward_refs,
options=self.options,
positional_only=param.kind == param.POSITIONAL_ONLY,
bound=self.bound,
**self.kwargs
)
except Exception as e:
Expand Down Expand Up @@ -760,6 +775,7 @@ def get_sync_generator(
@wraps(self.obj)
def eager_generator(*args, **kwargs) -> Generator:
context = (options or self.options).make_context()
self.resolve_forward_refs()
args, kwargs = self.get_params(
args,
kwargs,
Expand Down Expand Up @@ -846,6 +862,7 @@ def get_async_generator(
@wraps(self.obj)
def eager_generator(*args, **kwargs) -> AsyncGenerator:
context = (options or self.options).make_context()
self.resolve_forward_refs()
args, kwargs = self.get_params(
args,
kwargs,
Expand Down Expand Up @@ -886,6 +903,7 @@ def get_async_call(
@wraps(self.obj)
def eager_call(*args, **kwargs):
context = (options or self.options).make_context()
self.resolve_forward_refs()
args, kwargs = self.get_params(
args,
kwargs,
Expand Down Expand Up @@ -915,6 +933,7 @@ def sync_call(
parse_params: bool = None,
parse_result: bool = None,
):
self.resolve_forward_refs()
args, kwargs = self.get_params(
args,
kwargs,
Expand Down
4 changes: 4 additions & 0 deletions utype/parser/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable, List, Optional, Set, Type, Union

from ..utils import exceptions as exc
# from ..utils.base import ParamsCollector
from ..utils.compat import Literal
from ..utils.datastructures import unprovided
from ..utils.functional import multi
Expand Down Expand Up @@ -143,6 +144,7 @@ def __init__(
# if this value is another callable (like dict, list), return value()
# otherwise return this value directly when attr is unprovided
):
# super().__init__({k: v for k, v in locals().items() if not unprovided(v)})

if no_data_loss:
if addition is None:
Expand Down Expand Up @@ -182,6 +184,8 @@ def __init__(
for key, val in locals().items():
if unprovided(val):
continue
if key.startswith('_'):
continue
if hasattr(self, key):
# if getattr(self, key) == val:
# continue
Expand Down
Loading

0 comments on commit 42428eb

Please sign in to comment.