diff --git a/README.md b/README.md index a75575f67..a54b69856 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--strip-default-none] [--allow-population-by-field-name] [--use-default] [--force-optional] [--disable-timestamp] + [--class-name CLASS_NAME] [--custom-template-dir CUSTOM_TEMPLATE_DIR] [--extra-template-data EXTRA_TEMPLATE_DATA] [--aliases ALIASES] @@ -94,6 +95,8 @@ optional arguments: --use-default Use default value even if a field is required --force-optional Force optional for required fields --disable-timestamp Disable timestamp on file headers + --class-name CLASS_NAME + Set class name of root model --custom-template-dir CUSTOM_TEMPLATE_DIR Custom template directory --extra-template-data EXTRA_TEMPLATE_DATA diff --git a/datamodel_code_generator/__init__.py b/datamodel_code_generator/__init__.py index 4781a7e1b..ceb88d11d 100644 --- a/datamodel_code_generator/__init__.py +++ b/datamodel_code_generator/__init__.py @@ -112,6 +112,13 @@ def __str__(self) -> str: return self.message +class InvalidClassNameError(Error): + def __init__(self, class_name: str) -> None: + self.class_name = class_name + message = f'title={repr(class_name)} is invalid class name.' + super().__init__(message=message) + + def get_first_file(path: Path) -> Path: # pragma: no cover if path.is_file(): return path @@ -141,6 +148,7 @@ def generate( allow_population_by_field_name: bool = False, apply_default_values_for_required_fields: bool = False, force_optional_for_required_fields: bool = False, + class_name: Optional[str] = None, ) -> None: input_text: Optional[str] = None if input_file_type == InputFileType.Auto: @@ -196,6 +204,7 @@ def generate( allow_population_by_field_name=allow_population_by_field_name, apply_default_values_for_required_fields=apply_default_values_for_required_fields, force_optional_for_required_fields=force_optional_for_required_fields, + class_name=class_name, ) with chdir(output): diff --git a/datamodel_code_generator/__main__.py b/datamodel_code_generator/__main__.py index fcd3bc5f1..6accc6c78 100755 --- a/datamodel_code_generator/__main__.py +++ b/datamodel_code_generator/__main__.py @@ -23,6 +23,7 @@ DEFAULT_BASE_CLASS, Error, InputFileType, + InvalidClassNameError, enable_debug_message, generate, ) @@ -106,6 +107,10 @@ def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover default=None, ) +arg_parser.add_argument( + '--class-name', help='Set class name of root model', default=None, +) + arg_parser.add_argument( '--custom-template-dir', help='Custom template directory', type=str ) @@ -164,6 +169,7 @@ def validate_path(cls, value: Any) -> Optional[Path]: allow_population_by_field_name: bool = False use_default: bool = False force_optional: bool = False + class_name: Optional[str] = None def merge_args(self, args: Namespace) -> None: for field_name in self.__fields__: @@ -258,8 +264,12 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: allow_population_by_field_name=config.allow_population_by_field_name, apply_default_values_for_required_fields=config.use_default, force_optional_for_required_fields=config.force_optional, + class_name=config.class_name, ) return Exit.OK + except InvalidClassNameError as e: + print(f'{e} You have to set --class-name option', file=sys.stderr) + return Exit.ERROR except Error as e: print(str(e), file=sys.stderr) return Exit.ERROR diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index 2dcdac0f9..53cf47192 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -165,6 +165,7 @@ def __init__( allow_population_by_field_name: bool = False, apply_default_values_for_required_fields: bool = False, force_optional_for_required_fields: bool = False, + class_name: Optional[str] = None, ): self.data_type_manager: DataTypeManager = data_type_manager_type( target_python_version @@ -210,6 +211,7 @@ def __init__( self.field_preprocessors.append(snakify_field) if self.strip_default_none: self.field_preprocessors.append(set_strip_default_none) + self.class_name: Optional[str] = class_name @property def iter_source(self) -> Iterator[Source]: diff --git a/datamodel_code_generator/parser/jsonschema.py b/datamodel_code_generator/parser/jsonschema.py index 08078c2f7..228c14678 100644 --- a/datamodel_code_generator/parser/jsonschema.py +++ b/datamodel_code_generator/parser/jsonschema.py @@ -15,7 +15,7 @@ import yaml from pydantic import BaseModel, Field, root_validator, validator -from datamodel_code_generator import snooper_to_methods +from datamodel_code_generator import Error, InvalidClassNameError, snooper_to_methods from datamodel_code_generator.format import PythonVersion from datamodel_code_generator.model import DataModel, DataModelFieldBase from datamodel_code_generator.model.enum import Enum @@ -193,6 +193,7 @@ def __init__( allow_population_by_field_name: bool = False, apply_default_values_for_required_fields: bool = False, force_optional_for_required_fields: bool = False, + class_name: Optional[str] = None, ): super().__init__( source=source, @@ -213,6 +214,7 @@ def __init__( allow_population_by_field_name=allow_population_by_field_name, apply_default_values_for_required_fields=apply_default_values_for_required_fields, force_optional_for_required_fields=force_optional_for_required_fields, + class_name=class_name, ) self.remote_object_cache: Dict[str, Dict[str, Any]] = {} @@ -744,7 +746,14 @@ def parse_raw(self) -> None: path_parts = list(source.path.parts) self.model_resolver.set_current_root(path_parts) self.raw_obj = yaml.safe_load(source.text) - obj_name = self.raw_obj.get('title', 'Model') + if self.class_name: + obj_name = self.class_name + else: + # backward compatible + obj_name = self.raw_obj.get('title', 'Model') + if not self.model_resolver.validate_name(obj_name): + raise InvalidClassNameError(obj_name) + obj_name = self.model_resolver.add(path_parts, obj_name, unique=False).name self.parse_raw_obj(obj_name, self.raw_obj, path_parts) definitions = self.raw_obj.get('definitions', {}) diff --git a/datamodel_code_generator/reference.py b/datamodel_code_generator/reference.py index c4378ad5a..fa179b553 100644 --- a/datamodel_code_generator/reference.py +++ b/datamodel_code_generator/reference.py @@ -1,4 +1,5 @@ import re +from keyword import iskeyword from pathlib import Path from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union @@ -139,6 +140,10 @@ def _get_uniq_name(self, name: str, camel: bool = False) -> str: count += 1 return uniq_name + @classmethod + def validate_name(cls, name: str) -> bool: + return name.isidentifier() and not iskeyword(name) + def get_valid_name(self, name: str, camel: bool = False) -> str: # TODO: when first character is a number replaced_name = re.sub(r'\W', '_', name) diff --git a/docs/index.md b/docs/index.md index 1af8994f6..8b52dbef6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -39,6 +39,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--strip-default-none] [--allow-population-by-field-name] [--use-default] [--force-optional] [--disable-timestamp] + [--class-name CLASS_NAME] [--custom-template-dir CUSTOM_TEMPLATE_DIR] [--extra-template-data EXTRA_TEMPLATE_DATA] [--aliases ALIASES] @@ -61,6 +62,8 @@ optional arguments: --use-default Use default value even if a field is required --force-optional Force optional for required fields --disable-timestamp Disable timestamp on file headers + --class-name CLASS_NAME + Set class name of root model --custom-template-dir CUSTOM_TEMPLATE_DIR Custom template directory --extra-template-data EXTRA_TEMPLATE_DATA diff --git a/tests/data/expected/main/main_invalid_model_name/output.py b/tests/data/expected/main/main_invalid_model_name/output.py new file mode 100644 index 000000000..8551b09ed --- /dev/null +++ b/tests/data/expected/main/main_invalid_model_name/output.py @@ -0,0 +1,19 @@ +# generated by datamodel-codegen: +# filename: invalid_model_name.json +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import Any, List, Optional + +from pydantic import BaseModel, Field, conint + + +class ValidModelName(BaseModel): + firstName: Optional[str] = Field(None, description="The person's first name.") + lastName: Optional[str] = Field(None, description="The person's last name.") + age: Optional[conint(ge=0)] = Field( + None, description='Age in years which must be equal to or greater than zero.' + ) + friends: Optional[List] = None + comment: Optional[Any] = None diff --git a/tests/data/jsonschema/invalid_model_name.json b/tests/data/jsonschema/invalid_model_name.json new file mode 100644 index 000000000..de0d2e2ee --- /dev/null +++ b/tests/data/jsonschema/invalid_model_name.json @@ -0,0 +1,27 @@ +{ + "$id": "https://example.com/person.schema.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "1 xyz", + "type": "object", + "properties": { + "firstName": { + "type": "string", + "description": "The person's first name." + }, + "lastName": { + "type": "string", + "description": "The person's last name." + }, + "age": { + "description": "Age in years which must be equal to or greater than zero.", + "type": "integer", + "minimum": 0 + }, + "friends": { + "type": "array" + }, + "comment": { + "type": "null" + } + } +} \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py index 35040bc75..eca0378ba 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -883,3 +883,52 @@ def test_main_subclass_enum(): with pytest.raises(SystemExit): main() + + +@freeze_time('2019-07-26') +def test_main_invalid_model_name_failed(capsys): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(JSON_SCHEMA_DATA_PATH / 'invalid_model_name.json'), + '--output', + str(output_file), + '--input-file-type', + 'jsonschema', + ] + ) + captured = capsys.readouterr() + assert return_code == Exit.ERROR + assert ( + captured.err + == 'title=\'1 xyz\' is invalid class name. You have to set --class-name option\n' + ) + + +@freeze_time('2019-07-26') +def test_main_invalid_model_name(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(JSON_SCHEMA_DATA_PATH / 'invalid_model_name.json'), + '--output', + str(output_file), + '--input-file-type', + 'jsonschema', + '--class-name', + 'ValidModelName', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_MAIN_PATH / 'main_invalid_model_name' / 'output.py' + ).read_text() + ) + with pytest.raises(SystemExit): + main()