-
Notifications
You must be signed in to change notification settings - Fork 0
/
oaps.py
113 lines (83 loc) · 3.76 KB
/
oaps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import redis
from sentence_transformers import SentenceTransformer
from redis.commands.search.query import Query
from redis.commands.search.field import TextField, VectorField, TagField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
import numpy as np
import os
import re
from img2vec_pytorch import Img2Vec
from PIL import Image
# Image embedding model (ResNet-18 backbone by default); CPU-only inference.
img2vec = Img2Vec(cuda=False)
# Dimension of the vectors produced by img2vec; must match the picture index schema.
IMAGE_VECTOR_DIMENSION=512
# Sentence embedding model used for text indexing and similarity search.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v1')
# Output dimension of the sentence model; must match the text index schema.
MODELSIZE = 384
def get_db(decode=True):
    """Return a Redis client configured from environment variables.

    Connection settings come from DB_SERVICE, DB_PORT and DB_PWD,
    defaulting to 127.0.0.1:6379 with an empty password.

    Args:
        decode: when True, responses are decoded to str; pass False for
            binary-safe operations such as vector searches.

    Returns:
        A ``redis.StrictRedis`` client.  The client connects lazily, so
        no network I/O happens here.

    Raises:
        redis.exceptions.ConnectionError: re-raised after logging, instead
            of silently returning None (which made callers fail later with
            an opaque AttributeError).
    """
    try:
        return redis.StrictRedis(host=os.getenv('DB_SERVICE', '127.0.0.1'),
                                 port=int(os.getenv('DB_PORT', 6379)),
                                 password=os.getenv('DB_PWD', ''),
                                 decode_responses=decode)
    except redis.exceptions.ConnectionError:
        # Log for visibility, but propagate: a None client is never usable.
        print("connection error")
        raise
def init():
    """Ensure the two vector-search indexes exist, creating any that are missing.

    ``oaps_txt_idx`` covers JSON docs under ``oaps:seq:`` (sentence text plus a
    MODELSIZE-dim HNSW/cosine vector); ``oaps_pic_idx`` covers ``oaps:pic:``
    (file tag plus an IMAGE_VECTOR_DIMENSION-dim HNSW/cosine vector).
    """
    existing = get_db().execute_command("FT._LIST")

    if "oaps_txt_idx" not in existing:
        txt_definition = IndexDefinition(prefix=["oaps:seq:"],
                                         index_type=IndexType.JSON)
        txt_schema = (
            TextField("$.sentence", as_name="sentence"),
            VectorField("$.embedding", "HNSW",
                        {"TYPE": "FLOAT32",
                         "DIM": MODELSIZE,
                         "DISTANCE_METRIC": "COSINE"},
                        as_name="embedding"),
        )
        get_db(False).ft('oaps_txt_idx').create_index(txt_schema,
                                                      definition=txt_definition)

    if "oaps_pic_idx" not in existing:
        pic_definition = IndexDefinition(prefix=["oaps:pic:"],
                                         index_type=IndexType.JSON)
        pic_schema = (
            TagField("$.file", as_name="file"),
            VectorField("$.embedding", "HNSW",
                        {"TYPE": "FLOAT32",
                         "DIM": IMAGE_VECTOR_DIMENSION,
                         "DISTANCE_METRIC": "COSINE"},
                        as_name="embedding"),
        )
        get_db(False).ft('oaps_pic_idx').create_index(pic_schema,
                                                      definition=pic_definition)
def get_embedding_as_vector(text):
    """Embed *text* with the sentence model and return a plain Python list."""
    embedding = model.encode(text)
    return embedding.tolist()
def get_embedding_as_blob(text):
    """Embed *text* and return the vector as a float32 byte blob for Redis search params."""
    embedding = model.encode(text)
    return embedding.astype(np.float32).tobytes()
def get_image_embedding_as_vector(imagepath):
    """Load the image at *imagepath* and return its embedding as a plain list."""
    image = Image.open(imagepath).convert('RGB')
    vector = img2vec.get_vec(image)
    return vector.tolist()
