Skip to content

Commit

Permalink
fix ignore_alias_conflicts options
Browse files Browse the repository at this point in the history
  • Loading branch information
voidZXL committed Nov 7, 2024
1 parent 9afb981 commit dc876c4
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 16 deletions.
15 changes: 15 additions & 0 deletions tests/test_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,21 @@ class InsensitiveSchemaInvalid2(Schema):
value: str
# name is same

# ------------
class AliasSchema(Schema):
alias: str = Field(alias_from=['alias_from', '@af2'])

assert dict(AliasSchema({'alias': 1, 'alias_from': 1, '@af2': 1})) == {'alias': '1'}

with pytest.raises(exc.AliasConflictError):
dict(AliasSchema({'alias': 1, 'alias_from': 2}))

class AliasSchema2(Schema):
__options__ = Options(ignore_alias_conflicts=True)
alias: str = Field(alias_from=['alias_from', '@af2'])
assert dict(AliasSchema2({'alias': 1, 'alias_from': 2})) == {'alias': '1'}
assert dict(AliasSchema2({'alias_from': 1, '@af2': 2})) == {'alias': '1'}

def test_addition(self):
class UserSchemaDisallow(Schema):
__options__ = Options(addition=False)
Expand Down
16 changes: 11 additions & 5 deletions utype/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,13 @@ def data_first_parse(
result[name] = default
continue

if (name in result) or (excluded_keys and name in excluded_keys):
if options.ignore_alias_conflicts:
if not options.ignore_alias_conflicts:
if name in result: # or (excluded_keys and name in excluded_keys):
if result[name] != value:
context.handle_error(exc.AliasConflictError(item=name, value=value))
continue
context.handle_error(exc.AliasConflictError(item=name, value=value))

if excluded_keys and name in excluded_keys:
continue

parsed = field.parse_value(value, context=context)
Expand Down Expand Up @@ -547,8 +550,9 @@ def field_first_parse(
if unprovided(value):
value = data[alias]
else:
context.handle_error(exc.AliasConflictError(item=name))
break
if data[alias] != value:
context.handle_error(exc.AliasConflictError(item=name, value=data[alias]))
break

if unprovided(value):
unprovided_fields.add(name)
Expand Down Expand Up @@ -605,6 +609,8 @@ def field_first_parse(
for k, v in data.items():
if k in used_alias:
continue
# if excluded_keys and k in excluded_keys:
# pass
add_value = self.parse_addition(k, v, context=context)
if not unprovided(add_value):
addition[k] = add_value
Expand Down
33 changes: 23 additions & 10 deletions utype/parser/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..utils.base import ParamsCollector
from ..utils.compat import Literal, get_args, is_final, is_annotated, ForwardRef
from ..utils.datastructures import unprovided
from ..utils.functional import copy_value, get_name, multi
from ..utils.functional import copy_value, get_name, multi, distinct_add
from .options import Options, RuntimeContext
from .rule import ConstraintMode, Lax, LogicalType, Rule, resolve_forward_type
from ..settings import warning_settings
Expand Down Expand Up @@ -293,8 +293,8 @@ def get_alias(self, attname: str, generator=None):
alias = _alias
return alias

def get_alias_from(self, attname: str, generator=None) -> Set[str]:
aliases = {attname}
def get_alias_from(self, attname: str, generator=None) -> List[str]:
aliases = [attname]
alias_from = []
if self.alias_from:
if not multi(self.alias_from):
Expand All @@ -312,9 +312,9 @@ def get_alias_from(self, attname: str, generator=None) -> Set[str]:
if callable(alias):
alias = alias(attname)
if multi(alias):
aliases.update([a for a in alias if isinstance(a, str) and a])
distinct_add(aliases, [a for a in alias if isinstance(a, str) and a])
elif isinstance(alias, str) and alias:
aliases.add(alias)
distinct_add(aliases, [alias])

return aliases

Expand Down Expand Up @@ -437,7 +437,7 @@ def __init__(
# all the transformers and validators are infer from type
field: Field,
attname: str = None,
aliases: Set[str] = None,
aliases: List[str] = None,
field_property: property = None,
output_type: type = None,
output_field: Field = None,
Expand All @@ -458,8 +458,17 @@ def __init__(
self.property = field_property
self.final = final
self.name = name
self.aliases = set(aliases or []).difference({self.name})
self.all_aliases = self.aliases.union({self.name})

all_aliases = [self.name]
_aliases = []
for alias in aliases or []:
if alias not in all_aliases:
all_aliases.append(alias)
if alias != name:
_aliases.append(alias)
# !!!order matters
self.all_aliases = all_aliases
self.aliases = _aliases

# self.input_transformer = self.transformer_cls.resolver_transformer(input_type)
self.dependencies = dependencies
Expand Down Expand Up @@ -491,6 +500,10 @@ def __init__(
# ----------------
# below are static field properties

@property
def alias_set(self):
return set(self.aliases)

@property
def discriminator(self):
return self.field.discriminator
Expand Down Expand Up @@ -538,7 +551,7 @@ def setup(self, options: Options):
# do not lower name
# self.name = self.name.lower()
self.aliases = {a.lower() for a in self.aliases}
self.all_aliases = {a.lower() for a in self.all_aliases}
self.all_aliases = [a.lower() for a in self.all_aliases]

if self.repr_func is None:
if options.secret_names:
Expand Down Expand Up @@ -675,7 +688,7 @@ def apply_fields(
take the field
"""
if self.aliases:
inter = self.aliases.intersection(fields)
inter = self.alias_set.intersection(fields)
if inter:
raise exc.ConfigError(
f"Field(name={repr(self.name)}) aliases: {inter} conflict with fields"
Expand Down
2 changes: 1 addition & 1 deletion utype/parser/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
ignore_constraints: bool = unprovided, # for Rule, ignore constraints, only transform type
alias_from_generator: Union[Callable, List[Callable]] = unprovided,
alias_generator: Callable = unprovided,
ignore_alias_conflict: bool = unprovided,
ignore_alias_conflicts: bool = unprovided,
allow_subclasses: bool = unprovided,
cast_keyword_str: bool = unprovided,
# allowed_runtime_options: Union[str, None, List[str]] = "*",
Expand Down

0 comments on commit dc876c4

Please sign in to comment.