Skip to content

Commit

Permalink
Support subclass enum (#771)
Browse files Browse the repository at this point in the history
* Support subclass enum

* Add unittest

* Remove unsued import

* Update documents

* Fix condition

* Update unittest

* Update unittest
  • Loading branch information
koxudaxi authored May 27, 2022
1 parent 0f2192e commit 7c7e35d
Show file tree
Hide file tree
Showing 12 changed files with 230 additions and 12 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--url URL]
[--enum-field-as-literal {all,one}]
[--set-default-enum-member]
[--empty-enum-field-name EMPTY_ENUM_FIELD_NAME]
[--use-subclass-enum]
[--class-name CLASS_NAME] [--use-title-as-name]
[--custom-template-dir CUSTOM_TEMPLATE_DIR]
[--extra-template-data EXTRA_TEMPLATE_DATA]
Expand Down Expand Up @@ -161,6 +162,8 @@ optional arguments:
Set enum members as default values for enum field
--empty-enum-field-name EMPTY_ENUM_FIELD_NAME
Set field name when enum value is empty (default: `_`)
--use-subclass-enum Define Enum class as subclass with field type when enum has
type (int, float, bytes, str)
--class-name CLASS_NAME
Set class name of root model
--use-title-as-name use titles as class names of models
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def generate(
encoding: str = 'utf-8',
enum_field_as_literal: Optional[LiteralType] = None,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
Expand Down Expand Up @@ -340,6 +341,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]:
reuse_model=reuse_model,
enum_field_as_literal=enum_field_as_literal,
set_default_enum_member=set_default_enum_member,
use_subclass_enum=use_subclass_enum,
strict_nullable=strict_nullable,
use_generic_container_types=use_generic_container_types,
enable_faux_immutability=enable_faux_immutability,
Expand Down
8 changes: 8 additions & 0 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover
default=None,
)

arg_parser.add_argument(
'--use-subclass-enum',
help='Define Enum class as subclass with field type when enum has type (int, float, bytes, str)',
action='store_true',
default=False,
)

