diff --git a/README.md b/README.md index 8ea04bb16..9fb389e33 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--url URL] [--enum-field-as-literal {all,one}] [--set-default-enum-member] [--empty-enum-field-name EMPTY_ENUM_FIELD_NAME] + [--use-subclass-enum] [--class-name CLASS_NAME] [--use-title-as-name] [--custom-template-dir CUSTOM_TEMPLATE_DIR] [--extra-template-data EXTRA_TEMPLATE_DATA] @@ -161,6 +162,8 @@ optional arguments: Set enum members as default values for enum field --empty-enum-field-name EMPTY_ENUM_FIELD_NAME Set field name when enum value is empty (default: `_`) + --use-subclass-enum Define Enum class as subclass with field type when enum has + type (int, float, bytes, str) --class-name CLASS_NAME Set class name of root model --use-title-as-name use titles as class names of models diff --git a/datamodel_code_generator/__init__.py b/datamodel_code_generator/__init__.py index 07a1066aa..1b1592b1c 100644 --- a/datamodel_code_generator/__init__.py +++ b/datamodel_code_generator/__init__.py @@ -221,6 +221,7 @@ def generate( encoding: str = 'utf-8', enum_field_as_literal: Optional[LiteralType] = None, set_default_enum_member: bool = False, + use_subclass_enum: bool = False, strict_nullable: bool = False, use_generic_container_types: bool = False, enable_faux_immutability: bool = False, @@ -340,6 +341,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]: reuse_model=reuse_model, enum_field_as_literal=enum_field_as_literal, set_default_enum_member=set_default_enum_member, + use_subclass_enum=use_subclass_enum, strict_nullable=strict_nullable, use_generic_container_types=use_generic_container_types, enable_faux_immutability=enable_faux_immutability, diff --git a/datamodel_code_generator/__main__.py b/datamodel_code_generator/__main__.py index a3558335b..92ed73e2c 100755 --- a/datamodel_code_generator/__main__.py +++ b/datamodel_code_generator/__main__.py @@ -258,6 +258,12 @@ def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover default=None, ) +arg_parser.add_argument( + '--use-subclass-enum', + help='Define Enum class as subclass with field type when enum has type (int, float, bytes, str)', + action='store_true', + default=False, +) arg_parser.add_argument( '--class-name', @@ -404,6 +410,7 @@ def _validate_use_annotated(cls, values: Dict[str, Any]) -> Dict[str, Any]: encoding: str = 'utf-8' enum_field_as_literal: Optional[LiteralType] = None set_default_enum_member: bool = False + use_subclass_enum: bool = False strict_nullable: bool = False use_generic_container_types: bool = False enable_faux_immutability: bool = False @@ -536,6 +543,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: encoding=config.encoding, enum_field_as_literal=config.enum_field_as_literal, set_default_enum_member=config.set_default_enum_member, + use_subclass_enum=config.use_subclass_enum, strict_nullable=config.strict_nullable, use_generic_container_types=config.use_generic_container_types, enable_faux_immutability=config.enable_faux_immutability, diff --git a/datamodel_code_generator/model/enum.py b/datamodel_code_generator/model/enum.py index 45102fdf1..f446d716b 100644 --- a/datamodel_code_generator/model/enum.py +++ b/datamodel_code_generator/model/enum.py @@ -1,15 +1,71 @@ -from typing import Any, ClassVar, Optional, Tuple +from pathlib import Path +from typing import Any, ClassVar, DefaultDict, Dict, List, Optional, Tuple from datamodel_code_generator.imports import IMPORT_ANY, IMPORT_ENUM, Import from datamodel_code_generator.model import DataModel, DataModelFieldBase +from datamodel_code_generator.model.base import BaseClassDataType +from datamodel_code_generator.reference import Reference from datamodel_code_generator.types import DataType, Types +_INT: str = 'int' +_FLOAT: str = 'float' +_BYTES: str = 'bytes' +_STR: str = 'str' + +SUBCLASS_BASE_CLASSES: Dict[Types, str] = { + Types.int32: _INT, + Types.int64: _INT, + Types.integer: _INT, + Types.float: _FLOAT, + Types.double: _FLOAT, + Types.number: _FLOAT, + Types.byte: _BYTES, + Types.string: _STR, +} + class Enum(DataModel): TEMPLATE_FILE_PATH: ClassVar[str] = 'Enum.jinja2' BASE_CLASS: ClassVar[str] = 'enum.Enum' DEFAULT_IMPORTS: ClassVar[Tuple[Import, ...]] = (IMPORT_ENUM,) + def __init__( + self, + *, + reference: Reference, + fields: List[DataModelFieldBase], + decorators: Optional[List[str]] = None, + base_classes: Optional[List[Reference]] = None, + custom_base_class: Optional[str] = None, + custom_template_dir: Optional[Path] = None, + extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None, + methods: Optional[List[str]] = None, + path: Optional[Path] = None, + description: Optional[str] = None, + type_: Optional[Types] = None, + ): + + super().__init__( + reference=reference, + fields=fields, + decorators=decorators, + base_classes=base_classes, + custom_base_class=custom_base_class, + custom_template_dir=custom_template_dir, + extra_template_data=extra_template_data, + methods=methods, + path=path, + description=description, + ) + + if not base_classes and type_: + base_class = SUBCLASS_BASE_CLASSES.get(type_) + if base_class: + self.base_classes: List[BaseClassDataType] = [ + BaseClassDataType(type=base_class), + *self.base_classes, + ] + @classmethod def get_data_type(cls, types: Types, **kwargs: Any) -> DataType: raise NotImplementedError diff --git a/datamodel_code_generator/model/template/Enum.jinja2 b/datamodel_code_generator/model/template/Enum.jinja2 index 503443462..39458683e 100644 --- a/datamodel_code_generator/model/template/Enum.jinja2 +++ b/datamodel_code_generator/model/template/Enum.jinja2 @@ -1,7 +1,7 @@ {% for decorator in decorators -%} {{ decorator }} {% endfor -%} -class {{ class_name }}(Enum): +class {{ class_name }}({{ base_class }}): {%- if description %} """ {{ description }} diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index 692c67b77..7889fa04c 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -280,6 +280,7 @@ def __init__( encoding: str = 'utf-8', enum_field_as_literal: Optional[LiteralType] = None, set_default_enum_member: bool = False, + use_subclass_enum: bool = False, strict_nullable: bool = False, use_generic_container_types: bool = False, enable_faux_immutability: bool = False, @@ -330,6 +331,7 @@ def __init__( self.encoding: str = encoding self.enum_field_as_literal: Optional[LiteralType] = enum_field_as_literal self.set_default_enum_member: bool = set_default_enum_member + self.use_subclass_enum: bool = use_subclass_enum self.strict_nullable: bool = strict_nullable self.use_generic_container_types: bool = use_generic_container_types self.enable_faux_immutability: bool = enable_faux_immutability diff --git a/datamodel_code_generator/parser/jsonschema.py b/datamodel_code_generator/parser/jsonschema.py index 4dfbb301a..9df64ff8c 100644 --- a/datamodel_code_generator/parser/jsonschema.py +++ b/datamodel_code_generator/parser/jsonschema.py @@ -248,6 +248,22 @@ def get_ref_type(ref: str) -> JSONReference: return JSONReference.REMOTE +def _get_type(type_: str, format__: Optional[str] = None) -> Types: + if type_ not in json_schema_data_formats: + return Types.any + data_formats: Optional[Types] = json_schema_data_formats[type_].get( + 'default' if format__ is None else format__ + ) + if data_formats is not None: + return data_formats + + warn( + "format of {!r} not understood for {!r} - using default" + "".format(format__, type_) + ) + return json_schema_data_formats[type_]['default'] + + JsonSchemaObject.update_forward_refs() DEFAULT_FIELD_KEYS: Set[str] = { @@ -295,6 +311,7 @@ def __init__( encoding: str = 'utf-8', enum_field_as_literal: Optional[LiteralType] = None, set_default_enum_member: bool = False, + use_subclass_enum: bool = False, strict_nullable: bool = False, use_generic_container_types: bool = False, enable_faux_immutability: bool = False, @@ -339,6 +356,7 @@ def __init__( encoding=encoding, enum_field_as_literal=enum_field_as_literal, set_default_enum_member=set_default_enum_member, + use_subclass_enum=use_subclass_enum, strict_nullable=strict_nullable, use_generic_container_types=use_generic_container_types, enable_faux_immutability=enable_faux_immutability, @@ -397,17 +415,8 @@ def get_data_type(self, obj: JsonSchemaObject) -> DataType: ) def _get_data_type(type_: str, format__: str) -> DataType: - data_formats: Optional[Types] = json_schema_data_formats[type_].get( - format__ - ) - if data_formats is None: - warn( - "format of {!r} not understood for {!r} - using default" - "".format(format__, type_) - ) - data_formats = json_schema_data_formats[type_]['default'] return self.data_type_manager.get_data_type( - data_formats, + _get_type(type_, format__), **obj.dict() if not self.field_constraints else {}, ) @@ -965,6 +974,9 @@ def create_enum(reference_: Reference) -> DataType: path=self.current_source_path, description=obj.description if self.use_schema_description else None, custom_template_dir=self.custom_template_dir, + type_=_get_type(obj.type, obj.format) + if self.use_subclass_enum and isinstance(obj.type, str) + else None, ) self.results.append(enum) return self.data_type(reference=reference_) diff --git a/datamodel_code_generator/parser/openapi.py b/datamodel_code_generator/parser/openapi.py index 95e193606..e209bc8ff 100644 --- a/datamodel_code_generator/parser/openapi.py +++ b/datamodel_code_generator/parser/openapi.py @@ -165,6 +165,7 @@ def __init__( encoding: str = 'utf-8', enum_field_as_literal: Optional[LiteralType] = None, set_default_enum_member: bool = False, + use_subclass_enum: bool = False, strict_nullable: bool = False, use_generic_container_types: bool = False, enable_faux_immutability: bool = False, @@ -210,6 +211,7 @@ def __init__( encoding=encoding, enum_field_as_literal=enum_field_as_literal, set_default_enum_member=set_default_enum_member, + use_subclass_enum=use_subclass_enum, strict_nullable=strict_nullable, use_generic_container_types=use_generic_container_types, enable_faux_immutability=enable_faux_immutability, diff --git a/docs/index.md b/docs/index.md index 1ae636516..c0a5f5f23 100644 --- a/docs/index.md +++ b/docs/index.md @@ -57,6 +57,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--url URL] [--enum-field-as-literal {all,one}] [--set-default-enum-member] [--empty-enum-field-name EMPTY_ENUM_FIELD_NAME] + [--use-subclass-enum] [--class-name CLASS_NAME] [--use-title-as-name] [--custom-template-dir CUSTOM_TEMPLATE_DIR] [--extra-template-data EXTRA_TEMPLATE_DATA] @@ -127,6 +128,8 @@ optional arguments: Set enum members as default values for enum field --empty-enum-field-name EMPTY_ENUM_FIELD_NAME Set field name when enum value is empty (default: `_`) + --use-subclass-enum Define Enum class as subclass with field type when enum has + type (int, float, bytes, str) --class-name CLASS_NAME Set class name of root model --use-title-as-name use titles as class names of models diff --git a/tests/data/expected/main/main_jsonschema_subclass_enum/output.py b/tests/data/expected/main/main_jsonschema_subclass_enum/output.py new file mode 100644 index 000000000..b056f234b --- /dev/null +++ b/tests/data/expected/main/main_jsonschema_subclass_enum/output.py @@ -0,0 +1,53 @@ +# generated by datamodel-codegen: +# filename: subclass_enum.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + + +class IntEnum(int, Enum): + integer_1 = 1 + integer_2 = 2 + integer_3 = 3 + + +class FloatEnum(float, Enum): + number_1_1 = 1.1 + number_2_1 = 2.1 + number_3_1 = 3.1 + + +class StrEnum(str, Enum): + field_1 = '1' + field_2 = '2' + field_3 = '3' + + +class NonTypedEnum(Enum): + field_1 = '1' + field_2 = '2' + field_3 = '3' + + +class BooleanEnum(Enum): + boolean_True = True + boolean_False = False + + +class UnknownEnum(Enum): + a = 'a' + b = 'b' + + +class Model(BaseModel): + IntEnum: Optional[IntEnum] = None + FloatEnum: Optional[FloatEnum] = None + StrEnum: Optional[StrEnum] = None + NonTypedEnum: Optional[NonTypedEnum] = None + BooleanEnum: Optional[BooleanEnum] = None + UnknownEnum: Optional[UnknownEnum] = None diff --git a/tests/data/jsonschema/subclass_enum.json b/tests/data/jsonschema/subclass_enum.json new file mode 100644 index 000000000..7a5c2e76e --- /dev/null +++ b/tests/data/jsonschema/subclass_enum.json @@ -0,0 +1,51 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "IntEnum": { + "type": "integer", + "enum": [ + 1, + 2, + 3 + ] + }, + "FloatEnum": { + "type": "number", + "enum": [ + 1.1, + 2.1, + 3.1 + ] + }, + "StrEnum": { + "type": "string", + "enum": [ + "1", + "2", + "3" + ] + }, + "NonTypedEnum": { + "enum": [ + "1", + "2", + "3" + ] + }, + "BooleanEnum": { + "type": "boolean", + "enum": [ + true, + false + ] + }, + "UnknownEnum": { + "type": "unknown", + "enum": [ + "a", + "b" + ] + } + } +} \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py index e6e6f6660..4ea0fd011 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2895,6 +2895,32 @@ def test_main_jsonschema_special_enum(): main() +@freeze_time('2019-07-26') +def test_main_jsonschema_subclass_enum(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(JSON_SCHEMA_DATA_PATH / 'subclass_enum.json'), + '--output', + str(output_file), + '--input-file-type', + 'jsonschema', + '--use-subclass-enum', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_MAIN_PATH / 'main_jsonschema_subclass_enum' / 'output.py' + ).read_text() + ) + with pytest.raises(SystemExit): + main() + + @freeze_time('2019-07-26') def test_main_jsonschema_special_enum_empty_enum_field_name(): with TemporaryDirectory() as output_dir: