Skip to content

Commit 8eec885

Browse files
committed
Improved relationship support & customizable results set classes
1 parent 105c84f commit 8eec885

File tree

7 files changed

+109
-45
lines changed

7 files changed

+109
-45
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlorm-py"
3-
version = "0.3.0"
3+
version = "0.3.1"
44
description = "A new kind or ORM that do not abstract away your database or SQL queries."
55
authors = [
66
{"name" = "Maxime Bouroumeau-Fuseau", email = "maxime.bouroumeau@gmail.com"}

src/sqlorm/engine.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -465,19 +465,19 @@ def executemany(self, stmt, seq_of_parameters):
465465
self.after_execute.send(self, cursor=cur, stmt=stmt, params=seq_of_parameters, many=True)
466466
cur.close()
467467

468-
def fetch(self, stmt, params=None, model=None, obj=None, loader=None):
468+
def fetch(self, stmt, params=None, model=None, obj=None, loader=None, resultset_class=None):
469469
cursor = self.cursor(stmt, params)
470470

471471
if obj and not model:
472472
model = obj.__class__
473473
if model:
474-
rs = HydratedResultSet(cursor, model)
474+
rs = (resultset_class or HydratedResultSet)(cursor, model)
475475
if obj:
476476
rs.mapper.hydrate(obj, rs.first(with_loader=False))
477477
return obj
478478
return rs
479479

480-
return ResultSet(cursor, loader)
480+
return (resultset_class or ResultSet)(cursor, loader)
481481

482482
def fetchall(self, stmt, params=None, **fetch_kwargs):
483483
return self.fetch(stmt, params, **fetch_kwargs).all()
@@ -501,6 +501,7 @@ def fetchcomposite(
501501
obj=None,
502502
map=None,
503503
separator=None,
504+
resultset_class=CompositeResultSet,
504505
):
505506
if obj and not model:
506507
model = obj.__class__
@@ -514,7 +515,7 @@ def fetchcomposite(
514515
elif not map:
515516
map = CompositionMap.create([loader, nested])
516517

517-
rs = CompositeResultSet(
518+
rs = resultset_class(
518519
self.cursor(stmt, params), map, separator or self.default_composite_separator
519520
)
520521
if obj:

src/sqlorm/mapper.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,17 @@ def hydrate_new(self, data):
165165
def hydrate(self, obj, data, with_unknown=None):
166166
"""Populates a model object with data from the db"""
167167
attrs = set()
168+
relnames = {r.attribute: r for r in self.relationships}
168169
for key in (
169170
data.keys()
170171
): # avoid using .items() as some DBAPI returned objects only provide keys() (eg: sqlite3)
171172
if key in self.columns.names:
172173
col = self.columns[key]
173174
col.load(obj, data)
174175
key = col.attribute
176+
elif key in relnames:
177+
rel = relnames[key]
178+
rel.load(obj, data)
175179
elif with_unknown or with_unknown is None and self.allow_unknown_columns:
176180
obj.__dict__[key] = data[key] # ensure that no custom setter are used
177181
else:
@@ -414,8 +418,8 @@ def __init__(
414418
self,
415419
target_mapper,
416420
target_col=None,
417-
target_attr=None,
418421
source_col=None,
422+
target_attr=None,
419423
source_attr=None,
420424
join_type="LEFT JOIN",
421425
join_condition=None,
@@ -424,7 +428,7 @@ def __init__(
424428
lazy=True,
425429
):
426430
self._target_mapper = target_mapper
427-
self.target_col = target_col
431+
self._target_col = target_col
428432
self._target_attr = target_attr
429433
self._source_col = source_col
430434
self._source_attr = source_attr
@@ -443,6 +447,16 @@ def target_mapper(self) -> Mapper:
443447
def target_table(self):
444448
return self.target_mapper.table
445449

450+
@property
451+
def target_col(self):
452+
if self._target_col:
453+
return self._target_col
454+
if not self.single:
455+
raise MapperError(f"Missing target_col on relationship '{self.attribute}'")
456+
if self.target_mapper.primary_key:
457+
return self.target_mapper.primary_key.name
458+
return "id"
459+
446460
@property
447461
def target_attr(self):
448462
if self._target_attr:
@@ -493,8 +507,6 @@ def join_condition(self, target_alias=None, source_alias=None) -> SQL:
493507
)
494508
elif self._join_condition:
495509
return self._join_condition
496-
if not self.target_col:
497-
raise MapperError(f"Missing target_col on relationship '{self.attribute}'")
498510
return SQL.Col(self.target_col, table=target_alias) == SQL.Col(
499511
self.source_col, table=source_alias
500512
)
@@ -538,6 +550,23 @@ def delete_from_target(self, source_obj):
538550
return SQL.delete_from(self.target_table).where(
539551
SQL.Col(self.target_col) == SQL.Param(getattr(source_obj, source_attr))
540552
)
553+
554+
def update_related_objs(self, source_obj, target_obj):
555+
target_attr = self.target_attr
556+
if not target_attr:
557+
raise MapperError(
558+
f"Missing target_attr on relationship '{self.attribute}'"
559+
)
560+
if self.single:
561+
setattr(source_obj, self.source_attr, getattr(target_obj, target_attr) if target_obj else None)
562+
else:
563+
setattr(target_obj, target_attr, None if source_obj is None else getattr(source_obj, self.source_attr, None))
564+
565+
def load(self, obj, values):
566+
"""Sets the object attribute from the database row, using the provided load function if needed
567+
(used by Mapper.hydrate())
568+
"""
569+
obj.__dict__[self.attribute] = values[self.attribute]
541570

542571
def __repr__(self):
543572
return f"<Relationship({self.attribute})>"
@@ -552,14 +581,18 @@ def __init__(self, cursor, mapper, auto_close_cursor=True):
552581

553582

554583
class HydrationMap(CompositionMap):
555-
def __init__(self, mapper, nested=None, single=False):
584+
def __init__(self, mapper, nested=None, single=False, _already_visited=None):
556585
if not isinstance(mapper, Mapper):
557586
mapper = Mapper.from_class(mapper)
587+
if _already_visited is None:
588+
_already_visited = set()
558589
rowid = mapper.primary_key.name if mapper.primary_key else None
559590
_nested = {
560-
r.attribute: HydrationMap(r.target_mapper, single=r.single)
591+
r.attribute: HydrationMap(r.target_mapper, single=r.single, _already_visited={mapper} | _already_visited)
561592
for r in mapper.relationships
593+
if r.target_mapper not in _already_visited
562594
}
595+
_already_visited.add(mapper)
563596
if nested:
564597
_nested.update({k: HydrationMap.create(v) for k, v in nested.items()})
565598
super().__init__(mapper.hydrate_new, _nested, rowid, single)

src/sqlorm/model.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def target_mapper(self):
256256

257257
@property
258258
def __isabstractmethod__(self):
259-
# compatibility between this description __getattr__ usage and abc.ABC
259+
# compatibility between this descriptor __getattr__ usage and abc.ABC
260260
return False
261261

262262
def __get__(self, obj, owner=None):
@@ -267,18 +267,29 @@ def __get__(self, obj, owner=None):
267267
return obj.__dict__[self.attribute]
268268

269269
def __set__(self, obj, value):
270-
obj.__dict__[self.attribute] = value if self.single else self.list_class(obj, self, value)
270+
if self.single:
271+
obj.__dict__[self.attribute] = value
272+
self.update_related_objs(obj, value)
273+
flag_dirty_attr(obj, self.source_attr)
274+
else:
275+
obj.__dict__[self.attribute] = self.list_class(obj, self, [])
276+
for item in value:
277+
obj.__dict__[self.attribute].append(item) # ensure target_attr is set
271278

272279
def is_loaded(self, obj):
273280
return self.attribute in obj.__dict__
274281

275282
def fetch(self, obj):
276283
"""Fetches the list of related objects from the database and loads it in the object"""
277284
r = self.target.query(self.select_from_target(obj))
278-
self.__set__(obj, r.first() if self.single else r.all())
285+
obj.__dict__[self.attribute] = r.first() if self.single else self.list_class(obj, self, r.all())
286+
287+
def load(self, obj, values):
288+
value = values[self.attribute]
289+
obj.__dict__[self.attribute] = value if self.single else self.list_class(obj, self, value)
279290

280291

281-
class RelatedObjectsList(object):
292+
class RelatedObjectsList:
282293
def __init__(self, obj, relationship, items):
283294
self.obj = obj
284295
self.relationship = relationship
@@ -297,20 +308,12 @@ def __contains__(self, item):
297308
return item in self.items
298309

299310
def append(self, item):
300-
target_attr = self.relationship.target_attr
301-
if not target_attr:
302-
raise MapperError(
303-
f"Missing target_attr on relationship '{self.relationship.attribute}'"
304-
)
305-
setattr(item, target_attr, getattr(self.obj, self.relationship.source_attr))
311+
self.relationship.update_related_objs(self.obj, item)
312+
flag_dirty_attr(item, self.relationship.target_attr)
306313

307314
def remove(self, item):
308-
target_attr = self.relationship.target_attr
309-
if not target_attr:
310-
raise MapperError(
311-
f"Missing target_attr on relationship '{self.relationship.attribute}'"
312-
)
313-
setattr(item, target_attr, None)
315+
self.relationship.update_related_objs(None, item)
316+
flag_dirty_attr(item, self.relationship.target_attr)
314317

315318

316319
def flag_dirty_attr(obj, attr):
@@ -367,6 +370,7 @@ def __sql__(self):
367370

368371
class Model(BaseModel, abc.ABC):
369372
"""Our standard model class with CRUD methods"""
373+
__resultset_class__ = CompositeResultSet
370374

371375
class Meta:
372376
insert_update_dirty_only: bool = (
@@ -397,12 +401,12 @@ def query(cls, stmt, params=None) -> CompositeResultSet:
397401
with ensure_transaction(cls.__engine__) as tx:
398402
rv = _signal_rv(cls.before_query.send(cls, stmt=stmt, params=params))
399403
if rv is False:
400-
return ResultSet(None)
404+
return cls.__resultset_class__(None)
401405
if isinstance(rv, ResultSet):
402406
return rv
403407
if isinstance(rv, tuple):
404408
stmt, params = rv
405-
return tx.fetchhydrated(cls, stmt, params)
409+
return tx.fetchhydrated(cls, stmt, params, resultset_class=cls.__resultset_class__)
406410

407411
@classmethod
408412
def find_all(
@@ -473,8 +477,11 @@ def __init__(self, **values):
473477
setattr(self, k, v)
474478

475479
def __setattr__(self, name, value):
476-
self.__dict__[name] = value
477-
flag_dirty_attr(self, name)
480+
if isinstance(getattr(self.__class__, name, None), (ModelColumnMixin, Relationship)):
481+
super().__setattr__(name, value)
482+
else:
483+
self.__dict__[name] = value
484+
flag_dirty_attr(self, name)
478485

479486
def refresh(self, **select_kwargs):
480487
stmt = self.__mapper__.select_by_pk(self.__mapper__.get_primary_key(self), **select_kwargs)

src/sqlorm/sqlfunc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def query_builder(*args, **kwargs):
4242

4343
def is_sqlfunc(func):
4444
"""Checks if func is an empty function with a python doc"""
45+
if not inspect.isfunction(func):
46+
return False
4547
doc = inspect.getdoc(func)
4648
src = inspect.getsource(func).strip(' "\n\r')
4749
return doc and src.endswith(func.__doc__.strip(' \n\r')) and not getattr(func, "sqlfunc", False)

tests/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class Task(Model):
77
title: str
88
done: bool = Column("completed")
99
user_id: Integer = Column(references="users(id)")
10+
user = Relationship("User", "id", "user_id", single=True)
1011

1112
@classmethod
1213
def find_todos(cls):

tests/test_model.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -175,20 +175,6 @@ def on_after_refresh(sender, obj):
175175
assert listener_called == 2
176176

177177

178-
def test_relationships(engine):
179-
with engine:
180-
user = User.get(1)
181-
assert not User.tasks.is_loaded(user)
182-
assert len(user.tasks) == 1
183-
assert User.tasks.is_loaded(user)
184-
assert isinstance(user.tasks[0], Task)
185-
assert user.tasks[0].id == 1
186-
187-
user = User.get(1, with_rels=True)
188-
assert User.tasks.is_loaded(user)
189-
assert len(user.tasks) == 1
190-
191-
192178
def test_sql_methods(engine):
193179
assert Task.toggle.query_decorator == "update"
194180
assert Task.find_todos.query_decorator == "fetchall"
@@ -386,3 +372,37 @@ def on_after_delete(sender, obj):
386372

387373
user = User.get(4)
388374
assert not user
375+
376+
377+
def test_relationships_many(engine):
378+
with engine:
379+
user = User.get(1)
380+
assert not User.tasks.is_loaded(user)
381+
assert len(user.tasks) == 1
382+
assert User.tasks.is_loaded(user)
383+
assert isinstance(user.tasks[0], Task)
384+
assert user.tasks[0].id == 1
385+
386+
user = User.get(1, with_rels=True)
387+
assert User.tasks.is_loaded(user)
388+
assert len(user.tasks) == 1
389+
390+
user = User.get(1)
391+
task = Task()
392+
user.tasks.append(task)
393+
assert task.user_id == user.id
394+
user.tasks.remove(task)
395+
assert task.user_id is None
396+
397+
398+
def test_relationships_single(engine):
399+
with engine:
400+
task = Task.get(1)
401+
assert not Task.user.is_loaded(task)
402+
assert task.user.id == 1
403+
assert Task.user.is_loaded(task)
404+
405+
user = User.get(1)
406+
task = Task()
407+
task.user = user
408+
assert task.user_id == user.id

0 commit comments

Comments
 (0)