@@ -1,11 +1,11 @@
- import asyncio
from typing import List, Tuple

+ import pandas as pd
import torch
from FlagEmbedding import FlagReranker

from autorag.nodes.passagereranker.base import passage_reranker_node
- from autorag.utils.util import process_batch
+ from autorag.utils.util import make_batch, sort_by_scores, flatten_apply, select_top_k


@passage_reranker_node
@@ -31,48 +31,29 @@ def flag_embedding_reranker(queries: List[str], contents_list: List[List[str]],
    model = FlagReranker(
        model_name_or_path=model_name, use_fp16=use_fp16
    )
-     tasks = [flag_embedding_reranker_pure(query, contents, scores, top_k, ids, model)
-              for query, contents, scores, ids in zip(queries, contents_list, scores_list, ids_list)]
-     loop = asyncio.get_event_loop()
-     results = loop.run_until_complete(process_batch(tasks, batch_size=batch))
-     content_result = list(map(lambda x: x[0], results))
-     id_result = list(map(lambda x: x[1], results))
-     score_result = list(map(lambda x: x[2], results))
+     nested_list = [list(map(lambda x: [query, x], content_list)) for query, content_list in zip(queries, contents_list)]
+     rerank_scores = flatten_apply(flag_embedding_run_model, nested_list, model=model, batch_size=batch)
+
+     df = pd.DataFrame({
+         'contents': contents_list,
+         'ids': ids_list,
+         'scores': rerank_scores,
+     })
+     df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand')
+     results = select_top_k(df, ['contents', 'ids', 'scores'], top_k)

    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

-     return content_result, id_result, score_result
-
-
- async def flag_embedding_reranker_pure(query: str, contents: List[str], scores: List[float], top_k: int,
-                                        ids: List[str], model) -> Tuple[List[str], List[str], List[float]]:
-     """
-     Rerank a list of contents based on their relevance to a query using BAAI Reranker model.
-
-     :param query: The query to use for reranking
-     :param contents: The list of contents to rerank
-     :param scores: The list of scores retrieved from the initial ranking
-     :param ids: The list of ids retrieved from the initial ranking
-     :param top_k: The number of passages to be retrieved
-     :param model: BAAI Reranker model.
-     :return: tuple of lists containing the reranked contents, ids, and scores
-     """
-     input_texts = [(query, content) for content in contents]
-     with torch.no_grad():
-         pred_scores = model.compute_score(sentence_pairs=input_texts)
-
-     content_ids_probs = list(zip(contents, ids, pred_scores))
-
-     # Sort the list of pairs based on the relevance score in descending order
-     sorted_content_ids_probs = sorted(content_ids_probs, key=lambda x: x[2], reverse=True)
-
-     # crop with top_k
-     if len(contents) < top_k:
-         top_k = len(contents)
-     sorted_content_ids_probs = sorted_content_ids_probs[:top_k]
+     return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()

-     content_result, id_result, score_result = zip(*sorted_content_ids_probs)

-     return list(content_result), list(id_result), list(score_result)
+ def flag_embedding_run_model(input_texts, model, batch_size: int):
+     batch_input_texts = make_batch(input_texts, batch_size)
+     results = []
+     for batch_texts in batch_input_texts:
+         with torch.no_grad():
+             pred_scores = model.compute_score(sentence_pairs=batch_texts)
+         results.extend(pred_scores)
+     return results
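
For context, here is a minimal, self-contained sketch of the batched flow this commit switches to. The bodies of make_batch and flatten_apply below are hypothetical stand-ins written for illustration (the real helpers live in autorag.utils.util and may differ), and dummy_run_model replaces the FlagReranker call with a trivial token-overlap scorer:

# Hypothetical stand-ins for the autorag helpers, for illustration only.
from typing import Callable, List

import pandas as pd


def make_batch(items: list, batch_size: int) -> List[list]:
    # Split a flat list into consecutive chunks of at most batch_size items.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def flatten_apply(func: Callable, nested_list: List[list], **kwargs) -> List[list]:
    # Flatten the per-query pair lists into one flat list, score everything
    # in a single pass, then restore the original per-query nesting.
    lengths = [len(inner) for inner in nested_list]
    flat_results = func([item for inner in nested_list for item in inner], **kwargs)
    nested, cursor = [], 0
    for length in lengths:
        nested.append(flat_results[cursor:cursor + length])
        cursor += length
    return nested


def dummy_run_model(input_texts: List[list], batch_size: int) -> List[float]:
    # Stand-in for flag_embedding_run_model: one score per [query, passage]
    # pair, here just the number of shared tokens.
    scores = []
    for batch_texts in make_batch(input_texts, batch_size):
        scores.extend(float(len(set(q.split()) & set(p.split()))) for q, p in batch_texts)
    return scores


queries = ["what is autorag"]
contents_list = [["autorag is a RAG optimization tool", "unrelated text", "autorag docs"]]
ids_list = [["id-1", "id-2", "id-3"]]
top_k = 2

nested_list = [[[query, content] for content in contents]
               for query, contents in zip(queries, contents_list)]
rerank_scores = flatten_apply(dummy_run_model, nested_list, batch_size=2)

df = pd.DataFrame({'contents': contents_list, 'ids': ids_list, 'scores': rerank_scores})


def sort_row(row):
    # Per row: reorder contents/ids by score, descending (cf. sort_by_scores).
    order = sorted(range(len(row['scores'])), key=lambda i: row['scores'][i], reverse=True)
    return ([row['contents'][i] for i in order],
            [row['ids'][i] for i in order],
            [row['scores'][i] for i in order])


df[['contents', 'ids', 'scores']] = df.apply(sort_row, axis=1, result_type='expand')
for col in ['contents', 'ids', 'scores']:
    # Truncate each row's lists to the k best passages (cf. select_top_k).
    df[col] = df[col].apply(lambda lst: lst[:top_k])

print(df['contents'].tolist(), df['ids'].tolist(), df['scores'].tolist())

The point of the rewrite is visible here: instead of one asyncio task per query, every [query, passage] pair is scored in one flattened pass, so compute_score can be fed fixed-size batches regardless of how many passages each query retrieved.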