diff --git a/README.md b/README.md index 2b90f59d9..bd8224adf 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--input-file-type {auto,openapi,jsonschema,json,yaml,dict,csv}] [--output OUTPUT] [--base-class BASE_CLASS] [--field-constraints] [--snake-case-field] [--strip-default-none] - [--allow-population-by-field-name] [--use-default] [--force-optional] + [--allow-population-by-field-name] [--use-default] [--force-optional] [--strict-nullable] [--disable-timestamp] [--use-standard-collections] [--use-schema-description] [--reuse-model] [--enum-field-as-literal {all,one}] [--set-default-enum-member] [--class-name CLASS_NAME] [--custom-template-dir CUSTOM_TEMPLATE_DIR] @@ -94,6 +94,7 @@ optional arguments: Allow population by field name --use-default Use default value even if a field is required --force-optional Force optional for required fields + --strict-nullable Treat default field as a non-nullable field (only OpenAPI) --disable-timestamp Disable timestamp on file headers --use-standard-collections Use standard collections for type hinting (list, dict) diff --git a/datamodel_code_generator/__init__.py b/datamodel_code_generator/__init__.py index 419fa3443..82e674978 100644 --- a/datamodel_code_generator/__init__.py +++ b/datamodel_code_generator/__init__.py @@ -193,6 +193,7 @@ def generate( encoding: str = 'utf-8', enum_field_as_literal: Optional[LiteralType] = None, set_default_enum_member: bool = False, + strict_nullable: bool = False, ) -> None: input_text: Optional[str] = None if input_file_type == InputFileType.Auto: @@ -283,6 +284,7 @@ def get_header_and_first_line(csv_file: IO[str]) -> Dict[str, Any]: reuse_model=reuse_model, enum_field_as_literal=enum_field_as_literal, set_default_enum_member=set_default_enum_member, + strict_nullable=strict_nullable, ) with chdir(output): diff --git a/datamodel_code_generator/__main__.py b/datamodel_code_generator/__main__.py index 912b0fc3f..19148f82e 100755 --- a/datamodel_code_generator/__main__.py +++ b/datamodel_code_generator/__main__.py @@ -102,6 +102,13 @@ def sig_int_handler(_: int, __: Any) -> None: # pragma: no cover default=None, ) +arg_parser.add_argument( + '--strict-nullable', + help='Treat default field as a non-nullable field (Only OpenAPI)', + action='store_true', + default=None, +) + arg_parser.add_argument( '--disable-timestamp', help='Disable timestamp on file headers', @@ -233,6 +240,7 @@ def validate_literal_option(cls, values: Dict[str, Any]) -> Dict[str, Any]: encoding: str = 'utf-8' enum_field_as_literal: Optional[LiteralType] = None set_default_enum_member: bool = False + strict_nullable: bool = False def merge_args(self, args: Namespace) -> None: for field_name in self.__fields__: @@ -347,6 +355,7 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: encoding=config.encoding, enum_field_as_literal=config.enum_field_as_literal, set_default_enum_member=config.set_default_enum_member, + strict_nullable=config.strict_nullable, ) return Exit.OK except InvalidClassNameError as e: diff --git a/datamodel_code_generator/model/base.py b/datamodel_code_generator/model/base.py index b426cc79e..859acbcf3 100644 --- a/datamodel_code_generator/model/base.py +++ b/datamodel_code_generator/model/base.py @@ -37,19 +37,28 @@ class DataModelFieldBase(BaseModel): data_type: DataType constraints: Any = None strip_default_none: bool = False + nullable: Optional[bool] = None @property def type_hint(self) -> str: type_hint = self.data_type.type_hint - if self.required: - return type_hint - if type_hint is None or type_hint == '': + + if not type_hint: return OPTIONAL + elif self.nullable is not None: + if self.nullable: + return f'{OPTIONAL}[{type_hint}]' + return type_hint + elif self.required: + return type_hint return f'{OPTIONAL}[{type_hint}]' @property def imports(self) -> List[Import]: - if not self.required: + if self.nullable is None: + if not self.required: + return self.data_type.imports_ + [IMPORT_OPTIONAL] + elif self.nullable: return self.data_type.imports_ + [IMPORT_OPTIONAL] return self.data_type.imports_ diff --git a/datamodel_code_generator/model/pydantic/base_model.py b/datamodel_code_generator/model/pydantic/base_model.py index cd09134c0..df5b93c35 100644 --- a/datamodel_code_generator/model/pydantic/base_model.py +++ b/datamodel_code_generator/model/pydantic/base_model.py @@ -64,6 +64,8 @@ def __str__(self) -> str: f"{k}={repr(v)}" for k, v in data.items() if v is not None ) if not field_arguments: + if self.nullable and self.required: + return 'Field(...)' # Field() is for mypy return "" value_arg = "..." if self.required else repr(self.default) @@ -126,5 +128,6 @@ def __init__( self.extra_template_data['config'] = Config.parse_obj(config_parameters) for field in fields: - if field.field: + field_value = field.field + if field_value and field_value != '...': self.imports.append(IMPORT_FIELD) diff --git a/datamodel_code_generator/parser/base.py b/datamodel_code_generator/parser/base.py index 2540f044e..8e1fab154 100644 --- a/datamodel_code_generator/parser/base.py +++ b/datamodel_code_generator/parser/base.py @@ -238,6 +238,7 @@ def __init__( encoding: str = 'utf-8', enum_field_as_literal: Optional[LiteralType] = None, set_default_enum_member: bool = False, + strict_nullable: bool = False, ): self.data_type_manager: DataTypeManager = data_type_manager_type( target_python_version, use_standard_collections @@ -263,6 +264,7 @@ def __init__( self.encoding: str = encoding self.enum_field_as_literal: Optional[LiteralType] = enum_field_as_literal self.set_default_enum_member: bool = set_default_enum_member + self.strict_nullable: bool = strict_nullable self.current_source_path: Optional[Path] = None diff --git a/datamodel_code_generator/parser/jsonschema.py b/datamodel_code_generator/parser/jsonschema.py index 9fc20aa56..646d06044 100644 --- a/datamodel_code_generator/parser/jsonschema.py +++ b/datamodel_code_generator/parser/jsonschema.py @@ -32,10 +32,9 @@ from datamodel_code_generator.model.enum import Enum from datamodel_code_generator.parser import LiteralType -from ..imports import IMPORT_LITERAL, Import from ..model import pydantic as pydantic_model from ..parser.base import Parser -from ..reference import Reference, is_url +from ..reference import is_url from ..types import DataType, DataTypeManager, Types @@ -249,6 +248,7 @@ def __init__( encoding: str = 'utf-8', enum_field_as_literal: Optional[LiteralType] = None, set_default_enum_member: bool = False, + strict_nullable: bool = False, ): super().__init__( source=source, @@ -277,6 +277,7 @@ def __init__( encoding=encoding, enum_field_as_literal=enum_field_as_literal, set_default_enum_member=set_default_enum_member, + strict_nullable=strict_nullable, ) self.remote_object_cache: Dict[str, Dict[str, Any]] = {} @@ -556,6 +557,7 @@ def parse_object_fields( required=required, alias=alias, constraints=constraints, + nullable=field.nullable if self.strict_nullable else None, ) ) return fields @@ -663,9 +665,14 @@ def parse_array_fields( if self.force_optional_for_required_fields: required: bool = False else: - required = not obj.nullable and not ( - obj.has_default and self.apply_default_values_for_required_fields - ) + if self.strict_nullable: + required = not ( + obj.has_default and self.apply_default_values_for_required_fields + ) + else: + required = not obj.nullable and not ( + obj.has_default and self.apply_default_values_for_required_fields + ) return self.data_model_field_type( data_type=self.data_type(data_types=item_obj_data_types, is_list=True,), example=obj.example, @@ -675,6 +682,7 @@ def parse_array_fields( title=obj.title, required=required, constraints=obj.dict(), + nullable=obj.nullable if self.strict_nullable else None, ) def parse_array( @@ -713,9 +721,14 @@ def parse_root_type( if self.force_optional_for_required_fields: required: bool = False else: - required = not obj.nullable and not ( - obj.has_default and self.apply_default_values_for_required_fields - ) + if self.strict_nullable: + required = not ( + obj.has_default and self.apply_default_values_for_required_fields + ) + else: + required = not obj.nullable and not ( + obj.has_default and self.apply_default_values_for_required_fields + ) reference = self.model_resolver.add(path, name, loaded=True) self.set_title(name, obj) self.set_additional_properties(name, additional_properties or obj) @@ -730,6 +743,7 @@ def parse_root_type( default=obj.default, required=required, constraints=obj.dict() if self.field_constraints else {}, + nullable=obj.nullable if self.strict_nullable else None, ) ], custom_base_class=self.base_class, diff --git a/docs/index.md b/docs/index.md index 6ccd54e80..28b8a0d17 100644 --- a/docs/index.md +++ b/docs/index.md @@ -37,7 +37,7 @@ usage: datamodel-codegen [-h] [--input INPUT] [--input-file-type {auto,openapi,jsonschema,json,yaml,dict,csv}] [--output OUTPUT] [--base-class BASE_CLASS] [--field-constraints] [--snake-case-field] [--strip-default-none] - [--allow-population-by-field-name] [--use-default] [--force-optional] + [--allow-population-by-field-name] [--use-default] [--force-optional] [--strict-nullable] [--disable-timestamp] [--use-standard-collections] [--use-schema-description] [--reuse-model] [--enum-field-as-literal {all,one}] [--set-default-enum-member] [--class-name CLASS_NAME] [--custom-template-dir CUSTOM_TEMPLATE_DIR] @@ -61,6 +61,7 @@ optional arguments: Allow population by field name --use-default Use default value even if a field is required --force-optional Force optional for required fields + --strict-nullable Treat default field as a non-nullable field (only OpenAPI) --disable-timestamp Disable timestamp on file headers --use-standard-collections Use standard collections for type hinting (list, dict) diff --git a/tests/data/expected/main/main_openapi_nullable/output.py b/tests/data/expected/main/main_openapi_nullable/output.py new file mode 100644 index 000000000..479edb9c2 --- /dev/null +++ b/tests/data/expected/main/main_openapi_nullable/output.py @@ -0,0 +1,60 @@ +# generated by datamodel-codegen: +# filename: nullable.yaml +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import AnyUrl, BaseModel, Field + + +class Cursors(BaseModel): + prev: str + next: Optional[str] = 'last' + index: float + + +class TopLevel(BaseModel): + cursors: Cursors + + +class Info(BaseModel): + name: str + + +class User(BaseModel): + info: Info + + +class Api(BaseModel): + apiKey: Optional[str] = Field( + None, description='To be used as a dataset parameter value' + ) + apiVersionNumber: Optional[str] = Field( + None, description='To be used as a version parameter value' + ) + apiUrl: Optional[AnyUrl] = Field( + None, description="The URL describing the dataset's fields" + ) + apiDocumentationUrl: Optional[AnyUrl] = Field( + None, description='A URL to the API console for each API' + ) + + +class Apis(BaseModel): + __root__: Optional[List[Api]] = None + + +class EmailItem(BaseModel): + author: str + address: str = Field(..., description='email address') + description: Optional[str] = 'empty' + + +class Email(BaseModel): + __root__: List[EmailItem] + + +class Id(BaseModel): + __root__: int diff --git a/tests/data/expected/main/main_openapi_nullable_strict_nullable/output.py b/tests/data/expected/main/main_openapi_nullable_strict_nullable/output.py new file mode 100644 index 000000000..7882a2534 --- /dev/null +++ b/tests/data/expected/main/main_openapi_nullable_strict_nullable/output.py @@ -0,0 +1,58 @@ +# generated by datamodel-codegen: +# filename: nullable.yaml +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import AnyUrl, BaseModel, Field + + +class Cursors(BaseModel): + prev: Optional[str] = Field(...) + next: str = 'last' + index: float + + +class TopLevel(BaseModel): + cursors: Cursors + + +class Info(BaseModel): + name: str + + +class User(BaseModel): + info: Info + + +class Api(BaseModel): + apiKey: str = Field(None, description='To be used as a dataset parameter value') + apiVersionNumber: str = Field( + None, description='To be used as a version parameter value' + ) + apiUrl: Optional[AnyUrl] = Field( + None, description="The URL describing the dataset's fields" + ) + apiDocumentationUrl: Optional[AnyUrl] = Field( + None, description='A URL to the API console for each API' + ) + + +class Apis(BaseModel): + __root__: Optional[List[Api]] = Field(...) + + +class EmailItem(BaseModel): + author: str + address: str = Field(..., description='email address') + description: str = 'empty' + + +class Email(BaseModel): + __root__: List[EmailItem] + + +class Id(BaseModel): + __root__: int diff --git a/tests/data/openapi/nullable.yaml b/tests/data/openapi/nullable.yaml new file mode 100644 index 000000000..55fa604b4 --- /dev/null +++ b/tests/data/openapi/nullable.yaml @@ -0,0 +1,82 @@ +openapi: 3.0.3 +info: + version: 1.0.0 + title: testapi + license: + name: proprietary +servers: [] +paths: {} +components: + schemas: + TopLevel: + type: object + properties: + cursors: + type: object + properties: + prev: + type: string + nullable: true + next: + type: string + default: last + index: + type: number + required: + - prev + - index + required: + - cursors + User: + type: object + properties: + info: + type: object + properties: + name: + type: string + required: + - name + required: + - info + apis: + type: array + nullable: true + items: + type: object + properties: + apiKey: + type: string + description: To be used as a dataset parameter value + apiVersionNumber: + type: string + description: To be used as a version parameter value + apiUrl: + type: string + format: uri + description: "The URL describing the dataset's fields" + nullable: true + apiDocumentationUrl: + type: string + format: uri + description: A URL to the API console for each API + nullable: true + email: + type: array + items: + type: object + properties: + author: + type: string + address: + type: string + description: email address + description: + type: string + default: empty + required: + - author + - address + id: + type: integer + default: 1 \ No newline at end of file diff --git a/tests/model/pydantic/test_base_model.py b/tests/model/pydantic/test_base_model.py index 18e5db9c0..9ff2d574a 100644 --- a/tests/model/pydantic/test_base_model.py +++ b/tests/model/pydantic/test_base_model.py @@ -32,6 +32,43 @@ def test_base_model_optional(): ) +def test_base_model_nullable_required(): + field = DataModelField( + name='a', + data_type=DataType(type='str'), + default='abc', + required=True, + nullable=True, + ) + + base_model = BaseModel(name='test_model', fields=[field]) + + assert base_model.name == 'test_model' + assert base_model.fields == [field] + assert base_model.decorators == [] + assert ( + base_model.render() == 'class test_model(BaseModel):\n' + ' a: Optional[str] = Field(...)' + ) + + +def test_base_model_strict_non_nullable_required(): + field = DataModelField( + name='a', + data_type=DataType(type='str'), + default='abc', + required=True, + nullable=False, + ) + + base_model = BaseModel(name='test_model', fields=[field]) + + assert base_model.name == 'test_model' + assert base_model.fields == [field] + assert base_model.decorators == [] + assert base_model.render() == 'class test_model(BaseModel):\n' ' a: str' + + def test_base_model_decorator(): field = DataModelField( name='a', data_type=DataType(type='str'), default='abc', required=False diff --git a/tests/test_main.py b/tests/test_main.py index 57317da18..ab10e02c8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1785,3 +1785,54 @@ def test_main_all_of_ref(): ) with pytest.raises(SystemExit): main() + + +@freeze_time('2019-07-26') +def test_main_openapi_nullable(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(OPEN_API_DATA_PATH / 'nullable.yaml'), + '--output', + str(output_file), + '--input-file-type', + 'openapi', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == (EXPECTED_MAIN_PATH / 'main_openapi_nullable' / 'output.py').read_text() + ) + with pytest.raises(SystemExit): + main() + + +@freeze_time('2019-07-26') +def test_main_openapi_nullable_strict_nullable(): + with TemporaryDirectory() as output_dir: + output_file: Path = Path(output_dir) / 'output.py' + return_code: Exit = main( + [ + '--input', + str(OPEN_API_DATA_PATH / 'nullable.yaml'), + '--output', + str(output_file), + '--input-file-type', + 'openapi', + '--strict-nullable', + ] + ) + assert return_code == Exit.OK + assert ( + output_file.read_text() + == ( + EXPECTED_MAIN_PATH + / 'main_openapi_nullable_strict_nullable' + / 'output.py' + ).read_text() + ) + with pytest.raises(SystemExit): + main()