# vectorize.py
#
# Reads definition rows from the SQL database, encodes the definition texts
# with a model, and saves the vectors together with the definition ids, word
# ids and language ids to a `.safetensors` file.
import argparse
from sqlalchemy import select
from safetensors.numpy import save_file
import numpy as np
from loguru import logger
from model_utils import init_model
from constants import LANGS, PREFIXES
from schemas import Definition
from db_utils import init_sql_session
# Map every language in LANGS to its position in that sequence, so languages
# can be stored as integers next to the vectors.
lang2id = {language: index for index, language in enumerate(LANGS)}
def parse_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the vectorization script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script does
            when run from the command line; passing a list allows the
            parser to be driven programmatically.

    Returns:
        Namespace with the attributes ``model_name``, ``n_items``,
        ``prefix``, ``show_progress_bar``, ``random_weights`` and
        ``batch_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("model_name", type=str, help="Name of the model")
    parser.add_argument(
        "--n_items",
        type=int,
        default=None,
        help="Maximum number of definition rows to read (default: no limit)",
    )
    parser.add_argument(
        "--prefix",
        choices=["query", "passage"],
        default=None,
        help="Kind of prefix to prepend to every definition before encoding",
    )
    parser.add_argument(
        "--show_progress_bar",
        action="store_true",
        help="Show a progress bar while encoding",
    )
    parser.add_argument(
        "--random_weights",
        action="store_true",
        help="Random-weights flag passed on to the model initialisation",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1024,
        help="Batch size used for encoding (default: 1024)",
    )
    return parser.parse_args(argv)
def _load_definitions(session_maker, n_items):
    """Read distinct definition rows and return them as four parallel lists.

    Args:
        session_maker: Callable returning a context-managed SQL session.
        n_items: Value passed to ``.limit()``; ``None`` is passed through
            unchanged.

    Returns:
        ``(definition_ids, definitions, word_ids, langs)`` where position
        ``i`` of every list belongs to the same row.
    """
    statement = (
        select(
            Definition.definition_id,
            Definition.value,
            Definition.word_id,
            Definition.lang,
        )
        .distinct()
        .limit(n_items)
    )

    definition_ids = []
    definitions = []
    word_ids = []
    langs = []
    with session_maker() as session:
        # Consume the result while the session is still open so that no rows
        # are fetched from an already released connection.
        for definition_id, value, word_id, lang in session.execute(statement):
            definition_ids.append(definition_id)
            definitions.append(value)
            word_ids.append(word_id)
            langs.append(lang)
    return definition_ids, definitions, word_ids, langs


def main() -> None:
    """Encode the definitions from the database and save them to disk.

    Steps: parse the command line, initialise the SQL session factory and
    the model, read the definition rows, optionally prepend the selected
    prefix, encode the texts, and write vectors, definition ids, word ids
    and language ids into one ``.safetensors`` file named after the model
    (and the prefix, when one is given).

    Raises:
        KeyError: If a row has a language that is not a key of ``lang2id``.
        RuntimeError: If the number of encoded vectors differs from the
            number of definitions read.
    """
    args = parse_arguments()
    session_maker = init_sql_session()
    model_name, model = init_model(args.model_name, args.random_weights)

    logger.info("Reading from the database")
    definition_ids, definitions, word_ids, langs = _load_definitions(
        session_maker, args.n_items
    )

    # Resolve language ids before encoding: an unknown language then fails
    # immediately instead of after the expensive encoding step.
    lang_ids = [lang2id[lang] for lang in langs]

    if args.prefix:
        logger.info(f"Prepending definitions with prefix `{args.prefix}`")
        prefix = PREFIXES[args.prefix]
        cur_definitions = [f"{prefix}{definition}" for definition in definitions]
    else:
        cur_definitions = definitions

    logger.info("Encoding definitions")
    vectors = model.encode(
        cur_definitions,
        batch_size=args.batch_size,
        show_progress_bar=args.show_progress_bar,
    )
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if len(definition_ids) != vectors.shape[0]:
        raise RuntimeError(
            f"Got {vectors.shape[0]} vectors for "
            f"{len(definition_ids)} definitions"
        )

    output = dict(
        vectors=vectors,
        definition_ids=np.array(definition_ids),
        word_ids=np.array(word_ids),
        langs=np.array(lang_ids),
    )

    if args.prefix:
        file_name = f"{model_name}-{args.prefix}.safetensors"
    else:
        file_name = f"{model_name}.safetensors"

    logger.info(f"Saving to `{file_name}`")
    save_file(tensor_dict=output, filename=file_name)
    logger.success(f"Successfully saved to `{file_name}`")


if __name__ == "__main__":
    main()