
Commit 2c060f1

Add generic entities
1 parent 29171c3 commit 2c060f1

13 files changed, +111 -89 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 
 ## News (and Coming Soon)
 - [ ] Support for IDF weighting of entities
-- [ ] Support for generic entities and concepts
+- [x] Support for generic entities and concepts (initial commit)
 - [x] [2024.12.02] Benchmarks comparing Fast GraphRAG to LightRAG, GraphRAG and VectorDBs released [here](https://github.com/circlemind-ai/fast-graphrag/blob/main/benchmarks/README.md)
 
 ## Features

fast_graphrag/_models.py

Lines changed: 15 additions & 7 deletions
@@ -53,8 +53,8 @@ def dump_to_csv(
         chain(
             separator.join(
                 chain(
-                    (str(getattr(d, field)).replace("\t", " ") for field in fields),
-                    (str(v).replace("\t", " ") for v in vs),
+                    (str(getattr(d, field)).replace("\n", " ").replace("\t", " ") for field in fields),
+                    (str(v).replace("\n", " ").replace("\t", " ") for v in vs),
                 )
             )
             for d, *vs in zip(data, *values.values())
@@ -97,13 +97,21 @@ class TEntityDescription(BaseModel):
 
 
 class TQueryEntities(BaseModel):
-    entities: List[str] = Field(
+    named: List[str] = Field(
         ...,
-        description=("List of entities extracted from the query"),
+        description=("List of named entities extracted from the query"),
+    )
+    generic: List[str] = Field(
+        ...,
+        description=("List of generic entities extracted from the query"),
     )
-    n: int = Field(..., description="Number of named entities found")  # So that the LLM can answer 0.
 
-    @field_validator("entities", mode="before")
+    @field_validator("named", mode="before")
     @classmethod
-    def uppercase_source(cls, value: List[str]):
+    def uppercase_named(cls, value: List[str]):
         return [e.upper() for e in value] if value else value
+
+    # @field_validator("generic", mode="before")
+    # @classmethod
+    # def uppercase_generic(cls, value: List[str]):
+    #     return [e.upper() for e in value] if value else value
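Since only the `named` field keeps the uppercasing validator, named and generic entities now behave differently on construction. A minimal sketch of the new model's behavior (the class is re-declared standalone here for illustration; the real definition lives in `fast_graphrag/_models.py`):

```python
from typing import List

from pydantic import BaseModel, Field, field_validator


# Standalone copy of the new model, for illustration only.
class TQueryEntities(BaseModel):
    named: List[str] = Field(..., description="List of named entities extracted from the query")
    generic: List[str] = Field(..., description="List of generic entities extracted from the query")

    @field_validator("named", mode="before")
    @classmethod
    def uppercase_named(cls, value: List[str]):
        # Only named entities are normalized to uppercase; generic ones pass through.
        return [e.upper() for e in value] if value else value


entities = TQueryEntities(named=["Fred Gehrke"], generic=["23rd baseball draft pick"])
print(entities.named)    # ['FRED GEHRKE']
print(entities.generic)  # ['23rd baseball draft pick']
```

Note that the old `n` counter field is gone entirely: an empty list now communicates "no entities found" on its own.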

fast_graphrag/_prompt.py

Lines changed: 10 additions & 9 deletions
@@ -47,12 +47,12 @@
     {{"source": "Radio City", "target": "English", "desc": "Radio City broadcasts English songs."}},
     {{"source": "Radio City", "target": "Hindi", "desc": "Radio City broadcasts songs in the Hindi language."}},
     {{"source": "Radio City", "target": "PlanetRadiocity.com", "desc": "Radio City launched PlanetRadiocity.com in May 2008."}},
-    {{"source": "PlanetRadiocity.com", "target": "music portal", "desc": "PlanetRadiocity.com is a music portal that offers music related news, videos and more."}}
+    {{"source": "PlanetRadiocity.com", "target": "music portal", "desc": "PlanetRadiocity.com is a music portal that offers music related news, videos and more."}},
+    {{"source": "PlanetRadiocity.com", "target": "video", "desc": "PlanetRadiocity.com offers music related videos."}}
   ],
   "other_relationships": [
     {{"source": "Radio City", "target": "New Media", "desc": "Radio City forayed into New Media in May 2008."}},
     {{"source": "PlanetRadiocity.com", "target": "news", "desc": "PlanetRadiocity.com offers music related news."}},
-    {{"source": "PlanetRadiocity.com", "target": "video", "desc": "PlanetRadiocity.com offers music related videos."}},
     {{"source": "PlanetRadiocity.com", "target": "song", "desc": "PlanetRadiocity.com offers songs."}}
   ]
 }}
@@ -68,22 +68,23 @@
 
 PROMPTS["entity_relationship_gleaning_done_extraction"] = "Retrospectively check if all entities have been correctly identified: answer done if so, or continue if there are still entities that need to be added."
 
-PROMPTS["entity_extraction_query"] = """You are a helpful assistant that helps a human analyst identify all the named entities present in the input query that are important for answering the query.
+PROMPTS["entity_extraction_query"] = """Given the query below, your task is to extract all entities relevant to perform information retrieval to produce an answer.
 
-# Example 1
-Query: Do the magazines Arthur's Magazine or First for Women have the same publisher?
-Output: {{"entities": ["Arthur's Magazine", "First for Women"], "n": 2}}
+-Example 1-
+Query: Who directed the film that was shot in or around Leland, North Carolina in 1986?
+Output: {{"named": ["Leland", "North Carolina", "1986"], "generic": ["film director"]}}
 
-# Example 2
-Query: Which film has the director who was born earlier, Avatar II: The Return or The Interstellar?
-Output: {{"entities": ["Avatar II: The Return", "The Interstellar"], "n": 2}}
+-Example 2-
+Query: What relationship does Fred Gehrke have to the 23rd overall pick in the 2010 Major League Baseball Draft?
+Output: {{"named": ["Fred Gehrke", "2010 Major League Baseball Draft"], "generic": ["23rd baseball draft pick"]}}
 
 # INPUT
 Query: {query}
 Output:
 """
 
 
+
 PROMPTS[
     "summarize_entity_descriptions"
 ] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.

fast_graphrag/_services/_base.py

Lines changed: 2 additions & 3 deletions
@@ -22,7 +22,6 @@
     GTNode,
     TContext,
     TDocument,
-    TEntity,
     TIndex,
 )
 
@@ -58,7 +57,7 @@ def extract(
 
     async def extract_entities_from_query(
         self, llm: BaseLLMService, query: str, prompt_kwargs: Dict[str, str]
-    ) -> Iterable[TEntity]:
+    ) -> Dict[str, List[str]]:
         """Extract entities from the given query."""
         raise NotImplementedError
 
@@ -128,7 +127,7 @@ async def upsert(
         raise NotImplementedError
 
     async def get_context(
-        self, query: str, entities: Iterable[TEntity]
+        self, query: str, entities: Dict[str, List[str]]
     ) -> Optional[TContext[GTNode, GTEdge, GTHash, GTChunk]]:
         """Retrieve relevant state from the storage."""
         raise NotImplementedError

fast_graphrag/_services/_information_extraction.py

Lines changed: 5 additions & 2 deletions
@@ -41,7 +41,7 @@ def extract(
 
     async def extract_entities_from_query(
         self, llm: BaseLLMService, query: str, prompt_kwargs: Dict[str, str]
-    ) -> Iterable[TEntity]:
+    ) -> Dict[str, List[str]]:
         """Extract entities from the given query."""
         prompt_kwargs["query"] = query
         entities, _ = await format_and_send_prompt(
@@ -51,7 +51,10 @@ async def extract_entities_from_query(
             response_model=TQueryEntities,
         )
 
-        return [TEntity(name=name, type="", description="") for name in entities.entities]
+        return {
+            "named": entities.named,
+            "generic": entities.generic
+        }
 
     async def _extract(
         self, llm: BaseLLMService, chunks: Iterable[TChunk], prompt_kwargs: Dict[str, str], entity_types: List[str]
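Callers of `extract_entities_from_query` now get a plain dict of entity-name lists instead of `TEntity` objects. A hypothetical driver showing the new shape (`extractor` and `llm` are placeholders, not the package's actual construction API):

```python
from typing import Dict, List


async def run_query_extraction(extractor, llm) -> Dict[str, List[str]]:
    # Assumed: `extractor` behaves like the information extraction service above
    # and `llm` like a BaseLLMService instance; both are placeholders.
    entities = await extractor.extract_entities_from_query(
        llm=llm,
        query="Who directed the film that was shot in or around Leland, North Carolina in 1986?",
        prompt_kwargs={},
    )
    # Old shape: Iterable[TEntity]. New shape: {"named": [...], "generic": [...]}.
    return entities
```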

fast_graphrag/_services/_state_manager.py

Lines changed: 39 additions & 29 deletions
@@ -5,7 +5,7 @@
 
 import numpy as np
 import numpy.typing as npt
-from scipy.sparse import csr_matrix
+from scipy.sparse import csr_matrix, vstack
 from tqdm import tqdm
 
 from fast_graphrag._llm import BaseLLMService
@@ -35,7 +35,8 @@
 @dataclass
 class DefaultStateManagerService(BaseStateManagerService[TEntity, TRelation, THash, TChunk, TId, TEmbedding]):
     blob_storage_cls: Type[BaseBlobStorage[csr_matrix]] = field(default=PickleBlobStorage)
-    similarity_score_threshold: float = field(default=0.80)
+    insert_similarity_score_threshold: float = field(default=0.8)
+    query_similarity_score_threshold: Optional[float] = field(default=0.7)
 
     def __post_init__(self):
         assert self.workspace is not None, "Workspace must be provided."
@@ -77,7 +78,7 @@ async def upsert(
         llm: BaseLLMService,
         subgraphs: List[asyncio.Future[Optional[BaseGraphStorage[TEntity, TRelation, TId]]]],
         documents: Iterable[Iterable[TChunk]],
-        show_progress: bool = True
+        show_progress: bool = True,
     ) -> None:
         nodes: Iterable[List[TEntity]]
         edges: Iterable[List[TRelation]]
@@ -95,12 +96,16 @@ async def _get_graphs(
 
             return (nodes, edges)
 
-        graphs = [r for graph in tqdm(
-            asyncio.as_completed([_get_graphs(fgraph) for fgraph in subgraphs]),
-            total=len(subgraphs),
-            desc="Extracting data",
-            disable=not show_progress,
-        ) if (r := await graph) is not None]
+        graphs = [
+            r
+            for graph in tqdm(
+                asyncio.as_completed([_get_graphs(fgraph) for fgraph in subgraphs]),
+                total=len(subgraphs),
+                desc="Extracting data",
+                disable=not show_progress,
+            )
+            if (r := await graph) is not None
+        ]
 
         if len(graphs) == 0:
             return
@@ -136,7 +141,7 @@ async def _get_graphs(
             # We use the fact that similarity scores are symmetric between entity pairs,
             # so we only select half of that by index order
             similar_indices[
-                (scores < self.similarity_score_threshold)
+                (scores < self.insert_similarity_score_threshold)
                 | (similar_indices <= upserted_indices)  # remove indices smaller or equal the entity
            ] = 0  # 0 can be used here (not 100% sure, but 99% sure)
            progress_bar.update(1)
@@ -147,6 +152,7 @@
 
        # STEP: insert identity edges
        progress_bar.set_description("Building... [identity edges]")
+
        async def _insert_identiy_edges(
            source_index: TIndex, target_indices: npt.NDArray[np.int32]
        ) -> Iterable[Tuple[TIndex, TIndex]]:
@@ -177,30 +183,33 @@ async def _insert_identiy_edges(
         progress_bar.set_description("Building [done]")
 
     async def get_context(
-        self, query: str, entities: Iterable[TEntity]
+        self, query: str, entities: Dict[str, List[str]]
     ) -> Optional[TContext[TEntity, TRelation, THash, TChunk]]:
         if self.entity_storage.size == 0:
             return None
 
         try:
-            entity_names = [entity.name for entity in entities]
-
-            query_embeddings = await self.embedding_service.encode(entity_names + [query])
-
+            query_embeddings = await self.embedding_service.encode(
+                [f"[NAME] {n}" for n in entities["named"]] + [f"[NAME] {n}" for n in entities["generic"]] + [query]
+            )
+            entity_scores: List[csr_matrix] = []
             # Similarity-search over entities
-            if len(entity_names) > 0:
-                vdb_entity_scores_by_name = await self._score_entities_by_vectordb(
-                    query_embeddings=query_embeddings[:-1], top_k=1
+            if len(entities["named"]) > 0:
+                vdb_entity_scores_by_named_entity = await self._score_entities_by_vectordb(
+                    query_embeddings=query_embeddings[: len(entities["named"])],
+                    top_k=1,
+                    threshold=self.query_similarity_score_threshold,
                 )
-            else:
-                vdb_entity_scores_by_name = 0
-            vdb_entity_scores_by_query = await self._score_entities_by_vectordb(
-                query_embeddings=query_embeddings[-1:], top_k=8
+                entity_scores.append(vdb_entity_scores_by_named_entity)
+
+            vdb_entity_scores_by_generic_entity_and_query = await self._score_entities_by_vectordb(
+                query_embeddings=query_embeddings[len(entities["named"]) :], top_k=20, threshold=0.5
             )
+            entity_scores.append(vdb_entity_scores_by_generic_entity_and_query)
 
-            vdb_entity_scores = vdb_entity_scores_by_name + vdb_entity_scores_by_query
+            vdb_entity_scores = vstack(entity_scores).max(axis=0)
 
-            if vdb_entity_scores.nnz == 0:
+            if isinstance(vdb_entity_scores, int) or vdb_entity_scores.nnz == 0:
                 return None
         except Exception as e:
             logger.error(f"Error during information extraction and scoring for query entities {entities}.\n{e}")
@@ -254,24 +263,25 @@ async def get_context(
     async def _get_entities_to_num_docs(self) -> Any:
         raise NotImplementedError
 
-    async def _score_entities_by_vectordb(self, query_embeddings: Iterable[TEmbedding], top_k: int = 1) -> csr_matrix:
+    async def _score_entities_by_vectordb(
+        self, query_embeddings: Iterable[TEmbedding], top_k: int = 1, threshold: Optional[float] = None
+    ) -> csr_matrix:
         # TODO: check this
         # if top_k != 1:
         #     logger.warning(f"Top-k > 1 is not tested yet. Using top_k={top_k}.")
         if self.node_specificity:
             raise NotImplementedError("Node specificity is not supported yet.")
 
         all_entity_probs_by_query_entity = await self.entity_storage.score_all(
-            np.array(query_embeddings), top_k=top_k
+            np.array(query_embeddings), top_k=top_k, threshold=threshold
         )  # (#query_entities, #all_entities)
 
         # TODO: if top_k > 1, we need to aggregate the scores here
         if all_entity_probs_by_query_entity.shape[1] == 0:
             return all_entity_probs_by_query_entity
-        all_entity_weights: csr_matrix = all_entity_probs_by_query_entity.max(axis=0)  # (1, #all_entities)
-
         # Normalize the scores
-        all_entity_weights /= all_entity_weights.sum()
+        all_entity_probs_by_query_entity /= all_entity_probs_by_query_entity.sum(axis=1) + 1e-8
+        all_entity_weights: csr_matrix = all_entity_probs_by_query_entity.max(axis=0)  # (1, #all_entities)
 
         if self.node_specificity:
             all_entity_weights = all_entity_weights.multiply(1.0 / await self._get_entities_to_num_docs())
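The core retrieval change: named-entity scores (strict, `top_k=1`, configurable threshold) and generic-entity/query scores (loose, `top_k=20`, fixed 0.5 threshold) are computed separately and merged by element-wise maximum via `vstack`, instead of being summed. Scores are also row-normalized per query entity (with a `1e-8` guard against empty rows) before the max. A toy illustration of the merge with made-up score rows:

```python
import numpy as np
from scipy.sparse import csr_matrix, vstack

# Two (1, #all_entities) score rows, as _score_entities_by_vectordb would return:
# one from named-entity matches, one from generic-entity/query matches.
named_scores = csr_matrix(np.array([[0.9, 0.0, 0.0, 0.7]]))
generic_and_query_scores = csr_matrix(np.array([[0.2, 0.6, 0.0, 0.8]]))

# max(axis=0) keeps the strongest evidence per entity; the old code summed the
# two channels, which double-counted entities matching on both.
merged = vstack([named_scores, generic_and_query_scores]).max(axis=0)
print(merged.toarray())  # [[0.9 0.6 0.  0.8]]
```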

fast_graphrag/_storage/_base.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ async def upsert(
         raise NotImplementedError
 
     async def score_all(
-        self, embeddings: Iterable[GTEmbedding], top_k: int = 1, confidence_threshold: float = 0.0
+        self, embeddings: Iterable[GTEmbedding], top_k: int = 1, threshold: Optional[float] = None
     ) -> csr_matrix:
         """Score all embeddings against the given queries.
 
fast_graphrag/_storage/_vdb_hnswlib.py

Lines changed: 8 additions & 7 deletions
@@ -1,6 +1,6 @@
 import pickle
 from dataclasses import dataclass, field
-from typing import Any, Dict, Iterable, List, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import hnswlib
 import numpy as np
@@ -16,9 +16,9 @@
 
 @dataclass
 class HNSWVectorStorageConfig:
-    ef_construction: int = field(default=64)
-    M: int = field(default=48)
-    ef_search: int = field(default=64)
+    ef_construction: int = field(default=256)
+    M: int = field(default=64)
+    ef_search: int = field(default=96)
     num_threads: int = field(default=-1)
 
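The bumped HNSW defaults trade build time and memory for recall: `M` sets graph connectivity, `ef_construction` the candidate-list size while building, and `ef_search` the candidate-list size at query time. A minimal standalone hnswlib sketch using the new values (dimension and data are made up):

```python
import hnswlib
import numpy as np

dim, max_elements = 128, 10_000
index = hnswlib.Index(space="cosine", dim=dim)
# New defaults: ef_construction 64 -> 256, M 48 -> 64.
index.init_index(max_elements=max_elements, ef_construction=256, M=64)
index.set_ef(96)  # ef_search 64 -> 96; must be >= k at query time

data = np.random.rand(1_000, dim).astype(np.float32)
index.add_items(data)
labels, distances = index.knn_query(data[:5], k=3)
```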
@@ -83,10 +83,8 @@ async def get_knn(
         return ids, 1.0 - np.array(distances, dtype=TScore) * 0.5
 
     async def score_all(
-        self, embeddings: Iterable[GTEmbedding], top_k: int = 1, confidence_threshold: float = 0.0
+        self, embeddings: Iterable[GTEmbedding], top_k: int = 1, threshold: Optional[float] = None
     ) -> csr_matrix:
-        if confidence_threshold > 0.0:
-            raise NotImplementedError("Confidence threshold is not supported yet.")
         if not isinstance(embeddings, np.ndarray):
             embeddings = np.array(list(embeddings), dtype=np.float32)
 
@@ -104,6 +102,9 @@ async def score_all(
         ids = np.array(ids)
         scores = 1.0 - np.array(distances, dtype=TScore) * 0.5
 
+        if threshold is not None:
+            scores[scores < threshold] = 0
+
         # Create sparse distance matrix with shape (#embeddings, #all_embeddings)
         flattened_ids = ids.ravel()
         flattened_scores = scores.ravel()
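`score_all` now zeroes similarities below the optional `threshold` before assembling the sparse score matrix, so weak matches never reach downstream ranking (previously any non-zero `confidence_threshold` just raised `NotImplementedError`). A toy version of that step with made-up kNN results:

```python
import numpy as np
from scipy.sparse import csr_matrix

# Fake kNN results: per-query-embedding neighbor ids and similarity scores.
ids = np.array([[3, 7], [1, 7]])
scores = np.array([[0.91, 0.42], [0.65, 0.73]], dtype=np.float32)

threshold = 0.7
scores[scores < threshold] = 0  # the new filtering step

num_all_embeddings = 10
rows = np.repeat(np.arange(ids.shape[0]), ids.shape[1])
matrix = csr_matrix(
    (scores.ravel(), (rows, ids.ravel())), shape=(ids.shape[0], num_all_embeddings)
)
matrix.eliminate_zeros()  # drop the explicitly stored zeros
print(matrix.toarray())   # only scores >= 0.7 survive
```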

tests/_models_test.py

Lines changed: 3 additions & 3 deletions
@@ -15,9 +15,9 @@
 
 class TestModels(unittest.TestCase):
     def test_tqueryentities(self):
-        query_entities = TQueryEntities(entities=["Entity1", "Entity2"], n=2)
-        self.assertEqual(query_entities.entities, ["ENTITY1", "ENTITY2"])
-        self.assertEqual(query_entities.n, 2)
+        query_entities = TQueryEntities(named=["Entity1", "Entity2"], generic=["Generic1", "Generic2"])
+        self.assertEqual(query_entities.named, ["ENTITY1", "ENTITY2"])
+        self.assertEqual(query_entities.generic, ["Generic1", "Generic2"])
 
         with self.assertRaises(ValidationError):
             TQueryEntities(entities=["Entity1", "Entity2"], n="two")

tests/_policies/_graph_upsert_test.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ async def test_call_same_node_no_summarize(self):
 
         self.assertEqual(
             upserted_nodes[0][1].description,
-            "This is a short random description 1. This is a short random description 2.",
+            "This is a short random description 1.\nThis is a short random description 2.",
         )
 
         # Assertions
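This expectation changed because node descriptions are now joined with newlines on upsert, which is also why `dump_to_csv` in `_models.py` (first hunk above) now flattens `\n` as well as `\t`. A toy check of that flattening, with a stand-in record type rather than the package's actual `TEntity`:

```python
from dataclasses import dataclass
from itertools import chain
from typing import Any, Iterable, List


@dataclass
class Row:  # stand-in record type, not the package's TEntity
    name: str
    description: str


def dump_to_csv(data: Iterable[Any], fields: List[str], separator: str = "\t", **values: List[Any]) -> List[str]:
    # Mirrors the updated logic: newlines and tabs are both replaced so every
    # record stays on one separator-delimited line.
    return [
        separator.join(
            chain(
                (str(getattr(d, field)).replace("\n", " ").replace("\t", " ") for field in fields),
                (str(v).replace("\n", " ").replace("\t", " ") for v in vs),
            )
        )
        for d, *vs in zip(data, *values.values())
    ]


row = Row("NODE", "This is a short random description 1.\nThis is a short random description 2.")
print(dump_to_csv([row], ["name", "description"], score=[0.9]))
# ['NODE\tThis is a short random description 1. This is a short random description 2.\t0.9']
```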
