fix: type hints
Anhui-tqhuang committed Mar 14, 2024
1 parent 6b7d7dc commit a18dd55
Showing 2 changed files with 21 additions and 16 deletions.
31 changes: 16 additions & 15 deletions private_gpt/components/reranker/reranker.py
@@ -1,4 +1,9 @@
-from FlagEmbedding import FlagReranker
+from typing import (  # noqa: UP035, we need to keep consistency with llamaindex
+    List,
+    Tuple,
+)
+
+from FlagEmbedding import FlagReranker  # type: ignore
 from injector import inject, singleton
 from llama_index.bridge.pydantic import Field
 from llama_index.postprocessor.types import BaseNodePostprocessor
@@ -29,42 +34,38 @@ def __init__(self, settings: Settings) -> None:
             raise ValueError("Reranker component is not enabled.")

         path = models_path / "reranker"
-        top_n = settings.reranker.top_n
-        cut_off = settings.reranker.cut_off
-        reranker = FlagReranker(
+        self.top_n = settings.reranker.top_n
+        self.cut_off = settings.reranker.cut_off
+        self.reranker = FlagReranker(
             model_name_or_path=path,
         )

-        super().__init__(
-            top_n=top_n,
-            reranker=reranker,
-            cut_off=cut_off,
-        )
+        super().__init__()

     @classmethod
     def class_name(cls) -> str:
         return "Reranker"

     def _postprocess_nodes(
         self,
-        nodes: list[NodeWithScore],
+        nodes: List[NodeWithScore],  # noqa: UP006
         query_bundle: QueryBundle | None = None,
-    ) -> list[NodeWithScore]:
+    ) -> List[NodeWithScore]:  # noqa: UP006
         if query_bundle is None:
-            return ValueError("Query bundle must be provided.")
+            raise ValueError("Query bundle must be provided.")

         query_str = query_bundle.query_str
-        sentence_pairs: list[tuple[str, str]] = []
+        sentence_pairs: List[Tuple[str, str]] = []  # noqa: UP006
         for node in nodes:
             content = node.get_content()
-            sentence_pairs.append([query_str, content])
+            sentence_pairs.append((query_str, content))

         scores = self.reranker.compute_score(sentence_pairs)
         for i, node in enumerate(nodes):
             node.score = scores[i]

         # cut off nodes with low scores
-        res = [node for node in nodes if node.score > self.cut_off]
+        res = [node for node in nodes if (node.score or 0.0) > self.cut_off]
         if len(res) > self.top_n:
             return res
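Two of the changes above are easy to miss. The `sentence_pairs` entries are now real `(query, passage)` tuples rather than two-element lists, matching the declared `List[Tuple[str, str]]` annotation, and the cut-off comparison defaults a missing score to `0.0`: `NodeWithScore.score` is `Optional[float]` in llama-index, so `None > float` would raise a `TypeError`. A minimal sketch of that guard, with a hypothetical `StubNode` standing in for `NodeWithScore`:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StubNode:
    """Illustrative stand-in for llama_index's NodeWithScore."""

    content: str
    score: Optional[float] = None


def apply_cut_off(nodes: List[StubNode], cut_off: float) -> List[StubNode]:
    # Without the `or 0.0` default, a node whose score was never set
    # would make `None > cut_off` raise a TypeError.
    return [node for node in nodes if (node.score or 0.0) > cut_off]


nodes = [StubNode("a", 0.9), StubNode("b", None), StubNode("c", 0.2)]
print([n.content for n in apply_cut_off(nodes, cut_off=0.5)])  # ['a']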
6 changes: 5 additions & 1 deletion private_gpt/server/chat/chat_service.py
@@ -1,3 +1,4 @@
+import typing
 from dataclasses import dataclass

 from injector import inject, singleton
@@ -23,6 +24,9 @@
 from private_gpt.server.chunks.chunks_service import Chunk
 from private_gpt.settings.settings import Settings

+if typing.TYPE_CHECKING:
+    from llama_index.postprocessor.types import BaseNodePostprocessor
+

 class Completion(BaseModel):
     response: str
@@ -108,7 +112,7 @@ def _chat_engine(
             index=self.index, context_filter=context_filter
         )

-        node_postprocessors = [
+        node_postprocessors: list[BaseNodePostprocessor] = [
             MetadataReplacementPostProcessor(target_metadata_key="window"),
         ]
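The `typing.TYPE_CHECKING` guard makes `BaseNodePostprocessor` available to static type checkers without importing it at runtime, and the unquoted use in `node_postprocessors: list[BaseNodePostprocessor]` is safe because local-variable annotations are never evaluated at runtime (PEP 526). A self-contained sketch of the same pattern; the `Fraction` import and `thirds` function are illustrative stand-ins, not part of the commit:

import typing

if typing.TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); skipped at
    # runtime, so it adds no import cost and cannot create a cycle.
    from fractions import Fraction


def thirds() -> None:
    from fractions import Fraction as F  # the real runtime import

    # PEP 526: local-variable annotations are never evaluated at runtime,
    # so the guarded `Fraction` name is safe here unquoted, the same
    # trick as `node_postprocessors: list[BaseNodePostprocessor]` above.
    values: list[Fraction] = [F(1, 3), F(2, 3)]
    print(sum(values))  # prints 1


thirds()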
