Skip to content

Commit

Permalink
updated mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Ckk3 committed Jan 26, 2025
1 parent 3f7f13d commit ff3e419
Showing 1 changed file with 37 additions and 47 deletions.
84 changes: 37 additions & 47 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...

@overload
@classmethod
def from_type(cls, type_: type, *,
strict: bool = False) -> Optional[Self]: ...
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...

@classmethod
def from_type(
Expand All @@ -167,8 +166,7 @@ def from_type(
) -> Optional[Self]:
definition = getattr(type_, cls.TYPE_KEY_NAME, None)
if strict and definition is None:
raise TypeError(
f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
raise TypeError(f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
return definition


Expand Down Expand Up @@ -231,12 +229,11 @@ class StrawberrySQLAlchemyMapper(Generic[BaseModelType]):

def __init__(
self,
model_to_type_name: Optional[Callable[[
Type[BaseModelType]], str]] = None,
model_to_interface_name: Optional[Callable[[
Type[BaseModelType]], str]] = None,
extra_sqlalchemy_type_to_strawberry_type_map: Optional[Mapping[Type[TypeEngine], Type[Any]]
] = None,
model_to_type_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
model_to_interface_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
Mapping[Type[TypeEngine], Type[Any]]
] = None,
) -> None:
if model_to_type_name is None:
model_to_type_name = self._default_model_to_type_name
Expand Down Expand Up @@ -299,8 +296,7 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
"""
edge_name = f"{type_name}Edge"
if edge_name not in self.edge_types:
lazy_type = StrawberrySQLAlchemyLazy(
type_name=type_name, mapper=self)
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
self.edge_types[edge_name] = edge_type = strawberry.type(
dataclasses.make_dataclass(
edge_name,
Expand All @@ -319,15 +315,15 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
connection_name = f"{type_name}Connection"
if connection_name not in self.connection_types:
edge_type = self._edge_type_for(type_name)
lazy_type = StrawberrySQLAlchemyLazy(
type_name=type_name, mapper=self)
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
self.connection_types[connection_name] = connection_type = strawberry.type(
dataclasses.make_dataclass(
connection_name,
[
("edges", List[edge_type]), # type: ignore[valid-type]
],
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
# type: ignore[valid-type]
bases=(relay.ListConnection[lazy_type],),
)
)
setattr(connection_type, _GENERATED_FIELD_KEYS_KEY, ["edges"])
Expand Down Expand Up @@ -457,8 +453,7 @@ def _get_association_proxy_annotation(
strawberry_type.__forward_arg__
)
else:
strawberry_type = self._connection_type_for(
strawberry_type.__name__)
strawberry_type = self._connection_type_for(strawberry_type.__name__)
return strawberry_type

def make_connection_wrapper_resolver(
Expand Down Expand Up @@ -509,24 +504,25 @@ async def resolve(self, info: Info):
else:
if relationship.secondary is None:
relationship_key = tuple(
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
)
else:
# If has a secondary table, gets only the first ID as additional IDs require a separate query
if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}")
f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}"
)

local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[
0][0]
local_remote_pairs_secondary_table_local = (
relationship.local_remote_pairs[0][0]
)
relationship_key = tuple(
[
getattr(
self, str(local_remote_pairs_secondary_table_local.key)),
self, str(local_remote_pairs_secondary_table_local.key)
),
]
)

Expand Down Expand Up @@ -560,7 +556,8 @@ def connection_resolver_for(
return self.make_connection_wrapper_resolver(
relationship_resolver,
self.model_to_type_or_interface_name(
relationship.entity.entity), # type: ignore[arg-type]
relationship.entity.entity # type: ignore[arg-type]
),
)
else:
return relationship_resolver
Expand All @@ -578,15 +575,13 @@ def association_proxy_resolver_for(
Return an async field resolver for the given association proxy.
"""
in_between_relationship = mapper.relationships[descriptor.target_collection]
in_between_resolver = self.relationship_resolver_for(
in_between_relationship)
in_between_resolver = self.relationship_resolver_for(in_between_relationship)
in_between_mapper: Mapper = mapper.relationships[ # type: ignore[assignment]
descriptor.target_collection
].entity
assert descriptor.value_attr in in_between_mapper.relationships
end_relationship = in_between_mapper.relationships[descriptor.value_attr]
end_relationship_resolver = self.relationship_resolver_for(
end_relationship)
end_relationship_resolver = self.relationship_resolver_for(end_relationship)
end_type_name = self.model_to_type_or_interface_name(
end_relationship.entity.entity # type: ignore[arg-type]
)
Expand All @@ -613,8 +608,7 @@ async def resolve(self, info: Info):
if outputs and isinstance(outputs[0], list):
outputs = list(chain.from_iterable(outputs))
else:
outputs = [
output for output in outputs if output is not None]
outputs = [output for output in outputs if output is not None]
else:
outputs = await end_relationship_resolver(in_between_objects, info)
if not isinstance(outputs, collections.abc.Iterable):
Expand Down Expand Up @@ -710,8 +704,7 @@ def convert(type_: Any) -> Any:
setattr(type_, key, field(resolver=val))
generated_field_keys.append(key)

self._handle_columns(
mapper, type_, excluded_keys, generated_field_keys)
self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
relationship: RelationshipProperty
for key, relationship in mapper.relationships.items():
if (
Expand Down Expand Up @@ -813,7 +806,8 @@ def convert(type_: Any) -> Any:
# ignore inherited `is_type_of`
if "is_type_of" not in type_.__dict__:
type_.is_type_of = (
lambda obj, info: type(obj) == model or type(obj) == type_
lambda obj, info: type(obj) == model # noqa: E721
or type(obj) == type_ # noqa: E721
)

# Default querying methods for relay
Expand All @@ -833,7 +827,8 @@ def convert(type_: Any) -> Any:
setattr(
type_,
attr,
types.MethodType(func, type_), # type: ignore[arg-type]
# type: ignore[arg-type]
types.MethodType(func, type_),
)

# Adjust types that inherit from other types/interfaces that implement Node
Expand All @@ -846,8 +841,7 @@ def convert(type_: Any) -> Any:
setattr(
type_,
attr,
types.MethodType(
cast(classmethod, meth).__func__, type_),
types.MethodType(cast(classmethod, meth).__func__, type_),
)

# need to make fields that are already in the type
Expand Down Expand Up @@ -875,8 +869,7 @@ def convert(type_: Any) -> Any:
model=model,
),
)
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY,
generated_field_keys)
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
return mapped_type

