-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchunks_service.py
124 lines (109 loc) · 4.4 KB
/
chunks_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from typing import TYPE_CHECKING, Literal
from injector import inject, singleton
from llama_index import ServiceContext, StorageContext, VectorStoreIndex
from llama_index.schema import NodeWithScore
from pydantic import BaseModel, Field
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.ingest.ingest_service import IngestedDoc
if TYPE_CHECKING:
from llama_index.schema import RelatedNodeInfo
# API model for one retrieved piece of context: the chunk text, its relevance
# score, the ingested document it came from and, optionally, the texts of the
# chunks surrounding it.
#
# Documented with comments rather than a class docstring on purpose: pydantic
# surfaces a model's docstring as the description in the generated schema, and
# this note is meant for maintainers only.
class Chunk(BaseModel):
    # Constant discriminator identifying the payload type.
    object: Literal["context.chunk"]
    # Relevance score reported by the retriever.
    score: float = Field(examples=[0.023])
    # The ingested document this chunk belongs to.
    document: IngestedDoc
    # The chunk's own content.
    text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
    # Texts of the chunks preceding this one; None until a caller fills it in.
    previous_texts: list[str] | None = Field(
        default=None,
        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
    )
    # Texts of the chunks following this one; None until a caller fills it in.
    next_texts: list[str] | None = Field(
        default=None,
        examples=[
            [
                "New leads came from Google Ads campaign.",
                "The campaign was run by the Marketing Department",
            ]
        ],
    )

    @classmethod
    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
        """Build a chunk from a scored retrieval node.

        The source document id is taken from the wrapped node's
        ``ref_doc_id``; when that is ``None`` the placeholder ``"-"`` is used.
        A missing (or otherwise falsy) score is reported as ``0.0``.
        ``previous_texts`` and ``next_texts`` are not set here and keep their
        ``None`` default.

        Args:
            node: The scored node to convert.

        Returns:
            A new instance carrying the node's content, score and metadata.
        """
        source_doc_id = node.node.ref_doc_id
        if source_doc_id is None:
            # Placeholder for nodes that are not linked to a source document.
            source_doc_id = "-"

        source_document = IngestedDoc(
            object="ingest.document",
            doc_id=source_doc_id,
            doc_metadata=node.metadata,
        )
        return cls(
            object="context.chunk",
            score=node.score or 0.0,
            document=source_document,
            text=node.get_content(),
        )
@singleton
class ChunksService:
    """Retrieve stored chunks relevant to a query text.

    Each retrieved chunk can be enriched with the texts of the chunks that
    precede and follow it, looked up through the doc store.
    """

    @inject
    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        """Build the storage and service contexts from the injected components.

        Args:
            llm_component: Supplies the LLM for the query service context.
            vector_store_component: Supplies the vector store; also kept on the
                instance because it creates the retriever at query time.
            embedding_component: Supplies the embedding model for the query
                service context.
            node_store_component: Supplies the doc store and index store.
        """
        self.vector_store_component = vector_store_component

        vector_store = vector_store_component.vector_store
        doc_store = node_store_component.doc_store
        index_store = node_store_component.index_store
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store,
            docstore=doc_store,
            index_store=index_store,
        )

        llm = llm_component.llm
        embed_model = embedding_component.embedding_model
        self.query_service_context = ServiceContext.from_defaults(
            llm=llm, embed_model=embed_model
        )

    def _get_sibling_nodes_text(
        self, node_with_score: NodeWithScore, related_number: int, forward: bool = True
    ) -> list[str]:
        """Collect the texts of up to ``related_number`` neighbouring nodes.

        Starting at the given node, follows the ``next_node`` links (or the
        ``prev_node`` links when ``forward`` is false), fetching every
        neighbour from the doc store. The walk ends early as soon as a node
        has no link in the requested direction.

        Args:
            node_with_score: The scored node to start from.
            related_number: Maximum number of neighbours to visit.
            forward: Walk towards following nodes when true, preceding nodes
                otherwise.

        Returns:
            The neighbours' contents in the order visited, i.e. nearest first.
        """
        # NOTE(review): for a backward walk "nearest first" is reverse document
        # order, while the ``previous_texts`` example on ``Chunk`` reads like
        # document order - confirm which ordering callers expect.
        sibling_texts = []
        cursor = node_with_score.node
        for _ in range(related_number):
            link: RelatedNodeInfo | None
            if forward:
                link = cursor.next_node
            else:
                link = cursor.prev_node
            if link is None:
                break

            # Step onto the neighbour so the next hop continues from there.
            cursor = self.storage_context.docstore.get_node(link.node_id)
            sibling_texts.append(cursor.get_content())
        return sibling_texts

    def _build_chunk(self, scored_node: NodeWithScore, sibling_count: int) -> Chunk:
        """Convert one scored node into a chunk with its surrounding texts."""
        chunk = Chunk.from_node(scored_node)
        # Preceding texts are gathered first, then the following ones.
        chunk.previous_texts = self._get_sibling_nodes_text(
            scored_node, sibling_count, False
        )
        chunk.next_texts = self._get_sibling_nodes_text(scored_node, sibling_count)
        return chunk

    def retrieve_relevant(
        self,
        text: str,
        context_filter: ContextFilter | None = None,
        limit: int = 10,
        prev_next_chunks: int = 0,
    ) -> list[Chunk]:
        """Return the chunks most relevant to ``text``, best score first.

        An index is created over the vector store for this call, a retriever
        is requested from the vector store component, and the retrieved nodes
        are sorted by descending score (a missing score counts as ``0.0``).

        Args:
            text: The query text handed to the retriever.
            context_filter: Optional filter forwarded to the vector store
                component when it creates the retriever.
            limit: Passed to the retriever as ``similarity_top_k``.
            prev_next_chunks: Maximum number of preceding and of following
                chunk texts to attach to every result.

        Returns:
            One chunk per retrieved node, with ``previous_texts`` and
            ``next_texts`` filled in (empty lists when nothing was collected).
        """
        index = VectorStoreIndex.from_vector_store(
            self.vector_store_component.vector_store,
            storage_context=self.storage_context,
            service_context=self.query_service_context,
            show_progress=True,
        )
        retriever = self.vector_store_component.get_retriever(
            index=index, context_filter=context_filter, similarity_top_k=limit
        )

        def relevance(scored_node: NodeWithScore) -> float:
            # Nodes without a score sort as if they scored 0.0.
            return scored_node.score or 0.0

        scored_nodes = retriever.retrieve(text)
        scored_nodes.sort(key=relevance, reverse=True)
        return [
            self._build_chunk(scored_node, prev_next_chunks)
            for scored_node in scored_nodes
        ]