Commit 6b7d7dc

fix: tests

Anhui-tqhuang committed Mar 14, 2024
1 parent 0ab035f commit 6b7d7dc

Showing 7 changed files with 19 additions and 16 deletions.
2 changes: 1 addition & 1 deletion private_gpt/components/llm/prompt_helper.py
@@ -215,7 +215,7 @@ def _completion_to_prompt(self, completion: str) -> str:
 
 
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None,
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
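The only change in this hunk is a trailing comma after the final parameter, matching the repository's formatter style for multi-line signatures. For context, a minimal sketch of how the function is used; the completion_to_prompt call is an assumed public wrapper around the _completion_to_prompt method named in the hunk header:

from private_gpt.components.llm.prompt_helper import get_prompt_style

style = get_prompt_style("llama2")  # any Literal value above, or None for "default"
prompt = style.completion_to_prompt("Summarize the indexed documents.")  # assumed wrapper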
20 changes: 10 additions & 10 deletions private_gpt/components/reranker/reranker.py
@@ -1,17 +1,17 @@
-from typing import List, Tuple
-from injector import singleton, inject
-from llama_index.schema import NodeWithScore, QueryBundle
-from private_gpt.paths import models_path
-from llama_index.bridge.pydantic import Field
 from FlagEmbedding import FlagReranker
+from injector import inject, singleton
+from llama_index.bridge.pydantic import Field
 from llama_index.postprocessor.types import BaseNodePostprocessor
+from llama_index.schema import NodeWithScore, QueryBundle
+
+from private_gpt.paths import models_path
 from private_gpt.settings.settings import Settings
 
 
 @singleton
 class RerankerComponent(BaseNodePostprocessor):
-    """
-    Reranker component:
+    """Reranker component.
+
     - top_n: Top N nodes to return.
     - cut_off: Cut off score for nodes.
@@ -47,14 +47,14 @@ def class_name(cls) -> str:
 
     def _postprocess_nodes(
         self,
-        nodes: List[NodeWithScore],
+        nodes: list[NodeWithScore],
         query_bundle: QueryBundle | None = None,
-    ) -> List[NodeWithScore]:
+    ) -> list[NodeWithScore]:
         if query_bundle is None:
             return ValueError("Query bundle must be provided.")
 
         query_str = query_bundle.query_str
-        sentence_pairs: List[Tuple[str, str]] = []
+        sentence_pairs: list[tuple[str, str]] = []
         for node in nodes:
             content = node.get_content()
             sentence_pairs.append([query_str, content])
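This hunk only swaps typing.List/typing.Tuple for the PEP 585 built-ins; behavior is unchanged. Note the untouched context line return ValueError(...), where raise was presumably intended. The diff truncates after sentence_pairs is built; a hedged sketch of how FlagReranker typically consumes such pairs follows. compute_score is FlagEmbedding's real API, but the self.reranker attribute name and the top_n/cut_off filtering are assumptions based on the fields named in the docstring:

# Hypothetical continuation of _postprocess_nodes, not code from this commit.
scores = self.reranker.compute_score(sentence_pairs)  # one relevance score per (query, content) pair
for node, score in zip(nodes, scores):
    node.score = score
nodes.sort(key=lambda n: n.score or 0.0, reverse=True)
return [n for n in nodes[: self.top_n] if (n.score or 0.0) >= self.cut_off]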
1 change: 0 additions & 1 deletion private_gpt/launcher.py
@@ -21,7 +21,6 @@
 
 
 def create_app(root_injector: Injector) -> FastAPI:
-
     # Start the API
     async def bind_injector_to_request(request: Request) -> None:
         request.state.injector = root_injector
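Only a blank line is removed here, but the surrounding context shows the app's dependency-injection pattern: the root Injector is attached to request.state on every request, so handlers can resolve bound components lazily. A sketch of a consumer; the route itself is illustrative, not part of the repository:

from fastapi import Request

from private_gpt.settings.settings import Settings

async def example_route(request: Request) -> dict[str, str]:
    injector = request.state.injector  # set by bind_injector_to_request above
    settings = injector.get(Settings)  # resolve any component bound on the root injector
    return {"ui_enabled": str(settings.ui.enabled)}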
2 changes: 1 addition & 1 deletion private_gpt/server/chat/chat_service.py
@@ -12,10 +12,10 @@
 from llama_index.core.types import TokenGen
 from pydantic import BaseModel
 
-from private_gpt.components.reranker.reranker import RerankerComponent
 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
 from private_gpt.components.llm.llm_component import LLMComponent
 from private_gpt.components.node_store.node_store_component import NodeStoreComponent
+from private_gpt.components.reranker.reranker import RerankerComponent
 from private_gpt.components.vector_store.vector_store_component import (
     VectorStoreComponent,
 )
2 changes: 1 addition & 1 deletion private_gpt/server/utils/auth.py
@@ -60,7 +60,7 @@ def authenticated() -> bool:
 
 # Method to be used as a dependency to check if the request is authenticated.
 def authenticated(
-    _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
+    _simple_authentication: Annotated[bool, Depends(_simple_authentication)],
 ) -> bool:
     """Check if the request is authenticated."""
     assert settings().server.auth.enabled
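Again only a trailing comma. The Annotated[..., Depends(...)] chain shown is FastAPI's standard dependency pattern; a sketch of how a route would be gated by it, with the router and path being illustrative:

from typing import Annotated

from fastapi import APIRouter, Depends

from private_gpt.server.utils.auth import authenticated

router = APIRouter()

@router.get("/private")
def private_endpoint(_: Annotated[bool, Depends(authenticated)]) -> dict[str, str]:
    # FastAPI resolves the dependency before the handler runs; an
    # unauthenticated request fails before this body executes.
    return {"status": "authorized"}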
1 change: 0 additions & 1 deletion private_gpt/ui/ui.py
@@ -145,7 +145,6 @@ def build_history() -> list[ChatMessage]:
         )
         match mode:
             case "Query Files":
-
                 # Use only the selected file for the query
                 context_filter = None
                 if self._selected_filename is not None:
7 changes: 6 additions & 1 deletion settings-test.yaml
@@ -14,8 +14,13 @@ qdrant:
 llm:
   mode: mock
 
+<<<<<<< HEAD
 embedding:
   mode: mock
+=======
+reranker:
+  enabled: false
+>>>>>>> c096818 (fix: tests)
 
 ui:
-  enabled: false
\ No newline at end of file
+  enabled: false
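As committed, this hunk leaves unresolved Git conflict markers in settings-test.yaml, which makes the file invalid YAML: the mock embedding block and the new reranker flag are the two sides of the conflict. A resolved file would presumably keep both:

llm:
  mode: mock

embedding:
  mode: mock

reranker:
  enabled: false

ui:
  enabled: false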
