diff --git a/private_gpt/components/reranker/reranker.py b/private_gpt/components/reranker/reranker.py index 1eb406ea52..31da18f301 100644 --- a/private_gpt/components/reranker/reranker.py +++ b/private_gpt/components/reranker/reranker.py @@ -1,4 +1,9 @@ -from FlagEmbedding import FlagReranker +from typing import ( # noqa: UP035, we need to keep the consistence with llamaindex + List, + Tuple, +) + +from FlagEmbedding import FlagReranker # type: ignore from injector import inject, singleton from llama_index.bridge.pydantic import Field from llama_index.postprocessor.types import BaseNodePostprocessor @@ -29,17 +34,13 @@ def __init__(self, settings: Settings) -> None: raise ValueError("Reranker component is not enabled.") path = models_path / "reranker" - top_n = settings.reranker.top_n - cut_off = settings.reranker.cut_off - reranker = FlagReranker( + self.top_n = settings.reranker.top_n + self.cut_off = settings.reranker.cut_off + self.reranker = FlagReranker( model_name_or_path=path, ) - super().__init__( - top_n=top_n, - reranker=reranker, - cut_off=cut_off, - ) + super().__init__() @classmethod def class_name(cls) -> str: @@ -47,24 +48,24 @@ def class_name(cls) -> str: def _postprocess_nodes( self, - nodes: list[NodeWithScore], + nodes: List[NodeWithScore], # noqa: UP006 query_bundle: QueryBundle | None = None, - ) -> list[NodeWithScore]: + ) -> List[NodeWithScore]: # noqa: UP006 if query_bundle is None: - return ValueError("Query bundle must be provided.") + raise ValueError("Query bundle must be provided.") query_str = query_bundle.query_str - sentence_pairs: list[tuple[str, str]] = [] + sentence_pairs: List[Tuple[str, str]] = [] # noqa: UP006 for node in nodes: content = node.get_content() - sentence_pairs.append([query_str, content]) + sentence_pairs.append((query_str, content)) scores = self.reranker.compute_score(sentence_pairs) for i, node in enumerate(nodes): node.score = scores[i] # cut off nodes with low scores - res = [node for node in nodes if node.score > self.cut_off] + res = [node for node in nodes if (node.score or 0.0) > self.cut_off] if len(res) > self.top_n: return res diff --git a/private_gpt/server/chat/chat_service.py b/private_gpt/server/chat/chat_service.py index 714361e13c..0551df6e8a 100644 --- a/private_gpt/server/chat/chat_service.py +++ b/private_gpt/server/chat/chat_service.py @@ -1,3 +1,4 @@ +import typing from dataclasses import dataclass from injector import inject, singleton @@ -22,6 +23,9 @@ from private_gpt.server.chunks.chunks_service import Chunk from private_gpt.settings.settings import Settings +if typing.TYPE_CHECKING: + from llama_index.postprocessor.types import BaseNodePostprocessor + class Completion(BaseModel): response: str @@ -108,7 +112,7 @@ def _chat_engine( index=self.index, context_filter=context_filter ) - node_postprocessors = [ + node_postprocessors: list[BaseNodePostprocessor] = [ MetadataReplacementPostProcessor(target_metadata_key="window"), ]