Skip to content

Commit

Permalink
Refactor delete file into a method
Browse files Browse the repository at this point in the history
  • Loading branch information
trducng committed May 25, 2024
1 parent 6110302 commit bfff3c9
Showing 1 changed file with 24 additions and 34 deletions.
58 changes: 24 additions & 34 deletions libs/ktem/ktem/index/file/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,28 @@ def finish(self, file_id: str, file_path: Path) -> str:

return file_id

def delete_file(self, file_id: str):
"""Delete a file from the db, including its chunks in docstore and vectorstore
Args:
file_id: the file id
"""
with Session(engine) as session:
session.execute(delete(self.Source).where(self.Source.id == file_id))
vs_ids, ds_ids = [], []
index = session.execute(
select(self.Index).where(self.Index.source_id == file_id)
).all()
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self.VS.delete(vs_ids)
self.DS.delete(ds_ids)

def run(self, file_path: str | Path, reindex: bool, **kwargs) -> str:
"""Index the file and return the file id"""
# check for duplication
Expand All @@ -396,23 +418,7 @@ def run(self, file_path: str | Path, reindex: bool, **kwargs) -> str:
)
else:
# remove the existing records
with Session(engine) as session:
session.execute(
delete(self.Source).where(self.Source.id == file_id)
)
vs_ids, ds_ids = [], []
index = session.execute(
select(self.Index).where(self.Index.source_id == file_id)
).all()
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self.VS.delete(vs_ids)
self.DS.delete(ds_ids)
self.delete_file(file_id)
file_id = self.store_file(file_path)
else:
# add record to db
Expand Down Expand Up @@ -442,23 +448,7 @@ def stream(
else:
# remove the existing records
yield Document(f" => Removing old {file_path.name}", channel="debug")
with Session(engine) as session:
session.execute(
delete(self.Source).where(self.Source.id == file_id)
)
vs_ids, ds_ids = [], []
index = session.execute(
select(self.Index).where(self.Index.source_id == file_id)
).all()
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self.VS.delete(vs_ids)
self.DS.delete(ds_ids)
self.delete_file(file_id)
file_id = self.store_file(file_path)
else:
# add record to db
Expand Down

0 comments on commit bfff3c9

Please sign in to comment.