# search.py
from sqlalchemy import text, func, or_
import numpy as np


def full_text_search(search_term, db, DisplayCases, date_from=None, date_to=None,
                     originating_bodies=None, importance_levels=None,
                     respondent_states=None):
    # Build the base query, selecting only the columns the results page needs.
    # The filter arguments default to None rather than [] because mutable
    # default arguments are shared across calls in Python.
    query = db.session.query(
        DisplayCases.judgment_url,
        DisplayCases.title,
        DisplayCases.importance_level,
        DisplayCases.judgment_date,
        DisplayCases.judgment_facts,
        DisplayCases.judgment_conclusion,
    )
    # Apply the metadata filters if provided.
    if originating_bodies:
        query = query.filter(DisplayCases.originating_body.in_(originating_bodies))
    if importance_levels:
        query = query.filter(DisplayCases.importance_level.in_(importance_levels))
    if respondent_states:
        # A judgment can name several respondent states in one column, so match
        # each requested state as a substring and OR the conditions together.
        query = query.filter(or_(*[DisplayCases.respondent_state.like(f"%{state}%")
                                   for state in respondent_states]))
    if date_from:
        query = query.filter(DisplayCases.judgment_date >= date_from)
    if date_to:
        query = query.filter(DisplayCases.judgment_date <= date_to)
    # Only apply full-text search if there is a search term.
    if search_term:
        # websearch_to_tsquery accepts web-style query syntax (quoted phrases,
        # OR, and -term exclusion).
        squery = func.websearch_to_tsquery('english', search_term)
        query = query.add_columns(
            # ts_headline returns a snippet of the judgment with matched terms
            # wrapped in <b>...</b> for highlighting in the results page.
            func.ts_headline(
                'english', DisplayCases.judgment_full_text, squery,
                'StartSel = <b>, StopSel = </b>, ShortWord = 3, MinWords = 50, MaxWords = 60'
            ).label('highlighted_text'),
            # ts_rank_cd scores each match by cover density so results can be
            # ordered by relevance.
            func.ts_rank_cd(DisplayCases.search_vector, squery).label('rank')
        ).filter(
            # '@@' is PostgreSQL's "tsvector matches tsquery" operator.
            DisplayCases.search_vector.op('@@')(squery)
        ).order_by(
            text('rank DESC')
        )
    return query
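
# Example usage (a sketch, not from the original file; assumes a Flask-SQLAlchemy
# `db` handle and a mapped `DisplayCases` model with the columns referenced above,
# plus a PostgreSQL tsvector column `search_vector` kept up to date by a trigger
# or generated column):
#
#     results = full_text_search(
#         "freedom of expression", db, DisplayCases,
#         date_from="2010-01-01",                # hypothetical filter values
#         respondent_states=["United Kingdom"],
#     ).limit(20).all()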


def semantic_query_normalisation(query, model, tokenizer, torch):
    # `torch` is passed in as an argument so this module can be imported
    # without the heavy dependency when only full-text search is needed.
    def generate_embedding(query):
        inputs = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean pooling over the token dimension yields one vector per query.
        embeddings = outputs.last_hidden_state.mean(dim=1)
        return embeddings.squeeze().cpu().numpy()

    # Generate the embedding for the query text.
    query_embedding = generate_embedding(query)

    # L2-normalise so that inner-product search behaves like cosine similarity.
    def normalize_single(vector):
        norm = np.linalg.norm(vector)
        if norm == 0:  # guard against division by zero on a degenerate embedding
            return vector
        return vector / norm

    normalized_query_embedding = normalize_single(query_embedding)
    return normalized_query_embedding
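
# The caller supplies the model, tokenizer, and torch module. A minimal sketch
# of wiring them up with Hugging Face transformers (the model name here is an
# assumption, not taken from the original file):
#
#     import torch
#     from transformers import AutoModel, AutoTokenizer
#
#     name = "sentence-transformers/all-MiniLM-L6-v2"
#     tokenizer = AutoTokenizer.from_pretrained(name)
#     model = AutoModel.from_pretrained(name)
#     vec = semantic_query_normalisation("right to a fair trial", model, tokenizer, torch)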


def semantic_search(query, index_with_ids, k, model, tokenizer, torch):
    query_to_process = semantic_query_normalisation(query, model, tokenizer, torch)
    # FAISS expects a 2-D float32 array of shape (n_queries, dim).
    distances, indices = index_with_ids.search(
        query_to_process[np.newaxis, :].astype(np.float32), k)
    # `indices` has one row per query; return the ids for our single query.
    nearest_ids = indices[0]
    return nearest_ids
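
# A minimal sketch of how `index_with_ids` could be built (assumptions:
# `doc_embeddings` is an (n, dim) float32 array of L2-normalised case embeddings
# and `doc_ids` holds the matching DisplayCases primary keys; neither name comes
# from the original file):
#
#     import faiss
#
#     dim = doc_embeddings.shape[1]
#     base = faiss.IndexFlatIP(dim)            # inner product == cosine on unit vectors
#     index_with_ids = faiss.IndexIDMap(base)  # makes search() return database ids
#     index_with_ids.add_with_ids(doc_embeddings, np.asarray(doc_ids, dtype=np.int64))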


def semantic_search_with_filters(query, db, DisplayCases, date_from=None, date_to=None,
                                 originating_bodies=None, importance_levels=None,
                                 respondent_states=None, index_with_ids=None,
                                 model=None, tokenizer=None, torch=None):
    # Retrieve the 50 nearest cases first, then narrow them down with SQL filters.
    nearest_ids = semantic_search(query, index_with_ids, k=50,
                                  model=model, tokenizer=tokenizer, torch=torch)
    # Build the base query (`case_query`, to avoid shadowing the query string).
    case_query = db.session.query(
        DisplayCases.judgment_url,
        DisplayCases.title,
        DisplayCases.importance_level,
        DisplayCases.judgment_date,
        DisplayCases.judgment_facts,
        DisplayCases.judgment_conclusion,
    )
    # Apply the metadata filters if provided (None defaults, as above).
    if originating_bodies:
        case_query = case_query.filter(DisplayCases.originating_body.in_(originating_bodies))
    if importance_levels:
        case_query = case_query.filter(DisplayCases.importance_level.in_(importance_levels))
    if respondent_states:
        case_query = case_query.filter(or_(*[DisplayCases.respondent_state.like(f"%{state}%")
                                             for state in respondent_states]))
    if date_from:
        case_query = case_query.filter(DisplayCases.judgment_date >= date_from)
    if date_to:
        case_query = case_query.filter(DisplayCases.judgment_date <= date_to)
    # Restrict to the semantically nearest cases. Note that SQL IN (...) does
    # not preserve the FAISS ranking; re-order in Python if rank order matters.
    if nearest_ids.size > 0:
        case_query = case_query.filter(DisplayCases.id.in_(nearest_ids.tolist()))
    return case_query
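
# Example usage (a sketch; `faiss_index`, `model`, and `tokenizer` are the
# objects built in the sketches above, not names from the original file):
#
#     results = semantic_search_with_filters(
#         "expulsion of asylum seekers", db, DisplayCases,
#         respondent_states=["France"],          # hypothetical filter value
#         index_with_ids=faiss_index, model=model, tokenizer=tokenizer, torch=torch,
#     ).all()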


def prediction_semantic(query, db, DisplayCases, index_with_ids=None, model=None,
                        tokenizer=None, torch=None):
    # For prediction only the handful of most similar cases is needed, so k=5.
    nearest_ids = semantic_search(query, index_with_ids, k=5,
                                  model=model, tokenizer=tokenizer, torch=torch)
    # Build the base query.
    case_query = db.session.query(
        DisplayCases.judgment_url,
        DisplayCases.title,
        DisplayCases.importance_level,
        DisplayCases.judgment_date,
        DisplayCases.judgment_facts,
        DisplayCases.judgment_conclusion,
    )
    # Restrict to the semantically nearest cases (unordered; see note above).
    if nearest_ids.size > 0:
        case_query = case_query.filter(DisplayCases.id.in_(nearest_ids.tolist()))
    return case_query
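
# Example usage (a sketch; arguments as above, `case_facts` being the free-text
# description of a new case):
#
#     similar_cases = prediction_semantic(
#         case_facts, db, DisplayCases,
#         index_with_ids=faiss_index, model=model, tokenizer=tokenizer, torch=torch,
#     ).all()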