feat: support reranker #1532

Closed
wants to merge 11 commits
2 changes: 2 additions & 0 deletions fern/docs.yml
@@ -64,6 +64,8 @@ navigation:
    contents:
      - page: LLM Backends
        path: ./docs/pages/manual/llms.mdx
      - page: Reranker
        path: ./docs/pages/manual/reranker.mdx
  - section: User Interface
    contents:
      - page: User interface (Gradio) Manual
1 change: 1 addition & 0 deletions fern/docs/pages/installation/installation.mdx
@@ -58,6 +58,7 @@ Where `<extra>` can be any of the following:
- vector-stores-qdrant: adds support for Qdrant vector store
- vector-stores-chroma: adds support for Chroma DB vector store
- vector-stores-postgres: adds support for Postgres vector store
- reranker-flagembedding: adds support for the FlagEmbedding reranker

## Recommended Setups

57 changes: 57 additions & 0 deletions fern/docs/pages/manual/reranker.mdx
@@ -0,0 +1,57 @@
# Reranker

PrivateGPT supports integrating a `Reranker`, which can improve the quality of the results produced by the Retrieval-Augmented Generation (RAG) pipeline.

Currently, `flagembedding` is the only supported reranker mode. To use it, set the `reranker.mode` property in the `settings.yaml` file to `flagembedding`.

```yaml
reranker:
  mode: flagembedding
  enabled: true
```

Use the `enabled` flag to toggle the `Reranker` on or off as needed.

## FlagEmbeddingReranker

To enable FlagEmbeddingReranker, set the `reranker.mode` property in the `settings.yaml` file to `flagembedding` and install the `reranker-flagembedding` extra.

```bash
poetry install --extras reranker-flagembedding
```

Then download and set up the models from Hugging Face:

```bash
poetry run python scripts/setup
```

The FlagEmbeddingReranker can be configured using the following parameters:

- **top_n**: Represents the number of top documents to retrieve.
- **cut_off**: A threshold score for similarity below which documents are dismissed.
- **hf_model_name**: The Hugging Face model identifier for the FlagReranker.

### Behavior of Reranker

The functionality of the `Reranker` is as follows:

1. It evaluates the similarity between a query and documents retrieved by the retriever.
2. If the similarity score is less than `cut_off`, the document is excluded from the results.
3. If fewer than `top_n` documents pass the filter, the top `top_n` documents are returned regardless of the `cut_off` score.
4. The `hf_model_name` parameter allows users to specify the particular FlagReranker model from [Hugging Face](https://huggingface.co/) for the reranking process.
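
The selection rule above can be sketched in plain Python. This is a simplified illustration working on bare scores; the function name and signature are not part of PrivateGPT's API:

```python
def select_nodes(scores: list[float], top_n: int, cut_off: float) -> list[float]:
    """Apply the reranker's selection rule to a list of similarity scores."""
    # Keep only scores strictly above the cut-off threshold
    kept = [s for s in scores if s > cut_off]
    if len(kept) > top_n:
        # Enough candidates passed the filter: return all of them, best first
        return sorted(kept, reverse=True)
    # Too few passed the filter: fall back to the top_n scores, ignoring cut_off
    return sorted(scores, reverse=True)[:top_n]
```

For example, `select_nodes([0.9, 0.1, 0.2], top_n=2, cut_off=0.75)` falls back to the two best scores `[0.9, 0.2]` even though only one exceeds the cut-off.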

### Example Usage

To use the `Reranker` with your desired settings, configure the `flagembedding_reranker` section:

```yml
flagembedding_reranker:
  hf_model_name: BAAI/bge-reranker-large
  top_n: 5
  cut_off: 0.75
```

## Conclusion

`Reranker` serves as a [Node Postprocessor](https://docs.llamaindex.ai/en/stable/module_guides/querying/node_postprocessors/root.html). With these settings, it offers a robust and flexible way to improve the performance of the RAG system by filtering and ranking the retrieved documents by relevance.
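
As a node postprocessor, the reranker composes with the other postprocessors in the pipeline. The chaining idea can be sketched in plain Python — the names below are illustrative, not the llama-index API:

```python
from typing import Callable

# A postprocessor maps a list of (doc, score) pairs to a filtered/reordered list
Postprocessor = Callable[[list[tuple[str, float]]], list[tuple[str, float]]]


def apply_postprocessors(
    nodes: list[tuple[str, float]], postprocessors: list[Postprocessor]
) -> list[tuple[str, float]]:
    """Run each postprocessor in order, feeding its output to the next."""
    for post in postprocessors:
        nodes = post(nodes)
    return nodes


def similarity_cutoff(threshold: float) -> Postprocessor:
    """Drop nodes scoring at or below the threshold, preserving order."""
    return lambda nodes: [(d, s) for d, s in nodes if s > threshold]
```

A reranker slots into the same list as another `Postprocessor`, running after the similarity cut-off.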
3,865 changes: 2,193 additions & 1,672 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion private_gpt/components/llm/prompt_helper.py
@@ -215,7 +215,7 @@ def _completion_to_prompt(self, completion: str) -> str:


def get_prompt_style(
    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None,
) -> AbstractPromptStyle:
    """Get the prompt style to use from the given string.
Empty file.
65 changes: 65 additions & 0 deletions private_gpt/components/reranker/flagembedding_reranker.py
@@ -0,0 +1,65 @@
import logging
from typing import (  # noqa: UP035, kept for consistency with llama-index
    List,
    Tuple,
)

from FlagEmbedding import FlagReranker  # type: ignore
from llama_index.core.bridge.pydantic import Field
from llama_index.core.indices.postprocessor import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle

logger = logging.getLogger(__name__)


class FlagEmbeddingRerankerComponent(BaseNodePostprocessor):
    """Reranker component.

    - top_n: Top N nodes to return.
    - cut_off: Cut off score for nodes.

    If the number of nodes with score > cut_off is <= top_n, return the
    top_n nodes. Otherwise, return all nodes with score > cut_off.
    """

    top_n: int = Field(10, description="Top N nodes to return.")
    cut_off: float = Field(0.0, description="Cut off score for nodes.")
    reranker: FlagReranker = Field(..., description="Flag Reranker model.")

    @classmethod
    def class_name(cls) -> str:
        return "FlagEmbeddingReranker"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],  # noqa: UP006
        query_bundle: QueryBundle | None = None,
    ) -> List[NodeWithScore]:  # noqa: UP006
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        logger.info("Postprocessing nodes with FlagEmbeddingReranker.")
        logger.info("top_n: %s, cut_off: %s", self.top_n, self.cut_off)

        query_str = query_bundle.query_str
        sentence_pairs: List[Tuple[str, str]] = []  # noqa: UP006
        for node in nodes:
            content = node.get_content()
            sentence_pairs.append((query_str, content))

        scores = self.reranker.compute_score(sentence_pairs)
        for i, node in enumerate(nodes):
            node.score = scores[i]

        # Drop nodes whose score does not exceed the cut-off
        res = [node for node in nodes if (node.score or 0.0) > self.cut_off]
        if len(res) > self.top_n:
            logger.info(
                "More than top_n nodes scored above cut_off; "
                "returning all of them, ranked by score."
            )
            return sorted(res, key=lambda x: x.score or 0.0, reverse=True)

        logger.info(
            "At most top_n nodes scored above cut_off; "
            "returning the top_n nodes regardless of cut_off."
        )
        return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[: self.top_n]
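
The score-then-rank flow above can be exercised with a stand-in for the FlagReranker model. This is a sketch: `StubReranker`, the toy scoring rule, and the plain-string "nodes" are illustrative, not FlagEmbedding's or llama-index's types:

```python
class StubReranker:
    """Stand-in for FlagEmbedding's FlagReranker: scores (query, doc) pairs."""

    def compute_score(self, pairs: list[tuple[str, str]]) -> list[float]:
        # Toy relevance: fraction of query words present in the document
        scores = []
        for query, doc in pairs:
            words = query.lower().split()
            hits = sum(1 for w in words if w in doc.lower())
            scores.append(hits / max(len(words), 1))
        return scores


def rerank(query: str, docs: list[str], reranker: StubReranker) -> list[tuple[str, float]]:
    """Mimic _postprocess_nodes: score every doc against the query, then rank."""
    scores = reranker.compute_score([(query, d) for d in docs])
    return sorted(zip(docs, scores), key=lambda t: t[1], reverse=True)
```

Swapping `StubReranker` for a real cross-encoder is the whole point of the component: only `compute_score` is needed.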
59 changes: 59 additions & 0 deletions private_gpt/components/reranker/reranker.py
@@ -0,0 +1,59 @@
import logging

from injector import inject, singleton

from private_gpt.paths import models_path
from private_gpt.settings.settings import Settings

logger = logging.getLogger(__name__)


@singleton
class RerankerComponent:
    """Reranker component.

    - mode: Reranker mode.
    - enabled: Reranker enabled.
    """

    @inject
    def __init__(self, settings: Settings) -> None:
        if not settings.reranker.enabled:
            raise ValueError("Reranker component is not enabled.")

        match settings.reranker.mode:
            case "flagembedding":
                logger.info(
                    "Initializing the reranker model in mode=%s", settings.reranker.mode
                )

                try:
                    from FlagEmbedding import FlagReranker  # type: ignore

                    from private_gpt.components.reranker.flagembedding_reranker import (
                        FlagEmbeddingRerankerComponent,
                    )
                except ImportError as e:
                    raise ImportError(
                        "Local dependencies not found, install with "
                        "`poetry install --extras reranker-flagembedding`"
                    ) from e

                path = models_path / "flagembedding_reranker"

                if settings.flagembedding_reranker is None:
                    raise ValueError("FlagEmbeddingReranker settings are not provided.")

                top_n = settings.flagembedding_reranker.top_n
                cut_off = settings.flagembedding_reranker.cut_off
                flag_reranker = FlagReranker(
                    model_name_or_path=path,
                )
                self.nodePostPorcesser = FlagEmbeddingRerankerComponent(
                    top_n=top_n, cut_off=cut_off, reranker=flag_reranker
                )

            case _:
                raise ValueError(
                    "Reranker mode not supported; currently only flagembedding is supported."
                )
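
The try/except import in the component above follows a common optional-dependency pattern, sketched here as a standalone helper. The helper name is illustrative, and `_no_such_module` in the usage note is a deliberately missing module:

```python
import importlib
from types import ModuleType


def load_optional(module: str, extra: str) -> ModuleType:
    """Import an optional dependency, or fail with an actionable install hint."""
    try:
        return importlib.import_module(module)
    except ImportError as e:
        raise ImportError(
            f"Local dependencies not found, install with "
            f"`poetry install --extras {extra}`"
        ) from e
```

Calling `load_optional("_no_such_module", "reranker-flagembedding")` raises an `ImportError` whose message names the extra to install, rather than a bare `ModuleNotFoundError`.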
1 change: 0 additions & 1 deletion private_gpt/launcher.py
@@ -21,7 +21,6 @@


def create_app(root_injector: Injector) -> FastAPI:
    # Start the API
    async def bind_injector_to_request(request: Request) -> None:
        request.state.injector = root_injector
25 changes: 19 additions & 6 deletions private_gpt/server/chat/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from dataclasses import dataclass

from injector import inject, singleton
@@ -18,13 +19,17 @@
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.reranker.reranker import RerankerComponent
from private_gpt.components.vector_store.vector_store_component import (
VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk
from private_gpt.settings.settings import Settings

if typing.TYPE_CHECKING:
    from llama_index.core.indices.postprocessor import BaseNodePostprocessor


class Completion(BaseModel):
    response: str
@@ -99,6 +104,8 @@ def __init__(
            embed_model=embedding_component.embedding_model,
            show_progress=True,
        )
        # Always set the attribute so later truthiness checks cannot raise AttributeError
        self.reranker_component = (
            RerankerComponent(settings=settings) if settings.reranker.enabled else None
        )

    def _chat_engine(
        self,
@@ -113,16 +120,22 @@ def _chat_engine(
                context_filter=context_filter,
                similarity_top_k=self.settings.rag.similarity_top_k,
            )

            node_postprocessors: list[BaseNodePostprocessor] = [
                MetadataReplacementPostProcessor(target_metadata_key="window"),
                SimilarityPostprocessor(
                    similarity_cutoff=settings.rag.similarity_value
                ),
            ]

            if self.reranker_component:
                node_postprocessors.append(self.reranker_component.nodePostPorcesser)

            return ContextChatEngine.from_defaults(
                system_prompt=system_prompt,
                retriever=vector_index_retriever,
                llm=self.llm_component.llm,  # Takes no effect at the moment
                node_postprocessors=node_postprocessors,
            )
        else:
            return SimpleChatEngine.from_defaults(
2 changes: 1 addition & 1 deletion private_gpt/server/utils/auth.py
@@ -60,7 +60,7 @@ def authenticated() -> bool:

# Method to be used as a dependency to check if the request is authenticated.
def authenticated(
    _simple_authentication: Annotated[bool, Depends(_simple_authentication)],
) -> bool:
    """Check if the request is authenticated."""
    assert settings().server.auth.enabled
25 changes: 25 additions & 0 deletions private_gpt/settings/settings.py
@@ -114,6 +114,29 @@ class NodeStoreSettings(BaseModel):
    database: Literal["simple", "postgres"]


class FlagEmbeddingReRankerSettings(BaseModel):
    hf_model_name: str = Field(
        "BAAI/bge-reranker-large",
        description="Name of the HuggingFace model to use for reranking",
    )
    top_n: int = Field(
        5,
        description="Top N nodes to return.",
    )
    cut_off: float = Field(
        0.75,
        description="Cut off score for nodes.",
    )


class RerankerSettings(BaseModel):
    enabled: bool = Field(
        False,
        description="Flag indicating if reranker is enabled or not",
    )
    mode: Literal["flagembedding"]


class LlamaCPPSettings(BaseModel):
    llm_hf_repo_id: str
    llm_hf_model_file: str
@@ -391,6 +414,8 @@ class Settings(BaseModel):
    vectorstore: VectorstoreSettings
    nodestore: NodeStoreSettings
    rag: RagSettings
    reranker: RerankerSettings
    flagembedding_reranker: FlagEmbeddingReRankerSettings
    qdrant: QdrantSettings | None = None
    postgres: PostgresSettings | None = None
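
For reference, the new settings classes correspond to a `settings.yaml` fragment like this — a sketch populated with the defaults declared above:

```yaml
reranker:
  enabled: true
  mode: flagembedding

flagembedding_reranker:
  hf_model_name: BAAI/bge-reranker-large
  top_n: 5
  cut_off: 0.75
```
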
1 change: 0 additions & 1 deletion private_gpt/ui/ui.py
@@ -145,7 +145,6 @@ def build_history() -> list[ChatMessage]:
        )
        match mode:
            case "Query Files":
                # Use only the selected file for the query
                context_filter = None
                if self._selected_filename is not None:
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ asyncpg = {version="^0.29.0", optional = true}
boto3 = {version ="^1.34.51", optional = true}
# Optional UI
gradio = {version ="^4.19.2", optional = true}
flagembedding = {version="^1.2.5", optional = true}

[tool.poetry.extras]
ui = ["gradio"]
@@ -57,6 +58,8 @@ vector-stores-qdrant = ["llama-index-vector-stores-qdrant"]
vector-stores-chroma = ["llama-index-vector-stores-chroma"]
vector-stores-postgres = ["llama-index-vector-stores-postgres"]
storage-nodestore-postgres = ["llama-index-storage-docstore-postgres","llama-index-storage-index-store-postgres","psycopg2-binary","asyncpg"]
reranker-flagembedding = ["flagembedding"]


[tool.poetry.group.dev.dependencies]
black = "^22"
21 changes: 16 additions & 5 deletions scripts/setup
@@ -1,17 +1,17 @@
#!/usr/bin/env python3
import os
import argparse
import os

from huggingface_hub import hf_hub_download, snapshot_download
from transformers import AutoTokenizer

from private_gpt.paths import models_path, models_cache_path
from private_gpt.paths import models_cache_path, models_path
from private_gpt.settings.settings import settings

resume_download = True
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog="Setup: Download models from Hugging Face")
    parser.add_argument("--resume", default=True, action=argparse.BooleanOptionalAction, help="Enable/disable resume_download to resume an interrupted download")
    args = parser.parse_args()
    resume_download = args.resume

@@ -27,6 +27,17 @@ snapshot_download(
)
print("Embedding model downloaded!")

if settings().reranker.enabled and settings().reranker.mode == "flagembedding":
    # Download Reranker model
    reranker_path = models_path / "flagembedding_reranker"
    print(f"Downloading reranker {settings().flagembedding_reranker.hf_model_name}")
    snapshot_download(
        repo_id=settings().flagembedding_reranker.hf_model_name,
        cache_dir=models_cache_path,
        local_dir=reranker_path,
    )
    print("Reranker model downloaded!")

# Download LLM and create a symlink to the model file
print(f"Downloading LLM {settings().llamacpp.llm_hf_model_file}")
hf_hub_download(