Skip to content

Commit

Permalink
fix nested enum (#303)
Browse files Browse the repository at this point in the history
* fix nested enum

* fix coverage
  • Loading branch information
koxudaxi authored Jan 15, 2021
1 parent 0bc7e8a commit 4c9d215
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 34 deletions.
66 changes: 33 additions & 33 deletions datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,7 @@ def parse_list_item(
data_types.append(
self.data_type.from_model_name(
self.parse_object(
name,
item,
[*path, str(index)],
singular_name=True,
unique=True,
name, item, [*path, str(index)], singular_name=True,
).name,
)
)
Expand Down Expand Up @@ -437,9 +433,7 @@ def parse_object_fields(
elif field.is_object:
if field.properties:
field_type = self.data_type.from_model_name(
self.parse_object(
field_name, field, [*path, field_name], unique=True
).name
self.parse_object(field_name, field, [*path, field_name]).name
)

elif isinstance(field.additionalProperties, JsonSchemaObject):
Expand Down Expand Up @@ -481,7 +475,6 @@ def parse_object_fields(
if field.additionalProperties.ref
or field.additionalProperties.is_object
else field,
unique=True,
).name

field_type = self.data_type(
Expand All @@ -493,9 +486,7 @@ def parse_object_fields(
else:
field_type = self.data_type_manager.get_data_type(Types.object)
elif field.enum:
enum = self.parse_enum(
field_name, field, [*path, field_name], unique=True
)
enum = self.parse_enum(field_name, field, [*path, field_name])
field_type = self.data_type.from_model_name(enum.name)
else:
field_type = self.get_data_type(field)
Expand Down Expand Up @@ -529,11 +520,17 @@ def parse_object(
obj: JsonSchemaObject,
path: List[str],
singular_name: bool = False,
unique: bool = False,
unique: bool = True,
additional_properties: Optional[JsonSchemaObject] = None,
) -> DataModel:
if not unique: # pragma: no cover
warn(
f'{self.__class__.__name__}.parse_object() ignore `unique` argument.'
f'An object name must be unique.'
f'This argument will be removed in a future version'
)
class_name = self.model_resolver.add(
path, name, class_name=True, singular_name=singular_name, unique=unique
path, name, class_name=True, singular_name=singular_name, unique=True
).name
self.set_title(class_name, obj)
self.set_additional_properties(class_name, additional_properties or obj)
Expand All @@ -558,17 +555,14 @@ def parse_array_fields(
items = obj.items or []
item_obj_data_types: List[DataType] = []
for index, item in enumerate(items):
field_path = [*path, str(index)]
if item.has_constraint:
model = self.model_resolver.add(
[*path, f'items{index}'],
name,
class_name=True,
singular_name=True,
unique=True,
field_path, name, class_name=True, singular_name=True, unique=True,
)
item_obj_data_types.append(
self.data_type.from_model_name(
self.parse_root_type(model.name, item, path,).name
self.parse_root_type(model.name, item, field_path,).name
)
)
elif item.ref:
Expand All @@ -577,35 +571,35 @@ def parse_array_fields(
item_obj_data_types.append(
self.data_type.from_model_name(
self.parse_object(
name,
item,
[*path, str(index)],
singular_name=True,
unique=True,
name, item, field_path, singular_name=True,
).name,
)
)
elif item.anyOf:
item_obj_data_types.append(self.parse_any_of(name, item, path))
item_obj_data_types.append(self.parse_any_of(name, item, field_path))
elif item.allOf:
item_obj_data_types.append(
self.parse_all_of(
self.model_resolver.add(path, name, singular_name=True).name,
self.model_resolver.add(
field_path, name, singular_name=True
).name,
item,
path,
field_path,
)
)
elif item.enum:
item_obj_data_types.append(
self.data_type.from_model_name(
self.parse_enum(name, item, path, singular_name=True).name,
self.parse_enum(
name, item, field_path, singular_name=True
).name,
)
)
elif item.is_array:
array_field = self.parse_array_fields(
self.model_resolver.add(path, name, class_name=True).name,
self.model_resolver.add(field_path, name, class_name=True).name,
item,
path,
field_path,
)
item_obj_data_types.append(array_field.data_type)
else:
Expand Down Expand Up @@ -690,8 +684,14 @@ def parse_enum(
obj: JsonSchemaObject,
path: List[str],
singular_name: bool = False,
unique: bool = False,
unique: bool = True,
) -> DataModel:
if not unique: # pragma: no cover
warn(
f'{self.__class__.__name__}.parse_enum() ignore `unique` argument.'
f'An object name must be unique.'
f'This argument will be removed in a future version'
)
enum_fields: List[DataModelFieldBase] = []

for i, enum_part in enumerate(obj.enum):
Expand Down Expand Up @@ -720,7 +720,7 @@ def parse_enum(
class_name=True,
singular_name=singular_name,
singular_name_suffix='Enum',
unique=unique,
unique=True,
).name
enum = Enum(
enum_name,
Expand Down
15 changes: 14 additions & 1 deletion tests/data/expected/main/main_similar_nested_array/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

from enum import Enum
from typing import Any, List, Optional, Union

from pydantic import BaseModel
Expand Down Expand Up @@ -49,5 +50,17 @@ class KeyCItem3(BaseModel):
nestedB: Optional[str] = None


class KeyCEnum(Enum):
dog = 'dog'
cat = 'cat'
snake = 'snake'


class KeyCEnum1(Enum):
orange = 'orange'
apple = 'apple'
milk = 'milk'


class ObjectD(BaseModel):
keyC: Optional[List[Union[KeyCItem2, KeyCItem3]]] = None
keyC: Optional[List[Union[KeyCItem2, KeyCItem3, KeyCEnum, KeyCEnum1]]] = None
8 changes: 8 additions & 0 deletions tests/data/jsonschema/similar_nested_array.json
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@
"type": "string"
}
}
},
{
"type": "string",
"enum": ["dog", "cat", "snake"]
},
{
"type": "string",
"enum": ["orange", "apple", "milk"]
}
]
}
Expand Down

0 comments on commit 4c9d215

Please sign in to comment.