Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support to secondary tables relationships #218

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
de06188
first fix versoin, working only if the items has the same id
Ckk3 Nov 19, 2024
0cd732d
bring back the first version, still missin the different ids logic!
Ckk3 Nov 19, 2024
401cd65
fix: now query can pickup related_model and self_model id
Ckk3 Nov 22, 2024
6f644e3
fix: not working with different ids
Ckk3 Nov 22, 2024
ca9bc1c
add nes tests
Ckk3 Nov 23, 2024
eb852ce
add tests
Ckk3 Nov 23, 2024
5770379
Fix mypy erros, still missing some tests
Ckk3 Nov 24, 2024
be77996
update code to work with sqlalchemy 1.4
Ckk3 Nov 24, 2024
fb6a580
remove old code that only works with sqlalchemy 2
Ckk3 Nov 24, 2024
0fb61bb
add seconday tables tests in test_loader
Ckk3 Nov 24, 2024
03a5438
add new tests to loadar and start mapper tests
Ckk3 Nov 26, 2024
a575650
add mapper tests
Ckk3 Nov 28, 2024
beaa3f9
refactor conftest
Ckk3 Nov 30, 2024
8a65328
refactor test_loader
Ckk3 Nov 30, 2024
9d76061
refactor test_mapper
Ckk3 Nov 30, 2024
91c24c5
run autopep
Ckk3 Nov 30, 2024
1cd8df4
run autopep
Ckk3 Nov 30, 2024
e96f179
separate test
Ckk3 Nov 30, 2024
4b6516b
fix lint
Ckk3 Nov 30, 2024
9b079d4
add release file
Ckk3 Nov 30, 2024
4baa7ae
refactor tests
Ckk3 Nov 30, 2024
33d7758
refactor loader
Ckk3 Nov 30, 2024
2a53474
fix release
Ckk3 Nov 30, 2024
d04af46
update pre-commit to work with python 3.8
Ckk3 Jan 26, 2025
3f7f13d
update loader.py
Ckk3 Jan 26, 2025
ff3e419
updated mapper
Ckk3 Jan 26, 2025
6752231
fix lint
Ckk3 Jan 26, 2025
0cd68d2
remote autopep8 from dev container because it give problems when work…
Ckk3 Jan 26, 2025
0745c64
fix lint
Ckk3 Jan 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/strawberry_sqlalchemy_mapper/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ def __init__(self, model):
f"Model `{model}` is not polymorphic or is not the base model of its "
+ "inheritance chain, and thus cannot be used as an interface."
)


class InvalidLocalRemotePairs(Exception):
def __init__(self, relationship_name):
super().__init__(
f"The `local_remote_pairs` for the relationship `{relationship_name}` is invalid or missing. "
+ "This is likely an issue with the library. Please report this error to the maintainers."
)
77 changes: 68 additions & 9 deletions src/strawberry_sqlalchemy_mapper/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Tuple,
Union,
)
from strawberry_sqlalchemy_mapper.exc import InvalidLocalRemotePairs

from sqlalchemy import select, tuple_
from sqlalchemy.engine.base import Connection
Expand Down Expand Up @@ -45,12 +46,16 @@ def __init__(
"One of bind or async_bind_factory must be set for loader to function properly."
)

