Skip to content

Commit ca4c4a2

Browse files
committed
Add support for composite unique constraint
1 parent 363d683 commit ca4c4a2

File tree

12 files changed

+888
-7
lines changed

12 files changed

+888
-7
lines changed

piccolo/apps/migrations/auto/diffable_table.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55

66
from piccolo.apps.migrations.auto.operations import (
77
AddColumn,
8+
AddConstraint,
89
AlterColumn,
910
DropColumn,
11+
DropConstraint,
1012
)
1113
from piccolo.apps.migrations.auto.serialisation import (
1214
deserialise_params,
1315
serialise_params,
1416
)
1517
from piccolo.columns.base import Column
18+
from piccolo.constraint import Constraint
1619
from piccolo.table import Table, create_table_class
1720

1821

@@ -55,6 +58,8 @@ class TableDelta:
5558
add_columns: t.List[AddColumn] = field(default_factory=list)
5659
drop_columns: t.List[DropColumn] = field(default_factory=list)
5760
alter_columns: t.List[AlterColumn] = field(default_factory=list)
61+
add_constraints: t.List[AddConstraint] = field(default_factory=list)
62+
drop_constraints: t.List[DropConstraint] = field(default_factory=list)
5863

5964
def __eq__(self, value: TableDelta) -> bool: # type: ignore
6065
"""
@@ -85,6 +90,19 @@ def __eq__(self, value) -> bool:
8590
return False
8691

8792

93+
@dataclass
94+
class ConstraintComparison:
95+
constraint: Constraint
96+
97+
def __hash__(self) -> int:
98+
return self.constraint.__hash__()
99+
100+
def __eq__(self, value) -> bool:
101+
if isinstance(value, ConstraintComparison):
102+
return self.constraint._meta.name == value.constraint._meta.name
103+
return False
104+
105+
88106
@dataclass
89107
class DiffableTable:
90108
"""
@@ -96,6 +114,7 @@ class DiffableTable:
96114
tablename: str
97115
schema: t.Optional[str] = None
98116
columns: t.List[Column] = field(default_factory=list)
117+
constraints: t.List[Constraint] = field(default_factory=list)
99118
previous_class_name: t.Optional[str] = None
100119

101120
def __post_init__(self) -> None:
@@ -189,10 +208,54 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
189208
)
190209
)
191210

211+
add_constraints = [
212+
AddConstraint(
213+
table_class_name=self.class_name,
214+
constraint_name=i.constraint._meta.name,
215+
constraint_class_name=i.constraint.__class__.__name__,
216+
constraint_class=i.constraint.__class__,
217+
params=i.constraint._meta.params,
218+
schema=self.schema,
219+
)
220+
for i in sorted(
221+
{
222+
ConstraintComparison(constraint=constraint)
223+
for constraint in self.constraints
224+
}
225+
- {
226+
ConstraintComparison(constraint=constraint)
227+
for constraint in value.constraints
228+
},
229+
key=lambda x: x.constraint._meta.name,
230+
)
231+
]
232+
233+
drop_constraints = [
234+
DropConstraint(
235+
table_class_name=self.class_name,
236+
constraint_name=i.constraint._meta.name,
237+
tablename=value.tablename,
238+
schema=self.schema,
239+
)
240+
for i in sorted(
241+
{
242+
ConstraintComparison(constraint=constraint)
243+
for constraint in value.constraints
244+
}
245+
- {
246+
ConstraintComparison(constraint=constraint)
247+
for constraint in self.constraints
248+
},
249+
key=lambda x: x.constraint._meta.name,
250+
)
251+
]
252+
192253
return TableDelta(
193254
add_columns=add_columns,
194255
drop_columns=drop_columns,
195256
alter_columns=alter_columns,
257+
add_constraints=add_constraints,
258+
drop_constraints=drop_constraints,
196259
)
197260

198261
def __hash__(self) -> int:
@@ -218,10 +281,14 @@ def to_table_class(self) -> t.Type[Table]:
218281
"""
219282
Converts the DiffableTable into a Table subclass.
220283
"""
284+
class_members: t.Dict[str, t.Any] = {}
285+
for column in self.columns:
286+
class_members[column._meta.name] = column
287+
for constraint in self.constraints:
288+
class_members[constraint._meta.name] = constraint
289+
221290
return create_table_class(
222291
class_name=self.class_name,
223292
class_kwargs={"tablename": self.tablename, "schema": self.schema},
224-
class_members={
225-
column._meta.name: column for column in self.columns
226-
},
293+
class_members=class_members,
227294
)

0 commit comments

Comments
 (0)