From 4edf51f63a54a569aee25013b937d731b69f2e9c Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Sun, 14 Jun 2020 00:29:34 +0900 Subject: [PATCH] support one_of (#147) * support one_of * fix validator * update template * remove a validator * fix pytest version * fix pytest version --- datamodel_code_generator/model/base.py | 8 ++++ .../model/pydantic/__init__.py | 14 ++++++ .../model/pydantic/base_model.py | 19 ++++++++ .../model/template/pydantic/BaseModel.jinja2 | 3 ++ datamodel_code_generator/parser/jsonschema.py | 47 +++++++++++-------- setup.cfg | 2 +- tests/parser/test_jsonschema.py | 47 +++++++++++++++++-- tests/parser/test_openapi.py | 1 + 8 files changed, 117 insertions(+), 24 deletions(-) diff --git a/datamodel_code_generator/model/base.py b/datamodel_code_generator/model/base.py index ef4442b4f..63c97ab5a 100644 --- a/datamodel_code_generator/model/base.py +++ b/datamodel_code_generator/model/base.py @@ -79,6 +79,10 @@ def _get_type_hint(self) -> Optional[str]: self.imports.append(IMPORT_UNION) return f'{UNION}[{type_hint}]' + @property + def method(self) -> Optional[str]: + return None + @root_validator def validate_root(cls, values: Any) -> Dict[str, Any]: name = values.get('name') @@ -149,6 +153,7 @@ def __init__( imports: Optional[List[Import]] = None, auto_import: bool = True, reference_classes: Optional[List[str]] = None, + methods: Optional[List[str]] = None, ) -> None: if not self.TEMPLATE_FILE_PATH: raise Exception('TEMPLATE_FILE_PATH is undefined') @@ -208,6 +213,8 @@ def __init__( for field in self.fields: self.imports.extend(field.imports) + self.methods: List[str] = methods or [] + super().__init__(template_file_path=template_file_path) def render(self) -> str: @@ -216,6 +223,7 @@ def render(self) -> str: fields=self.fields, decorators=self.decorators, base_class=self.base_class, + methods=self.methods, **self.extra_template_data, ) return response diff --git a/datamodel_code_generator/model/pydantic/__init__.py b/datamodel_code_generator/model/pydantic/__init__.py index 31ab5d5f9..e45ee2bd9 100644 --- a/datamodel_code_generator/model/pydantic/__init__.py +++ b/datamodel_code_generator/model/pydantic/__init__.py @@ -1,7 +1,10 @@ +from pathlib import Path from typing import List, Optional +from jinja2 import Environment, FileSystemLoader, Template from pydantic import BaseModel as _BaseModel +from ..base import TEMPLATE_DIR from .base_model import BaseModel, DataModelField from .custom_root_type import CustomRootType from .dataclass import DataClass @@ -18,10 +21,21 @@ class Config(_BaseModel): title: Optional[str] = None +# def get_validator_template() -> Template: +# template_file_path: Path = Path('pydantic') / 'one_of_validator.jinja2' +# loader = FileSystemLoader(str(TEMPLATE_DIR / template_file_path.parent)) +# environment: Environment = Environment(loader=loader, autoescape=True) +# return environment.get_template(template_file_path.name) +# +# +# VALIDATOR_TEMPLATE: Template = get_validator_template() + + __all__ = [ 'BaseModel', 'DataModelField', 'CustomRootType', 'DataClass', 'dump_resolve_reference_action', + 'VALIDATOR_TEMPLATE', ] diff --git a/datamodel_code_generator/model/pydantic/base_model.py b/datamodel_code_generator/model/pydantic/base_model.py index 7522b7c00..f6da44bae 100644 --- a/datamodel_code_generator/model/pydantic/base_model.py +++ b/datamodel_code_generator/model/pydantic/base_model.py @@ -1,6 +1,8 @@ from pathlib import Path from typing import Any, DefaultDict, Dict, List, Optional, Set, Union +from jinja2 import Environment, FileSystemLoader, Template + from datamodel_code_generator.imports import Import from datamodel_code_generator.model import DataModel, DataModelFieldBase from datamodel_code_generator.model.pydantic.types import get_data_type @@ -25,6 +27,20 @@ def get_valid_argument(self, value: Any) -> Union[str, List[Any], Dict[Any, Any] def __init__(self, **values: Any) -> None: super().__init__(**values) + @property + def method(self) -> Optional[str]: + return self.validator + + @property + def validator(self) -> Optional[str]: + return None + # TODO refactor this method for other validation logic + # from datamodel_code_generator.model.pydantic import VALIDATOR_TEMPLATE + # + # return VALIDATOR_TEMPLATE.render( + # field_name=self.name, types=','.join([t.type_hint for t in self.data_types]) + # ) + @property def field(self) -> Optional[str]: """for backwards compatibility""" @@ -66,6 +82,8 @@ def __init__( imports: Optional[List[Import]] = None, ): + methods: List[str] = [field.method for field in fields if field.method] + super().__init__( name=name, fields=fields, # type: ignore @@ -77,6 +95,7 @@ def __init__( auto_import=auto_import, reference_classes=reference_classes, imports=imports, + methods=methods, ) config_parameters: Dict[str, Any] = {} diff --git a/datamodel_code_generator/model/template/pydantic/BaseModel.jinja2 b/datamodel_code_generator/model/template/pydantic/BaseModel.jinja2 index ed5bffe16..46b07408e 100644 --- a/datamodel_code_generator/model/template/pydantic/BaseModel.jinja2 +++ b/datamodel_code_generator/model/template/pydantic/BaseModel.jinja2 @@ -16,4 +16,7 @@ class {{ class_name }}({{ base_class }}): {%- else %} {{ field.name }}: {{ field.type_hint }} = {{ field.default }} {%- endif %} +{%- for method in methods -%} + {{ method }} +{%- endfor -%} {%- endfor -%} diff --git a/datamodel_code_generator/parser/jsonschema.py b/datamodel_code_generator/parser/jsonschema.py index ebfded7dd..a11c69a3a 100644 --- a/datamodel_code_generator/parser/jsonschema.py +++ b/datamodel_code_generator/parser/jsonschema.py @@ -86,6 +86,7 @@ class JsonSchemaObject(BaseModel): exclusiveMaximum: Optional[bool] exclusiveMinimum: Optional[bool] additionalProperties: Union['JsonSchemaObject', bool, None] + oneOf: List['JsonSchemaObject'] = [] anyOf: List['JsonSchemaObject'] = [] allOf: List['JsonSchemaObject'] = [] enum: List[str] = [] @@ -197,32 +198,30 @@ def set_title(self, name: str, obj: JsonSchemaObject) -> None: self.extra_template_data[name]['title'] = obj.title - def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]: - any_of_data_types: List[DataType] = [] - for any_of_item in obj.anyOf: - if any_of_item.ref: # $ref - any_of_data_types.append( + def parse_list_item( + self, name: str, target_items: List[JsonSchemaObject] + ) -> List[DataType]: + data_types: List[DataType] = [] + for item in target_items: + if item.ref: # $ref + data_types.append( self.data_type( - type=self.get_class_name( - any_of_item.ref_object_name, unique=False - ), + type=self.get_class_name(item.ref_object_name, unique=False), ref=True, version_compatible=True, ) ) - elif not any(v for k, v in vars(any_of_item).items() if k != 'type'): + elif not any(v for k, v in vars(item).items() if k != 'type'): # trivial types - any_of_data_types.extend(self.get_data_type(any_of_item)) + data_types.extend(self.get_data_type(item)) elif ( - any_of_item.is_array - and isinstance(any_of_item.items, JsonSchemaObject) - and not any( - v for k, v in vars(any_of_item.items).items() if k != 'type' - ) + item.is_array + and isinstance(item.items, JsonSchemaObject) + and not any(v for k, v in vars(item.items).items() if k != 'type') ): # trivial item types - types = [t.type_hint for t in self.get_data_type(any_of_item.items)] - any_of_data_types.append( + types = [t.type_hint for t in self.get_data_type(item.items)] + data_types.append( self.data_type( type=f"List[Union[{', '.join(types)}]]" if len(types) > 1 @@ -233,13 +232,19 @@ def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]: else: name = self.get_class_name(name, unique=False) singular_name = get_singular_name(name) - self.parse_object(singular_name, any_of_item) - any_of_data_types.append( + self.parse_object(singular_name, item) + data_types.append( self.data_type( type=singular_name, ref=True, version_compatible=True ) ) - return any_of_data_types + return data_types + + def parse_any_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]: + return self.parse_list_item(name, obj.anyOf) + + def parse_one_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]: + return self.parse_list_item(name, obj.oneOf) def parse_all_of(self, name: str, obj: JsonSchemaObject) -> List[DataType]: fields: List[DataModelFieldBase] = [] @@ -306,6 +311,8 @@ def parse_object_fields(self, obj: JsonSchemaObject) -> List[DataModelFieldBase] is_union = True elif field.anyOf: field_types = self.parse_any_of(field_name, field) + elif field.oneOf: + field_types = self.parse_one_of(field_name, field) elif field.allOf: class_name = self.get_class_name(field_name) field_types = self.parse_all_of(class_name, field) diff --git a/setup.cfg b/setup.cfg index f2921e461..eebb0a553 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,7 @@ install_requires = genson==1.2.1 tests_require = - pytest + pytest>=4.6 pytest-benchmark pytest-cov pytest-mock diff --git a/tests/parser/test_jsonschema.py b/tests/parser/test_jsonschema.py index cdf733d36..414fbfeea 100644 --- a/tests/parser/test_jsonschema.py +++ b/tests/parser/test_jsonschema.py @@ -4,6 +4,8 @@ import pytest +from datamodel_code_generator import DataModelField +from datamodel_code_generator.model import DataModelFieldBase from datamodel_code_generator.model.pydantic import BaseModel, CustomRootType from datamodel_code_generator.parser.base import dump_templates from datamodel_code_generator.parser.jsonschema import ( @@ -29,7 +31,9 @@ def test_get_model_by_path(schema: Dict, path: str, model: Dict): def test_json_schema_parser_parse_ref(): - parser = JsonSchemaParser(BaseModel, CustomRootType) + parser = JsonSchemaParser( + BaseModel, CustomRootType, data_model_field_type=DataModelField + ) parser.parse_raw_obj = Mock() external_parent_path = Path(DATA_PATH / 'external_parent.json') parser.base_path = external_parent_path.parent @@ -53,7 +57,9 @@ def test_json_schema_parser_parse_ref(): def test_json_schema_object_ref_url(): - parser = JsonSchemaParser(BaseModel, CustomRootType) + parser = JsonSchemaParser( + BaseModel, CustomRootType, data_model_field_type=DataModelField + ) obj = JsonSchemaObject.parse_obj({'$ref': 'https://example.org'}) with pytest.raises(NotImplementedError): parser.parse_ref(obj) @@ -142,6 +148,41 @@ def test_parse_object(source_obj, generated_classes): ], ) def test_parse_any_root_object(source_obj, generated_classes): - parser = JsonSchemaParser(BaseModel, CustomRootType) + parser = JsonSchemaParser( + BaseModel, CustomRootType, data_model_field_type=DataModelField + ) parser.parse_root_type('AnyObject', JsonSchemaObject.parse_obj(source_obj)) assert dump_templates(list(parser.results)) == generated_classes + + +@pytest.mark.parametrize( + 'source_obj,generated_classes', + [ + ( + { + "properties": { + "item": { + "properties": { + "timeout": { + "oneOf": [{"type": "string"}, {"type": "integer"}] + } + }, + "type": "object", + } + } + }, + """class Item(BaseModel): + timeout: Optional[Union[str, int]] = None + + +class OnOfObject(BaseModel): + item: Optional[Item] = None""", + ) + ], +) +def test_parse_one_of_object(source_obj, generated_classes): + parser = JsonSchemaParser( + BaseModel, CustomRootType, data_model_field_type=DataModelField + ) + parser.parse_raw_obj('onOfObject', source_obj) + assert dump_templates(list(parser.results)) == generated_classes diff --git a/tests/parser/test_openapi.py b/tests/parser/test_openapi.py index c4f7110b9..fd8853a6d 100644 --- a/tests/parser/test_openapi.py +++ b/tests/parser/test_openapi.py @@ -849,6 +849,7 @@ def test_openapi_parser_parse_nested_anyof(): parser = OpenAPIParser( BaseModel, CustomRootType, + data_model_field_type=DataModelField, text=Path(DATA_PATH / 'nested_anyof.yaml').read_text(), ) assert (