diff --git a/datamodel_code_generator/format.py b/datamodel_code_generator/format.py index 675d9b14d..c5f04bfe0 100644 --- a/datamodel_code_generator/format.py +++ b/datamodel_code_generator/format.py @@ -1,6 +1,6 @@ from enum import Enum from pathlib import Path -from typing import Dict +from typing import Dict, Optional import black import isort @@ -20,7 +20,11 @@ class PythonVersion(Enum): } -def format_code(code: str, python_version: PythonVersion, settings_path: Path) -> str: +def format_code( + code: str, python_version: PythonVersion, settings_path: Optional[Path] = None +) -> str: + if not settings_path: + settings_path = Path().resolve() code = apply_isort(code, settings_path) code = apply_black(code, python_version, settings_path) return code diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index c2158ae56..2dcdac0f9 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -354,9 +354,7 @@ def parse( body = '\n'.join(result) if format_: - body = format_code( - body, self.target_python_version, settings_path or Path().resolve() - ) + body = format_code(body, self.target_python_version, settings_path) results[module] = Result(body=body, source=models[0].path) diff --git a/datamodel_code_generator/parser/jsonschema.py b/datamodel_code_generator/parser/jsonschema.py index 104cd8ff9..08078c2f7 100644 --- a/datamodel_code_generator/parser/jsonschema.py +++ b/datamodel_code_generator/parser/jsonschema.py @@ -216,6 +216,7 @@ def __init__( ) self.remote_object_cache: Dict[str, Dict[str, Any]] = {} + self.raw_obj: Dict[Any, Any] = {} def get_data_type(self, obj: JsonSchemaObject) -> DataType: if obj.type is None: @@ -299,6 +300,13 @@ def parse_all_of( ) -> DataType: fields: List[DataModelFieldBase] = [] base_classes: List[DataType] = [] + if len(obj.allOf) == 1: + single_obj = obj.allOf[0] + if single_obj.ref and single_obj.ref.startswith('#/'): + if get_model_by_path(self.raw_obj, single_obj.ref[2:].split('/')).get( + 'enum' + ): + return self.get_ref_data_type(single_obj.ref) for all_of_item in obj.allOf: if all_of_item.ref: # $ref base_classes.append(self.get_ref_data_type(all_of_item.ref)) @@ -735,10 +743,10 @@ def parse_raw(self) -> None: self.current_source_path = source.path path_parts = list(source.path.parts) self.model_resolver.set_current_root(path_parts) - raw_obj: Dict[Any, Any] = yaml.safe_load(source.text) - obj_name = raw_obj.get('title', 'Model') + self.raw_obj = yaml.safe_load(source.text) + obj_name = self.raw_obj.get('title', 'Model') obj_name = self.model_resolver.add(path_parts, obj_name, unique=False).name - self.parse_raw_obj(obj_name, raw_obj, path_parts) - definitions = raw_obj.get('definitions', {}) + self.parse_raw_obj(obj_name, self.raw_obj, path_parts) + definitions = self.raw_obj.get('definitions', {}) for key, model in definitions.items(): self.parse_raw_obj(key, model, [*path_parts, '#/definitions', key]) diff --git a/datamodel_code_generator/parser/openapi.py b/datamodel_code_generator/parser/openapi.py index aa4e3281a..b19fe5746 100644 --- a/datamodel_code_generator/parser/openapi.py +++ b/datamodel_code_generator/parser/openapi.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional import yaml @@ -16,13 +16,17 @@ def parse_raw(self) -> None: base_parser = BaseParser( spec_string=source.text, backend='openapi-spec-validator' ) - components: Dict[str, Any] = base_parser.specification['components'] + specification: Dict[str, Any] = base_parser.specification else: - components = yaml.safe_load(source.text)['components'] - self.model_resolver.set_current_root(list(source.path.parts)) - for obj_name, raw_obj in components[ + specification = yaml.safe_load(source.text) + self.raw_obj = specification + schemas: Optional[Dict[Any, Any]] = specification.get('components', {}).get( 'schemas' - ].items(): # type: str, Dict[Any, Any] + ) + if not schemas: # pragma: no cover + continue + self.model_resolver.set_current_root(list(source.path.parts)) + for obj_name, raw_obj in schemas.items(): # type: str, Dict[Any, Any] self.parse_raw_obj( obj_name, raw_obj, ['components', 'schemas', obj_name] ) diff --git a/tests/data/expected/main/main_subclass_enum/output.py b/tests/data/expected/main/main_subclass_enum/output.py new file mode 100644 index 000000000..c80d03e4a --- /dev/null +++ b/tests/data/expected/main/main_subclass_enum/output.py @@ -0,0 +1,22 @@ +# generated by datamodel-codegen: +# filename: subclass_enum.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class ProcessingStatus(Enum): + COMPLETED = 'COMPLETED' + PENDING = 'PENDING' + FAILED = 'FAILED' + + +class ProcessingTask(BaseModel): + processing_status: Optional[ProcessingStatus] = Field( + 'COMPLETED', title='Status of the task' + ) diff --git a/tests/data/openapi/subclass_enum.json b/tests/data/openapi/subclass_enum.json new file mode 100644 index 000000000..739911d71 --- /dev/null +++ b/tests/data/openapi/subclass_enum.json @@ -0,0 +1,37 @@ +{ + "openapi": "3.0.2", + "components": { + "schemas": { + "ProcessingStatus": { + "title": "ProcessingStatus", + "enum": [ + "COMPLETED", + "PENDING", + "FAILED" + ], + "type": "string", + "description": "The processing status" + }, + "ProcessingTask": { + "title": "ProcessingTask", + "type": "object", + "properties": { + "processing_status": { + "title": "Status of the task", + "allOf": [ + { + "$ref": "#/components/schemas/ProcessingStatus" + } + ], + "default": "COMPLETED" + } + } + }, + } + }, + "info": { + "title": "", + "version": "" + }, + "paths": {} +} \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py index 0e3683a1c..35040bc75 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -861,3 +861,25 @@ def test_main_with_exclusive(): with pytest.raises(SystemExit): main() + + +@freeze_time('2019-07-26') +def test_main_subclass_enum(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(OPEN_API_DATA_PATH / 'subclass_enum.json'), + '--output', + str(output_file), + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == (EXPECTED_MAIN_PATH / 'main_subclass_enum' / 'output.py').read_text() + ) + + with pytest.raises(SystemExit): + main()