Expand Down Expand Up @@ -916,16 +909,14 @@ def _fix_annotation_namespaces(self) -> None:
self.edge_types.values(),
self.connection_types.values(),
):
strawberry_definition = get_object_definition(
mapped_type, strict=True)
strawberry_definition = get_object_definition(mapped_type, strict=True)
for f in strawberry_definition.fields:
if f.name in getattr(mapped_type, _GENERATED_FIELD_KEYS_KEY):
namespace = {}
if hasattr(mapped_type, _ORIGINAL_TYPE_KEY):
namespace.update(
sys.modules[
getattr(mapped_type,
_ORIGINAL_TYPE_KEY).__module__
getattr(mapped_type, _ORIGINAL_TYPE_KEY).__module__
].__dict__
)
namespace.update(self.mapped_types)
Expand Down Expand Up @@ -956,8 +947,7 @@ def _map_unmapped_relationships(self) -> None:
if type_name not in self.mapped_interfaces:
unmapped_interface_models.add(model)
for model in unmapped_models:
self.type(model)(
type(self.model_to_type_name(model), (object,), {}))
self.type(model)(type(self.model_to_type_name(model), (object,), {}))
for model in unmapped_interface_models:
self.interface(model)(
type(self.model_to_interface_name(model), (object,), {})
Expand Down

0 comments on commit ff3e419

Please sign in to comment.