async def _scalars_all(self, *args, **kwargs):
async def _scalars_all(self, *args, disabled_optimization_to_secondary_tables=False, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: maybe call this enable_ and have True as the default?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont want to do this because it removes optimizations that we only need to remove when we need to pick up secondary tables values, so if the default is True we will lose peformance in queries that dont need it.
But I agree that this var name aren't good enought, so I will change the name to query_secondary_tables and refactor the function.

if self._async_bind_factory:
async with self._async_bind_factory() as bind:
if disabled_optimization_to_secondary_tables is True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick:

Suggested change
if disabled_optimization_to_secondary_tables is True:
if disabled_optimization_to_secondary_tables:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated! Thank you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated!

return (await bind.execute(*args, **kwargs)).all()
return (await bind.scalars(*args, **kwargs)).all()
else:
assert self._bind is not None
if disabled_optimization_to_secondary_tables is True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick:

Suggested change
if disabled_optimization_to_secondary_tables is True:
if disabled_optimization_to_secondary_tables:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

return self._bind.execute(*args, **kwargs).all()
return self._bind.scalars(*args, **kwargs).all()

def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
Expand All @@ -63,14 +68,63 @@ def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
related_model = relationship.entity.entity

async def load_fn(keys: List[Tuple]) -> List[Any]:
query = select(related_model).filter(
tuple_(
*[remote for _, remote in relationship.local_remote_pairs or []]
).in_(keys)
)
if relationship.secondary is None:
sourcery-ai[bot] marked this conversation as resolved.
Show resolved Hide resolved
query = select(related_model).filter(
tuple_(
*[remote for _, remote in relationship.local_remote_pairs or []]
).in_(keys)
)
else:
# Use another query when relationship uses a secondary table
self_model = relationship.parent.entity

if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{related_model.__name__} -- {self_model.__name__}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

polish: for ruff/black this parenthesis should be closed in the next line. I think you forgot to pre-commit install =P (ditto for the lines below)

Also, we are probably missing a lint check in here which runs ruff/black/etc (and maybe migrate to ruff formatter instead of black soon)

Copy link
Contributor Author

@Ckk3 Ckk3 Jan 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about that, I see now that pre-commit dont run due to some updates that dont work with dev container python version (3.8).
I updated that imports and now i'm fixing all the erros ;)


self_model_key_label = str(
relationship.local_remote_pairs[0][1].key)
related_model_key_label = str(
relationship.local_remote_pairs[1][1].key)

self_model_key = str(
relationship.local_remote_pairs[0][0].key)
related_model_key = str(
relationship.local_remote_pairs[1][0].key)

remote_to_use = relationship.local_remote_pairs[0][1]
query_keys = tuple([item[0] for item in keys])

# This query returns rows in this format -> (self_model.key, related_model)
query = (
select(
getattr(self_model, self_model_key).label(
self_model_key_label),
related_model
)
.join(
sourcery-ai[bot] marked this conversation as resolved.
Show resolved Hide resolved
relationship.secondary,
getattr(relationship.secondary.c,
related_model_key_label) == getattr(related_model, related_model_key)
)
.join(
self_model,
getattr(relationship.secondary.c,
self_model_key_label) == getattr(self_model, self_model_key)
)
.filter(
remote_to_use.in_(query_keys)
)
)

if relationship.order_by:
query = query.order_by(*relationship.order_by)
rows = await self._scalars_all(query)

if relationship.secondary is not None:
# We need to retrieve values from both the self_model and related_model. To achieve this, we must disable the default SQLAlchemy optimization that returns only related_model values. This is necessary because we use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=True)
else:
rows = await self._scalars_all(query)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion:

Suggested change
if relationship.secondary is not None:
# We need to retrieve values from both the self_model and related_model. To achieve this, we must disable the default SQLAlchemy optimization that returns only related_model values. This is necessary because we use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=True)
else:
rows = await self._scalars_all(query)
# We need to retrieve values from both the self_model and related_model.
# To achieve this, we must disable the default SQLAlchemy optimization
# that returns only related_model values. This is necessary because we
# use the keys variable to match both related_model and self_model.
rows = await self._scalars_all(query, disabled_optimization_to_secondary_tables=relationship.secondary is not None)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Updated!


def group_by_remote_key(row: Any) -> Tuple:
return tuple(
Expand All @@ -82,8 +136,13 @@ def group_by_remote_key(row: Any) -> Tuple:
)

grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
for row in rows:
grouped_keys[group_by_remote_key(row)].append(row)
if relationship.secondary is None:
for row in rows:
grouped_keys[group_by_remote_key(row)].append(row)
else:
for row in rows:
grouped_keys[(row[0],)].append(row[1])

if relationship.uselist:
return [grouped_keys[key] for key in keys]
else:
Expand Down
93 changes: 63 additions & 30 deletions src/strawberry_sqlalchemy_mapper/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
from strawberry_sqlalchemy_mapper.exc import (
HybridPropertyNotAnnotated,
InterfaceModelNotPolymorphic,
InvalidLocalRemotePairs,
UnsupportedAssociationProxyTarget,
UnsupportedColumnType,
UnsupportedDescriptorType,
Expand Down Expand Up @@ -154,7 +155,8 @@ def from_type(cls, type_: type, *, strict: Literal[True]) -> Self: ...

@overload
@classmethod
def from_type(cls, type_: type, *, strict: bool = False) -> Optional[Self]: ...
def from_type(cls, type_: type, *,
strict: bool = False) -> Optional[Self]: ...

@classmethod
def from_type(
Expand All @@ -165,7 +167,8 @@ def from_type(
) -> Optional[Self]:
definition = getattr(type_, cls.TYPE_KEY_NAME, None)
if strict and definition is None:
raise TypeError(f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
raise TypeError(
f"{type_!r} does not have a StrawberrySQLAlchemyType in it")
return definition


Expand Down Expand Up @@ -228,11 +231,12 @@ class StrawberrySQLAlchemyMapper(Generic[BaseModelType]):

def __init__(
self,
model_to_type_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
model_to_interface_name: Optional[Callable[[Type[BaseModelType]], str]] = None,
extra_sqlalchemy_type_to_strawberry_type_map: Optional[
Mapping[Type[TypeEngine], Type[Any]]
] = None,
model_to_type_name: Optional[Callable[[
Type[BaseModelType]], str]] = None,
model_to_interface_name: Optional[Callable[[
Type[BaseModelType]], str]] = None,
extra_sqlalchemy_type_to_strawberry_type_map: Optional[Mapping[Type[TypeEngine], Type[Any]]
] = None,
) -> None:
if model_to_type_name is None:
model_to_type_name = self._default_model_to_type_name
Expand Down Expand Up @@ -295,7 +299,8 @@ def _edge_type_for(self, type_name: str) -> Type[Any]:
"""
edge_name = f"{type_name}Edge"
if edge_name not in self.edge_types:
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
lazy_type = StrawberrySQLAlchemyLazy(
type_name=type_name, mapper=self)
self.edge_types[edge_name] = edge_type = strawberry.type(
dataclasses.make_dataclass(
edge_name,
Expand All @@ -314,14 +319,15 @@ def _connection_type_for(self, type_name: str) -> Type[Any]:
connection_name = f"{type_name}Connection"
if connection_name not in self.connection_types:
edge_type = self._edge_type_for(type_name)
lazy_type = StrawberrySQLAlchemyLazy(type_name=type_name, mapper=self)
lazy_type = StrawberrySQLAlchemyLazy(
type_name=type_name, mapper=self)
self.connection_types[connection_name] = connection_type = strawberry.type(
dataclasses.make_dataclass(
connection_name,
[
("edges", List[edge_type]), # type: ignore[valid-type]
],
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
bases=(relay.ListConnection[lazy_type],), # type: ignore[valid-type]
)
)
setattr(connection_type, _GENERATED_FIELD_KEYS_KEY, ["edges"])
Expand Down Expand Up @@ -387,7 +393,7 @@ def _convert_relationship_to_strawberry_type(
if relationship.uselist:
# Use list if excluding relay pagination
if use_list:
return List[ForwardRef(type_name)] # type: ignore
return List[ForwardRef(type_name)] # type: ignore

return self._connection_type_for(type_name)
else:
Expand Down Expand Up @@ -451,7 +457,8 @@ def _get_association_proxy_annotation(
strawberry_type.__forward_arg__
)
else:
strawberry_type = self._connection_type_for(strawberry_type.__name__)
strawberry_type = self._connection_type_for(
strawberry_type.__name__)
return strawberry_type

def make_connection_wrapper_resolver(
Expand Down Expand Up @@ -500,13 +507,29 @@ async def resolve(self, info: Info):
if relationship.key not in instance_state.unloaded:
related_objects = getattr(self, relationship.key)
else:
relationship_key = tuple(
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
)
if relationship.secondary is None:
relationship_key = tuple(
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praise: you can pass the iterator to the tuple directly, no need to create a list for that

Suggested change
[
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key
]
getattr(self, local.key)
for local, _ in relationship.local_remote_pairs or []
if local.key

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

)
else:
# If has a secondary table, gets only the first ID as additional IDs require a separate query
if not relationship.local_remote_pairs:
raise InvalidLocalRemotePairs(
f"{relationship.entity.entity.__name__} -- {relationship.parent.entity.__name__}")

local_remote_pairs_secondary_table_local = relationship.local_remote_pairs[
0][0]
relationship_key = tuple(
[
getattr(
self, str(local_remote_pairs_secondary_table_local.key)),
]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated!

)

if any(item is None for item in relationship_key):
if relationship.uselist:
return []
Expand Down Expand Up @@ -536,7 +559,8 @@ def connection_resolver_for(
if relationship.uselist and not use_list:
return self.make_connection_wrapper_resolver(
relationship_resolver,
self.model_to_type_or_interface_name(relationship.entity.entity), # type: ignore[arg-type]
self.model_to_type_or_interface_name(
relationship.entity.entity), # type: ignore[arg-type]
)
else:
return relationship_resolver
Expand All @@ -554,13 +578,15 @@ def association_proxy_resolver_for(
Return an async field resolver for the given association proxy.
"""
in_between_relationship = mapper.relationships[descriptor.target_collection]
in_between_resolver = self.relationship_resolver_for(in_between_relationship)
in_between_resolver = self.relationship_resolver_for(
in_between_relationship)
in_between_mapper: Mapper = mapper.relationships[ # type: ignore[assignment]
descriptor.target_collection
].entity
assert descriptor.value_attr in in_between_mapper.relationships
end_relationship = in_between_mapper.relationships[descriptor.value_attr]
end_relationship_resolver = self.relationship_resolver_for(end_relationship)
end_relationship_resolver = self.relationship_resolver_for(
end_relationship)
end_type_name = self.model_to_type_or_interface_name(
end_relationship.entity.entity # type: ignore[arg-type]
)
Expand All @@ -587,7 +613,8 @@ async def resolve(self, info: Info):
if outputs and isinstance(outputs[0], list):
outputs = list(chain.from_iterable(outputs))
else:
outputs = [output for output in outputs if output is not None]
outputs = [
output for output in outputs if output is not None]
else:
outputs = await end_relationship_resolver(in_between_objects, info)
if not isinstance(outputs, collections.abc.Iterable):
Expand Down Expand Up @@ -683,7 +710,8 @@ def convert(type_: Any) -> Any:
setattr(type_, key, field(resolver=val))
generated_field_keys.append(key)

self._handle_columns(mapper, type_, excluded_keys, generated_field_keys)
self._handle_columns(
mapper, type_, excluded_keys, generated_field_keys)
relationship: RelationshipProperty
for key, relationship in mapper.relationships.items():
if (
Expand Down Expand Up @@ -805,7 +833,7 @@ def convert(type_: Any) -> Any:
setattr(
type_,
attr,
types.MethodType(func, type_), # type: ignore[arg-type]
types.MethodType(func, type_), # type: ignore[arg-type]
)

# Adjust types that inherit from other types/interfaces that implement Node
Expand All @@ -818,7 +846,8 @@ def convert(type_: Any) -> Any:
setattr(
type_,
attr,
types.MethodType(cast(classmethod, meth).__func__, type_),
types.MethodType(
cast(classmethod, meth).__func__, type_),
)

# need to make fields that are already in the type
Expand Down Expand Up @@ -846,7 +875,8 @@ def convert(type_: Any) -> Any:
model=model,
),
)
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY, generated_field_keys)
setattr(mapped_type, _GENERATED_FIELD_KEYS_KEY,
generated_field_keys)
setattr(mapped_type, _ORIGINAL_TYPE_KEY, type_)
return mapped_type

Expand Down Expand Up @@ -886,14 +916,16 @@ def _fix_annotation_namespaces(self) -> None:
self.edge_types.values(),
self.connection_types.values(),
):
strawberry_definition = get_object_definition(mapped_type, strict=True)
strawberry_definition = get_object_definition(
mapped_type, strict=True)
for f in strawberry_definition.fields:
if f.name in getattr(mapped_type, _GENERATED_FIELD_KEYS_KEY):
namespace = {}
if hasattr(mapped_type, _ORIGINAL_TYPE_KEY):
namespace.update(
sys.modules[
getattr(mapped_type, _ORIGINAL_TYPE_KEY).__module__
getattr(mapped_type,
_ORIGINAL_TYPE_KEY).__module__
].__dict__
)
namespace.update(self.mapped_types)
Expand Down Expand Up @@ -924,7 +956,8 @@ def _map_unmapped_relationships(self) -> None:
if type_name not in self.mapped_interfaces:
unmapped_interface_models.add(model)
for model in unmapped_models:
self.type(model)(type(self.model_to_type_name(model), (object,), {}))
self.type(model)(
type(self.model_to_type_name(model), (object,), {}))
for model in unmapped_interface_models:
self.interface(model)(
type(self.model_to_interface_name(model), (object,), {})
Expand Down
Loading
Loading