Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions fastapi_solo/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def exec(self, q, **kwargs):
return result.scalars()
return result

def find_or_create(self, model: Type[Base], find_by=None, flush=True, **kwargs):
def find_or_create(self, _model: Type[Base], find_by=None, flush=True, **kwargs):
"""Find or create a model

it doesnt update the model if it already exists
Expand All @@ -240,26 +240,26 @@ def find_or_create(self, model: Type[Base], find_by=None, flush=True, **kwargs):
user = User.find_or_create(find_by=["name"], name="Albert")
```
"""
q = select(model)
q = select(_model)
filters = kwargs
if find_by:
filters = {k: v for k, v in kwargs.items() if k in find_by}
else:
pks = list(map(lambda x: x.name, model.__mapper__.primary_key))
pks = list(map(lambda x: x.name, _model.__mapper__.primary_key))
if all(k in kwargs for k in pks):
filters = {k: v for k, v in kwargs.items() if k in pks}

q = q.filter_by(**filters)

ret = self.exec(q).one_or_none()
if not ret:
ret = model(**kwargs)
ret = _model(**kwargs)
self.add(ret)
if flush:
self.flush()
return ret

def upsert(self, model: Type[Base], find_by=None, flush=True, **kwargs):
def upsert(self, _model: Type[Base], find_by=None, flush=True, **kwargs):
"""Update or create a model

it will update the model if it already exists before returning it
Expand All @@ -278,10 +278,10 @@ def upsert(self, model: Type[Base], find_by=None, flush=True, **kwargs):
```
"""
if not find_by:
pks = list(map(lambda x: x.name, model.__mapper__.primary_key))
pks = list(map(lambda x: x.name, _model.__mapper__.primary_key))
if not all(k in kwargs for k in pks):
raise DbException("find_by or primary key must be provided")
e = self.find_or_create(model, find_by=find_by, flush=False, **kwargs)
e = self.find_or_create(_model, find_by=find_by, flush=False, **kwargs)
for k, v in kwargs.items():
setattr(e, k, v)
if flush:
Expand Down
Loading