-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
142 lines (114 loc) · 4.29 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Main script.

This script does two things:
1. Populates the Weaviate instance with data.
2. Runs a search query against the indexed and vectorized data.
"""
import json
from typing import Sequence

import pandas as pd
from tqdm import tqdm
from weaviate import Client

from src import config
from src.data import WeaviateDateLoader
from src.utils import (
    add_schema,
    pprint_response,
    print_weaviate_class_stats,
    split_string,
)
def _count_words(text) -> int:
    """Return the number of whitespace-separated words in *text*.

    Non-string values (e.g. the ``NaN`` that ``pd.read_csv`` produces for an
    empty cell) count as zero words instead of raising ``AttributeError``.
    """
    if not isinstance(text, str):
        return 0
    # str.split() with no argument splits on runs of any whitespace and drops
    # leading/trailing whitespace: "a  b", " a b " and "a\nb" all give 2 and
    # "" gives 0 (counting spaces + 1 would give 3, 4, 1 and 1 respectively).
    return len(text.split())


def load_data(client: Client, data: pd.DataFrame) -> None:
    """Loads data into Weaviate instance.

    Parameters
    ----------
    client : Client
        connection to Weaviate instance
    data : pd.DataFrame
        data that will be loaded into Weaviate instance; one row per article,
        with columns named ``[class_name]_[property_name]`` (this function
        reads ``author_name``, ``article_keywords`` and ``article_description``)
    """
    with WeaviateDateLoader(client) as loader:
        for _, row in tqdm(data.iterrows(), total=len(data), ascii=True):
            # some articles have multiple authors
            # in order to create `Author` object for each author
            # we need to unfold articles
            # for example:
            #                                           ----------------------------
            # --------------------------------------    | "article_1" | "author_1" |
            # | "article_1" | "author_1, author_2" | --> ----------------------------
            # --------------------------------------    | "article_1" | "author_2" |
            #                                           ----------------------------
            # Weaviate data loader will not send
            # the same object to weaviate more than once
            row["author_name"] = split_string(row["author_name"])
            article_rows = row.to_frame().T.explode("author_name")
            for _, inner_row in article_rows.iterrows():
                record = dict(inner_row)
                record["article_keywords"] = split_string(record["article_keywords"])
                description = record["article_description"]
                record["article_descriptionWordCount"] = _count_words(description)
                loader.load(record)
def search(
    client: Client,
    *,
    concepts: Sequence[str] = ("banks hedge funds predictions",),
    keyword: str = "bonds",
    certainty: float = 0.5,
    limit: int = 5,
) -> dict:
    """Search example.

    Runs a near-text vector search over ``Article`` objects, keeping only
    candidates whose ``keywords`` property matches ``keyword``.

    Parameters
    ----------
    client : Client
        connection to Weaviate instance
    concepts : Sequence[str], optional
        concepts for the vector search,
        by default ``("banks hedge funds predictions",)``
    keyword : str, optional
        keyword used to filter candidates, by default ``"bonds"``
    certainty : float, optional
        certainty threshold passed to the near-text search, by default 0.5
    limit : int, optional
        maximum number of returned objects, by default 5

    Returns
    -------
    dict
        dictionary with response
    """
    # ----------- CONSTRUCT QUERY REQUEST ----------- #
    # concepts for vector search
    near_text = {
        "concepts": list(concepts),
        "certainty": certainty,
    }
    # against what class to run query
    class_name = "Article"
    # what fields to return
    properties = [
        "title",
        "keywords",
        "short_description",
        "_additional {certainty}",
    ]
    # filter candidates by keywords
    where_filter = {
        "operator": "Equal",
        "path": ["keywords"],
        # NOTE(review): Weaviate's where-filter docs show `valueText` as a
        # scalar string; the single-element list form is kept from the
        # original code - confirm the client version in use accepts it.
        "valueText": [keyword],
    }

    # ---------------- RUN QUERY ---------------- #
    response = (  # noqa: says that expression is too complex
        client.query.get(class_name=class_name, properties=properties)
        .with_near_text(content=near_text)
        .with_where(content=where_filter)
        .with_limit(limit=limit)
        .do()
    )
    return response
def main() -> None:
    """Main function.

    Creates connection to Weaviate instance, creates schema inside it, loads data,
    outputs information with number of objects per each class
    and runs example search query as a sanity check.
    """
    # --------- CREATE CONNECTION --------- #
    client = Client(f"{config.weaviate.instance.host}:{config.weaviate.instance.port}")

    # ----------- CREATE SCHEMA ----------- #
    # JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so the
    # platform's locale encoding (e.g. cp1252 on Windows) is not used instead.
    with open(config.weaviate.schema.path, encoding="utf-8") as fin:
        schema = json.load(fin)
    add_schema(client, schema)

    # ------------- LOAD DATA ------------- #
    data = pd.read_csv(config.data.path.interim)
    # as different classes might have the same name, in order to access data correctly
    # we need to rename columns to a format that is expected by data loader:
    # [class_name]_[property_name] -> article_title, author_name, ..
    data = data.rename(columns=config.data.loader.names_map)
    load_data(client, data)

    # --------- PRINT CLASS STATS ---------- #
    # show number of objects per each class
    print_weaviate_class_stats(client)

    # ------------- DO SEARCH -------------- #
    response = search(client)
    pprint_response(response)


if __name__ == "__main__":
    main()