Skip to content

Commit

Permalink
support custom class name (#257)
Browse files Browse the repository at this point in the history
* support custom class name

* improve coverage

* update document
  • Loading branch information
koxudaxi authored Nov 11, 2020
1 parent e8fc364 commit f8d2e31
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 2 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ usage: datamodel-codegen [-h] [--input INPUT]
[--strip-default-none]
[--allow-population-by-field-name] [--use-default]
[--force-optional] [--disable-timestamp]
[--class-name CLASS_NAME]
[--custom-template-dir CUSTOM_TEMPLATE_DIR]
[--extra-template-data EXTRA_TEMPLATE_DATA]
[--aliases ALIASES]
Expand All @@ -94,6 +95,8 @@ optional arguments:
--use-default Use default value even if a field is required
--force-optional Force optional for required fields
--disable-timestamp Disable timestamp on file headers
--class-name CLASS_NAME
Set class name of root model
--custom-template-dir CUSTOM_TEMPLATE_DIR
Custom template directory
--extra-template-data EXTRA_TEMPLATE_DATA
Expand Down
9 changes: 9 additions & 0 deletions datamodel_code_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def __str__(self) -> str:
return self.message


class InvalidClassNameError(Error):
def __init__(self, class_name: str) -> None:
self.class_name = class_name
message = f'title={repr(class_name)} is invalid class name.'
super().__init__(message=message)


def get_first_file(path: Path) -> Path: # pragma: no cover
if path.is_file():
return path
Expand Down Expand Up @@ -141,6 +148,7 @@ def generate(
allow_population_by_field_name: bool = False,
apply_default_values_for_required_fields: bool = False,
force_optional_for_required_fields: bool = False,
class_name: Optional[str] = None,
) -> None:
input_text: Optional[str] = None
if input_file_type == InputFileType.Auto:
Expand Down Expand Up @@ -196,6 +204,7 @@ def generate(
allow_population_by_field_name=allow_population_by_field_name,
apply_default_values_for_required_fields=apply_default_values_for_required_fields,
force_optional_for_required_fields=force_optional_for_required_fields,
class_name=class_name,
)

with chdir(output):
Expand Down
10 changes: 10 additions & 0 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DEFAULT_BASE_CLASS,
Error,
InputFileType,
InvalidClassNameError,
enable_debug_message,
generate,
)
Expand Down Expand Up @@ -106,6 +107,10 @@ def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover
default=None,
)

arg_parser.add_argument(
'--class-name', help='Set class name of root model', default=None,
)

arg_parser.add_argument(
'--custom-template-dir', help='Custom template directory', type=str
)
Expand Down Expand Up @@ -164,6 +169,7 @@ def validate_path(cls, value: Any) -> Optional[Path]:
allow_population_by_field_name: bool = False
use_default: bool = False
force_optional: bool = False
class_name: Optional[str] = None

