Skip to content

Commit

Permalink
support one_of (#147)
Browse files Browse the repository at this point in the history
* support one_of

* fix validator

* update template

* remove a validator

* fix pytest version

* fix pytest version
  • Loading branch information
koxudaxi authored Jun 13, 2020
1 parent 26cc2b4 commit 4edf51f
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 24 deletions.
8 changes: 8 additions & 0 deletions datamodel_code_generator/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def _get_type_hint(self) -> Optional[str]:
self.imports.append(IMPORT_UNION)
return f'{UNION}[{type_hint}]'

@property
def method(self) -> Optional[str]:
return None

@root_validator
def validate_root(cls, values: Any) -> Dict[str, Any]:
name = values.get('name')
Expand Down Expand Up @@ -149,6 +153,7 @@ def __init__(
imports: Optional[List[Import]] = None,
auto_import: bool = True,
reference_classes: Optional[List[str]] = None,
methods: Optional[List[str]] = None,
) -> None:
if not self.TEMPLATE_FILE_PATH:
raise Exception('TEMPLATE_FILE_PATH is undefined')
Expand Down Expand Up @@ -208,6 +213,8 @@ def __init__(
for field in self.fields:
self.imports.extend(field.imports)

self.methods: List[str] = methods or []

super().__init__(template_file_path=template_file_path)

def render(self) -> str:
Expand All @@ -216,6 +223,7 @@ def render(self) -> str:
fields=self.fields,
decorators=self.decorators,
base_class=self.base_class,
methods=self.methods,
**self.extra_template_data,
)
return response
Expand Down
14 changes: 14 additions & 0 deletions datamodel_code_generator/model/pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from pathlib import Path
from typing import List, Optional

from jinja2 import Environment, FileSystemLoader, Template
from pydantic import BaseModel as _BaseModel

from ..base import TEMPLATE_DIR
from .base_model import BaseModel, DataModelField
from .custom_root_type import CustomRootType
from .dataclass import DataClass
Expand All @@ -18,10 +21,21 @@ class Config(_BaseModel):
title: Optional[str] = None


# def get_validator_template() -> Template:
# template_file_path: Path = Path('pydantic') / 'one_of_validator.jinja2'
# loader = FileSystemLoader(str(TEMPLATE_DIR / template_file_path.parent))
# environment: Environment = Environment(loader=loader, autoescape=True)
# return environment.get_template(template_file_path.name)
#
#
# VALIDATOR_TEMPLATE: Template = get_validator_template()


__all__ = [
'BaseModel',
'DataModelField',
'CustomRootType',
'DataClass',
'dump_resolve_reference_action',
'VALIDATOR_TEMPLATE',
]
19 changes: 19 additions & 0 deletions datamodel_code_generator/model/pydantic/base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path
from typing import Any, DefaultDict, Dict, List, Optional, Set, Union

from jinja2 import Environment, FileSystemLoader, Template

from datamodel_code_generator.imports import Import
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.pydantic.types import get_data_type
Expand All @@ -25,6 +27,20 @@ def get_valid_argument(self, value: Any) -> Union[str, List[Any], Dict[Any, Any]
def __init__(self, **values: Any) -> None:
super().__init__(**values)

@property
def method(self) -> Optional[str]:
return self.validator

@property
def validator(self) -> Optional[str]:
return None
# TODO refactor this method for other validation logic
# from datamodel_code_generator.model.pydantic import VALIDATOR_TEMPLATE
#
# return VALIDATOR_TEMPLATE.render(
# field_name=self.name, types=','.join([t.type_hint for t in self.data_types])
# )

@property
def field(self) -> Optional[str]:
"""for backwards compatibility"""
Expand Down Expand Up @@ -66,6 +82,8 @@ def __init__(
imports: Optional[List[Import]] = None,
):

methods: List[str] = [field.method for field in fields if field.method]

super().__init__(
name=name,
fields=fields, # type: ignore
Expand All @@ -77,6 +95,7 @@ def __init__(
auto_import=auto_import,
reference_classes=reference_classes,
imports=imports,
methods=methods,
)

config_parameters: Dict[str, Any] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@ class {{ class_name }}({{ base_class }}):
{%- else %}
{{ field.name }}: {{ field.type_hint }} = {{ field.default }}
{%- endif %}
{%- for method in methods -%}
{{ method }}
{%- endfor -%}
{%- endfor -%}
47 changes: 27 additions & 20 deletions datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class JsonSchemaObject(BaseModel):
exclusiveMaximum: Optional[bool]
exclusiveMinimum: Optional[bool]
additionalProperties: Union['JsonSchemaObject', bool, None]
oneOf: List['JsonSchemaObject'] = []
anyOf: List['JsonSchemaObject'] = []
allOf: List['JsonSchemaObject'] = []
enum: List[str] = []
Expand Down Expand Up @@ -197,32 +198,30 @@ def set_title(self, name: str, obj: JsonSchemaObject) -> None:

self.extra_template_data[name]['title'] = obj.title

def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
any_of_data_types: List[DataType] = []
for any_of_item in obj.anyOf:
if any_of_item.ref: # $ref
any_of_data_types.append(
def parse_list_item(
self, name: str, target_items: List[JsonSchemaObject]
) -> List[DataType]:
data_types: List[DataType] = []
for item in target_items:
if item.ref: # $ref
data_types.append(
self.data_type(
type=self.get_class_name(
any_of_item.ref_object_name, unique=False
),
type=self.get_class_name(item.ref_object_name, unique=False),
ref=True,
version_compatible=True,
)
)
elif not any(v for k, v in vars(any_of_item).items() if k != 'type'):
elif not any(v for k, v in vars(item).items() if k != 'type'):
# trivial types
any_of_data_types.extend(self.get_data_type(any_of_item))
data_types.extend(self.get_data_type(item))
elif (
any_of_item.is_array
and isinstance(any_of_item.items, JsonSchemaObject)
and not any(
v for k, v in vars(any_of_item.items).items() if k != 'type'
)
item.is_array
and isinstance(item.items, JsonSchemaObject)
and not any(v for k, v in vars(item.items).items() if k != 'type')
):
# trivial item types
types = [t.type_hint for t in self.get_data_type(any_of_item.items)]
any_of_data_types.append(
types = [t.type_hint for t in self.get_data_type(item.items)]
data_types.append(
self.data_type(
type=f"List[Union[{', '.join(types)}]]"
if len(types) > 1
Expand All @@ -233,13 +232,19 @@ def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
else:
name = self.get_class_name(name, unique=False)
singular_name = get_singular_name(name)
self.parse_object(singular_name, any_of_item)
any_of_data_types.append(
self.parse_object(singular_name, item)
data_types.append(
self.data_type(
type=singular_name, ref=True, version_compatible=True
)
)
return any_of_data_types
return data_types

def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
return self.parse_list_item(name, obj.anyOf)

def parse_one_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
return self.parse_list_item(name, obj.oneOf)

def parse_all_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]:
fields: List[DataModelFieldBase] = []
Expand Down Expand Up @@ -306,6 +311,8 @@ def parse_object_fields(self, obj: JsonSchemaObject) -> List[DataModelFieldBase]
is_union = True
elif field.anyOf:
field_types = self.parse_any_of(field_name, field)
elif field.oneOf:
field_types = self.parse_one_of(field_name, field)
elif field.allOf:
class_name = self.get_class_name(field_name)
field_types = self.parse_all_of(class_name, field)
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ install_requires =
genson==1.2.1

tests_require =
pytest
pytest>=4.6
pytest-benchmark
pytest-cov
pytest-mock
Expand Down
47 changes: 44 additions & 3 deletions tests/parser/test_jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import pytest

from datamodel_code_generator import DataModelField
from datamodel_code_generator.model import DataModelFieldBase
from datamodel_code_generator.model.pydantic import BaseModel, CustomRootType
from datamodel_code_generator.parser.base import dump_templates
from datamodel_code_generator.parser.jsonschema import (
Expand All @@ -29,7 +31,9 @@ def test_get_model_by_path(schema: Dict, path: str, model: Dict):


def test_json_schema_parser_parse_ref():
parser = JsonSchemaParser(BaseModel, CustomRootType)
parser = JsonSchemaParser(
BaseModel, CustomRootType, data_model_field_type=DataModelField
)
parser.parse_raw_obj = Mock()
external_parent_path = Path(DATA_PATH / 'external_parent.json')
parser.base_path = external_parent_path.parent
Expand All @@ -53,7 +57,9 @@ def test_json_schema_parser_parse_ref():


def test_json_schema_object_ref_url():
parser = JsonSchemaParser(BaseModel, CustomRootType)
parser = JsonSchemaParser(
BaseModel, CustomRootType, data_model_field_type=DataModelField
)
obj = JsonSchemaObject.parse_obj({'$ref': 'https://example.org'})
with pytest.raises(NotImplementedError):
parser.parse_ref(obj)
Expand Down Expand Up @@ -142,6 +148,41 @@ def test_parse_object(source_obj, generated_classes):
],
)
def test_parse_any_root_object(source_obj, generated_classes):
parser = JsonSchemaParser(BaseModel, CustomRootType)
parser = JsonSchemaParser(
BaseModel, CustomRootType, data_model_field_type=DataModelField
)
parser.parse_root_type('AnyObject', JsonSchemaObject.parse_obj(source_obj))
assert dump_templates(list(parser.results)) == generated_classes


@pytest.mark.parametrize(
'source_obj,generated_classes',
[
(
{
"properties": {
"item": {
"properties": {
"timeout": {
"oneOf": [{"type": "string"}, {"type": "integer"}]
}
},
"type": "object",
}
}
},
"""class Item(BaseModel):
timeout: Optional[Union[str, int]] = None
class OnOfObject(BaseModel):
item: Optional[Item] = None""",
)
],
)
def test_parse_one_of_object(source_obj, generated_classes):
parser = JsonSchemaParser(
BaseModel, CustomRootType, data_model_field_type=DataModelField
)
parser.parse_raw_obj('onOfObject', source_obj)
assert dump_templates(list(parser.results)) == generated_classes
1 change: 1 addition & 0 deletions tests/parser/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ def test_openapi_parser_parse_nested_anyof():
parser = OpenAPIParser(
BaseModel,
CustomRootType,
data_model_field_type=DataModelField,
text=Path(DATA_PATH / 'nested_anyof.yaml').read_text(),
)
assert (
Expand Down

0 comments on commit 4edf51f

Please sign in to comment.