diff --git a/datargs/compat/__init__.py b/datargs/compat/__init__.py index 1dd9718..2ac0d92 100644 --- a/datargs/compat/__init__.py +++ b/datargs/compat/__init__.py @@ -70,8 +70,17 @@ def has_default(self) -> bool: @property def default(self): + if self._field.default is not dataclasses.MISSING: + return self._field.default + if self._field.default_factory is not dataclasses.MISSING: + return self._field.default_factory() + # in this case, normally None will be returned return self._field.default + @property + def default_factory(self): + return self._field.default_factory + @property def name(self): return self._field.name @@ -98,7 +107,8 @@ class DataField(RecordField[dataclasses.Field]): """ def is_required(self) -> bool: - return self.default is dataclasses.MISSING + return self.default is dataclasses.MISSING and \ + self.default_factory is dataclasses.MISSING class NotARecordClass(Exception): diff --git a/datargs/make.py b/datargs/make.py index 4d6858e..1f96759 100644 --- a/datargs/make.py +++ b/datargs/make.py @@ -42,6 +42,7 @@ Specifying enums by name is not currently supported. """ import dataclasses +import argparse # noinspection PyUnresolvedReferences,PyProtectedMember from argparse import ( @@ -288,19 +289,26 @@ def literal_arg(name: str, field: RecordField, override: dict) -> Action: field, {**override, "choices": choices, "type": inner_type} ) +class NegateAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, option_string[2:4] != 'no') + @TypeDispatch.register(bool) def bool_arg(name: str, field: RecordField, override: dict) -> Action: kwargs = { **subdict(common_kwargs(field), ["type"]), **override, - "action": "store_false" - if field.default and field.has_default() - else "store_true", + "action": NegateAction, + "nargs": 0, # required for NegateAction + "default": field.default if field.has_default() else None, } kwargs.pop("type", None) + + negate_name = f"--no-{name[2:]}" + args = [*get_option_strings(name, field), negate_name] return Action( - args=get_option_strings(name, field), + args=args, kwargs=kwargs, ) @@ -478,7 +486,8 @@ def parse(cls: Type[T], args: Optional[Sequence[str]] = None, *, parser=None) -> >>> parse(Args, ["--num", "1"]) Args(is_flag=False, num=1) """ - result = vars(make_parser(cls, parser=parser).parse_args(args)) + parser = make_parser(cls, parser=parser) + result = vars(parser.parse_args(args)) try: command_dest = cls.__datargs_params__.sub_commands.get("dest", None) except AttributeError: diff --git a/tests/test_arg_type.py b/tests/test_arg_type.py index 3b596d1..6311863 100644 --- a/tests/test_arg_type.py +++ b/tests/test_arg_type.py @@ -1,7 +1,7 @@ from abc import ABC import sys from argparse import ArgumentParser -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from pathlib import Path from typing import Type, Sequence, Text, NoReturn, TypeVar, Optional @@ -79,12 +79,12 @@ class TestStoreTrueNoDefault: @factory class TestStoreFalse: - store_false: bool = True + default_true: bool = True args = parse_test(TestStoreFalse, []) - assert args.store_false - args = parse_test(TestStoreFalse, ["--store-false"]) - assert not args.store_false + assert args.default_true + args = parse_test(TestStoreFalse, ["--no-default-true"]) + assert not args.default_true def test_str(factory): @@ -360,6 +360,19 @@ class Nargs: args = parse_test(Nargs, []) assert args.nums == [] +def test_default_factory_dataclass(): + class UserClass: + def __init__(self, value=1): + self.value = value + @dataclass() + class DefaultFactory: + user_class: UserClass = field(default_factory=UserClass) + + args = parse_test(DefaultFactory, ["--user-class", "2"]) + assert args.user_class.value == "2" + args = parse_test(DefaultFactory, []) + assert args.user_class.value == 1 + is_pre_38 = sys.version_info < (3, 8)