# collection.py — 85 lines (2.88 KB)
import argparse
from pathlib import Path
import time
from qdrant_client import models
from qdrant_client.models import Distance, VectorParams, PointStruct
from safetensors.numpy import load_file
from loguru import logger
from constants import LANGS
from db_utils import init_qdrant_client, init_sql_session
def parse_arguments() -> argparse.Namespace:
    """Parse command-line options for loading embeddings into Qdrant.

    Positional:
        cur_model: name of the model whose ``embeddings/<name>.safetensors``
            file will be read.
    Options:
        --step: batch size for upserts (default 1000).
        --collection_name: target Qdrant collection name (default
            ``sonaveeb-semantic-search``).
        --postfix: optional suffix, one of ``query``/``passage``; the caller
            appends it to the collection name.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("cur_model", type=str)
    arg_parser.add_argument("--step", type=int, default=1000)
    arg_parser.add_argument(
        "--collection_name",
        type=str,
        default="sonaveeb-semantic-search",
    )
    arg_parser.add_argument(
        "--postfix",
        choices=["query", "passage"],
        default=None,
    )
    return arg_parser.parse_args()
if __name__ == "__main__":
    args = parse_arguments()

    # Optional "-query"/"-passage" suffix lets query and passage embeddings
    # live in separate collections.
    collection_name: str = args.collection_name
    if args.postfix:
        collection_name = f"{collection_name}-{args.postfix}"

    # NOTE(review): session_maker is never used below — presumably
    # init_sql_session() is called for its side effects; confirm before
    # removing.
    session_maker = init_sql_session()

    file_path = Path(f"embeddings/{args.cur_model}.safetensors").resolve()
    logger.info(f"Reading safetensors at {file_path}")
    loaded = load_file(file_path)
    logger.success(f"Loaded safetensors from {file_path}")

    client = init_qdrant_client()
    logger.info(f"Creating collection `{collection_name}`")
    # indexing_threshold=0 disables indexing during the bulk load; it is
    # restored to 20000 after all points are upserted (see bottom).
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(
            size=loaded["vectors"].shape[1], distance=Distance.COSINE
        ),
        optimizers_config=models.OptimizersConfigDiff(
            indexing_threshold=0,
        ),
    )
    logger.success("Created collection")

    logger.info("Inserting items")
    cur = time.time()
    step: int = args.step
    total = loaded["vectors"].shape[0]
    for i in range(0, total, step):
        logger.debug(f"Current step: {i}")
        # FIX: slice the numpy arrays BEFORE calling .tolist(). The original
        # code did `loaded[...].tolist()[i : i + step]`, which converted the
        # ENTIRE array to a Python list on every batch iteration — quadratic
        # work and a full-size temporary per batch. Slicing first converts
        # only the current batch.
        batch = zip(
            loaded["vectors"][i : i + step].tolist(),
            loaded["definition_ids"][i : i + step].tolist(),
            loaded["word_ids"][i : i + step].tolist(),
            loaded["langs"][i : i + step].tolist(),
        )
        # Point ids are the global row indices, so they stay stable across
        # batches (enumerate starts at i).
        points = [
            PointStruct(
                id=index,
                vector=vector,
                payload=dict(
                    definition_id=definition_id,
                    word_id=word_id,
                    lang=LANGS[lang],
                ),
            )
            for index, (vector, definition_id, word_id, lang) in enumerate(
                batch, start=i
            )
        ]
        client.upsert(
            collection_name=collection_name,
            points=points,
        )
    logger.success(f"All items are inserted to `{collection_name}`")

    logger.info("Restoring indexing")
    client.update_collection(
        collection_name=collection_name,
        optimizer_config=models.OptimizersConfigDiff(indexing_threshold=20000),
    )
    logger.info(f"Time elapsed: {time.time() - cur}")