@@ -256,7 +256,7 @@ def target_mapper(self):
256
256
257
257
@property
258
258
def __isabstractmethod__ (self ):
259
- # compatibility between this description __getattr__ usage and abc.ABC
259
+ # compatibility between this descriptor __getattr__ usage and abc.ABC
260
260
return False
261
261
262
262
def __get__ (self , obj , owner = None ):
@@ -267,18 +267,29 @@ def __get__(self, obj, owner=None):
267
267
return obj .__dict__ [self .attribute ]
268
268
269
269
def __set__ (self , obj , value ):
270
- obj .__dict__ [self .attribute ] = value if self .single else self .list_class (obj , self , value )
270
+ if self .single :
271
+ obj .__dict__ [self .attribute ] = value
272
+ self .update_related_objs (obj , value )
273
+ flag_dirty_attr (obj , self .source_attr )
274
+ else :
275
+ obj .__dict__ [self .attribute ] = self .list_class (obj , self , [])
276
+ for item in value :
277
+ obj .__dict__ [self .attribute ].append (item ) # ensure target_attr is set
271
278
272
279
def is_loaded (self , obj ):
273
280
return self .attribute in obj .__dict__
274
281
275
282
def fetch (self , obj ):
276
283
"""Fetches the list of related objects from the database and loads it in the object"""
277
284
r = self .target .query (self .select_from_target (obj ))
278
- self .__set__ (obj , r .first () if self .single else r .all ())
285
+ obj .__dict__ [self .attribute ] = r .first () if self .single else self .list_class (obj , self , r .all ())
286
+
287
+ def load (self , obj , values ):
288
+ value = values [self .attribute ]
289
+ obj .__dict__ [self .attribute ] = value if self .single else self .list_class (obj , self , value )
279
290
280
291
281
- class RelatedObjectsList ( object ) :
292
+ class RelatedObjectsList :
282
293
def __init__ (self , obj , relationship , items ):
283
294
self .obj = obj
284
295
self .relationship = relationship
@@ -297,20 +308,12 @@ def __contains__(self, item):
297
308
return item in self .items
298
309
299
310
def append (self , item ):
300
- target_attr = self .relationship .target_attr
301
- if not target_attr :
302
- raise MapperError (
303
- f"Missing target_attr on relationship '{ self .relationship .attribute } '"
304
- )
305
- setattr (item , target_attr , getattr (self .obj , self .relationship .source_attr ))
311
+ self .relationship .update_related_objs (self .obj , item )
312
+ flag_dirty_attr (item , self .relationship .target_attr )
306
313
307
314
def remove (self , item ):
308
- target_attr = self .relationship .target_attr
309
- if not target_attr :
310
- raise MapperError (
311
- f"Missing target_attr on relationship '{ self .relationship .attribute } '"
312
- )
313
- setattr (item , target_attr , None )
315
+ self .relationship .update_related_objs (None , item )
316
+ flag_dirty_attr (item , self .relationship .target_attr )
314
317
315
318
316
319
def flag_dirty_attr (obj , attr ):
@@ -367,6 +370,7 @@ def __sql__(self):
367
370
368
371
class Model (BaseModel , abc .ABC ):
369
372
"""Our standard model class with CRUD methods"""
373
+ __resultset_class__ = CompositeResultSet
370
374
371
375
class Meta :
372
376
insert_update_dirty_only : bool = (
@@ -397,12 +401,12 @@ def query(cls, stmt, params=None) -> CompositeResultSet:
397
401
with ensure_transaction (cls .__engine__ ) as tx :
398
402
rv = _signal_rv (cls .before_query .send (cls , stmt = stmt , params = params ))
399
403
if rv is False :
400
- return ResultSet (None )
404
+ return cls . __resultset_class__ (None )
401
405
if isinstance (rv , ResultSet ):
402
406
return rv
403
407
if isinstance (rv , tuple ):
404
408
stmt , params = rv
405
- return tx .fetchhydrated (cls , stmt , params )
409
+ return tx .fetchhydrated (cls , stmt , params , resultset_class = cls . __resultset_class__ )
406
410
407
411
@classmethod
408
412
def find_all (
@@ -473,8 +477,11 @@ def __init__(self, **values):
473
477
setattr (self , k , v )
474
478
475
479
def __setattr__ (self , name , value ):
476
- self .__dict__ [name ] = value
477
- flag_dirty_attr (self , name )
480
+ if isinstance (getattr (self .__class__ , name , None ), (ModelColumnMixin , Relationship )):
481
+ super ().__setattr__ (name , value )
482
+ else :
483
+ self .__dict__ [name ] = value
484
+ flag_dirty_attr (self , name )
478
485
479
486
def refresh (self , ** select_kwargs ):
480
487
stmt = self .__mapper__ .select_by_pk (self .__mapper__ .get_primary_key (self ), ** select_kwargs )
0 commit comments