Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

https://github.com/piccolo-orm/piccolo_api/discussions/265#discussion… #929

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion piccolo/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__VERSION__ = "1.16.0"
__VERSION__ = "1.16.1.dev2"
3 changes: 2 additions & 1 deletion piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2052,7 +2052,8 @@ class Treasurer(Table):
if not self._meta.unique or any(
not i._meta.unique for i in self._meta.call_chain
):
raise ValueError("Only reverse unique foreign keys.")
pass
# raise ValueError("Only reverse unique foreign keys.")

foreign_keys = [*self._meta.call_chain, self]

Expand Down
120 changes: 91 additions & 29 deletions piccolo/columns/m2m.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
m2m: M2M,
as_list: bool = False,
load_json: bool = False,
reverse: t.Optional[bool] = None,
):
"""
:param columns:
Expand All @@ -39,12 +40,15 @@ def __init__(
flattened list will be returned, rather than a list of objects.
:param load_json:
If ``True``, any JSON strings are loaded as Python objects.
:param reverse:
If ``True``, make reverse query to self reference tables.

"""
self.as_list = as_list
self.columns = columns
self.m2m = m2m
self.load_json = load_json
self.reverse = reverse

safe_types = (int, str)

Expand Down Expand Up @@ -74,20 +78,43 @@ def get_select_string(
fk_2 = self.m2m._meta.secondary_foreign_key
fk_2_name = fk_2._meta.db_column_name
table_2 = fk_2._foreign_key_meta.resolved_references
table_2_name = table_2._meta.tablename
table_2_name_with_schema = table_2._meta.get_formatted_tablename()
table_2_pk_name = table_2._meta.primary_key._meta.db_column_name

inner_select = f"""
{m2m_table_name_with_schema}
JOIN {table_1_name_with_schema} "inner_{table_1_name}" ON (
{m2m_table_name_with_schema}."{fk_1_name}" = "inner_{table_1_name}"."{table_1_pk_name}"
)
JOIN {table_2_name_with_schema} "inner_{table_2_name}" ON (
{m2m_table_name_with_schema}."{fk_2_name}" = "inner_{table_2_name}"."{table_2_pk_name}"
)
WHERE {m2m_table_name_with_schema}."{fk_1_name}" = "{table_1_name}"."{table_1_pk_name}"
""" # noqa: E501
# if primary and secondary table are the same
if table_1 == table_2:
table_2_name = table_1._meta.tablename
table_2_name_with_schema = table_1._meta.get_formatted_tablename()
table_2_pk_name = table_1._meta.primary_key._meta.db_column_name
# check reverse argument. If True change direction in query
if self.reverse:
inner_select = f"""
{m2m_table_name_with_schema}
JOIN {table_1_name_with_schema} "inner_{table_1_name}" ON (
{m2m_table_name_with_schema}."{fk_1_name}" = "inner_{table_1_name}"."{table_1_pk_name}"
)
WHERE {m2m_table_name_with_schema}."{fk_2_name}" = "{table_2_name}"."{table_2_pk_name}"
""" # noqa: E501
else:
inner_select = f"""
{m2m_table_name_with_schema}
JOIN {table_2_name_with_schema} "inner_{table_2_name}" ON (
{m2m_table_name_with_schema}."{fk_2_name}" = "inner_{table_2_name}"."{table_2_pk_name}"
)
WHERE {m2m_table_name_with_schema}."{fk_1_name}" = "{table_1_name}"."{table_1_pk_name}"
""" # noqa: E501
else:
table_2_name = table_2._meta.tablename
table_2_name_with_schema = table_2._meta.get_formatted_tablename()
table_2_pk_name = table_2._meta.primary_key._meta.db_column_name

inner_select = f"""
{m2m_table_name_with_schema}
JOIN {table_1_name_with_schema} "inner_{table_1_name}" ON (
{m2m_table_name_with_schema}."{fk_1_name}" = "inner_{table_1_name}"."{table_1_pk_name}"
)
JOIN {table_2_name_with_schema} "inner_{table_2_name}" ON (
{m2m_table_name_with_schema}."{fk_2_name}" = "inner_{table_2_name}"."{table_2_pk_name}"
)
WHERE {m2m_table_name_with_schema}."{fk_1_name}" = "{table_1_name}"."{table_1_pk_name}"
""" # noqa: E501

if engine_type in ("postgres", "cockroach"):
if self.as_list:
Expand Down Expand Up @@ -244,9 +271,17 @@ def secondary_foreign_key(self) -> ForeignKey:
"""
See ``primary_foreign_key``.
"""
# if primary and secondary table are the same
for fk_column in self.foreign_key_columns:
if fk_column._foreign_key_meta.resolved_references != self.table:
return fk_column
if (
fk_column._foreign_key_meta.resolved_references
== self.primary_table
):
return self.foreign_key_columns[-1]
if self.table == self.primary_table:
return self.foreign_key_columns[-1]

raise ValueError("No matching foreign key column found!")

Expand Down Expand Up @@ -366,29 +401,54 @@ def __await__(self):
class M2MGetRelated:
row: Table
m2m: M2M
reverse: t.Optional[bool] = False

async def run(self):

joining_table = self.m2m._meta.resolved_joining_table

secondary_table = self.m2m._meta.secondary_table
if self.reverse:
try:
ids = (
await joining_table.select(
getattr(
self.m2m._meta.primary_foreign_key,
secondary_table._meta.primary_key._meta.name,
)
)
.where(self.m2m._meta.secondary_foreign_key == self.row)
.output(as_list=True)
)

# TODO - replace this with a subquery in the future.
ids = (
await joining_table.select(
getattr(
self.m2m._meta.secondary_foreign_key,
secondary_table._meta.primary_key._meta.name,
results = await secondary_table.objects().where(
secondary_table._meta.primary_key.is_in(ids)
)
)
.where(self.m2m._meta.primary_foreign_key == self.row)
.output(as_list=True)
)
except ValueError:
results = []

results = await secondary_table.objects().where(
secondary_table._meta.primary_key.is_in(ids)
)
return results
else:
try:
# TODO - replace this with a subquery in the future.
ids = (
await joining_table.select(
getattr(
self.m2m._meta.secondary_foreign_key,
secondary_table._meta.primary_key._meta.name,
)
)
.where(self.m2m._meta.primary_foreign_key == self.row)
.output(as_list=True)
)

results = await secondary_table.objects().where(
secondary_table._meta.primary_key.is_in(ids)
)
except ValueError:
results = []

return results
return results

def run_sync(self):
return run_sync(self.run())
Expand Down Expand Up @@ -427,6 +487,7 @@ def __call__(
*columns: t.Union[Column, t.List[Column]],
as_list: bool = False,
load_json: bool = False,
reverse: t.Optional[bool] = None,
) -> M2MSelect:
"""
:param columns:
Expand All @@ -437,6 +498,7 @@ def __call__(
flattened list will be returned, rather than a list of objects.
:param load_json:
If ``True``, any JSON strings are loaded as Python objects.

"""
columns_ = flatten(columns)

Expand All @@ -449,5 +511,5 @@ def __call__(
)

return M2MSelect(
*columns_, m2m=self, as_list=as_list, load_json=load_json
*columns_, m2m=self, as_list=as_list, load_json=load_json, reverse=reverse
)
20 changes: 10 additions & 10 deletions piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,14 @@ def refresh(

@t.overload
def get_related(
self, foreign_key: ForeignKey[ReferencedTable]
self, foreign_key: ForeignKey[ReferencedTable], many: bool = False
) -> First[ReferencedTable]: ...

@t.overload
def get_related(self, foreign_key: str) -> First[Table]: ...
def get_related(self, foreign_key: str, many: bool = False) -> First[Table]: ...

def get_related(
self, foreign_key: t.Union[str, ForeignKey[ReferencedTable]]
self, foreign_key: t.Union[str, ForeignKey[ReferencedTable]], many: bool = False
) -> t.Union[First[Table], First[ReferencedTable]]:
"""
Used to fetch a ``Table`` instance, for the target of a foreign key.
Expand Down Expand Up @@ -612,16 +612,16 @@ def get_related(

references = foreign_key._foreign_key_meta.resolved_references

return (
references.objects()
.where(
insts = references.objects().where(
foreign_key._foreign_key_meta.resolved_target_column
== getattr(self, column_name)
)
.first()
)
if many:
return insts
return insts.first()


def get_m2m(self, m2m: M2M) -> M2MGetRelated:
def get_m2m(self, m2m: M2M, reverse: t.Optional[bool] = None) -> M2MGetRelated:
"""
Get all matching rows via the join table.

Expand All @@ -632,7 +632,7 @@ def get_m2m(self, m2m: M2M) -> M2MGetRelated:
[<Genre: 1>, <Genre: 2>]

"""
return M2MGetRelated(row=self, m2m=m2m)
return M2MGetRelated(row=self, m2m=m2m, reverse=reverse)

def add_m2m(
self,
Expand Down