Skip to content

Commit

Permalink
Add generic entities
Browse files Browse the repository at this point in the history
  • Loading branch information
liukidar committed Dec 5, 2024
1 parent 29171c3 commit 2c060f1
Show file tree
Hide file tree
Showing 13 changed files with 111 additions and 89 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
## News (and Coming Soon)
- [ ] Support for IDF weighting of entities
- [ ] Support for generic entities and concepts
- [x] Support for generic entities and concepts (initial commit)
- [x] [2024.12.02] Benchmarks comparing Fast GraphRAG to LightRAG, GraphRAG and VectorDBs released [here](https://github.com/circlemind-ai/fast-graphrag/blob/main/benchmarks/README.md)

## Features
Expand Down
22 changes: 15 additions & 7 deletions fast_graphrag/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def dump_to_csv(
chain(
separator.join(
chain(
(str(getattr(d, field)).replace("\t", " ") for field in fields),
(str(v).replace("\t", " ") for v in vs),
(str(getattr(d, field)).replace("\n", " ").replace("\t", " ") for field in fields),
(str(v).replace("\n", " ").replace("\t", " ") for v in vs),
)
)
for d, *vs in zip(data, *values.values())
Expand Down Expand Up @@ -97,13 +97,21 @@ class TEntityDescription(BaseModel):


class TQueryEntities(BaseModel):
entities: List[str] = Field(
named: List[str] = Field(
...,
description=("List of entities extracted from the query"),
description=("List of named entities extracted from the query"),
)
generic: List[str] = Field(
...,
description=("List of generic entities extracted from the query"),
)
n: int = Field(..., description="Number of named entities found") # So that the LLM can answer 0.

@field_validator("entities", mode="before")
@field_validator("named", mode="before")
@classmethod
def uppercase_source(cls, value: List[str]):
def uppercase_named(cls, value: List[str]):
return [e.upper() for e in value] if value else value

# @field_validator("generic", mode="before")
# @classmethod
# def uppercase_generic(cls, value: List[str]):
# return [e.upper() for e in value] if value else value
19 changes: 10 additions & 9 deletions fast_graphrag/_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@
{{"source": "Radio City", "target": "English", "desc": "Radio City broadcasts English songs."}},
{{"source": "Radio City", "target": "Hindi", "desc": "Radio City broadcasts songs in the Hindi language."}},
{{"source": "Radio City", "target": "PlanetRadiocity.com", "desc": "Radio City launched PlanetRadiocity.com in May 2008."}},
{{"source": "PlanetRadiocity.com", "target": "music portal", "desc": "PlanetRadiocity.com is a music portal that offers music related news, videos and more."}}
{{"source": "PlanetRadiocity.com", "target": "music portal", "desc": "PlanetRadiocity.com is a music portal that offers music related news, videos and more."}},
{{"source": "PlanetRadiocity.com", "target": "video", "desc": "PlanetRadiocity.com offers music related videos."}}
],
"other_relationships": [
{{"source": "Radio City", "target": "New Media", "desc": "Radio City forayed into New Media in May 2008."}},
{{"source": "PlanetRadiocity.com", "target": "news", "desc": "PlanetRadiocity.com offers music related news."}},
{{"source": "PlanetRadiocity.com", "target": "video", "desc": "PlanetRadiocity.com offers music related videos."}},
{{"source": "PlanetRadiocity.com", "target": "song", "desc": "PlanetRadiocity.com offers songs."}}
]
}}
Expand All @@ -68,22 +68,23 @@

PROMPTS["entity_relationship_gleaning_done_extraction"] = "Retrospectively check if all entities have been correctly identified: answer done if so, or continue if there are still entities that need to be added."

PROMPTS["entity_extraction_query"] = """You are a helpful assistant that helps a human analyst identify all the named entities present in the input query that are important for answering the query.
PROMPTS["entity_extraction_query"] = """Given the query below, your task is to extract all entities relevant to perform information retrieval to produce an answer.
# Example 1
Query: Do the magazines Arthur's Magazine or First for Women have the same publisher?
Ouput: {{"entities": ["Arthur's Magazine", "First for Women"], "n": 2}}
-Example 1-
Query: Who directed the film that was shot in or around Leland, North Carolina in 1986?
Ouput: {{"named": ["Leland", "North Carolina", "1986"], "generic": ["film director"]}}
# Example 2
Query: Which film has the director who was born earlier, Avatar II: The Return or The Interstellar?
Ouput: {{"entities": ["Avatar II: The Return", "The Interstellar"], "n": 2}}
-Example 2-
Query: What relationship does Fred Gehrke have to the 23rd overall pick in the 2010 Major League Baseball Draft?
Ouput: {{"named": ["Fred Gehrke", "2010 Major League Baseball Draft"], "generic": ["23rd baseball draft pick"]}}
# INPUT
Query: {query}
Output:
"""



PROMPTS[
"summarize_entity_descriptions"
] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
Expand Down
5 changes: 2 additions & 3 deletions fast_graphrag/_services/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
GTNode,
TContext,
TDocument,
TEntity,
TIndex,
)

Expand Down Expand Up @@ -58,7 +57,7 @@ def extract(

async def extract_entities_from_query(
self, llm: BaseLLMService, query: str, prompt_kwargs: Dict[str, str]
) -> Iterable[TEntity]:
) -> Dict[str, List[str]]:
"""Extract entities from the given query."""
raise NotImplementedError

Expand Down Expand Up @@ -128,7 +127,7 @@ async def upsert(
raise NotImplementedError

async def get_context(
self, query: str, entities: Iterable[TEntity]
self, query: str, entities: Dict[str, List[str]]
) -> Optional[TContext[GTNode, GTEdge, GTHash, GTChunk]]:
"""Retrieve relevant state from the storage."""
raise NotImplementedError
Expand Down
7 changes: 5 additions & 2 deletions fast_graphrag/_services/_information_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def extract(

async def extract_entities_from_query(
self, llm: BaseLLMService, query: str, prompt_kwargs: Dict[str, str]
) -> Iterable[TEntity]:
) -> Dict[str, List[str]]:
"""Extract entities from the given query."""
prompt_kwargs["query"] = query
entities, _ = await format_and_send_prompt(
Expand All @@ -51,7 +51,10 @@ async def extract_entities_from_query(
response_model=TQueryEntities,
)

return [TEntity(name=name, type="", description="") for name in entities.entities]
return {
"named": entities.named,
"generic": entities.generic
}

async def _extract(
self, llm: BaseLLMService, chunks: Iterable[TChunk], prompt_kwargs: Dict[str, str], entity_types: List[str]
Expand Down
68 changes: 39 additions & 29 deletions fast_graphrag/_services/_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix
from scipy.sparse import csr_matrix, vstack
from tqdm import tqdm

from fast_graphrag._llm import BaseLLMService
Expand Down Expand Up @@ -35,7 +35,8 @@
@dataclass
class DefaultStateManagerService(BaseStateManagerService[TEntity, TRelation, THash, TChunk, TId, TEmbedding]):
blob_storage_cls: Type[BaseBlobStorage[csr_matrix]] = field(default=PickleBlobStorage)
similarity_score_threshold: float = field(default=0.80)
insert_similarity_score_threshold: float = field(default=0.8)
query_similarity_score_threshold: Optional[float] = field(default=0.7)

def __post_init__(self):
assert self.workspace is not None, "Workspace must be provided."
Expand Down Expand Up @@ -77,7 +78,7 @@ async def upsert(
llm: BaseLLMService,
subgraphs: List[asyncio.Future[Optional[BaseGraphStorage[TEntity, TRelation, TId]]]],
documents: Iterable[Iterable[TChunk]],
show_progress: bool = True
show_progress: bool = True,
) -> None:
nodes: Iterable[List[TEntity]]
edges: Iterable[List[TRelation]]
Expand All @@ -95,12 +96,16 @@ async def _get_graphs(

return (nodes, edges)

graphs = [r for graph in tqdm(
asyncio.as_completed([_get_graphs(fgraph) for fgraph in subgraphs]),
total=len(subgraphs),
desc="Extracting data",
disable=not show_progress,
) if (r := await graph) is not None]
graphs = [
r
for graph in tqdm(
asyncio.as_completed([_get_graphs(fgraph) for fgraph in subgraphs]),
total=len(subgraphs),
desc="Extracting data",
disable=not show_progress,
)
if (r := await graph) is not None
]

if len(graphs) == 0:
return
Expand Down Expand Up @@ -136,7 +141,7 @@ async def _get_graphs(
# We use the fact that similarity scores are symmetric between entity pairs,
# so we only select half of that by index order
similar_indices[
(scores < self.similarity_score_threshold)
(scores < self.insert_similarity_score_threshold)
| (similar_indices <= upserted_indices) # remove indices smaller or equal the entity
] = 0 # 0 can be used here (not 100% sure, but 99% sure)
progress_bar.update(1)
Expand All @@ -147,6 +152,7 @@ async def _get_graphs(

# STEP: insert identity edges
progress_bar.set_description("Building... [identity edges]")

async def _insert_identiy_edges(
source_index: TIndex, target_indices: npt.NDArray[np.int32]
) -> Iterable[Tuple[TIndex, TIndex]]:
Expand Down Expand Up @@ -177,30 +183,33 @@ async def _insert_identiy_edges(
progress_bar.set_description("Building [done]")

async def get_context(
self, query: str, entities: Iterable[TEntity]
self, query: str, entities: Dict[str, List[str]]
) -> Optional[TContext[TEntity, TRelation, THash, TChunk]]:
if self.entity_storage.size == 0:
return None

try:
entity_names = [entity.name for entity in entities]

query_embeddings = await self.embedding_service.encode(entity_names + [query])

query_embeddings = await self.embedding_service.encode(
[f"[NAME] {n}" for n in entities["named"]] + [f"[NAME] {n}" for n in entities["generic"]] + [query]
)
entity_scores: List[csr_matrix] = []
# Similarity-search over entities
if len(entity_names) > 0:
vdb_entity_scores_by_name = await self._score_entities_by_vectordb(
query_embeddings=query_embeddings[:-1], top_k=1
if len(entities["named"]) > 0:
vdb_entity_scores_by_named_entity = await self._score_entities_by_vectordb(
query_embeddings=query_embeddings[: len(entities["named"])],
top_k=1,
threshold=self.query_similarity_score_threshold,
)
else:
vdb_entity_scores_by_name = 0
vdb_entity_scores_by_query = await self._score_entities_by_vectordb(
query_embeddings=query_embeddings[-1:], top_k=8
entity_scores.append(vdb_entity_scores_by_named_entity)

vdb_entity_scores_by_generic_entity_and_query = await self._score_entities_by_vectordb(
query_embeddings=query_embeddings[len(entities["named"]) :], top_k=20, threshold=0.5
)
entity_scores.append(vdb_entity_scores_by_generic_entity_and_query)

vdb_entity_scores = vdb_entity_scores_by_name + vdb_entity_scores_by_query
vdb_entity_scores = vstack(entity_scores).max(axis=0)

if vdb_entity_scores.nnz == 0:
if isinstance(vdb_entity_scores, int) or vdb_entity_scores.nnz == 0:
return None
except Exception as e:
logger.error(f"Error during information extraction and scoring for query entities {entities}.\n{e}")
Expand Down Expand Up @@ -254,24 +263,25 @@ async def get_context(
async def _get_entities_to_num_docs(self) -> Any:
raise NotImplementedError

async def _score_entities_by_vectordb(self, query_embeddings: Iterable[TEmbedding], top_k: int = 1) -> csr_matrix:
async def _score_entities_by_vectordb(
self, query_embeddings: Iterable[TEmbedding], top_k: int = 1, threshold: Optional[float] = None
) -> csr_matrix:
# TODO: check this
# if top_k != 1:
# logger.warning(f"Top-k > 1 is not tested yet. Using top_k={top_k}.")
if self.node_specificity:
raise NotImplementedError("Node specificity is not supported yet.")

all_entity_probs_by_query_entity = await self.entity_storage.score_all(
np.array(query_embeddings), top_k=top_k
np.array(query_embeddings), top_k=top_k, threshold=threshold
) # (#query_entities, #all_entities)

# TODO: if top_k > 1, we need to aggregate the scores here
if all_entity_probs_by_query_entity.shape[1] == 0:
return all_entity_probs_by_query_entity
all_entity_weights: csr_matrix = all_entity_probs_by_query_entity.max(axis=0) # (1, #all_entities)

# Normalize the scores
all_entity_weights /= all_entity_weights.sum()
all_entity_probs_by_query_entity /= all_entity_probs_by_query_entity.sum(axis=1) + 1e-8
all_entity_weights: csr_matrix = all_entity_probs_by_query_entity.max(axis=0) # (1, #all_entities)

if self.node_specificity:
all_entity_weights = all_entity_weights.multiply(1.0 / await self._get_entities_to_num_docs())
Expand Down
2 changes: 1 addition & 1 deletion fast_graphrag/_storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async def upsert(
raise NotImplementedError

async def score_all(
self, embeddings: Iterable[GTEmbedding], top_k: int = 1, confidence_threshold: float = 0.0
self, embeddings: Iterable[GTEmbedding], top_k: int = 1, threshold: Optional[float] = None
) -> csr_matrix:
"""Score all embeddings against the given queries.
Expand Down
15 changes: 8 additions & 7 deletions fast_graphrag/_storage/_vdb_hnswlib.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import hnswlib
import numpy as np
Expand All @@ -16,9 +16,9 @@

@dataclass
class HNSWVectorStorageConfig:
ef_construction: int = field(default=64)
M: int = field(default=48)
ef_search: int = field(default=64)
ef_construction: int = field(default=256)
M: int = field(default=64)
ef_search: int = field(default=96)
num_threads: int = field(default=-1)


Expand Down Expand Up @@ -83,10 +83,8 @@ async def get_knn(
return ids, 1.0 - np.array(distances, dtype=TScore) * 0.5

async def score_all(
self, embeddings: Iterable[GTEmbedding], top_k: int = 1, confidence_threshold: float = 0.0
self, embeddings: Iterable[GTEmbedding], top_k: int = 1, threshold: Optional[float] = None
) -> csr_matrix:
if confidence_threshold > 0.0:
raise NotImplementedError("Confidence threshold is not supported yet.")
if not isinstance(embeddings, np.ndarray):
embeddings = np.array(list(embeddings), dtype=np.float32)

Expand All @@ -104,6 +102,9 @@ async def score_all(
ids = np.array(ids)
scores = 1.0 - np.array(distances, dtype=TScore) * 0.5

if threshold is not None:
scores[scores < threshold] = 0

# Create sparse distance matrix with shape (#embeddings, #all_embeddings)
flattened_ids = ids.ravel()
flattened_scores = scores.ravel()
Expand Down
6 changes: 3 additions & 3 deletions tests/_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

class TestModels(unittest.TestCase):
def test_tqueryentities(self):
query_entities = TQueryEntities(entities=["Entity1", "Entity2"], n=2)
self.assertEqual(query_entities.entities, ["ENTITY1", "ENTITY2"])
self.assertEqual(query_entities.n, 2)
query_entities = TQueryEntities(named=["Entity1", "Entity2"], generic=["Generic1", "Generic2"])
self.assertEqual(query_entities.named, ["ENTITY1", "ENTITY2"])
self.assertEqual(query_entities.generic, ["Generic1", "Generic2"])

with self.assertRaises(ValidationError):
TQueryEntities(entities=["Entity1", "Entity2"], n="two")
Expand Down
2 changes: 1 addition & 1 deletion tests/_policies/_graph_upsert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def test_call_same_node_no_summarize(self):

self.assertEqual(
upserted_nodes[0][1].description,
"This is a short random description 1. This is a short random description 2.",
"This is a short random description 1.\nThis is a short random description 2.",
)

# Assertions
Expand Down
Loading

0 comments on commit 2c060f1

Please sign in to comment.