arg_parser.add_argument(
'--class-name',
Expand Down Expand Up @@ -404,6 +410,7 @@ def _validate_use_annotated(cls, values: Dict[str, Any]) -> Dict[str, Any]:
encoding: str = 'utf-8'
enum_field_as_literal: Optional[LiteralType] = None
set_default_enum_member: bool = False
use_subclass_enum: bool = False
strict_nullable: bool = False
use_generic_container_types: bool = False
enable_faux_immutability: bool = False
Expand Down Expand Up @@ -536,6 +543,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
encoding=config.encoding,
enum_field_as_literal=config.enum_field_as_literal,
set_default_enum_member=config.set_default_enum_member,
use_subclass_enum=config.use_subclass_enum,
strict_nullable=config.strict_nullable,
use_generic_container_types=config.use_generic_container_types,
enable_faux_immutability=config.enable_faux_immutability,
Expand Down
58 changes: 57 additions & 1 deletion datamodel_code_generator/model/enum.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,71 @@
from typing import Any, ClassVar, Optional, Tuple
from pathlib import Path
from typing import Any, ClassVar, DefaultDict, Dict, List, Optional, Tuple

from datamodel_code_generator.imports import IMPORT_ANY, IMPORT_ENUM, Import
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import BaseClassDataType
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.types import DataType, Types

_INT: str = 'int'
_FLOAT: str = 'float'
_BYTES: str = 'bytes'
_STR: str = 'str'

SUBCLASS_BASE_CLASSES: Dict[Types, str] = {
Types.int32: _INT,
Types.int64: _INT,
Types.integer: _INT,
Types.float: _FLOAT,
Types.double: _FLOAT,
Types.number: _FLOAT,
Types.byte: _BYTES,
Types.string: _STR,
}


class Enum(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = 'Enum.jinja2'
BASE_CLASS: ClassVar[str] = 'enum.Enum'
DEFAULT_IMPORTS: ClassVar[Tuple[Import, ...]] = (IMPORT_ENUM,)

def __init__(
self,
*,
reference: Reference,
fields: List[DataModelFieldBase],
decorators: Optional[List[str]] = None,
base_classes: Optional[List[Reference]] = None,
custom_base_class: Optional[str] = None,
custom_template_dir: Optional[Path] = None,
extra_template_data: Optional[DefaultDict[str, Dict[str, Any]]] = None,
methods: Optional[List[str]] = None,
path: Optional[Path] = None,
description: Optional[str] = None,
type_: Optional[Types] = None,
):

super().__init__(
reference=reference,
fields=fields,
decorators=decorators,
base_classes=base_classes,
custom_base_class=custom_base_class,
custom_template_dir=custom_template_dir,
extra_template_data=extra_template_data,
methods=methods,
path=path,
description=description,
)

if not base_classes and type_:
base_class = SUBCLASS_BASE_CLASSES.get(type_)
if base_class:
self.base_classes: List[BaseClassDataType] = [
BaseClassDataType(type=base_class),
*self.base_classes,
]

@classmethod
def get_data_type(cls, types: Types, **kwargs: Any) -> DataType:
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/template/Enum.jinja2
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{% for decorator in decorators -%}
{{ decorator }}
{% endfor -%}
class {{ class_name }}(Enum):
class {{ class_name }}({{ base_class }}):
{%- if description %}
"""
{{ description }}
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def __init__(
encoding: str = 'utf-8',
enum_field_as_literal: Optional[LiteralType] = None,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
Expand Down Expand Up @@ -330,6 +331,7 @@ def __init__(
self.encoding: str = encoding
self.enum_field_as_literal: Optional[LiteralType] = enum_field_as_literal
self.set_default_enum_member: bool = set_default_enum_member
self.use_subclass_enum: bool = use_subclass_enum
self.strict_nullable: bool = strict_nullable
self.use_generic_container_types: bool = use_generic_container_types
self.enable_faux_immutability: bool = enable_faux_immutability
Expand Down
32 changes: 22 additions & 10 deletions datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,22 @@ def get_ref_type(ref: str) -> JSONReference:
return JSONReference.REMOTE


def _get_type(type_: str, format__: Optional[str] = None) -> Types:
if type_ not in json_schema_data_formats:
return Types.any
data_formats: Optional[Types] = json_schema_data_formats[type_].get(
'default' if format__ is None else format__
)
if data_formats is not None:
return data_formats

warn(
"format of {!r} not understood for {!r} - using default"
"".format(format__, type_)
)
return json_schema_data_formats[type_]['default']


JsonSchemaObject.update_forward_refs()

DEFAULT_FIELD_KEYS: Set[str] = {
Expand Down Expand Up @@ -295,6 +311,7 @@ def __init__(
encoding: str = 'utf-8',
enum_field_as_literal: Optional[LiteralType] = None,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
Expand Down Expand Up @@ -339,6 +356,7 @@ def __init__(
encoding=encoding,
enum_field_as_literal=enum_field_as_literal,
set_default_enum_member=set_default_enum_member,
use_subclass_enum=use_subclass_enum,
strict_nullable=strict_nullable,
use_generic_container_types=use_generic_container_types,
enable_faux_immutability=enable_faux_immutability,
Expand Down Expand Up @@ -397,17 +415,8 @@ def get_data_type(self, obj: JsonSchemaObject) -> DataType:
)

def _get_data_type(type_: str, format__: str) -> DataType:
data_formats: Optional[Types] = json_schema_data_formats[type_].get(
format__
)
if data_formats is None:
warn(
"format of {!r} not understood for {!r} - using default"
"".format(format__, type_)
)
data_formats = json_schema_data_formats[type_]['default']
return self.data_type_manager.get_data_type(
data_formats,
_get_type(type_, format__),
**obj.dict() if not self.field_constraints else {},
)

Expand Down Expand Up @@ -965,6 +974,9 @@ def create_enum(reference_: Reference) -> DataType:
path=self.current_source_path,
description=obj.description if self.use_schema_description else None,
custom_template_dir=self.custom_template_dir,
type_=_get_type(obj.type, obj.format)
if self.use_subclass_enum and isinstance(obj.type, str)
else None,
)
self.results.append(enum)
return self.data_type(reference=reference_)
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(
encoding: str = 'utf-8',
enum_field_as_literal: Optional[LiteralType] = None,
set_default_enum_member: bool = False,
use_subclass_enum: bool = False,
strict_nullable: bool = False,
use_generic_container_types: bool = False,
enable_faux_immutability: bool = False,
Expand Down Expand Up @@ -210,6 +211,7 @@ def __init__(
encoding=encoding,
enum_field_as_literal=enum_field_as_literal,
set_default_enum_member=set_default_enum_member,
use_subclass_enum=use_subclass_enum,
strict_nullable=strict_nullable,
use_generic_container_types=use_generic_container_types,
enable_faux_immutability=enable_faux_immutability,
Expand Down
3 changes: 3 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--url URL]
[--enum-field-as-literal {all,one}]
[--set-default-enum-member]
[--empty-enum-field-name EMPTY_ENUM_FIELD_NAME]
[--use-subclass-enum]
[--class-name CLASS_NAME] [--use-title-as-name]
[--custom-template-dir CUSTOM_TEMPLATE_DIR]
[--extra-template-data EXTRA_TEMPLATE_DATA]
Expand Down Expand Up @@ -127,6 +128,8 @@ optional arguments:
Set enum members as default values for enum field
--empty-enum-field-name EMPTY_ENUM_FIELD_NAME
Set field name when enum value is empty (default: `_`)
--use-subclass-enum Define Enum class as subclass with field type when enum has
type (int, float, bytes, str)
--class-name CLASS_NAME
Set class name of root model
--use-title-as-name use titles as class names of models
Expand Down
53 changes: 53 additions & 0 deletions tests/data/expected/main/main_jsonschema_subclass_enum/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# generated by datamodel-codegen:
# filename: subclass_enum.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from enum import Enum
from typing import Optional

from pydantic import BaseModel


class IntEnum(int, Enum):
integer_1 = 1
integer_2 = 2
integer_3 = 3


class FloatEnum(float, Enum):
number_1_1 = 1.1
number_2_1 = 2.1
number_3_1 = 3.1


class StrEnum(str, Enum):
field_1 = '1'
field_2 = '2'
field_3 = '3'


class NonTypedEnum(Enum):
field_1 = '1'
field_2 = '2'
field_3 = '3'


class BooleanEnum(Enum):
boolean_True = True
boolean_False = False


class UnknownEnum(Enum):
a = 'a'
b = 'b'


class Model(BaseModel):
IntEnum: Optional[IntEnum] = None
FloatEnum: Optional[FloatEnum] = None
StrEnum: Optional[StrEnum] = None
NonTypedEnum: Optional[NonTypedEnum] = None
BooleanEnum: Optional[BooleanEnum] = None
UnknownEnum: Optional[UnknownEnum] = None
51 changes: 51 additions & 0 deletions tests/data/jsonschema/subclass_enum.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"IntEnum": {
"type": "integer",
"enum": [
1,
2,
3
]
},
"FloatEnum": {
"type": "number",
"enum": [
1.1,
2.1,
3.1
]
},
"StrEnum": {
"type": "string",
"enum": [
"1",
"2",
"3"
]
},
"NonTypedEnum": {
"enum": [
"1",
"2",
"3"
]
},
"BooleanEnum": {
"type": "boolean",
"enum": [
true,
false
]
},
"UnknownEnum": {
"type": "unknown",
"enum": [
"a",
"b"
]
}
}
}
Loading

0 comments on commit 7c7e35d

Please sign in to comment.