Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion datargs/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,17 @@ def has_default(self) -> bool:

@property
def default(self):
if self._field.default is not dataclasses.MISSING:
return self._field.default
if self._field.default_factory is not dataclasses.MISSING:
return self._field.default_factory()
# in this case, normally None will be returned
return self._field.default

@property
def default_factory(self):
return self._field.default_factory

@property
def name(self):
return self._field.name
Expand All @@ -98,7 +107,8 @@ class DataField(RecordField[dataclasses.Field]):
"""

def is_required(self) -> bool:
return self.default is dataclasses.MISSING
return self.default is dataclasses.MISSING and \
self.default_factory is dataclasses.MISSING


class NotARecordClass(Exception):
Expand Down
19 changes: 14 additions & 5 deletions datargs/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
Specifying enums by name is not currently supported.
"""
import dataclasses
import argparse

# noinspection PyUnresolvedReferences,PyProtectedMember
from argparse import (
Expand Down Expand Up @@ -288,19 +289,26 @@ def literal_arg(name: str, field: RecordField, override: dict) -> Action:
field, {**override, "choices": choices, "type": inner_type}
)

class NegateAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, option_string[2:4] != 'no')


@TypeDispatch.register(bool)
def bool_arg(name: str, field: RecordField, override: dict) -> Action:
kwargs = {
**subdict(common_kwargs(field), ["type"]),
**override,
"action": "store_false"
if field.default and field.has_default()
else "store_true",
"action": NegateAction,
"nargs": 0, # required for NegateAction
"default": field.default if field.has_default() else None,
}
kwargs.pop("type", None)

negate_name = f"--no-{name[2:]}"
args = [*get_option_strings(name, field), negate_name]
return Action(
args=get_option_strings(name, field),
args=args,
kwargs=kwargs,
)

Expand Down Expand Up @@ -478,7 +486,8 @@ def parse(cls: Type[T], args: Optional[Sequence[str]] = None, *, parser=None) ->
>>> parse(Args, ["--num", "1"])
Args(is_flag=False, num=1)
"""
result = vars(make_parser(cls, parser=parser).parse_args(args))
parser = make_parser(cls, parser=parser)
result = vars(parser.parse_args(args))
try:
command_dest = cls.__datargs_params__.sub_commands.get("dest", None)
except AttributeError:
Expand Down
23 changes: 18 additions & 5 deletions tests/test_arg_type.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Type, Sequence, Text, NoReturn, TypeVar, Optional
Expand Down Expand Up @@ -79,12 +79,12 @@ class TestStoreTrueNoDefault:

@factory
class TestStoreFalse:
store_false: bool = True
default_true: bool = True

args = parse_test(TestStoreFalse, [])
assert args.store_false
args = parse_test(TestStoreFalse, ["--store-false"])
assert not args.store_false
assert args.default_true
args = parse_test(TestStoreFalse, ["--no-default-true"])
assert not args.default_true


def test_str(factory):
Expand Down Expand Up @@ -360,6 +360,19 @@ class Nargs:
args = parse_test(Nargs, [])
assert args.nums == []

def test_default_factory_dataclass():
class UserClass:
def __init__(self, value=1):
self.value = value
@dataclass()
class DefaultFactory:
user_class: UserClass = field(default_factory=UserClass)

args = parse_test(DefaultFactory, ["--user-class", "2"])
assert args.user_class.value == "2"
args = parse_test(DefaultFactory, [])
assert args.user_class.value == 1


is_pre_38 = sys.version_info < (3, 8)

Expand Down