From c7ee6c84d2ad08e429b8bbe0c769c2fc90d37a70 Mon Sep 17 00:00:00 2001 From: C2D <50617709+i404788@users.noreply.github.com> Date: Wed, 28 Dec 2022 16:53:24 +0100 Subject: [PATCH] Add collapse root model feature (#933) * Add collapse root model feature * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unused root type model * Update unittest * ignore coverage * copy field arguments * Update documents Co-authored-by: ferris Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Koudai Aono --- README.md | 5 +- datamodel_code_generator/__init__.py | 2 + datamodel_code_generator/__main__.py | 12 +++++ datamodel_code_generator/parser/base.py | 20 ++++++++ datamodel_code_generator/parser/jsonschema.py | 2 + datamodel_code_generator/parser/openapi.py | 2 + datamodel_code_generator/types.py | 11 ++-- docs/index.md | 5 +- .../main/main_collapse_root_models/output.py | 17 +++++++ .../output.py | 17 +++++++ tests/data/openapi/not_real_string.json | 33 ++++++++++++ tests/test_main.py | 51 +++++++++++++++++++ 12 files changed, 172 insertions(+), 5 deletions(-) mode change 100755 => 100644 datamodel_code_generator/__main__.py create mode 100644 tests/data/expected/main/main_collapse_root_models/output.py create mode 100644 tests/data/expected/main/main_collapse_root_models_field_constraints/output.py create mode 100644 tests/data/openapi/not_real_string.json diff --git a/README.md b/README.md index e2dd39c84..ad3b15665 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--url URL] [--disable-timestamp] [--use-standard-collections] [--use-generic-container-types] [--use-schema-description] [--use-field-description] [--reuse-model] - [--enum-field-as-literal {all,one}] + [--collapse-root-models] [--enum-field-as-literal {all,one}] [--set-default-enum-member] [--empty-enum-field-name EMPTY_ENUM_FIELD_NAME] [--special-field-name-prefix SPECIAL_FIELD_NAME_PREFIX] @@ -169,6 +169,9 @@ optional arguments: Use schema description to populate field docstring --reuse-model Re-use models on the field when a module has the model with the same content + --collapse-root-models + Models generated with a root-type field will be + merged into the models using that root-type model --enum-field-as-literal {all,one} Parse enum field as literal. all: all enum field type are Literal. one: field type is Literal when an enum diff --git a/datamodel_code_generator/__init__.py b/datamodel_code_generator/__init__.py index f3817629c..9c26fb6c4 100644 --- a/datamodel_code_generator/__init__.py +++ b/datamodel_code_generator/__init__.py @@ -246,6 +246,7 @@ def generate( original_field_name_delimiter: Optional[str] = None, use_double_quotes: bool = False, use_union_operator: bool = False, + collapse_root_models: bool = False, special_field_name_prefix: Optional[str] = None, ) -> None: remote_text_cache: DefaultPutDict[str, str] = DefaultPutDict() @@ -372,6 +373,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]: original_field_name_delimiter=original_field_name_delimiter, use_double_quotes=use_double_quotes, use_union_operator=use_union_operator, + collapse_root_models=collapse_root_models, special_field_name_prefix=special_field_name_prefix, **kwargs, ) diff --git a/datamodel_code_generator/__main__.py b/datamodel_code_generator/__main__.py old mode 100755 new mode 100644 index cb29e40f9..f4773ebfc --- a/datamodel_code_generator/__main__.py +++ b/datamodel_code_generator/__main__.py @@ -265,6 +265,16 @@ def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover default=None, ) + +arg_parser.add_argument( + "--collapse-root-models", + action='store_true', + default=False, + help="Models generated with a root-type field will be merged" + "into the models using that root-type model", +) + + arg_parser.add_argument( '--enum-field-as-literal', help='Parse enum field as literal. ' @@ -501,6 +511,7 @@ def _validate_use_union_operator(cls, values: Dict[str, Any]) -> Dict[str, Any]: use_non_positive_negative_number_constrained_types: bool = False original_field_name_delimiter: Optional[str] = None use_double_quotes: bool = False + collapse_root_models: bool = False special_field_name_prefix: Optional[str] = None def merge_args(self, args: Namespace) -> None: @@ -647,6 +658,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: use_non_positive_negative_number_constrained_types=config.use_non_positive_negative_number_constrained_types, original_field_name_delimiter=config.original_field_name_delimiter, use_double_quotes=config.use_double_quotes, + collapse_root_models=config.collapse_root_models, use_union_operator=config.use_union_operator, special_field_name_prefix=config.special_field_name_prefix, ) diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index 08be25cfc..fd3caf985 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -311,6 +311,7 @@ def __init__( use_double_quotes: bool = False, use_union_operator: bool = False, allow_responses_without_content: bool = False, + collapse_root_models: bool = False, special_field_name_prefix: Optional[str] = None, ): self.data_type_manager: DataTypeManager = data_type_manager_type( @@ -412,6 +413,7 @@ def __init__( ) self.use_double_quotes = use_double_quotes self.allow_responses_without_content = allow_responses_without_content + self.collapse_root_models = collapse_root_models @property def iter_source(self) -> Iterator[Source]: @@ -692,6 +694,24 @@ def parse( for duplicate in duplicates: models.remove(duplicate) + if self.collapse_root_models: + for model in models: + for model_field in model.fields: + reference = model_field.data_type.reference + if reference and isinstance( + reference.source, self.data_model_root_type + ): + # Use root-type as model_field type + root_type_field = reference.source.fields[0] + model_field.data_type.remove_reference() + model_field.data_type = root_type_field.data_type + model_field.data_type.parent = model_field + model_field.extras = root_type_field.extras + model_field.constraints = root_type_field.constraints + + if not reference.children: # pragma: no cover + models.remove(reference.source) + if self.set_default_enum_member: for model in models: for model_field in model.fields: diff --git a/datamodel_code_generator/parser/jsonschema.py b/datamodel_code_generator/parser/jsonschema.py index 3510dfdb1..c21efbb39 100644 --- a/datamodel_code_generator/parser/jsonschema.py +++ b/datamodel_code_generator/parser/jsonschema.py @@ -345,6 +345,7 @@ def __init__( use_double_quotes: bool = False, use_union_operator: bool = False, allow_responses_without_content: bool = False, + collapse_root_models: bool = False, special_field_name_prefix: Optional[str] = None, ): super().__init__( @@ -397,6 +398,7 @@ def __init__( use_double_quotes=use_double_quotes, use_union_operator=use_union_operator, allow_responses_without_content=allow_responses_without_content, + collapse_root_models=collapse_root_models, special_field_name_prefix=special_field_name_prefix, ) diff --git a/datamodel_code_generator/parser/openapi.py b/datamodel_code_generator/parser/openapi.py index 8802366d6..4ae52d639 100644 --- a/datamodel_code_generator/parser/openapi.py +++ b/datamodel_code_generator/parser/openapi.py @@ -194,6 +194,7 @@ def __init__( use_double_quotes: bool = False, use_union_operator: bool = False, allow_responses_without_content: bool = False, + collapse_root_models: bool = False, special_field_name_prefix: Optional[str] = None, ): super().__init__( @@ -246,6 +247,7 @@ def __init__( use_double_quotes=use_double_quotes, use_union_operator=use_union_operator, allow_responses_without_content=allow_responses_without_content, + collapse_root_models=collapse_root_models, special_field_name_prefix=special_field_name_prefix, ) self.open_api_scopes: List[OpenAPIScope] = openapi_scopes or [ diff --git a/datamodel_code_generator/types.py b/datamodel_code_generator/types.py index 1aee5eb90..10d00dbd1 100644 --- a/datamodel_code_generator/types.py +++ b/datamodel_code_generator/types.py @@ -133,16 +133,21 @@ def unresolved_types(self) -> FrozenSet[str]: | ({self.reference.path} if self.reference else set()) ) - def replace_reference(self, reference: Reference) -> None: + def replace_reference(self, reference: Optional[Reference]) -> None: if not self.reference: # pragma: no cover raise Exception( f'`{self.__class__.__name__}.replace_reference()` can\'t be called' f' when `reference` field is empty.' ) - self.reference.children.remove(self) + while self in self.reference.children: + self.reference.children.remove(self) self.reference = reference - reference.children.append(self) + if reference: + reference.children.append(self) + + def remove_reference(self) -> None: + self.replace_reference(None) @property def module_name(self) -> Optional[str]: diff --git a/docs/index.md b/docs/index.md index 76a29469d..c16d572d0 100644 --- a/docs/index.md +++ b/docs/index.md @@ -56,7 +56,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--url URL] [--disable-timestamp] [--use-standard-collections] [--use-generic-container-types] [--use-schema-description] [--use-field-description] [--reuse-model] - [--enum-field-as-literal {all,one}] + [--collapse-root-models] [--enum-field-as-literal {all,one}] [--set-default-enum-member] [--empty-enum-field-name EMPTY_ENUM_FIELD_NAME] [--special-field-name-prefix SPECIAL_FIELD_NAME_PREFIX] @@ -131,6 +131,9 @@ optional arguments: Use schema description to populate field docstring --reuse-model Re-use models on the field when a module has the model with the same content + --collapse-root-models + Models generated with a root-type field will be + merged into the models using that root-type model --enum-field-as-literal {all,one} Parse enum field as literal. all: all enum field type are Literal. one: field type is Literal when an enum diff --git a/tests/data/expected/main/main_collapse_root_models/output.py b/tests/data/expected/main/main_collapse_root_models/output.py new file mode 100644 index 000000000..0b7ec4f48 --- /dev/null +++ b/tests/data/expected/main/main_collapse_root_models/output.py @@ -0,0 +1,17 @@ +# generated by datamodel-codegen: +# filename: not_real_string.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel, constr + + +class Tweet(BaseModel): + author_id: Optional[str] = None + + +class FileRequest(BaseModel): + file_hash: constr(regex=r'^[a-fA-F\d]{32}$', min_length=32, max_length=32) diff --git a/tests/data/expected/main/main_collapse_root_models_field_constraints/output.py b/tests/data/expected/main/main_collapse_root_models_field_constraints/output.py new file mode 100644 index 000000000..db3d9edb8 --- /dev/null +++ b/tests/data/expected/main/main_collapse_root_models_field_constraints/output.py @@ -0,0 +1,17 @@ +# generated by datamodel-codegen: +# filename: not_real_string.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel, Field + + +class Tweet(BaseModel): + author_id: Optional[str] = None + + +class FileRequest(BaseModel): + file_hash: str = Field(..., max_length=32, min_length=32, regex='^[a-fA-F\\d]{32}$') diff --git a/tests/data/openapi/not_real_string.json b/tests/data/openapi/not_real_string.json new file mode 100644 index 000000000..0507d210d --- /dev/null +++ b/tests/data/openapi/not_real_string.json @@ -0,0 +1,33 @@ +{ + "openapi" : "3.0.0", + "components" : { + "schemas" : { + "UserId" : { + "type" : "string" + }, + "Tweet" : { + "type" : "object", + "properties" : { + "author_id" : { + "$ref" : "#/components/schemas/UserId" + } + } + }, + "FileHash": { + "type": "string", + "minLength": 32, + "maxLength": 32, + "pattern": "^[a-fA-F\\d]{32}$" + }, + "FileRequest": { + "type": "object", + "required": ["file_hash"], + "properties": { + "file_hash": { + "$ref": "#/components/schemas/FileHash" + } + } + } + } + } +} \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py index f9367bd64..ec5cf1e5c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4080,3 +4080,54 @@ def test_external_relative_ref(): with pytest.raises(SystemExit): main() + + +@freeze_time('2019-07-26') +def test_main_collapse_root_models(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(OPEN_API_DATA_PATH / 'not_real_string.json'), + '--output', + str(output_file), + "--collapse-root-models", + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_MAIN_PATH / 'main_collapse_root_models' / 'output.py' + ).read_text() + ) + with pytest.raises(SystemExit): + main() + + +@freeze_time('2019-07-26') +def test_main_collapse_root_models_field_constraints(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(OPEN_API_DATA_PATH / 'not_real_string.json'), + '--output', + str(output_file), + "--collapse-root-models", + '--field-constraints', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_MAIN_PATH + / 'main_collapse_root_models_field_constraints' + / 'output.py' + ).read_text() + ) + with pytest.raises(SystemExit): + main()