From fc1c25756a0b41ff68ce541c644456363cd1418b Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Sun, 17 Apr 2022 01:59:51 +0900 Subject: [PATCH] Fix nested Enum (#747) --- datamodel_code_generator/parser/base.py | 22 +++++++ .../expected/main/main_nested_enum/output.py | 32 ++++++++++ tests/data/openapi/nested_enum.json | 63 +++++++++++++++++++ tests/test_main.py | 23 +++++++ 4 files changed, 140 insertions(+) create mode 100644 tests/data/expected/main/main_nested_enum/output.py create mode 100644 tests/data/openapi/nested_enum.json diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index d0e741223..9d0b3f324 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -571,6 +571,28 @@ def parse( from_ += "." imports.append(Import(from_=from_, import_=import_, alias=alias)) + # extract inherited enum + for model in models: + if model.fields: + continue + enums: List[Enum] = [] + for base_model in model.base_classes: + if not base_model.reference: + continue + source_model = base_model.reference.source + if isinstance(source_model, Enum): + enums.append(source_model) + if enums: + models.insert( + models.index(model), + enums[0].__class__( + fields=[f for e in enums for f in e.fields], + description=model.description, + reference=model.reference, + ), + ) + models.remove(model) + if self.reuse_model: model_cache: Dict[Tuple[str, ...], Reference] = {} duplicates = [] diff --git a/tests/data/expected/main/main_nested_enum/output.py b/tests/data/expected/main/main_nested_enum/output.py new file mode 100644 index 000000000..3073813fa --- /dev/null +++ b/tests/data/expected/main/main_nested_enum/output.py @@ -0,0 +1,32 @@ +# generated by datamodel-codegen: +# filename: nested_enum.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from enum import Enum + +from pydantic import BaseModel + + +class State(Enum): + field_1 = '1' + field_2 = '2' + + +class NestedState1(Enum): + field_1 = '1' + field_2 = '2' + + +class NestedState2(Enum): + field_1 = '1' + field_2 = '2' + + +class Result1(BaseModel): + state: NestedState1 + + +class Result2(BaseModel): + state: NestedState2 diff --git a/tests/data/openapi/nested_enum.json b/tests/data/openapi/nested_enum.json new file mode 100644 index 000000000..2bf582bb0 --- /dev/null +++ b/tests/data/openapi/nested_enum.json @@ -0,0 +1,63 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0" + }, + "paths": {}, + "components": { + "schemas": { + "Result1": { + "type": "object", + "description": "description for Result1", + "properties": { + "state": { + "$ref": "#/components/schemas/NestedState1" + } + }, + "required": [ + "state" + ] + }, + "Result2": { + "type": "object", + "description": "description for Result2", + "properties": { + "state": { + "$ref": "#/components/schemas/NestedState2" + } + }, + "required": [ + "state" + ] + }, + "NestedState1": { + "allOf": [ + { + "description": "description for NestedState1" + }, + { + "$ref": "#/components/schemas/State" + } + ] + }, + "NestedState2": { + "allOf": [ + { + "description": "description for NestedState2" + }, + { + "$ref": "#/components/schemas/State" + } + ] + }, + "State": { + "type": "string", + "enum": [ + "1", + "2" + ] + } + } + } +} \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py index 608a18e6d..af616f9ca 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3613,3 +3613,26 @@ def test_main_use_annotated_with_field_constraints(): with pytest.raises(SystemExit): main() + + +@freeze_time('2019-07-26') +def test_main_nested_enum(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(OPEN_API_DATA_PATH / 'nested_enum.json'), + '--output', + str(output_file), + '--input-file-type', + 'openapi', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == (EXPECTED_MAIN_PATH / 'main_nested_enum' / 'output.py').read_text() + ) + with pytest.raises(SystemExit): + main()