Skip to content

Commit

Permalink
fix: black
Browse files Browse the repository at this point in the history
  • Loading branch information
Anhui-tqhuang committed Mar 21, 2024
1 parent 8d0e0cb commit 841158e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 47 deletions.
33 changes: 14 additions & 19 deletions private_gpt/components/reranker/flagembedding_reranker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import ( # noqa: UP035, we need to keep consistency with llamaindex
List,
Tuple,
Expand All @@ -8,8 +9,7 @@
from llama_index.core.indices.postprocessor import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle

from private_gpt.paths import models_path
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)


class FlagEmbeddingRerankerComponent(BaseNodePostprocessor):
Expand All @@ -22,23 +22,9 @@ class FlagEmbeddingRerankerComponent(BaseNodePostprocessor):
Otherwise, return all nodes with score > cut_off.
"""

reranker: FlagReranker = Field(description="Reranker class.")
top_n: int = Field(description="Top N nodes to return.")
cut_off: float = Field(description="Cut off score for nodes.")

def __init__(self, settings: Settings) -> None:
path = models_path / "flagembedding_reranker"
top_n = settings.flagembedding_reranker.top_n
cut_off = settings.flagembedding_reranker.cut_off
reranker = FlagReranker(
model_name_or_path=path,
)

super().__init__(
top_n=top_n,
cut_off=cut_off,
reranker=reranker,
)
top_n: int = Field(10, description="Top N nodes to return.")
cut_off: float = Field(0.0, description="Cut off score for nodes.")
reranker: FlagReranker = Field(..., description="Flag Reranker model.")

@classmethod
def class_name(cls) -> str:
Expand All @@ -52,6 +38,9 @@ def _postprocess_nodes(
if query_bundle is None:
raise ValueError("Query bundle must be provided.")

logger.info("Postprocessing nodes with FlagEmbeddingReranker.")
logger.info(f"top_n: {self.top_n}, cut_off: {self.cut_off}")

query_str = query_bundle.query_str
sentence_pairs: List[Tuple[str, str]] = [] # noqa: UP006
for node in nodes:
Expand All @@ -65,6 +54,12 @@ def _postprocess_nodes(
# cut off nodes with low scores
res = [node for node in nodes if (node.score or 0.0) > self.cut_off]
if len(res) > self.top_n:
logger.info(
"Number of nodes with score > cut_off is > top_n, returning all nodes with score > cut_off."
)
return res

logger.info(
"Number of nodes with score > cut_off is <= top_n, returning top_n nodes."
)
return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[: self.top_n]
44 changes: 17 additions & 27 deletions private_gpt/components/reranker/reranker.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
import logging
from typing import ( # noqa: UP035, we need to keep consistency with llamaindex
List,
)

from injector import inject, singleton
from llama_index.core.bridge.pydantic import Field
from llama_index.core.indices.postprocessor import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle

from private_gpt.paths import models_path
from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


@singleton
class RerankerComponent(BaseNodePostprocessor):
class RerankerComponent:
"""Reranker component.
- mode: Reranker mode.
- enabled: Reranker enabled.
"""

nodePostPorcesser: BaseNodePostprocessor = Field(
description="BaseNodePostprocessor class."
)

@inject
def __init__(self, settings: Settings) -> None:
if settings.reranker.enabled is False:
Expand All @@ -38,6 +29,8 @@ def __init__(self, settings: Settings) -> None:
)

try:
from FlagEmbedding import FlagReranker # type: ignore

from private_gpt.components.reranker.flagembedding_reranker import (
FlagEmbeddingRerankerComponent,
)
Expand All @@ -46,24 +39,21 @@ def __init__(self, settings: Settings) -> None:
"Local dependencies not found, install with `poetry install --extras reranker-flagembedding`"
) from e

nodePostPorcesser = FlagEmbeddingRerankerComponent(settings)
path = models_path / "flagembedding_reranker"

if settings.flagembedding_reranker is None:
raise ValueError("FlagEmbeddingReranker settings is not provided.")

top_n = settings.flagembedding_reranker.top_n
cut_off = settings.flagembedding_reranker.cut_off
flagReranker = FlagReranker(
model_name_or_path=path,
)
self.nodePostPorcesser = FlagEmbeddingRerankerComponent(
top_n=top_n, cut_off=cut_off, reranker=flagReranker
)

case _:
raise ValueError(
"Reranker mode not supported, currently only support flagembedding."
)

super().__init__(
nodePostPorcesser=nodePostPorcesser,
)

@classmethod
def class_name(cls) -> str:
return "Reranker"

def _postprocess_nodes(
self,
nodes: List[NodeWithScore], # noqa: UP006
query_bundle: QueryBundle | None = None,
) -> List[NodeWithScore]: # noqa: UP006
return self.nodePostPorcesser._postprocess_nodes(nodes, query_bundle)
2 changes: 1 addition & 1 deletion private_gpt/server/chat/chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _chat_engine(
]

if self.reranker_component:
node_postprocessors.append(self.reranker_component)
node_postprocessors.append(self.reranker_component.nodePostPorcesser)

return ContextChatEngine.from_defaults(
system_prompt=system_prompt,
Expand Down

0 comments on commit 841158e

Please sign in to comment.