Commit a71f8a4

Merge branch 'main' into Feature/#335
2 parents f8ddc5f + b9d24ef commit a71f8a4

2 files changed: +35 -51 lines changed

autorag/nodes/passagereranker/flag_embedding.py

Lines changed: 21 additions & 40 deletions
@@ -1,11 +1,11 @@
-import asyncio
 from typing import List, Tuple

+import pandas as pd
 import torch
 from FlagEmbedding import FlagReranker

 from autorag.nodes.passagereranker.base import passage_reranker_node
-from autorag.utils.util import process_batch
+from autorag.utils.util import make_batch, sort_by_scores, flatten_apply, select_top_k


 @passage_reranker_node
@@ -31,48 +31,29 @@ def flag_embedding_reranker(queries: List[str], contents_list: List[List[str]],
     model = FlagReranker(
         model_name_or_path=model_name, use_fp16=use_fp16
     )
-    tasks = [flag_embedding_reranker_pure(query, contents, scores, top_k, ids, model)
-             for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)]
-    loop = asyncio.get_event_loop()
-    results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
-    content_result = list(map(lambda x: x[0], results))
-    id_result = list(map(lambda x: x[1], results))
-    score_result = list(map(lambda x: x[2], results))
+    nested_list = [list(map(lambda x: [query, x], content_list)) for query, content_list in zip(queries, contents_list)]
+    rerank_scores = flatten_apply(flag_embedding_run_model, nested_list, model=model, batch_size=batch)
+
+    df = pd.DataFrame({
+        'contents': contents_list,
+        'ids': ids_list,
+        'scores': rerank_scores,
+    })
+    df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
+    results = select_top_k(df, ['contents', 'ids', 'scores'], top_k)

     del model
     if torch.cuda.is_available():
         torch.cuda.empty_cache()

-    return content_result, id_result, score_result
-
-
-async def flag_embedding_reranker_pure(query: str, contents: List[str], scores: List[float], top_k: int,
-                                       ids: List[str], model) -> Tuple[List[str], List[str], List[float]]:
-    """
-    Rerank a list of contents based on their relevance to a query using BAAI Reranker model.
-
-    :param query: The query to use for reranking
-    :param contents: The list of contents to rerank
-    :param scores: The list of scores retrieved from the initial ranking
-    :param ids: The list of ids retrieved from the initial ranking
-    :param top_k: The number of passages to be retrieved
-    :param model: BAAI Reranker model.
-    :return: tuple of lists containing the reranked contents, ids, and scores
-    """
-    input_texts = [(query, content) for content in contents]
-    with torch.no_grad():
-        pred_scores = model.compute_score(sentence_pairs=input_texts)
-
-    content_ids_probs = list(zip(contents, ids, pred_scores))
-
-    # Sort the list of pairs based on the relevance score in descending order
-    sorted_content_ids_probs = sorted(content_ids_probs, key=lambda x: x[2], reverse=True)
-
-    # crop with top_k
-    if len(contents) < top_k:
-        top_k = len(contents)
-    sorted_content_ids_probs = sorted_content_ids_probs[:top_k]
+    return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()

-    content_result, id_result, score_result = zip(*sorted_content_ids_probs)

-    return list(content_result), list(id_result), list(score_result)
+
+def flag_embedding_run_model(input_texts, model, batch_size: int):
+    batch_input_texts = make_batch(input_texts, batch_size)
+    results = []
+    for batch_texts in batch_input_texts:
+        with torch.no_grad():
+            pred_scores = model.compute_score(sentence_pairs=batch_texts)
+        results.extend(pred_scores)
+    return results
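For readers tracing the new code path: the reranker now builds a [query, passage] pair for every retrieved passage, scores all pairs in flat batches, and re-nests the scores per query. Below is a minimal sketch of the two batching helpers involved. make_batch and flatten_apply come from autorag.utils.util; these bodies are assumptions inferred from the call sites in this diff, not the library's actual source.

# Hypothetical reconstructions of the autorag.utils.util batching helpers,
# inferred from how the diff above calls them.
from itertools import accumulate
from typing import Any, Callable, List


def make_batch(instances: List[Any], batch_size: int) -> List[List[Any]]:
    # Split a flat list into consecutive chunks of at most batch_size items.
    return [instances[i:i + batch_size]
            for i in range(0, len(instances), batch_size)]


def flatten_apply(func: Callable[..., List[Any]], nested_list: List[List[Any]],
                  **kwargs) -> List[List[Any]]:
    # Flatten the per-query [query, passage] pairs into one list, score them
    # in a single batched pass, then re-nest results to the original shape.
    lengths = [len(sub) for sub in nested_list]
    flat = [item for sub in nested_list for item in sub]
    flat_results = func(flat, **kwargs)
    bounds = [0] + list(accumulate(lengths))
    return [flat_results[bounds[i]:bounds[i + 1]] for i in range(len(lengths))]

Under that reading, one batched scoring pass replaces the previous one-asyncio-task-per-query design, which is why the diff can drop asyncio, process_batch, and the per-query flag_embedding_reranker_pure helper entirely.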

autorag/nodes/passagereranker/flag_embedding_llm.py

Lines changed: 14 additions & 11 deletions
@@ -1,12 +1,12 @@
-import asyncio
 from typing import List, Tuple

+import pandas as pd
 import torch
 from FlagEmbedding import FlagLLMReranker

 from autorag.nodes.passagereranker.base import passage_reranker_node
-from autorag.nodes.passagereranker.flag_embedding import flag_embedding_reranker_pure
-from autorag.utils.util import process_batch
+from autorag.nodes.passagereranker.flag_embedding import flag_embedding_run_model
+from autorag.utils.util import flatten_apply, sort_by_scores, select_top_k


 @passage_reranker_node
@@ -32,16 +32,19 @@ def flag_embedding_llm_reranker(queries: List[str], contents_list: List[List[str]],
     model = FlagLLMReranker(
         model_name_or_path=model_name, use_fp16=use_fp16
     )
-    tasks = [flag_embedding_reranker_pure(query, contents, scores, top_k, ids, model)
-             for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)]
-    loop = asyncio.get_event_loop()
-    results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
-    content_result = list(map(lambda x: x[0], results))
-    id_result = list(map(lambda x: x[1], results))
-    score_result = list(map(lambda x: x[2], results))
+    nested_list = [list(map(lambda x: [query, x], content_list)) for query, content_list in zip(queries, contents_list)]
+    rerank_scores = flatten_apply(flag_embedding_run_model, nested_list, model=model, batch_size=batch)
+
+    df = pd.DataFrame({
+        'contents': contents_list,
+        'ids': ids_list,
+        'scores': rerank_scores,
+    })
+    df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
+    results = select_top_k(df, ['contents', 'ids', 'scores'], top_k)

     del model
     if torch.cuda.is_available():
         torch.cuda.empty_cache()

-    return content_result, id_result, score_result
+    return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()
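The DataFrame stage is shared verbatim by both rerankers, so a toy walk-through may help. The sort_by_scores and select_top_k bodies below are likewise hypothetical stand-ins inferred from their call sites; only the df.apply(..., result_type='expand') pattern is taken from the diff itself.

# Hypothetical stand-ins for autorag.utils.util's sort_by_scores and
# select_top_k, inferred from how the diff calls them.
import pandas as pd


def sort_by_scores(row: pd.Series) -> list:
    # Reorder one row's contents/ids/scores together, by descending score.
    triples = sorted(zip(row['contents'], row['ids'], row['scores']),
                     key=lambda t: t[2], reverse=True)
    contents, ids, scores = map(list, zip(*triples))
    return [contents, ids, scores]


def select_top_k(df: pd.DataFrame, columns: list, top_k: int) -> pd.DataFrame:
    # Keep only the first top_k elements of each list-valued column.
    for col in columns:
        df[col] = df[col].apply(lambda lst: lst[:top_k])
    return df


# Worked example with one query and three passages:
df = pd.DataFrame({
    'contents': [['a', 'b', 'c']],
    'ids': [['id1', 'id2', 'id3']],
    'scores': [[0.1, 0.9, 0.5]],
})
df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
df = select_top_k(df, ['contents', 'ids', 'scores'], top_k=2)
print(df['contents'].tolist())  # [['b', 'c']]

Note that, unlike the deleted flag_embedding_reranker_pure, top_k cropping now happens after sorting, and the old "if len(contents) < top_k" guard is subsumed by list slicing, which already tolerates short lists.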