def get_image_embedding_as_blob(imagepath):
    """Load the image at *imagepath* and return its embedding as a float32 byte blob."""
    image = Image.open(imagepath).convert('RGB')
    vector = img2vec.get_vec(image)
    return vector.astype(np.float32).tobytes()
def index_document(pk, text):
    """Split *text* into sentences and store each one as a JSON doc with its embedding.

    Keys have the form ``oaps:seq:<pk>:<seq>`` so they fall under the
    ``oaps_txt_idx`` index prefix.

    Args:
        pk: primary key of the document; namespaces the sentence keys.
        text: full document text to be sentence-split and embedded.
    """
    # BUGFIX: the original pattern "[//.|//!|//?]" was a character class that
    # also split on '/' and '|' ('//' is not an escape, and '|' is literal
    # inside []).  The intent is sentence-ending punctuation only.
    sentences = re.split(r"[.!?]", text)
    seq = 0
    for txt_sentence in sentences:
        txt_sentence = txt_sentence.strip()
        if not txt_sentence:
            # Trailing punctuation yields empty fragments; don't index them.
            continue
        sentence = {
            'seq': seq,
            'sentence': txt_sentence,
            'embedding': get_embedding_as_vector(txt_sentence)
        }
        # Binary-safe client (decode=False) since JSON values carry vectors.
        get_db(False).json().set("oaps:seq:{}:{}".format(pk, seq), '$', sentence)
        seq += 1
def check_document(text, epsilon):
    """Return, for each sentence of *text*, the ids of indexed sentences within *epsilon*.

    Each sentence is embedded and run through a VECTOR_RANGE query against
    ``oaps_txt_idx``; cosine distance is yielded as ``score``.

    Args:
        text: document text to check, sentence by sentence.
        epsilon: maximum cosine distance for a match.

    Returns:
        A list with one entry per matching sentence; each entry is the list
        of matching document ids.  Sentences with no match contribute nothing.
    """
    # BUGFIX: the original pattern "[//.|//!|//?]" also split on '/' and '|';
    # split on sentence-ending punctuation only.
    # The query and index handle are loop-invariant, so build them once.
    q = Query("@embedding:[VECTOR_RANGE $radius $vec]=>{$YIELD_DISTANCE_AS: score}")\
        .sort_by("score", asc=True)\
        .return_field("score")\
        .dialect(2)
    index = get_db(False).ft("oaps_txt_idx")
    res = []
    for txt_sentence in re.split(r"[.!?]", text):
        txt_sentence = txt_sentence.strip()
        if not txt_sentence:
            # Skip empty fragments from trailing punctuation: embedding and
            # searching them is wasted work and can produce spurious matches.
            continue
        params = {"vec": get_embedding_as_blob(txt_sentence), "radius": epsilon}
        found = index.search(q, params).docs
        if found:
            res.append([doc['id'] for doc in found])
    return res
def index_image(pk, imagepath):
    """Store the image at *imagepath* as a JSON doc keyed ``oaps:pic:<pk>``.

    The doc carries the image's embedding vector and its file path, matching
    the ``oaps_pic_idx`` schema.
    """
    document = {
        'embedding': get_image_embedding_as_vector(imagepath),
        'file': imagepath,
    }
    db = get_db(False)
    db.json().set("oaps:pic:{}".format(pk), '$', document)
def check_image(imagepath, epsilon):
    """Find indexed pictures within *epsilon* cosine distance of the given image.

    Args:
        imagepath: path to the probe image.
        epsilon: maximum cosine distance for a match.

    Returns:
        The raw list of matching search docs (with ``score`` and ``file``),
        sorted by ascending distance.
    """
    range_query = Query(
        "@embedding:[VECTOR_RANGE $radius $vec]=>{$YIELD_DISTANCE_AS: score}"
    ).sort_by("score", asc=True).return_field("score").return_field("file").dialect(2)
    params = {
        "vec": get_image_embedding_as_blob(imagepath),
        "radius": epsilon,
    }
    matches = get_db(False).ft("oaps_pic_idx").search(range_query, params).docs
    return matches