def merge_args(self, args: Namespace) -> None:
for field_name in self.__fields__:
Expand Down Expand Up @@ -258,8 +264,12 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
allow_population_by_field_name=config.allow_population_by_field_name,
apply_default_values_for_required_fields=config.use_default,
force_optional_for_required_fields=config.force_optional,
class_name=config.class_name,
)
return Exit.OK
except InvalidClassNameError as e:
print(f'{e} You have to set --class-name option', file=sys.stderr)
return Exit.ERROR
except Error as e:
print(str(e), file=sys.stderr)
return Exit.ERROR
Expand Down
2 changes: 2 additions & 0 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(
allow_population_by_field_name: bool = False,
apply_default_values_for_required_fields: bool = False,
force_optional_for_required_fields: bool = False,
class_name: Optional[str] = None,
):
self.data_type_manager: DataTypeManager = data_type_manager_type(
target_python_version
Expand Down Expand Up @@ -210,6 +211,7 @@ def __init__(
self.field_preprocessors.append(snakify_field)
if self.strip_default_none:
self.field_preprocessors.append(set_strip_default_none)
self.class_name: Optional[str] = class_name

@property
def iter_source(self) -> Iterator[Source]:
Expand Down
13 changes: 11 additions & 2 deletions datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import yaml
from pydantic import BaseModel, Field, root_validator, validator

from datamodel_code_generator import snooper_to_methods
from datamodel_code_generator import Error, InvalidClassNameError, snooper_to_methods
from datamodel_code_generator.format import PythonVersion
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.enum import Enum
Expand Down Expand Up @@ -193,6 +193,7 @@ def __init__(
allow_population_by_field_name: bool = False,
apply_default_values_for_required_fields: bool = False,
force_optional_for_required_fields: bool = False,
class_name: Optional[str] = None,
):
super().__init__(
source=source,
Expand All @@ -213,6 +214,7 @@ def __init__(
allow_population_by_field_name=allow_population_by_field_name,
apply_default_values_for_required_fields=apply_default_values_for_required_fields,
force_optional_for_required_fields=force_optional_for_required_fields,
class_name=class_name,
)

self.remote_object_cache: Dict[str, Dict[str, Any]] = {}
Expand Down Expand Up @@ -744,7 +746,14 @@ def parse_raw(self) -> None:
path_parts = list(source.path.parts)
self.model_resolver.set_current_root(path_parts)
self.raw_obj = yaml.safe_load(source.text)
obj_name = self.raw_obj.get('title', 'Model')
if self.class_name:
obj_name = self.class_name
else:
# backward compatible
obj_name = self.raw_obj.get('title', 'Model')
if not self.model_resolver.validate_name(obj_name):
raise InvalidClassNameError(obj_name)

obj_name = self.model_resolver.add(path_parts, obj_name, unique=False).name
self.parse_raw_obj(obj_name, self.raw_obj, path_parts)
definitions = self.raw_obj.get('definitions', {})
Expand Down
5 changes: 5 additions & 0 deletions datamodel_code_generator/reference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from keyword import iskeyword
from pathlib import Path
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -139,6 +140,10 @@ def _get_uniq_name(self, name: str, camel: bool = False) -> str:
count += 1
return uniq_name

@classmethod
def validate_name(cls, name: str) -> bool:
return name.isidentifier() and not iskeyword(name)

def get_valid_name(self, name: str, camel: bool = False) -> str:
# TODO: when first character is a number
replaced_name = re.sub(r'\W', '_', name)
Expand Down
3 changes: 3 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ usage: datamodel-codegen [-h] [--input INPUT]
[--strip-default-none]
[--allow-population-by-field-name] [--use-default]
[--force-optional] [--disable-timestamp]
[--class-name CLASS_NAME]
[--custom-template-dir CUSTOM_TEMPLATE_DIR]
[--extra-template-data EXTRA_TEMPLATE_DATA]
[--aliases ALIASES]
Expand All @@ -61,6 +62,8 @@ optional arguments:
--use-default Use default value even if a field is required
--force-optional Force optional for required fields
--disable-timestamp Disable timestamp on file headers
--class-name CLASS_NAME
Set class name of root model
--custom-template-dir CUSTOM_TEMPLATE_DIR
Custom template directory
--extra-template-data EXTRA_TEMPLATE_DATA
Expand Down
19 changes: 19 additions & 0 deletions tests/data/expected/main/main_invalid_model_name/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# generated by datamodel-codegen:
# filename: invalid_model_name.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import Any, List, Optional

from pydantic import BaseModel, Field, conint


class ValidModelName(BaseModel):
firstName: Optional[str] = Field(None, description="The person's first name.")
lastName: Optional[str] = Field(None, description="The person's last name.")
age: Optional[conint(ge=0)] = Field(
None, description='Age in years which must be equal to or greater than zero.'
)
friends: Optional[List] = None
comment: Optional[Any] = None
27 changes: 27 additions & 0 deletions tests/data/jsonschema/invalid_model_name.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"$id": "https://example.com/person.schema.json",
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "1 xyz",
"type": "object",
"properties": {
"firstName": {
"type": "string",
"description": "The person's first name."
},
"lastName": {
"type": "string",
"description": "The person's last name."
},
"age": {
"description": "Age in years which must be equal to or greater than zero.",
"type": "integer",
"minimum": 0
},
"friends": {
"type": "array"
},
"comment": {
"type": "null"
}
}
}
49 changes: 49 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,3 +883,52 @@ def test_main_subclass_enum():

with pytest.raises(SystemExit):
main()


@freeze_time('2019-07-26')
def test_main_invalid_model_name_failed(capsys):
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(JSON_SCHEMA_DATA_PATH / 'invalid_model_name.json'),
'--output',
str(output_file),
'--input-file-type',
'jsonschema',
]
)
captured = capsys.readouterr()
assert return_code == Exit.ERROR
assert (
captured.err
== 'title=\'1 xyz\' is invalid class name. You have to set --class-name option\n'
)


@freeze_time('2019-07-26')
def test_main_invalid_model_name():
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(JSON_SCHEMA_DATA_PATH / 'invalid_model_name.json'),
'--output',
str(output_file),
'--input-file-type',
'jsonschema',
'--class-name',
'ValidModelName',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (
EXPECTED_MAIN_PATH / 'main_invalid_model_name' / 'output.py'
).read_text()
)
with pytest.raises(SystemExit):
main()

0 comments on commit f8d2e31

Please sign in to comment.