diff --git a/README.md b/README.md index 92c0b0c0f..b4a7ea103 100644 --- a/README.md +++ b/README.md @@ -187,13 +187,25 @@ components: `model.py`: ```python +# generated by datamodel-codegen: +# filename: api.yaml +# timestamp: 2019-07-23T14:23:18+00:00 + +from typing import List, Optional + +from pydantic import BaseModel + + + class Pet(BaseModel): id: int name: str - tag: str = None + tag: Optional[str] = None + +class Pets(BaseModel): + __root__: List[Pet] -Pets = List[Pet] class Error(BaseModel): code: int @@ -201,11 +213,12 @@ class Error(BaseModel): class api(BaseModel): - apiKey: str = None - apiVersionNumber: str = None + apiKey: Optional[str] = None + apiVersionNumber: Optional[str] = None -apis = List[api] +class apis(BaseModel): + __root__: List[api] ``` ## Development diff --git a/datamodel_code_generator/parser/openapi.py b/datamodel_code_generator/parser/openapi.py index 83b432bf4..b037e1c7a 100644 --- a/datamodel_code_generator/parser/openapi.py +++ b/datamodel_code_generator/parser/openapi.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Set, Type, Union +from typing import Dict, List, Optional, Set, Type, Union, Iterator from dataclasses import Field, dataclass from prance import BaseParser, ResolvingParser @@ -44,7 +44,7 @@ def get_data_type(_type, format =None) -> DataType: def dump_templates(templates: Union[TemplateBase, List[TemplateBase]]) -> str: if isinstance(templates, TemplateBase): templates = [templates] - return '\n\n'.join(str(m) for m in templates) + return '\n\n\n'.join(str(m) for m in templates) class Parser: @@ -55,9 +55,8 @@ def __init__(self, data_model_type: Type[DataModel], data_model_field_type: Type self.data_model_type: Type[DataModel] = data_model_type self.data_model_field_type: Type[DataModelField] = data_model_field_type - self.models = [] - def parse_object(self, name: str, obj: Dict) -> str: + def parse_object(self, name: str, obj: Dict) -> Iterator[TemplateBase]: requires: Set[str] = set(obj.get('required', [])) d_list: List[DataModelField] = [] for field_name, filed in obj['properties'].items(): @@ -66,25 +65,23 @@ def parse_object(self, name: str, obj: Dict) -> str: name=field_name, type_hint=get_data_type(filed["type"], filed.get("format")).type_hint, required=field_name in requires)) - return dump_templates(self.data_model_type(name, fields=d_list)) + yield self.data_model_type(name, fields=d_list) - def parse_array(self, name: str, obj: Dict) -> str: - templates: List[TemplateBase] = [] + def parse_array(self, name: str, obj: Dict) -> Iterator[TemplateBase]: # continue if '$ref' in obj['items']: _type: str = f"List[{obj['items']['$ref'].split('/')[-1]}]" - templates.append(CustomRootType(name, _type)) + yield CustomRootType(name, _type) elif 'properties' in obj['items']: - self.parse_object(name[:-1], obj['items']) - templates.append(CustomRootType(name, f'List[{name[:-1]}]')) - return dump_templates(templates) + yield from self.parse_object(name[:-1], obj['items']) + yield CustomRootType(name, f'List[{name[:-1]}]') def parse(self) -> str: - parsed_objects: List[str] = [] + templates: List[TemplateBase] = [] for obj_name, obj in self.base_parser.specification['components']['schemas'].items(): if 'properties' in obj: - parsed_objects.append(self.parse_object(obj_name, obj)) + templates.extend(self.parse_object(obj_name, obj)) elif 'items' in obj: - parsed_objects.append(self.parse_array(obj_name, obj)) + templates.extend(self.parse_array(obj_name, obj)) - return '\n\n\n'.join(parsed_objects) + return dump_templates(templates)