From 31c5a640dbf3cf6dc83da1bd0e1ecaaebcdb4f15 Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Wed, 3 Dec 2025 17:27:53 +0800
Subject: [PATCH 01/28] feat: add config and operator node types

---
 graphgen/bases/__init__.py           |  2 +-
 graphgen/bases/datatypes.py          | 37 +++++++++++++++++
 graphgen/{ => operators}/evaluate.py |  0
 graphgen/operators/storage.py        | 59 ----------------------------
 4 files changed, 38 insertions(+), 60 deletions(-)
 rename graphgen/{ => operators}/evaluate.py (100%)
 delete mode 100644 graphgen/operators/storage.py

diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py
index 3d0bc800..f4c3a0e8 100644
--- a/graphgen/bases/__init__.py
+++ b/graphgen/bases/__init__.py
@@ -13,4 +13,4 @@
     StorageNameSpace,
 )
 from .base_tokenizer import BaseTokenizer
-from .datatypes import Chunk, QAPair, Token
+from .datatypes import Chunk, Config, Node, QAPair, Token

diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py
index cb3be345..199ba80b 100644
--- a/graphgen/bases/datatypes.py
+++ b/graphgen/bases/datatypes.py
@@ -2,6 +2,8 @@
 from dataclasses import dataclass, field
 from typing import List, Union
 
+from pydantic import BaseModel, Field, field_validator
+
 
 @dataclass
 class Chunk:
@@ -48,3 +50,38 @@ class Community:
     nodes: List[str] = field(default_factory=list)
     edges: List[tuple] = field(default_factory=list)
     metadata: dict = field(default_factory=dict)
+
+
+class Node(BaseModel):
+    id: str = Field(..., description="unique node id")
+    op_name: str = Field(..., description="operator name")
+    type: str = Field(
+        ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch"
+    )
+    params: dict = Field(default_factory=dict, description="operator parameters")
+    dependencies: List[str] = Field(
+        default_factory=list, description="list of dependent node ids"
+    )
+
+    @field_validator("type")
+    @classmethod
+    def validate_type(cls, v: str) -> str:
+        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
+        if v not in valid_types:
+            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
+        return v
+
+
+class Config(BaseModel):
+    nodes: List[Node] = Field(
+        ..., min_length=1, description="list of nodes in the computation graph"
+    )
+
+    @field_validator("nodes")
+    @classmethod
+    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
+        ids = [node.id for node in v]
+        if len(ids) != len(set(ids)):
+            duplicates = {id_ for id_ in ids if ids.count(id_) > 1}
+            raise ValueError(f"Duplicate node ids found: {duplicates}")
+        return v

diff --git a/graphgen/evaluate.py b/graphgen/operators/evaluate.py
similarity index 100%
rename from graphgen/evaluate.py
rename to graphgen/operators/evaluate.py

diff --git a/graphgen/operators/storage.py b/graphgen/operators/storage.py
deleted file mode 100644
index ea5488ac..00000000
--- a/graphgen/operators/storage.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import os
-from typing import Any
-
-import ray
-
-from graphgen.models import JsonKVStorage, JsonListStorage, NetworkXStorage
-
-
-@ray.remote
-class StorageManager:
-    """
-    Centralized storage for all operators
-
-    Example Usage:
-    ----------
-    # init
-    storage_manager = StorageManager.remote(working_dir="/path/to/dir", unique_id=123)
-
-    # visit storage in tasks
-    @ray.remote
-    def some_task(storage_manager):
-        full_docs_storage = ray.get(storage_manager.get_storage.remote("full_docs"))
-
-    # visit storage in other actors
-    @ray.remote
-    class SomeOperator:
-        def __init__(self, storage_manager):
-            self.storage_manager = storage_manager
-        def some_method(self):
-            full_docs_storage = ray.get(self.storage_manager.get_storage.remote("full_docs"))
-    """
-
-    def __init__(self, working_dir: str, unique_id: int):
-        self.working_dir = working_dir
-        self.unique_id = unique_id
-
-        # Initialize all storage backends
-        self.storages = {
-            "full_docs": JsonKVStorage(working_dir, namespace="full_docs"),
-            "chunks": JsonKVStorage(working_dir, namespace="chunks"),
-            "graph": NetworkXStorage(working_dir, namespace="graph"),
-            "rephrase": JsonKVStorage(working_dir, namespace="rephrase"),
-            "partition": JsonListStorage(working_dir, namespace="partition"),
-            "search": JsonKVStorage(
-                os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
-                namespace="search",
-            ),
-            "extraction": JsonKVStorage(
-                os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
-                namespace="extraction",
-            ),
-            "qa": JsonListStorage(
-                os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
-                namespace="qa",
-            ),
-        }
-
-    def get_storage(self, name: str) -> Any:
-        return self.storages.get(name)
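A quick usage sketch for the new types (values here are illustrative, not taken from the patch; Node and Config are as defined above):

    from graphgen.bases import Config, Node

    # A two-node pipeline: chunking depends on reading.
    config = Config(
        nodes=[
            Node(id="read", op_name="read", type="map", params={"input_path": "data/"}),
            Node(id="chunk", op_name="chunk_documents", type="map_batch", dependencies=["read"]),
        ]
    )

    # The validators reject unknown task types and duplicate node ids, e.g.:
    # Node(id="x", op_name="x", type="reduce")  -> ValueError: Invalid node type: reduce. ...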
From 8bcbe519bb0bc730b85a868813e9fdd254b34d2f Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Wed, 3 Dec 2025 18:43:20 +0800
Subject: [PATCH 02/28] refactor: refactor readers with ray data

---
 graphgen/bases/base_reader.py            |  95 +++++++++------
 graphgen/models/reader/__init__.py       |   1 -
 graphgen/models/reader/csv_reader.py     |  30 +++--
 graphgen/models/reader/json_reader.py    |  36 +++---
 graphgen/models/reader/jsonl_reader.py   |  30 -----
 graphgen/models/reader/parquet_reader.py |  31 +++--
 graphgen/models/reader/pdf_reader.py     |  56 ++++++---
 graphgen/models/reader/pickle_reader.py  |  84 ++++++++++---
 graphgen/models/reader/rdf_reader.py     | 108 +++++++++++++++--
 graphgen/models/reader/txt_reader.py     |  33 ++++-
 graphgen/operators/read/read_files.py    | 148 +++++++++++++----------
 11 files changed, 429 insertions(+), 223 deletions(-)
 delete mode 100644 graphgen/models/reader/jsonl_reader.py

diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py
index 89778469..91d55fcd 100644
---
a/graphgen/bases/base_reader.py +++ b/graphgen/bases/base_reader.py @@ -1,8 +1,10 @@ import os from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Union +import pandas as pd import requests +from ray.data import Dataset class BaseReader(ABC): @@ -14,52 +16,65 @@ def __init__(self, text_column: str = "content"): self.text_column = text_column @abstractmethod - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read(self, input_path: Union[str, List[str]]) -> Dataset: """ Read data from the specified file path. - :param file_path: Path to the input file. - :return: List of dictionaries containing the data. + :param input_path: Path to the input file or list of file paths. + :return: Ray Dataset containing the read data. """ - @staticmethod - def filter(data: List[dict]) -> List[dict]: + def _should_keep_item(self, item: Dict[str, Any]) -> bool: + """ + Determine whether to keep the given item based on the text column. + + :param item: Dictionary representing a data entry. + :return: True if the item should be kept, False otherwise. """ - Filter out entries with empty or missing text in the specified column. + item_type = item.get("type") + assert item_type in [ + "text", + "image", + "table", + "equation", + "protein", + ], f"Unsupported item type: {item_type}" + if item_type == "text": + content = item.get(self.text_column, "").strip() + return bool(content) + return True - :param data: List of dictionaries containing the data. - :return: Filtered list of dictionaries. + def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame: + """ + Validate data format. """ + if "type" not in batch.columns: + raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}") - def _image_exists(path_or_url: str, timeout: int = 3) -> bool: - """ - Check if an image exists at the given local path or URL. - :param path_or_url: Local file path or remote URL of the image. - :param timeout: Timeout for remote URL requests in seconds. - :return: True if the image exists, False otherwise. - """ - if not path_or_url: - return False - if not path_or_url.startswith(("http://", "https://", "ftp://")): - path = path_or_url.replace("file://", "", 1) - path = os.path.abspath(path) - return os.path.isfile(path) - try: - resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) - return resp.status_code == 200 - except requests.RequestException: - return False + if "text" in batch["type"].values: + if self.text_column not in batch.columns: + raise ValueError( + f"Missing '{self.text_column}' column for text documents" + ) - filtered_data = [] - for item in data: - if item.get("type") == "text": - content = item.get("content", "").strip() - if content: - filtered_data.append(item) - elif item.get("type") in ("image", "table", "equation"): - img_path = item.get("img_path") - if _image_exists(img_path): - filtered_data.append(item) - else: - filtered_data.append(item) - return filtered_data + return batch + + @staticmethod + def _image_exists(path_or_url: str, timeout: int = 3) -> bool: + """ + Check if an image exists at the given local path or URL. + :param path_or_url: Local file path or remote URL of the image. + :param timeout: Timeout for remote URL requests in seconds. + :return: True if the image exists, False otherwise. 
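+
+        Example (sketch; paths are illustrative):
+            BaseReader._image_exists("images/fig1.png")               # local file
+            BaseReader._image_exists("https://example.com/fig1.png")  # HTTP HEAD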
+ """ + if not path_or_url: + return False + if not path_or_url.startswith(("http://", "https://", "ftp://")): + path = path_or_url.replace("file://", "", 1) + path = os.path.abspath(path) + return os.path.isfile(path) + try: + resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) + return resp.status_code == 200 + except requests.RequestException: + return False diff --git a/graphgen/models/reader/__init__.py b/graphgen/models/reader/__init__.py index 600ffb4a..220460c3 100644 --- a/graphgen/models/reader/__init__.py +++ b/graphgen/models/reader/__init__.py @@ -1,6 +1,5 @@ from .csv_reader import CSVReader from .json_reader import JSONReader -from .jsonl_reader import JSONLReader from .parquet_reader import ParquetReader from .pdf_reader import PDFReader from .pickle_reader import PickleReader diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index bc865a3b..99faa30e 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,13 +14,20 @@ class CSVReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = None, + ) -> Dataset: + """ + Read CSV files and return Ray Dataset. - df = pd.read_csv(file_path) - for _, row in df.iterrows(): - assert "type" in row, f"Missing 'type' column in document: {row.to_dict()}" - if row["type"] == "text" and self.text_column not in row: - raise ValueError( - f"Missing '{self.text_column}' in document: {row.to_dict()}" - ) - return self.filter(df.to_dict(orient="records")) + :param input_path: Path to CSV file or list of CSV files. + :param parallelism: Number of blocks for Ray Dataset reading. + :return: Ray Dataset containing validated and filtered data. + """ + + ds = ray.data.read_csv(input_path, override_num_blocks=parallelism) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index 8253041c..1bcba4ea 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -1,26 +1,32 @@ -import json -from typing import Any, Dict, List +from typing import List, Union + +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader class JSONReader(BaseReader): """ - Reader for JSON files. + Reader for JSON and JSONL files. Columns: - type: The type of the document (e.g., "text", "image", etc.) - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, list): - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - return self.filter(data) - raise ValueError("JSON file must contain a list of documents.") + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + ) -> Dataset: + """ + Read JSON file and return Ray Dataset. 
+ :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. + :param parallelism: Number of parallel workers for reading files. + :return: Ray Dataset containing validated and filtered data. + """ + + ds = ray.data.read_json(input_path, override_num_blocks=parallelism) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/jsonl_reader.py b/graphgen/models/reader/jsonl_reader.py deleted file mode 100644 index 31bc3195..00000000 --- a/graphgen/models/reader/jsonl_reader.py +++ /dev/null @@ -1,30 +0,0 @@ -import json -from typing import Any, Dict, List - -from graphgen.bases.base_reader import BaseReader -from graphgen.utils import logger - - -class JSONLReader(BaseReader): - """ - Reader for JSONL files. - Columns: - - type: The type of the document (e.g., "text", "image", etc.) - - if type is "text", "content" column must be present. - """ - - def read(self, file_path: str) -> List[Dict[str, Any]]: - docs = [] - with open(file_path, "r", encoding="utf-8") as f: - for line in f: - try: - doc = json.loads(line) - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - docs.append(doc) - except json.JSONDecodeError as e: - logger.error("Error decoding JSON line: %s. Error: %s", line, e) - return self.filter(docs) diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index a325b876..5423643b 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,12 +14,22 @@ class ParquetReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - df = pd.read_parquet(file_path) - data: List[Dict[str, Any]] = df.to_dict(orient="records") + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = None, + ) -> Dataset: + """ + Read Parquet files using Ray Data. - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError(f"Missing '{self.text_column}' in document: {doc}") - return self.filter(data) + :param input_path: Path to Parquet file or list of Parquet files. + :param parallelism: Number of blocks for Ray Dataset reading. + :return: Ray Dataset containing validated documents. 
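+
+        Example (sketch; paths are illustrative):
+            ds = ParquetReader().read(["docs/a.parquet", "docs/b.parquet"])
+            ds.show(5)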
+ """ + if not ray.is_initialized(): + ray.init() + + ds = ray.data.read_parquet(input_path, override_num_blocks=parallelism) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/pdf_reader.py b/graphgen/models/reader/pdf_reader.py index 94562cb5..9d5c7c27 100644 --- a/graphgen/models/reader/pdf_reader.py +++ b/graphgen/models/reader/pdf_reader.py @@ -5,6 +5,9 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +import ray +from ray.data import Dataset + from graphgen.bases.base_reader import BaseReader from graphgen.models.reader.txt_reader import TXTReader from graphgen.utils import logger, pick_device @@ -62,19 +65,32 @@ def __init__( self.parser = MinerUParser() self.txt_reader = TXTReader() - def read(self, file_path: str, **override) -> List[Dict[str, Any]]: - """ - file_path - **override: override MinerU parameters - """ - pdf_path = Path(file_path).expanduser().resolve() - if not pdf_path.is_file(): - raise FileNotFoundError(pdf_path) + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + **override, + ) -> Dataset: + + # Ensure input_path is a list + if isinstance(input_path, str): + input_path = [input_path] + + paths_ds = ray.data.from_items(input_path) + + def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + try: + pdf_path = row["item"] + kwargs = {**self._default_kwargs, **override} + return self._call_mineru(Path(pdf_path), kwargs) + except Exception as e: + logger.error("Failed to process %s: %s", row, e) + return [] - kwargs = {**self._default_kwargs, **override} + docs_ds = paths_ds.flat_map(process_pdf) + docs_ds = docs_ds.filter(self._should_keep_item) - mineru_result = self._call_mineru(pdf_path, kwargs) - return self.filter(mineru_result) + return docs_ds def _call_mineru( self, pdf_path: Path, kwargs: Dict[str, Any] @@ -161,18 +177,18 @@ def _try_load_cached_result( base = os.path.dirname(json_file) results = [] - for item in data: + for it in data: for key in ("img_path", "table_img_path", "equation_img_path"): - rel_path = item.get(key) + rel_path = it.get(key) if rel_path: - item[key] = str(Path(base).joinpath(rel_path).resolve()) - if item["type"] == "text": - item["content"] = item["text"] - del item["text"] + it[key] = str(Path(base).joinpath(rel_path).resolve()) + if it["type"] == "text": + it["content"] = it["text"] + del it["text"] for key in ("page_idx", "bbox", "text_level"): - if item.get(key) is not None: - del item[key] - results.append(item) + if it.get(key) is not None: + del it[key] + results.append(it) return results @staticmethod diff --git a/graphgen/models/reader/pickle_reader.py b/graphgen/models/reader/pickle_reader.py index 1a11dc11..0b0e5719 100644 --- a/graphgen/models/reader/pickle_reader.py +++ b/graphgen/models/reader/pickle_reader.py @@ -1,30 +1,82 @@ import pickle -from typing import Any, Dict, List +from typing import List, Union + +import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class PickleReader(BaseReader): """ - Read pickle files, requiring the top-level object to be List[Dict[str, Any]]. - - Columns: + Read pickle files, requiring the schema to be restored to List[Dict[str, Any]]. + Each pickle file should contain a list of dictionaries with at least: - type: The type of the document (e.g., "text", "image", etc.) 
    - if type is "text", "content" column must be present.
+
+    Note: Uses ray.data.read_binary_files as ray.data.read_pickle is not available.
+    For Ray >= 2.5, consider using read_pickle if available in your version.
     """

     def read(
         self,
         input_path: Union[str, List[str]],
         parallelism: int = None,
     ) -> Dataset:
+        """
+        Read Pickle files using Ray Data.
+
+        :param input_path: Path to pickle file or list of pickle files.
+        :param parallelism: Number of blocks for Ray Dataset reading.
+        :return: Ray Dataset containing validated documents.
+        """
+        if not ray.is_initialized():
+            ray.init()
+
+        # Use read_binary_files as a reliable alternative to read_pickle
+        ds = ray.data.read_binary_files(
+            input_path, override_num_blocks=parallelism, include_paths=True
+        )
+
+        # Deserialize pickle files and flatten into individual records
+        def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
+            all_records = []
+            for _, row in batch.iterrows():
+                try:
+                    # Load pickle data from bytes
+                    data = pickle.loads(row["bytes"])
+
+                    # Validate structure
+                    if not isinstance(data, list):
+                        logger.error(
+                            "Pickle file %s must contain a list, got %s",
+                            row["path"],
+                            type(data),
+                        )
+                        continue
+
+                    if not all(isinstance(item, dict) for item in data):
+                        logger.error(
+                            "Pickle file %s must contain a list of dictionaries",
+                            row["path"],
+                        )
+                        continue
+
+                    # Flatten: each dict in the list becomes a separate row
+                    all_records.extend(data)
+                except Exception as e:
+                    logger.error(
+                        "Failed to deserialize pickle file %s: %s", row["path"], str(e)
+                    )
+                    continue
+
+            return pd.DataFrame(all_records)
-        if not isinstance(data, list):
-            raise ValueError("Pickle file must contain a list of documents.")
+        # Apply deserialization and flattening
+        ds = ds.map_batches(deserialize_batch, batch_format="pandas")
-        for doc in data:
-            if not isinstance(doc, dict):
-                raise ValueError("Every item in the list must be a dict.")
-            assert "type" in doc, f"Missing 'type' in document: {doc}"
-            if doc.get("type") == "text" and self.text_column not in doc:
-                raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
+        # Validate the schema
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
-        return self.filter(data)
+        # Filter valid items
+        ds = ds.filter(self._should_keep_item)
+        return ds

diff --git a/graphgen/models/reader/rdf_reader.py b/graphgen/models/reader/rdf_reader.py
index cce167c1..406478f5 100644
--- a/graphgen/models/reader/rdf_reader.py
+++ b/graphgen/models/reader/rdf_reader.py
@@ -1,48 +1,130 @@
-from typing import Any, Dict, List
+from pathlib import Path
+from typing import Any, Dict, List, Union
 
+import ray
 import rdflib
+from ray.data import Dataset
 from rdflib import Literal
 from rdflib.util import guess_format
 
 from graphgen.bases.base_reader import BaseReader
+from graphgen.utils import logger
 
 
 class RDFReader(BaseReader):
     """
     Reader for RDF files that extracts triples and represents them as dictionaries.
+
+    Uses Ray Data for distributed processing of multiple RDF files.
     """
 
+    def __init__(self, *, text_column: str = "content", **kwargs):
+        """
+        Initialize RDFReader.
+
+        :param text_column: The column name for text content (default: "content").
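+
+        Example (sketch):
+            reader = RDFReader(text_column="content")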
+ """ + super().__init__(**kwargs) + self.text_column = text_column + + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + ) -> Dataset: + """ + Read RDF file(s) using Ray Data. + + :param input_path: Path to RDF file or list of RDF files. + :param parallelism: Number of parallel workers for processing. + :return: Ray Dataset containing extracted documents. + """ + if not ray.is_initialized(): + ray.init() + + # Ensure input_path is a list to prevent Ray from splitting string into characters + if isinstance(input_path, str): + input_path = [input_path] + + # Create dataset from file paths + paths_ds = ray.data.from_items(input_path) + + def process_rdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + """Process a single RDF file and return list of documents.""" + try: + file_path = row["item"] + return self._parse_rdf_file(Path(file_path)) + except Exception as e: + logger.error( + "Failed to process RDF file %s: %s", row.get("item", "unknown"), e + ) + return [] + + # Process files in parallel and flatten results + docs_ds = paths_ds.flat_map(process_rdf) + + # Filter valid documents + docs_ds = docs_ds.filter(self._should_keep_item) + + return docs_ds + + def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]: + """ + Parse a single RDF file and extract documents. + + :param file_path: Path to RDF file. + :return: List of document dictionaries. + """ + if not file_path.is_file(): + raise FileNotFoundError(f"RDF file not found: {file_path}") + g = rdflib.Graph() - fmt = guess_format(file_path) + fmt = guess_format(str(file_path)) + try: - g.parse(file_path, format=fmt) + g.parse(str(file_path), format=fmt) except Exception as e: raise ValueError(f"Cannot parse RDF file {file_path}: {e}") from e docs: List[Dict[str, Any]] = [] - text_col = self.text_column + # Process each unique subject in the RDF graph for subj in set(g.subjects()): literals = [] props = {} + + # Extract all triples for this subject for _, pred, obj in g.triples((subj, None, None)): pred_str = str(pred) + obj_str = str(obj) + + # Collect literal values as text content if isinstance(obj, Literal): - literals.append(str(obj)) - props.setdefault(pred_str, []).append(str(obj)) + literals.append(obj_str) + + # Store all properties (including non-literals) + props.setdefault(pred_str, []).append(obj_str) + # Join all literal values as the text content text = " ".join(literals).strip() if not text: - raise ValueError( - f"Subject {subj} has no literal values; " - f"missing '{text_col}' for text column." 
+ logger.warning( + "Subject %s in %s has no literal values; document will have empty '%s' field.", + subj, + file_path, + self.text_column, ) - doc = {"id": str(subj), text_col: text, "properties": props} + # Create document dictionary + doc = { + "id": str(subj), + self.text_column: text, + "properties": props, + "source_file": str(file_path), + } docs.append(doc) if not docs: - raise ValueError("RDF file contains no valid documents.") + logger.warning("RDF file %s contains no valid documents.", file_path) - return self.filter(docs) + return docs diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index ec2ff747..bb6cce9e 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -1,10 +1,33 @@ -from typing import Any, Dict, List +from typing import List, Union + +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader class TXTReader(BaseReader): - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - docs = [{"type": "text", self.text_column: f.read()}] - return self.filter(docs) + def read( + self, + input_path: Union[str, List[str]], + parallelism: int = 4, + ) -> Dataset: + """ + Read text files from the specified input path. + :param input_path: Path to the input text file or list of text files. + :param parallelism: Number of blocks to override for Ray Dataset reading. + :return: Ray Dataset containing the read text data. + """ + docs_ds = ray.data.read_text( + input_path, encoding="utf-8", override_num_blocks=parallelism + ) + + docs_ds = docs_ds.map( + lambda row: { + "type": "text", + self.text_column: row["text"], + } + ) + + docs_ds = docs_ds.filter(self._should_keep_item) + return docs_ds diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py index d9e7f673..34ffee85 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read_files.py @@ -1,9 +1,10 @@ from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, List, Optional, Union + +import ray from graphgen.models import ( CSVReader, - JSONLReader, JSONReader, ParquetReader, PDFReader, @@ -16,7 +17,7 @@ from .parallel_file_scanner import ParallelFileScanner _MAPPING = { - "jsonl": JSONLReader, + "jsonl": JSONReader, "json": JSONReader, "txt": TXTReader, "csv": CSVReader, @@ -30,70 +31,93 @@ } -def _build_reader(suffix: str, cache_dir: str | None): +def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): + """Factory function to build appropriate reader instance""" suffix = suffix.lower() - if suffix == "pdf" and cache_dir is not None: - return _MAPPING[suffix](output_dir=cache_dir) - return _MAPPING[suffix]() + reader_cls = _MAPPING.get(suffix) + if not reader_cls: + raise ValueError(f"Unsupported file suffix: {suffix}") + + # Special handling for PDFReader which needs output_dir + if suffix == "pdf": + if cache_dir is None: + raise ValueError("cache_dir must be provided for PDFReader") + return reader_cls(output_dir=cache_dir, **reader_kwargs) + + return reader_cls(**reader_kwargs) def read_files( - input_file: str, + input_path: Union[str, List[str]], allowed_suffix: Optional[List[str]] = None, cache_dir: Optional[str] = None, - max_workers: int = 4, - rescan: bool = False, -) -> Iterator[Dict[str, Any]]: + parallelism: int = 4, + recursive: bool = True, + **reader_kwargs: Any, +) -> ray.data.Dataset: """ - Read files from a path using 
parallel scanning and appropriate readers. - - Args: - input_file: Path to a file or directory - allowed_suffix: List of file suffixes to read. If None, uses all supported types - cache_dir: Directory for caching PDF extraction and scan results - max_workers: Number of workers for parallel scanning - rescan: Whether to force rescan even if cached results exist + Unified entry point to read files of multiple types using Ray Data. + + :param input_path: File or directory path(s) to read from + :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) + :param cache_dir: Directory to cache intermediate files (PDF processing) + :param parallelism: Number of parallel workers + :param recursive: Whether to scan directories recursively + :param reader_kwargs: Additional kwargs passed to readers + :return: Ray Dataset containing all documents """ - - path = Path(input_file).expanduser() - if not path.exists(): - raise FileNotFoundError(f"input_path not found: {input_file}") - - if allowed_suffix is None: - support_suffix = set(_MAPPING.keys()) - else: - support_suffix = {s.lower().lstrip(".") for s in allowed_suffix} - - with ParallelFileScanner( - cache_dir=cache_dir or "cache", - allowed_suffix=support_suffix, - rescan=rescan, - max_workers=max_workers, - ) as scanner: - scan_results = scanner.scan(str(path), recursive=True) - - # Extract files from scan results - files_to_read = [] - for path_result in scan_results.values(): - if "error" in path_result: - logger.warning("Error scanning %s: %s", path_result.path, path_result.error) - continue - files_to_read.extend(path_result.get("files", [])) - - logger.info( - "Found %d eligible file(s) under folder %s (allowed_suffix=%s)", - len(files_to_read), - input_file, - support_suffix, - ) - - for file_info in files_to_read: - try: - file_path = file_info["path"] - suffix = Path(file_path).suffix.lstrip(".").lower() - reader = _build_reader(suffix, cache_dir) - - yield from reader.read(file_path) - - except Exception as e: # pylint: disable=broad-except - logger.exception("Error reading %s: %s", file_info.get("path"), e) + try: + # 1. Scan all paths to discover files + logger.info("[READ] Scanning paths: %s", input_path) + scanner = ParallelFileScanner( + cache_dir=cache_dir, + allowed_suffix=allowed_suffix, + rescan=False, + max_workers=parallelism if parallelism > 0 else 1, + ) + + all_files = [] + scan_results = scanner.scan(input_path, recursive=recursive) + + for result in scan_results.values(): + all_files.extend(result.get("files", [])) + + logger.info("[READ] Found %d files to process", len(all_files)) + + if not all_files: + return ray.data.from_items([]) + + # 2. Group files by suffix to use appropriate reader + files_by_suffix = {} + for file_info in all_files: + suffix = Path(file_info["path"]).suffix.lower().lstrip(".") + if allowed_suffix and suffix not in [ + s.lower().lstrip(".") for s in allowed_suffix + ]: + continue + files_by_suffix.setdefault(suffix, []).append(file_info["path"]) + + # 3. Create read tasks + read_tasks = [] + for suffix, file_paths in files_by_suffix.items(): + reader = _build_reader(suffix, cache_dir, **reader_kwargs) + ds = reader.read(file_paths, parallelism=parallelism) + read_tasks.append(ds) + + # 4. 
Combine all datasets + if not read_tasks: + logger.warning("[READ] No datasets created") + return ray.data.from_items([]) + + if len(read_tasks) == 1: + logger.info("[READ] Successfully read files from %s", input_path) + return read_tasks[0] + # len(read_tasks) > 1 + combined_ds = read_tasks[0].union(*read_tasks[1:]) + + logger.info("[READ] Successfully read files from %s", input_path) + return combined_ds + + except Exception as e: + logger.error("[READ] Failed to read files from %s: %s", input_path, e) + raise From 246348fb9781b3c0585bffcac5eb9f89a23b8c0f Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 3 Dec 2025 19:26:09 +0800 Subject: [PATCH 03/28] fix: delete param parallelism for readers --- graphgen/models/reader/csv_reader.py | 9 ++------ graphgen/models/reader/json_reader.py | 9 ++------ graphgen/models/reader/parquet_reader.py | 9 ++------ graphgen/models/reader/pdf_reader.py | 1 - graphgen/models/reader/pickle_reader.py | 6 +---- graphgen/models/reader/rdf_reader.py | 2 -- graphgen/models/reader/txt_reader.py | 6 +---- graphgen/operators/__init__.py | 2 +- graphgen/operators/evaluate.py | 10 ++++++--- graphgen/operators/read/__init__.py | 2 +- .../operators/read/{read_files.py => read.py} | 6 ++--- graphgen/operators/registry.py | 22 +++++++++++++++++++ 12 files changed, 42 insertions(+), 42 deletions(-) rename graphgen/operators/read/{read_files.py => read.py} (96%) create mode 100644 graphgen/operators/registry.py diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index 99faa30e..a0343d97 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -14,20 +14,15 @@ class CSVReader(BaseReader): - if type is "text", "content" column must be present. """ - def read( - self, - input_path: Union[str, List[str]], - parallelism: int = None, - ) -> Dataset: + def read(self, input_path: Union[str, List[str]]) -> Dataset: """ Read CSV files and return Ray Dataset. :param input_path: Path to CSV file or list of CSV files. - :param parallelism: Number of blocks for Ray Dataset reading. :return: Ray Dataset containing validated and filtered data. """ - ds = ray.data.read_csv(input_path, override_num_blocks=parallelism) + ds = ray.data.read_csv(input_path) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index 1bcba4ea..b53c8b1d 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -14,19 +14,14 @@ class JSONReader(BaseReader): - if type is "text", "content" column must be present. """ - def read( - self, - input_path: Union[str, List[str]], - parallelism: int = 4, - ) -> Dataset: + def read(self, input_path: Union[str, List[str]]) -> Dataset: """ Read JSON file and return Ray Dataset. :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. - :param parallelism: Number of parallel workers for reading files. :return: Ray Dataset containing validated and filtered data. 
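
        Example (sketch; path is illustrative):
            ds = JSONReader().read("data/docs.jsonl")
            print(ds.count())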
""" - ds = ray.data.read_json(input_path, override_num_blocks=parallelism) + ds = ray.data.read_json(input_path) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index 5423643b..dd289e31 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -14,22 +14,17 @@ class ParquetReader(BaseReader): - if type is "text", "content" column must be present. """ - def read( - self, - input_path: Union[str, List[str]], - parallelism: int = None, - ) -> Dataset: + def read(self, input_path: Union[str, List[str]]) -> Dataset: """ Read Parquet files using Ray Data. :param input_path: Path to Parquet file or list of Parquet files. - :param parallelism: Number of blocks for Ray Dataset reading. :return: Ray Dataset containing validated documents. """ if not ray.is_initialized(): ray.init() - ds = ray.data.read_parquet(input_path, override_num_blocks=parallelism) + ds = ray.data.read_parquet(input_path) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/pdf_reader.py b/graphgen/models/reader/pdf_reader.py index 9d5c7c27..55dab30b 100644 --- a/graphgen/models/reader/pdf_reader.py +++ b/graphgen/models/reader/pdf_reader.py @@ -68,7 +68,6 @@ def __init__( def read( self, input_path: Union[str, List[str]], - parallelism: int = 4, **override, ) -> Dataset: diff --git a/graphgen/models/reader/pickle_reader.py b/graphgen/models/reader/pickle_reader.py index 0b0e5719..6e3d1949 100644 --- a/graphgen/models/reader/pickle_reader.py +++ b/graphgen/models/reader/pickle_reader.py @@ -23,22 +23,18 @@ class PickleReader(BaseReader): def read( self, input_path: Union[str, List[str]], - parallelism: int = None, ) -> Dataset: """ Read Pickle files using Ray Data. :param input_path: Path to pickle file or list of pickle files. - :param parallelism: Number of blocks for Ray Dataset reading. :return: Ray Dataset containing validated documents. """ if not ray.is_initialized(): ray.init() # Use read_binary_files as a reliable alternative to read_pickle - ds = ray.data.read_binary_files( - input_path, override_num_blocks=parallelism, include_paths=True - ) + ds = ray.data.read_binary_files(input_path, include_paths=True) # Deserialize pickle files and flatten into individual records def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame: diff --git a/graphgen/models/reader/rdf_reader.py b/graphgen/models/reader/rdf_reader.py index 406478f5..9670107a 100644 --- a/graphgen/models/reader/rdf_reader.py +++ b/graphgen/models/reader/rdf_reader.py @@ -30,13 +30,11 @@ def __init__(self, *, text_column: str = "content", **kwargs): def read( self, input_path: Union[str, List[str]], - parallelism: int = 4, ) -> Dataset: """ Read RDF file(s) using Ray Data. :param input_path: Path to RDF file or list of RDF files. - :param parallelism: Number of parallel workers for processing. :return: Ray Dataset containing extracted documents. 
""" if not ray.is_initialized(): diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index bb6cce9e..0194ca68 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -10,17 +10,13 @@ class TXTReader(BaseReader): def read( self, input_path: Union[str, List[str]], - parallelism: int = 4, ) -> Dataset: """ Read text files from the specified input path. :param input_path: Path to the input text file or list of text files. - :param parallelism: Number of blocks to override for Ray Dataset reading. :return: Ray Dataset containing the read text data. """ - docs_ds = ray.data.read_text( - input_path, encoding="utf-8", override_num_blocks=parallelism - ) + docs_ds = ray.data.read_text(input_path, encoding="utf-8") docs_ds = docs_ds.map( lambda row: { diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 97f4b3c8..38ced41e 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -4,6 +4,6 @@ from .init import init_llm from .partition import partition_kg from .quiz_and_judge import judge_statement, quiz -from .read import read_files +from .read import read from .search import search_all from .split import chunk_documents diff --git a/graphgen/operators/evaluate.py b/graphgen/operators/evaluate.py index d1e2413b..fdbfbf82 100644 --- a/graphgen/operators/evaluate.py +++ b/graphgen/operators/evaluate.py @@ -9,9 +9,13 @@ from dotenv import load_dotenv from graphgen.bases.datatypes import QAPair - -from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator -from .utils import logger, set_logger +from graphgen.models import ( + LengthEvaluator, + MTLDEvaluator, + RewardEvaluator, + UniEvaluator, +) +from graphgen.utils import logger, set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log")) diff --git a/graphgen/operators/read/__init__.py b/graphgen/operators/read/__init__.py index 075ae938..cda44587 100644 --- a/graphgen/operators/read/__init__.py +++ b/graphgen/operators/read/__init__.py @@ -1 +1 @@ -from .read_files import read_files +from .read import read diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read.py similarity index 96% rename from graphgen/operators/read/read_files.py rename to graphgen/operators/read/read.py index 34ffee85..a7a41f72 100644 --- a/graphgen/operators/read/read_files.py +++ b/graphgen/operators/read/read.py @@ -47,10 +47,10 @@ def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): return reader_cls(**reader_kwargs) -def read_files( +def read( input_path: Union[str, List[str]], allowed_suffix: Optional[List[str]] = None, - cache_dir: Optional[str] = None, + cache_dir: Optional[str] = "cache", parallelism: int = 4, recursive: bool = True, **reader_kwargs: Any, @@ -101,7 +101,7 @@ def read_files( read_tasks = [] for suffix, file_paths in files_by_suffix.items(): reader = _build_reader(suffix, cache_dir, **reader_kwargs) - ds = reader.read(file_paths, parallelism=parallelism) + ds = reader.read(file_paths) read_tasks.append(ds) # 4. 
Combine all datasets diff --git a/graphgen/operators/registry.py b/graphgen/operators/registry.py new file mode 100644 index 00000000..e0c5826d --- /dev/null +++ b/graphgen/operators/registry.py @@ -0,0 +1,22 @@ +from .build_kg import build_kg +from .extract import extract_info +from .generate import generate_qas +from .init import init_llm +from .partition import partition_kg +from .quiz_and_judge import judge_statement, quiz +from .read import read +from .search import search_all +from .split import chunk_documents + +operators = { + "read": read, + "init_llm": init_llm, + "chunk_documents": chunk_documents, + "extract_info": extract_info, + "search_all": search_all, + "build_kg": build_kg, + "partition_kg": partition_kg, + "generate_qas": generate_qas, + "quiz": quiz, + "judge_statement": judge_statement, +} From 319e1e7a956f70f6d6a2928d5a7c7cb3ff17d618 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 3 Dec 2025 19:38:25 +0800 Subject: [PATCH 04/28] fix: fix import error --- graphgen/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 3ef1ff69..3716aa9a 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -18,7 +18,6 @@ ) from .reader import ( CSVReader, - JSONLReader, JSONReader, ParquetReader, PDFReader, From 42fcb0970b895ea7e0703c3fc7b43fe6a393dcbe Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 4 Dec 2025 13:11:14 +0800 Subject: [PATCH 05/28] refactor read and chunk operators with no side effects --- .../storage/{ => graph}/networkx_storage.py | 0 .../models/storage/{ => kv}/json_storage.py | 40 +-------- graphgen/operators/chunk/__init__.py | 1 + graphgen/operators/chunk/chunk_service.py | 87 +++++++++++++++++++ graphgen/operators/read/read.py | 16 ++-- graphgen/operators/split/__init__.py | 1 - graphgen/operators/split/split_chunks.py | 84 ------------------ 7 files changed, 100 insertions(+), 129 deletions(-) rename graphgen/models/storage/{ => graph}/networkx_storage.py (100%) rename graphgen/models/storage/{ => kv}/json_storage.py (59%) create mode 100644 graphgen/operators/chunk/__init__.py create mode 100644 graphgen/operators/chunk/chunk_service.py delete mode 100644 graphgen/operators/split/__init__.py delete mode 100644 graphgen/operators/split/split_chunks.py diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/graph/networkx_storage.py similarity index 100% rename from graphgen/models/storage/networkx_storage.py rename to graphgen/models/storage/graph/networkx_storage.py diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/kv/json_storage.py similarity index 59% rename from graphgen/models/storage/json_storage.py rename to graphgen/models/storage/kv/json_storage.py index 53962117..3fcae5a3 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/kv/json_storage.py @@ -1,7 +1,7 @@ import os from dataclasses import dataclass -from graphgen.bases.base_storage import BaseKVStorage, BaseListStorage +from graphgen.bases.base_storage import BaseKVStorage from graphgen.utils import load_json, logger, write_json @@ -54,41 +54,3 @@ def upsert(self, data: dict): def drop(self): if self._data: self._data.clear() - - -@dataclass -class JsonListStorage(BaseListStorage): - working_dir: str = None - namespace: str = None - _data: list = None - - def __post_init__(self): - self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json") - self._data = 
load_json(self._file_name) or [] - logger.info("Load List %s with %d data", self.namespace, len(self._data)) - - @property - def data(self): - return self._data - - def all_items(self) -> list: - return self._data - - def index_done_callback(self): - write_json(self._data, self._file_name) - - def get_by_index(self, index: int): - if index < 0 or index >= len(self._data): - return None - return self._data[index] - - def append(self, data): - self._data.append(data) - - def upsert(self, data: list): - left_data = [d for d in data if d not in self._data] - self._data.extend(left_data) - return left_data - - def drop(self): - self._data = [] diff --git a/graphgen/operators/chunk/__init__.py b/graphgen/operators/chunk/__init__.py new file mode 100644 index 00000000..f2f116f7 --- /dev/null +++ b/graphgen/operators/chunk/__init__.py @@ -0,0 +1 @@ +from .chunk_service import ChunkService diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py new file mode 100644 index 00000000..df54ef10 --- /dev/null +++ b/graphgen/operators/chunk/chunk_service.py @@ -0,0 +1,87 @@ +import asyncio +import os +from functools import lru_cache +from typing import Union + +import pandas as pd +from tqdm.asyncio import tqdm as tqdm_async + +from graphgen.models import ( + ChineseRecursiveTextSplitter, + RecursiveCharacterSplitter, + Tokenizer, +) +from graphgen.utils import compute_content_hash, detect_main_language + +_MAPPING = { + "en": RecursiveCharacterSplitter, + "zh": ChineseRecursiveTextSplitter, +} + +SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] + + +@lru_cache(maxsize=None) +def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: + cls = _MAPPING[language] + kwargs = dict(frozen_kwargs) + return cls(**kwargs) + + +def split_chunks(text: str, language: str = "en", **kwargs) -> list: + if language not in _MAPPING: + raise ValueError( + f"Unsupported language: {language}. 
" + f"Supported languages are: {list(_MAPPING.keys())}" + ) + frozen_kwargs = frozenset( + (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items() + ) + splitter = _get_splitter(language, frozen_kwargs) + return splitter.split_text(text) + + +class ChunkService: + def __init__(self, **chunk_kwargs): + tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") + self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model) + self.chunk_kwargs = chunk_kwargs + + def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: + docs = batch.to_dict(orient="records") + return pd.DataFrame(self.chunk_documents(docs)) + + def chunk_documents(self, new_docs: list) -> list: + for doc in new_docs: + doc_id = doc.get("_doc_id") + doc_type = doc.get("type") + + if doc_type == "text": + doc_language = detect_main_language(doc["content"]) + text_chunks = split_chunks( + doc["content"], + language=doc_language, + **self.chunk_kwargs, + ) + + return [ + { + "_chunk_id": compute_content_hash(chunk_text, prefix="chunk-"), + "content": chunk_text, + "type": "text", + "_doc_id": doc_id, + "length": len(self.tokenizer_instance.encode(chunk_text)) + if self.tokenizer_instance + else len(chunk_text), + "language": doc_language, + } + for chunk_text in text_chunks + ] + + # other types of documents(images, sequences) are not chunked + return [ + { + "_chunk_id": doc_id.replace("doc-", f"{doc_type}-"), + **doc, + } + ] diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py index a7a41f72..f98a97ca 100644 --- a/graphgen/operators/read/read.py +++ b/graphgen/operators/read/read.py @@ -12,7 +12,7 @@ RDFReader, TXTReader, ) -from graphgen.utils import logger +from graphgen.utils import compute_mm_hash, logger from .parallel_file_scanner import ParallelFileScanner @@ -110,10 +110,16 @@ def read( return ray.data.from_items([]) if len(read_tasks) == 1: - logger.info("[READ] Successfully read files from %s", input_path) - return read_tasks[0] - # len(read_tasks) > 1 - combined_ds = read_tasks[0].union(*read_tasks[1:]) + combined_ds = read_tasks[0] + else: + combined_ds = read_tasks[0].union(*read_tasks[1:]) + + combined_ds = combined_ds.map( + lambda record: { + **record, + "_doc_id": compute_mm_hash(record), + } + ) logger.info("[READ] Successfully read files from %s", input_path) return combined_ds diff --git a/graphgen/operators/split/__init__.py b/graphgen/operators/split/__init__.py deleted file mode 100644 index 2afc738d..00000000 --- a/graphgen/operators/split/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .split_chunks import chunk_documents diff --git a/graphgen/operators/split/split_chunks.py b/graphgen/operators/split/split_chunks.py deleted file mode 100644 index 3f728e00..00000000 --- a/graphgen/operators/split/split_chunks.py +++ /dev/null @@ -1,84 +0,0 @@ -from functools import lru_cache -from typing import Union - -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.models import ( - ChineseRecursiveTextSplitter, - RecursiveCharacterSplitter, - Tokenizer, -) -from graphgen.utils import compute_content_hash, detect_main_language - -_MAPPING = { - "en": RecursiveCharacterSplitter, - "zh": ChineseRecursiveTextSplitter, -} - -SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] - - -@lru_cache(maxsize=None) -def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: - cls = _MAPPING[language] - kwargs = dict(frozen_kwargs) - return cls(**kwargs) - - -def split_chunks(text: str, language: str = "en", **kwargs) -> list: 
- if language not in _MAPPING: - raise ValueError( - f"Unsupported language: {language}. " - f"Supported languages are: {list(_MAPPING.keys())}" - ) - frozen_kwargs = frozenset( - (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items() - ) - splitter = _get_splitter(language, frozen_kwargs) - return splitter.split_text(text) - - -async def chunk_documents( - new_docs: dict, - tokenizer_instance: Tokenizer = None, - progress_bar=None, - **kwargs, -) -> dict: - inserting_chunks = {} - cur_index = 1 - doc_number = len(new_docs) - async for doc_key, doc in tqdm_async( - new_docs.items(), desc="[1/4]Chunking documents", unit="doc" - ): - doc_type = doc.get("type") - if doc_type == "text": - doc_language = detect_main_language(doc["content"]) - - text_chunks = split_chunks( - doc["content"], - language=doc_language, - **kwargs, - ) - - chunks = { - compute_content_hash(txt, prefix="chunk-"): { - "content": txt, - "type": "text", - "_full_docs_id": doc_key, - "length": len(tokenizer_instance.encode(txt)) - if tokenizer_instance - else len(txt), - "language": doc_language, - } - for txt in text_chunks - } - else: - chunks = {doc_key.replace("doc-", f"{doc_type}-"): {**doc}} - - inserting_chunks.update(chunks) - - if progress_bar is not None: - progress_bar(cur_index / doc_number, f"Chunking {doc_key}") - cur_index += 1 - - return inserting_chunks From b458e48901655ae4930d8c3a53df372b67e2c0d9 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 4 Dec 2025 13:12:10 +0800 Subject: [PATCH 06/28] fix: fix import error --- graphgen/operators/chunk/chunk_service.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py index df54ef10..f4bd05e2 100644 --- a/graphgen/operators/chunk/chunk_service.py +++ b/graphgen/operators/chunk/chunk_service.py @@ -1,10 +1,8 @@ -import asyncio import os from functools import lru_cache from typing import Union import pandas as pd -from tqdm.asyncio import tqdm as tqdm_async from graphgen.models import ( ChineseRecursiveTextSplitter, From 95c478339eb36563ff5b9f5e2cbb9853c24d979d Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 4 Dec 2025 13:20:52 +0800 Subject: [PATCH 07/28] fix: fix return logic --- graphgen/operators/chunk/chunk_service.py | 44 +++++++++++++---------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py index f4bd05e2..a0abbe57 100644 --- a/graphgen/operators/chunk/chunk_service.py +++ b/graphgen/operators/chunk/chunk_service.py @@ -50,6 +50,7 @@ def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: return pd.DataFrame(self.chunk_documents(docs)) def chunk_documents(self, new_docs: list) -> list: + chunks = [] for doc in new_docs: doc_id = doc.get("_doc_id") doc_type = doc.get("type") @@ -62,24 +63,29 @@ def chunk_documents(self, new_docs: list) -> list: **self.chunk_kwargs, ) - return [ + chunks.extend( + [ + { + "_chunk_id": compute_content_hash( + chunk_text, prefix="chunk-" + ), + "content": chunk_text, + "type": "text", + "_doc_id": doc_id, + "length": len(self.tokenizer_instance.encode(chunk_text)) + if self.tokenizer_instance + else len(chunk_text), + "language": doc_language, + } + for chunk_text in text_chunks + ] + ) + else: + # other types of documents(images, sequences) are not chunked + chunks.append( { - "_chunk_id": compute_content_hash(chunk_text, prefix="chunk-"), - "content": chunk_text, - "type": "text", - "_doc_id": doc_id, - 
"length": len(self.tokenizer_instance.encode(chunk_text)) - if self.tokenizer_instance - else len(chunk_text), - "language": doc_language, + "_chunk_id": doc_id.replace("doc-", f"{doc_type}-"), + **doc, } - for chunk_text in text_chunks - ] - - # other types of documents(images, sequences) are not chunked - return [ - { - "_chunk_id": doc_id.replace("doc-", f"{doc_type}-"), - **doc, - } - ] + ) + return chunks From c844d656803d03c7db17bf66e8063448f77b79fa Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 4 Dec 2025 13:32:04 +0800 Subject: [PATCH 08/28] refactor: rename operator split to chunk --- graphgen/bases/base_splitter.py | 6 +++--- graphgen/models/splitter/character_splitter.py | 2 +- graphgen/models/splitter/markdown_splitter.py | 4 ++-- graphgen/models/splitter/recursive_character_splitter.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py index b2d1ad3a..f77be6e4 100644 --- a/graphgen/bases/base_splitter.py +++ b/graphgen/bases/base_splitter.py @@ -4,7 +4,7 @@ from typing import Callable, Iterable, List, Literal, Optional, Union from graphgen.bases.datatypes import Chunk -from graphgen.utils import logger +from graphgen.utils.log import logger class BaseSplitter(ABC): @@ -33,7 +33,7 @@ def split_text(self, text: str) -> List[str]: """ Split the input text into smaller chunks. - :param text: The input text to be split. + :param text: The input text to be chunk. :return: A list of text chunks. """ @@ -111,7 +111,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: def _split_text_with_regex( text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]] ) -> List[str]: - # Now that we have the separator, split the text + # Now that we have the separator, chunk the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. diff --git a/graphgen/models/splitter/character_splitter.py b/graphgen/models/splitter/character_splitter.py index 1c91877e..8877c861 100644 --- a/graphgen/models/splitter/character_splitter.py +++ b/graphgen/models/splitter/character_splitter.py @@ -17,7 +17,7 @@ def __init__( def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" - # First we naively split the large input into a bunch of smaller ones. + # First we naively chunk the large input into a bunch of smaller ones. 
separator = ( self._separator if self._is_separator_regex else re.escape(self._separator) ) diff --git a/graphgen/models/splitter/markdown_splitter.py b/graphgen/models/splitter/markdown_splitter.py index 03def6ae..40b6a44e 100644 --- a/graphgen/models/splitter/markdown_splitter.py +++ b/graphgen/models/splitter/markdown_splitter.py @@ -6,12 +6,12 @@ class MarkdownTextRefSplitter(RecursiveCharacterSplitter): - """Attempts to split the text along Markdown-formatted headings.""" + """Attempts to chunk the text along Markdown-formatted headings.""" def __init__(self, **kwargs: Any) -> None: """Initialize a MarkdownTextRefSplitter.""" separators = [ - # First, try to split along Markdown headings (starting with level 2) + # First, try to chunk along Markdown headings (starting with level 2) "\n#{1,6} ", # Note the alternative syntax for headings (below) is not handled here # Heading level 2 diff --git a/graphgen/models/splitter/recursive_character_splitter.py b/graphgen/models/splitter/recursive_character_splitter.py index c9d7c543..b1ee8e06 100644 --- a/graphgen/models/splitter/recursive_character_splitter.py +++ b/graphgen/models/splitter/recursive_character_splitter.py @@ -7,7 +7,7 @@ class RecursiveCharacterSplitter(BaseSplitter): """Splitting text by recursively look at characters. - Recursively tries to split by different characters to find one that works. + Recursively tries to chunk by different characters to find one that works. """ def __init__( @@ -88,7 +88,7 @@ def __init__( def _split_text_with_regex_from_end( self, text: str, separator: str, keep_separator: bool ) -> List[str]: - # Now that we have the separator, split the text + # Now that we have the separator, chunk the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. 
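With patches 02-07 applied, ingestion runs as a single Ray Data pipeline. A minimal driver sketch (paths and chunking parameters are illustrative; read and ChunkService are the operators introduced above):

    import ray

    from graphgen.operators import read
    from graphgen.operators.chunk import ChunkService

    ray.init(ignore_reinit_error=True)

    # Scan a directory and load every supported file into one Dataset of typed docs.
    docs = read("data/", allowed_suffix=["txt", "jsonl"], cache_dir="cache")

    # ChunkService is a callable batch mapper: pandas batch in, chunked batch out.
    chunks = docs.map_batches(
        ChunkService(chunk_size=512, chunk_overlap=64),
        batch_format="pandas",
    )
    print(chunks.take(3))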
From c44793616eed66ffd10d5b2812f557b5572d2a8e Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 4 Dec 2025 17:11:37 +0800 Subject: [PATCH 09/28] refactor: refactor build_kg to accommodate ray data --- .../{operators/init => common}/init_llm.py | 2 + graphgen/engine.py | 269 +++++++++++------- graphgen/operators/build_kg/__init__.py | 2 +- graphgen/operators/build_kg/build_kg.py | 59 ---- .../operators/build_kg/build_kg_service.py | 55 ++++ graphgen/operators/build_kg/build_mm_kg.py | 15 +- graphgen/operators/build_kg/build_text_kg.py | 15 +- graphgen/operators/init/__init__.py | 1 - graphgen/utils/run_concurrent.py | 66 ++--- 9 files changed, 258 insertions(+), 226 deletions(-) rename graphgen/{operators/init => common}/init_llm.py (97%) delete mode 100644 graphgen/operators/build_kg/build_kg.py create mode 100644 graphgen/operators/build_kg/build_kg_service.py delete mode 100644 graphgen/operators/init/__init__.py diff --git a/graphgen/operators/init/init_llm.py b/graphgen/common/init_llm.py similarity index 97% rename from graphgen/operators/init/init_llm.py rename to graphgen/common/init_llm.py index e294d2c3..9d3fe12a 100644 --- a/graphgen/operators/init/init_llm.py +++ b/graphgen/common/init_llm.py @@ -79,3 +79,5 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]: backend = config.pop("backend") llm_wrapper = LLMFactory.create_llm_wrapper(backend, config) return llm_wrapper + +# TODO: use ray serve when loading large models to avoid re-loading in each actor diff --git a/graphgen/engine.py b/graphgen/engine.py index 2989226c..bc529f27 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -1,125 +1,186 @@ -""" -orchestration engine for GraphGen -""" +import inspect +import logging +from collections import defaultdict, deque +from functools import wraps +from typing import Any, Callable, Dict, List, Set -import threading -import traceback -from typing import Any, Callable, List +import ray +import ray.data +from graphgen.bases import Config, Node -class Context(dict): - _lock = threading.Lock() - def set(self, k, v): - with self._lock: - self[k] = v +class Engine: + def __init__( + self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs + ): + self.config = Config(**config) + self.functions = functions + self.datasets: Dict[str, ray.data.Dataset] = {} + + if not ray.is_initialized(): + context = ray.init( + ignore_reinit_error=True, + logging_level=logging.ERROR, + log_to_driver=True, + **ray_init_kwargs, + ) + print(f"Ray Dashboard URL: {context.dashboard_url}") - def get(self, k, default=None): - with self._lock: - return super().get(k, default) + @staticmethod + def _topo_sort(nodes: List[Node]) -> List[Node]: + id_to_node: Dict[str, Node] = {} + for n in nodes: + id_to_node[n.id] = n + + indeg: Dict[str, int] = {nid: 0 for nid in id_to_node} + adj: Dict[str, List[str]] = defaultdict(list) + + for n in nodes: + nid = n.id + deps: List[str] = n.dependencies + uniq_deps: Set[str] = set(deps) + for d in uniq_deps: + if d not in id_to_node: + raise ValueError( + f"The dependency node id {d} of node {nid} is not defined in the configuration." 
+ ) + indeg[nid] += 1 + adj[d].append(nid) + + zero_deg: deque = deque( + [id_to_node[nid] for nid, deg in indeg.items() if deg == 0] + ) + sorted_nodes: List[Node] = [] + + while zero_deg: + cur = zero_deg.popleft() + sorted_nodes.append(cur) + cur_id = cur.id + for nb_id in adj.get(cur_id, []): + indeg[nb_id] -= 1 + if indeg[nb_id] == 0: + zero_deg.append(id_to_node[nb_id]) + + if len(sorted_nodes) != len(nodes): + remaining = [nid for nid, deg in indeg.items() if deg > 0] + raise ValueError( + f"The configuration contains cycles, unable to execute. Remaining nodes with indegree > 0: {remaining}" + ) + return sorted_nodes -class OpNode: - def __init__( - self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any] - ): - self.name, self.deps, self.func = name, deps, func + def _get_input_dataset( + self, node: Node, initial_ds: ray.data.Dataset + ) -> ray.data.Dataset: + deps = node.dependencies + if not deps: + return initial_ds -class Engine: - def __init__(self, max_workers: int = 4): - self.max_workers = max_workers - - def run(self, ops: List[OpNode], ctx: Context): - self._validate(ops) - name2op = {operation.name: operation for operation in ops} - - # topological sort - graph = {n: set(name2op[n].deps) for n in name2op} - topo = [] - q = [n for n, d in graph.items() if not d] - while q: - cur = q.pop(0) - topo.append(cur) - for child in [c for c, d in graph.items() if cur in d]: - graph[child].remove(cur) - if not graph[child]: - q.append(child) - - if len(topo) != len(ops): + if len(deps) == 1: + return self.datasets[deps[0]] + + main_ds = self.datasets[deps[0]] + other_dss = [self.datasets[d] for d in deps[1:]] + if not all(ds.schema() == main_ds.schema() for ds in other_dss): raise ValueError( - "Cyclic dependencies detected among operations." - "Please check your configuration." 
+ f"Union requires all datasets to have the same schema for node {node.id}" ) + return main_ds.union(*other_dss) + + def _execute_node(self, node: Node, initial_ds: ray.data.Dataset): + if node.op_name not in self.functions: + raise ValueError(f"Operator {node.op_name} not found for node {node.id}") + + if node.type == "source": + op_handler = self.functions[node.op_name] + node_params = node.params + self.datasets[node.id] = op_handler(**node_params) + return - # semaphore for max_workers - sem = threading.Semaphore(self.max_workers) - done = {n: threading.Event() for n in name2op} - exc = {} - - def _exec(n: str): - with sem: - for d in name2op[n].deps: - done[d].wait() - if any(d in exc for d in name2op[n].deps): - exc[n] = Exception("Skipped due to failed dependencies") - done[n].set() - return - try: - name2op[n].func(name2op[n], ctx) - except Exception: - exc[n] = traceback.format_exc() - done[n].set() - - ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo] - for t in ts: - t.start() - for t in ts: - t.join() - if exc: - raise RuntimeError( - "Some operations failed:\n" - + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items()) + input_ds = self._get_input_dataset(node, initial_ds) + + op_handler = self.functions[node.op_name] + node_params = node.params + + if inspect.isclass(op_handler): + replicas = node_params.pop("replicas", 1) + batch_size = ( + int(node_params.pop("batch_size")) + if "batch_size" in node_params + else "default" ) + compute_resources = node_params.pop("compute_resources", {}) + + if node.type == "aggregate": + self.datasets[node.id] = input_ds.repartition(1).map_batches( + op_handler, + compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1), + batch_size=None, # aggregate processes the whole dataset at once + num_gpus=compute_resources.get("num_gpus", 0) + if compute_resources + else 0, + fn_constructor_kwargs=node_params, + batch_format="pandas", + ) + else: + # others like map, filter, flatmap, map_batch let actors process data inside batches + self.datasets[node.id] = input_ds.map_batches( + op_handler, + compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas), + batch_size=batch_size, + num_gpus=compute_resources.get("num_gpus", 0) + if compute_resources + else 0, + fn_constructor_kwargs=node_params, + batch_format="pandas", + ) - @staticmethod - def _validate(ops: List[OpNode]): - name_set = set() - for op in ops: - if op.name in name_set: - raise ValueError(f"Duplicate operation name: {op.name}") - name_set.add(op.name) - for op in ops: - for dep in op.deps: - if dep not in name_set: - raise ValueError( - f"Operation {op.name} has unknown dependency: {dep}" - ) + else: + @wraps(op_handler) + def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]: + return op_handler(row_or_batch, **node_params) + + if node.type == "map": + self.datasets[node.id] = input_ds.map(func_wrapper) + elif node.type == "filter": + self.datasets[node.id] = input_ds.filter(func_wrapper) + elif node.type == "flatmap": + self.datasets[node.id] = input_ds.flat_map(func_wrapper) + elif node.type == "aggregate": + self.datasets[node.id] = input_ds.repartition(1).map_batches( + func_wrapper, batch_format="default" + ) + elif node.type == "map_batch": + self.datasets[node.id] = input_ds.map_batches(func_wrapper) + else: + raise ValueError( + f"Unsupported node type {node.type} for node {node.id}" + ) -def collect_ops(config: dict, graph_gen) -> List[OpNode]: - """ - build operation nodes from yaml config - :param config - :param 
graph_gen - """ - ops: List[OpNode] = [] - for stage in config["pipeline"]: - name = stage["name"] - method_name = stage.get("op_key") - method = getattr(graph_gen, method_name) - deps = stage.get("deps", []) + @staticmethod + def _find_leaf_nodes(nodes: List[Node]) -> Set[str]: + all_ids = {n.id for n in nodes} + deps_set = set() + for n in nodes: + deps_set.update(n.dependencies) + return all_ids - deps_set - if "params" in stage: + def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, List[Any]]: + sorted_nodes = self._topo_sort(self.config.nodes) - def func(self, ctx, _method=method, _params=stage.get("params", {})): - return _method(_params) + for node in sorted_nodes: + self._execute_node(node, initial_ds) - else: + leaf_nodes = self._find_leaf_nodes(sorted_nodes) - def func(self, ctx, _method=method): - return _method() + @ray.remote + def _fetch_result(ds: ray.data.Dataset) -> List[Any]: + return ds.take_all() - op_node = OpNode(name=name, deps=deps, func=func) - ops.append(op_node) - return ops + results = ray.get( + [_fetch_result.remote(self.datasets[node_id]) for node_id in leaf_nodes] + ) + return dict(zip(leaf_nodes, results)) diff --git a/graphgen/operators/build_kg/__init__.py b/graphgen/operators/build_kg/__init__.py index 18766fe6..a8b22ce9 100644 --- a/graphgen/operators/build_kg/__init__.py +++ b/graphgen/operators/build_kg/__init__.py @@ -1 +1 @@ -from .build_kg import build_kg +from .build_kg_service import BuildKGService diff --git a/graphgen/operators/build_kg/build_kg.py b/graphgen/operators/build_kg/build_kg.py deleted file mode 100644 index a8a6146d..00000000 --- a/graphgen/operators/build_kg/build_kg.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import List - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.bases.base_storage import BaseGraphStorage -from graphgen.bases.datatypes import Chunk -from graphgen.utils import logger - -from .build_mm_kg import build_mm_kg -from .build_text_kg import build_text_kg - - -async def build_kg( - llm_client: BaseLLMWrapper, - kg_instance: BaseGraphStorage, - chunks: List[Chunk], - progress_bar: gr.Progress = None, -): - """ - Build knowledge graph (KG) and merge into kg_instance - :param llm_client: Synthesizer LLM model to extract entities and relationships - :param kg_instance - :param chunks - :param anchor_type: get this type of information from chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction - :return: - """ - - text_chunks = [chunk for chunk in chunks if chunk.type == "text"] - mm_chunks = [ - chunk - for chunk in chunks - if chunk.type in ("image", "video", "table", "formula") - ] - - if len(text_chunks) == 0: - logger.info("All text chunks are already in the storage") - else: - logger.info("[Text Entity and Relation Extraction] processing ...") - await build_text_kg( - llm_client=llm_client, - kg_instance=kg_instance, - chunks=text_chunks, - progress_bar=progress_bar, - ) - - if len(mm_chunks) == 0: - logger.info("All multi-modal chunks are already in the storage") - else: - logger.info("[Multi-modal Entity and Relation Extraction] processing ...") - await build_mm_kg( - llm_client=llm_client, - kg_instance=kg_instance, - chunks=mm_chunks, - progress_bar=progress_bar, - ) - - return kg_instance diff --git a/graphgen/operators/build_kg/build_kg_service.py b/graphgen/operators/build_kg/build_kg_service.py new file mode 100644 index 00000000..7520dd7e --- /dev/null +++ b/graphgen/operators/build_kg/build_kg_service.py @@ -0,0 +1,55 @@ +from 
typing import List +import pandas as pd + +from graphgen.bases import BaseLLMWrapper, BaseGraphStorage +from graphgen.bases.datatypes import Chunk +from graphgen.common import init_llm, init_storage +from graphgen.utils import logger +from .build_text_kg import build_text_kg +from .build_mm_kg import build_mm_kg + + +class BuildKGService: + def __init__(self): + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.graph_storage: BaseGraphStorage = init_storage( + backend="networkx", working_dir="cache",namespace="graph") + + def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: + docs = batch.to_dict(orient="records") + docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs] + return pd.DataFrame(self.build_kg(docs)) + + + def build_kg(self, chunks: List[Chunk]) -> List: + """ + Build knowledge graph (KG) and merge into kg_instance + """ + text_chunks = [chunk for chunk in chunks if chunk.type == "text"] + mm_chunks = [ + chunk + for chunk in chunks + if chunk.type in ("image", "video", "table", "formula") + ] + + if len(text_chunks) == 0: + logger.info("All text chunks are already in the storage") + else: + logger.info("[Text Entity and Relation Extraction] processing ...") + build_text_kg( + llm_client=self.llm_client, + kg_instance=self.graph_storage, + chunks=text_chunks, + ) + if len(mm_chunks) == 0: + logger.info("All multi-modal chunks are already in the storage") + else: + logger.info("[Multi-modal Entity and Relation Extraction] processing ...") + build_mm_kg( + llm_client=self.llm_client, + kg_instance=self.graph_storage, + chunks=mm_chunks, + ) + + self.graph_storage.index_done_callback() + return [{"_chunk_id": chunk.id} for chunk in chunks] diff --git a/graphgen/operators/build_kg/build_mm_kg.py b/graphgen/operators/build_kg/build_mm_kg.py index 624b10ad..ee0459ea 100644 --- a/graphgen/operators/build_kg/build_mm_kg.py +++ b/graphgen/operators/build_kg/build_mm_kg.py @@ -1,8 +1,6 @@ from collections import defaultdict from typing import List -import gradio as gr - from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk @@ -10,28 +8,25 @@ from graphgen.utils import run_concurrent -async def build_mm_kg( +def build_mm_kg( llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], - progress_bar: gr.Progress = None, ): """ Build multi-modal KG and merge into kg_instance :param llm_client: Synthesizer LLM model to extract entities and relationships :param kg_instance :param chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction :return: """ mm_builder = MMKGBuilder(llm_client=llm_client) - results = await run_concurrent( + results = run_concurrent( mm_builder.extract, chunks, desc="[2/4] Extracting entities and relationships from multi-modal chunks", unit="chunk", - progress_bar=progress_bar, ) nodes = defaultdict(list) @@ -42,16 +37,14 @@ async def build_mm_kg( for k, v in e.items(): edges[tuple(sorted(k))].extend(v) - await run_concurrent( + run_concurrent( lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance), list(nodes.items()), desc="Inserting entities into storage", ) - await run_concurrent( + run_concurrent( lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance), list(edges.items()), desc="Inserting relationships into storage", ) - - return kg_instance diff --git a/graphgen/operators/build_kg/build_text_kg.py b/graphgen/operators/build_kg/build_text_kg.py index 3c75f022..1b5a8762 100644 --- 
a/graphgen/operators/build_kg/build_text_kg.py +++ b/graphgen/operators/build_kg/build_text_kg.py @@ -1,8 +1,6 @@ from collections import defaultdict from typing import List -import gradio as gr - from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk @@ -10,28 +8,25 @@ from graphgen.utils import run_concurrent -async def build_text_kg( +def build_text_kg( llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], - progress_bar: gr.Progress = None, ): """ :param llm_client: Synthesizer LLM model to extract entities and relationships :param kg_instance :param chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction :return: """ kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3) - results = await run_concurrent( + results = run_concurrent( kg_builder.extract, chunks, desc="[2/4]Extracting entities and relationships from chunks", unit="chunk", - progress_bar=progress_bar, ) nodes = defaultdict(list) @@ -42,16 +37,14 @@ async def build_text_kg( for k, v in e.items(): edges[tuple(sorted(k))].extend(v) - await run_concurrent( + run_concurrent( lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance), list(nodes.items()), desc="Inserting entities into storage", ) - await run_concurrent( + run_concurrent( lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance), list(edges.items()), desc="Inserting relationships into storage", ) - - return kg_instance diff --git a/graphgen/operators/init/__init__.py b/graphgen/operators/init/__init__.py deleted file mode 100644 index ec604441..00000000 --- a/graphgen/operators/init/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .init_llm import init_llm diff --git a/graphgen/utils/run_concurrent.py b/graphgen/utils/run_concurrent.py index ac63f87b..7c85a6b6 100644 --- a/graphgen/utils/run_concurrent.py +++ b/graphgen/utils/run_concurrent.py @@ -1,55 +1,43 @@ import asyncio -from typing import Awaitable, Callable, List, Optional, TypeVar +from typing import Awaitable, Callable, List, TypeVar -import gradio as gr from tqdm.asyncio import tqdm as tqdm_async from graphgen.utils.log import logger +from .loop import create_event_loop T = TypeVar("T") R = TypeVar("R") -async def run_concurrent( +def run_concurrent( coro_fn: Callable[[T], Awaitable[R]], items: List[T], *, desc: str = "processing", unit: str = "item", - progress_bar: Optional[gr.Progress] = None, ) -> List[R]: - tasks = [asyncio.create_task(coro_fn(it)) for it in items] - - completed_count = 0 - results = [] - - pbar = tqdm_async(total=len(items), desc=desc, unit=unit) - - if progress_bar is not None: - progress_bar(0.0, desc=f"{desc} (0/{len(items)})") - - for future in asyncio.as_completed(tasks): - try: - result = await future - results.append(result) - except Exception as e: # pylint: disable=broad-except - logger.exception("Task failed: %s", e) - # even if failed, record it to keep results consistent with tasks - results.append(e) - - completed_count += 1 - pbar.update(1) - - if progress_bar is not None: - progress = completed_count / len(items) - progress_bar(progress, desc=f"{desc} ({completed_count}/{len(items)})") - - pbar.close() - - if progress_bar is not None: - progress_bar(1.0, desc=f"{desc} (completed)") - - # filter out exceptions - results = [res for res in results if not isinstance(res, Exception)] - - return results + async def _run_all(): + tasks = [asyncio.create_task(coro_fn(item)) for item in items] + + results = [] + 
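# results arrive in completion order (asyncio.as_completed), not input order;
+        # exceptions are collected alongside results and filtered out before returning
+       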
pbar = tqdm_async(total=len(items), desc=desc, unit=unit) + + for future in asyncio.as_completed(tasks): + try: + result = await future + results.append(result) + except Exception as e: # pylint: disable=broad-except + logger.exception("Task failed: %s", e) + results.append(e) + + pbar.update(1) + + pbar.close() + return [res for res in results if not isinstance(res, Exception)] + + loop = create_event_loop() + try: + return loop.run_until_complete(_run_all()) + finally: + loop.close() From 3edbb817d7d90867e0d855c52d956a97df7fc7e4 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 4 Dec 2025 19:46:55 +0800 Subject: [PATCH 10/28] feat: add StorageFactory & global params --- graphgen/bases/base_storage.py | 17 ------- graphgen/bases/datatypes.py | 7 +++ graphgen/common/__init__.py | 2 + graphgen/common/init_llm.py | 2 + graphgen/common/init_storage.py | 28 +++++++++++ graphgen/engine.py | 50 +++++++++++++++---- .../operators/build_kg/build_kg_service.py | 18 ++++--- graphgen/operators/{ => evaluate}/evaluate.py | 0 .../{quiz_and_judge => judge}/judge.py | 0 .../{quiz_and_judge => quiz}/quiz.py | 0 graphgen/operators/quiz_and_judge/__init__.py | 2 - graphgen/operators/registry.py | 22 -------- 12 files changed, 91 insertions(+), 57 deletions(-) create mode 100644 graphgen/common/__init__.py create mode 100644 graphgen/common/init_storage.py rename graphgen/operators/{ => evaluate}/evaluate.py (100%) rename graphgen/operators/{quiz_and_judge => judge}/judge.py (100%) rename graphgen/operators/{quiz_and_judge => quiz}/quiz.py (100%) delete mode 100644 graphgen/operators/quiz_and_judge/__init__.py delete mode 100644 graphgen/operators/registry.py diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index bfcd658c..53610a5d 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -16,23 +16,6 @@ def query_done_callback(self): """commit the storage operations after querying""" -class BaseListStorage(Generic[T], StorageNameSpace): - def all_items(self) -> list[T]: - raise NotImplementedError - - def get_by_index(self, index: int) -> Union[T, None]: - raise NotImplementedError - - def append(self, data: T): - raise NotImplementedError - - def upsert(self, data: list[T]): - raise NotImplementedError - - def drop(self): - raise NotImplementedError - - class BaseKVStorage(Generic[T], StorageNameSpace): def all_keys(self) -> list[str]: raise NotImplementedError diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py index 199ba80b..df719fdf 100644 --- a/graphgen/bases/datatypes.py +++ b/graphgen/bases/datatypes.py @@ -62,6 +62,9 @@ class Node(BaseModel): dependencies: List[str] = Field( default_factory=list, description="list of dependent node ids" ) + execution_params: dict = Field( + default_factory=dict, description="execution parameters like replicas, batch_size" + ) @classmethod @field_validator("type") @@ -73,6 +76,10 @@ def validate_type(cls, v: str) -> str: class Config(BaseModel): + global_params: dict = Field( + default_factory=dict, description="global context for the computation graph" + ) + nodes: List[Node] = Field( ..., min_length=1, description="list of nodes in the computation graph" ) diff --git a/graphgen/common/__init__.py b/graphgen/common/__init__.py new file mode 100644 index 00000000..deb99459 --- /dev/null +++ b/graphgen/common/__init__.py @@ -0,0 +1,2 @@ +from .init_llm import init_llm +from .init_storage import init_storage diff --git a/graphgen/common/init_llm.py b/graphgen/common/init_llm.py index 
9d3fe12a..79a8677b 100644 --- a/graphgen/common/init_llm.py +++ b/graphgen/common/init_llm.py @@ -29,6 +29,7 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper: return HTTPClient(**config) if backend in ("openai_api", "azure_openai_api"): from graphgen.models.llm.api.openai_client import OpenAIClient + # pass in concrete backend to the OpenAIClient so that internally we can distinguish # between OpenAI and Azure OpenAI return OpenAIClient(**config, backend=backend) @@ -80,4 +81,5 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]: llm_wrapper = LLMFactory.create_llm_wrapper(backend, config) return llm_wrapper + # TODO: use ray serve when loading large models to avoid re-loading in each actor diff --git a/graphgen/common/init_storage.py b/graphgen/common/init_storage.py new file mode 100644 index 00000000..f9c4de57 --- /dev/null +++ b/graphgen/common/init_storage.py @@ -0,0 +1,28 @@ +from graphgen.models import JsonKVStorage, NetworkXStorage + + +class StorageFactory: + """ + Factory class to create storage instances based on backend. + Supported backends: + kv_storage(key-value storage): + - json_kv: JsonKVStorage + graph_storage: + - networkx: NetworkXStorage (graph storage) + """ + + @staticmethod + def create_storage(backend: str, working_dir: str, namespace: str): + if backend == "json_kv": + return JsonKVStorage(working_dir, namespace=namespace) + + if backend == "networkx": + return NetworkXStorage(working_dir, namespace=namespace) + + raise NotImplementedError( + f"Storage backend '{backend}' is not implemented yet." + ) + + +def init_storage(backend: str, working_dir: str, namespace: str): + return StorageFactory.create_storage(backend, working_dir, namespace) diff --git a/graphgen/engine.py b/graphgen/engine.py index bc529f27..1bdce370 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -15,6 +15,7 @@ def __init__( self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs ): self.config = Config(**config) + self.global_params = self.config.global_params self.functions = functions self.datasets: Dict[str, ray.data.Dataset] = {} @@ -90,28 +91,59 @@ def _get_input_dataset( return main_ds.union(*other_dss) def _execute_node(self, node: Node, initial_ds: ray.data.Dataset): + def _filter_kwargs( + func_or_class: Callable, + global_params: Dict[str, Any], + func_params: Dict[str, Any], + ) -> Dict[str, Any]: + """ + 1. global_params: only when specified in function signature, will be passed + 2. 
func_params: pass specified params first, then **kwargs if exists + """ + try: + sig = inspect.signature(func_or_class) + except ValueError: + return {} + + params = sig.parameters + final_kwargs = {} + + has_var_keywords = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + valid_keys = set(params.keys()) + for k, v in global_params.items(): + if k in valid_keys: + final_kwargs[k] = v + + for k, v in func_params.items(): + if k in valid_keys or has_var_keywords: + final_kwargs[k] = v + elif has_var_keywords: + final_kwargs[k] = v + return final_kwargs + if node.op_name not in self.functions: raise ValueError(f"Operator {node.op_name} not found for node {node.id}") + op_handler = self.functions[node.op_name] + node_params = _filter_kwargs(op_handler, self.global_params, node.params or {}) + if node.type == "source": - op_handler = self.functions[node.op_name] - node_params = node.params self.datasets[node.id] = op_handler(**node_params) return input_ds = self._get_input_dataset(node, initial_ds) - op_handler = self.functions[node.op_name] - node_params = node.params - if inspect.isclass(op_handler): - replicas = node_params.pop("replicas", 1) + execution_params = node.execution_params or {} + replicas = execution_params.get("replicas", 1) batch_size = ( - int(node_params.pop("batch_size")) - if "batch_size" in node_params + int(execution_params.get("batch_size")) + if "batch_size" in execution_params else "default" ) - compute_resources = node_params.pop("compute_resources", {}) + compute_resources = execution_params.get("compute_resources", {}) if node.type == "aggregate": self.datasets[node.id] = input_ds.repartition(1).map_batches( diff --git a/graphgen/operators/build_kg/build_kg_service.py b/graphgen/operators/build_kg/build_kg_service.py index 7520dd7e..00b423cc 100644 --- a/graphgen/operators/build_kg/build_kg_service.py +++ b/graphgen/operators/build_kg/build_kg_service.py @@ -1,27 +1,32 @@ from typing import List + import pandas as pd -from graphgen.bases import BaseLLMWrapper, BaseGraphStorage +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper from graphgen.bases.datatypes import Chunk from graphgen.common import init_llm, init_storage from graphgen.utils import logger -from .build_text_kg import build_text_kg + from .build_mm_kg import build_mm_kg +from .build_text_kg import build_text_kg class BuildKGService: - def __init__(self): + def __init__(self, working_dir: str = "cache"): self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.graph_storage: BaseGraphStorage = init_storage( - backend="networkx", working_dir="cache",namespace="graph") + backend="networkx", working_dir=working_dir, namespace="graph" + ) def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: docs = batch.to_dict(orient="records") docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs] - return pd.DataFrame(self.build_kg(docs)) + # consume the chunks and build kg + self.build_kg(docs) + return pd.DataFrame() - def build_kg(self, chunks: List[Chunk]) -> List: + def build_kg(self, chunks: List[Chunk]) -> None: """ Build knowledge graph (KG) and merge into kg_instance """ @@ -52,4 +57,3 @@ def build_kg(self, chunks: List[Chunk]) -> List: ) self.graph_storage.index_done_callback() - return [{"_chunk_id": chunk.id} for chunk in chunks] diff --git a/graphgen/operators/evaluate.py b/graphgen/operators/evaluate/evaluate.py similarity index 100% rename from graphgen/operators/evaluate.py rename to graphgen/operators/evaluate/evaluate.py diff --git 
a/graphgen/operators/quiz_and_judge/judge.py b/graphgen/operators/judge/judge.py similarity index 100% rename from graphgen/operators/quiz_and_judge/judge.py rename to graphgen/operators/judge/judge.py diff --git a/graphgen/operators/quiz_and_judge/quiz.py b/graphgen/operators/quiz/quiz.py similarity index 100% rename from graphgen/operators/quiz_and_judge/quiz.py rename to graphgen/operators/quiz/quiz.py diff --git a/graphgen/operators/quiz_and_judge/__init__.py b/graphgen/operators/quiz_and_judge/__init__.py deleted file mode 100644 index cb73251a..00000000 --- a/graphgen/operators/quiz_and_judge/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .judge import judge_statement -from .quiz import quiz diff --git a/graphgen/operators/registry.py b/graphgen/operators/registry.py deleted file mode 100644 index e0c5826d..00000000 --- a/graphgen/operators/registry.py +++ /dev/null @@ -1,22 +0,0 @@ -from .build_kg import build_kg -from .extract import extract_info -from .generate import generate_qas -from .init import init_llm -from .partition import partition_kg -from .quiz_and_judge import judge_statement, quiz -from .read import read -from .search import search_all -from .split import chunk_documents - -operators = { - "read": read, - "init_llm": init_llm, - "chunk_documents": chunk_documents, - "extract_info": extract_info, - "search_all": search_all, - "build_kg": build_kg, - "partition_kg": partition_kg, - "generate_qas": generate_qas, - "quiz": quiz, - "judge_statement": judge_statement, -} From ee0639dbc069e04921c10e69ac9392562bfe3b37 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 5 Dec 2025 12:29:30 +0800 Subject: [PATCH 11/28] refactor: refactor quiz to accommodate ray data engine --- graphgen/operators/__init__.py | 18 +- .../operators/build_kg/build_kg_service.py | 2 +- graphgen/operators/evaluate/__init__.py | 0 graphgen/operators/judge/__init__.py | 0 graphgen/operators/quiz/__init__.py | 1 + graphgen/operators/quiz/quiz.py | 190 ++++++++++-------- graphgen/utils/run_concurrent.py | 3 +- 7 files changed, 120 insertions(+), 94 deletions(-) create mode 100644 graphgen/operators/evaluate/__init__.py create mode 100644 graphgen/operators/judge/__init__.py create mode 100644 graphgen/operators/quiz/__init__.py diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 38ced41e..83a843d5 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,9 +1,19 @@ -from .build_kg import build_kg +from .build_kg import BuildKGService +from .chunk import ChunkService from .extract import extract_info from .generate import generate_qas -from .init import init_llm from .partition import partition_kg -from .quiz_and_judge import judge_statement, quiz +from .quiz import QuizService from .read import read from .search import search_all -from .split import chunk_documents + +operators = { + "read": read, + "chunk": ChunkService, + "build_kg": BuildKGService, + "quiz": QuizService, + "extract_info": extract_info, + "search_all": search_all, + "partition_kg": partition_kg, + "generate_qas": generate_qas, +} diff --git a/graphgen/operators/build_kg/build_kg_service.py b/graphgen/operators/build_kg/build_kg_service.py index 00b423cc..c6842089 100644 --- a/graphgen/operators/build_kg/build_kg_service.py +++ b/graphgen/operators/build_kg/build_kg_service.py @@ -24,7 +24,7 @@ def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: # consume the chunks and build kg self.build_kg(docs) - return pd.DataFrame() + return pd.DataFrame([{"status": 
"kg_building_completed"}]) def build_kg(self, chunks: List[Chunk]) -> None: """ diff --git a/graphgen/operators/evaluate/__init__.py b/graphgen/operators/evaluate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/judge/__init__.py b/graphgen/operators/judge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/quiz/__init__.py b/graphgen/operators/quiz/__init__.py new file mode 100644 index 00000000..318ae515 --- /dev/null +++ b/graphgen/operators/quiz/__init__.py @@ -0,0 +1 @@ +from .quiz import QuizService diff --git a/graphgen/operators/quiz/quiz.py b/graphgen/operators/quiz/quiz.py index 9aadb34b..61dfde49 100644 --- a/graphgen/operators/quiz/quiz.py +++ b/graphgen/operators/quiz/quiz.py @@ -1,93 +1,107 @@ -from collections import defaultdict - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import JsonKVStorage, NetworkXStorage, QuizGenerator -from graphgen.utils import logger, run_concurrent - - -async def quiz( - synth_llm_client: BaseLLMWrapper, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - max_samples: int = 1, - progress_bar: gr.Progress = None, -) -> JsonKVStorage: - """ - Get all edges and quiz them using QuizGenerator. - - :param synth_llm_client: generate statements - :param graph_storage: graph storage instance - :param rephrase_storage: rephrase storage instance - :param max_samples: max samples for each edge - :param progress_bar - :return: - """ - - generator = QuizGenerator(synth_llm_client) - - async def _process_single_quiz(item: tuple[str, str, str]): - description, template_type, gt = item - try: - # if rephrase_storage exists already, directly get it - descriptions = rephrase_storage.get_by_id(description) - if descriptions: - return None - - prompt = generator.build_prompt_for_description(description, template_type) - new_description = await synth_llm_client.generate_answer( - prompt, temperature=1 - ) - rephrased_text = generator.parse_rephrased_text(new_description) - return {description: [(rephrased_text, gt)]} - - except Exception as e: # pylint: disable=broad-except - logger.error("Error when quizzing description %s: %s", description, e) +from collections.abc import Iterable + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.common import init_llm, init_storage +from graphgen.models import QuizGenerator +from graphgen.utils import compute_content_hash, logger, run_concurrent + + +class QuizService: + def __init__(self, working_dir: str = "cache", quiz_samples: int = 1): + self.quiz_samples = quiz_samples + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.graph_storage: BaseGraphStorage = init_storage( + backend="networkx", working_dir=working_dir, namespace="graph" + ) + # { _description_id: { "description": str, "quizzes": List[Tuple[str, str]] } } + self.quiz_storage: BaseKVStorage = init_storage( + backend="json_kv", working_dir=working_dir, namespace="quiz" + ) + self.generator = QuizGenerator(self.llm_client) + + self.concurrency_limit = 20 + + def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: + # this operator does not consume any batch data + # but for compatibility we keep the interface + _ = batch.to_dict(orient="records") + + yield from self.quiz() + + async def _process_single_quiz(self, item: str) -> dict | None: + # if quiz in quiz_storage exists already, directly get it + _description_id = 
compute_content_hash(item) + if self.quiz_storage.get_by_id(_description_id): return None - edges = graph_storage.get_all_edges() - nodes = graph_storage.get_all_nodes() - - results = defaultdict(list) - items = [] - for edge in edges: - edge_data = edge[2] - description = edge_data["description"] - - results[description] = [(description, "yes")] - - for i in range(max_samples): + tasks = [] + for i in range(self.quiz_samples): if i > 0: - items.append((description, "TEMPLATE", "yes")) - items.append((description, "ANTI_TEMPLATE", "no")) - - for node in nodes: - node_data = node[1] - description = node_data["description"] + tasks.append((item, "TEMPLATE", "yes")) + tasks.append((item, "ANTI_TEMPLATE", "no")) + try: + quizzes = [] + for description, template_type, gt in tasks: + prompt = self.generator.build_prompt_for_description( + description, template_type + ) + new_description = await self.llm_client.generate_answer( + prompt, temperature=1 + ) + rephrased_text = self.generator.parse_rephrased_text(new_description) + quizzes.append((rephrased_text, gt)) + return { + "_description_id": _description_id, + "description": item, + "quizzes": quizzes, + } + except Exception as e: + logger.error("Error when quizzing description %s: %s", item, e) + return None - results[description] = [(description, "yes")] + def quiz(self) -> Iterable[pd.DataFrame]: + """ + Get all nodes and edges and quiz their descriptions using QuizGenerator. + """ + edges = self.graph_storage.get_all_edges() + nodes = self.graph_storage.get_all_nodes() + + items = [] + + for edge in edges: + edge_data = edge[2] + description = edge_data["description"] + items.append(description) + + for node in nodes: + node_data = node[1] + description = node_data["description"] + items.append(description) + + logger.info("Total descriptions to quiz: %d", len(items)) + + for i in range(0, len(items), self.concurrency_limit): + batch_items = items[i : i + self.concurrency_limit] + batch_results = run_concurrent( + self._process_single_quiz, + batch_items, + desc=f"Quizzing descriptions ({i} / {i + len(batch_items)})", + unit="description", + ) - for i in range(max_samples): - if i > 0: - items.append((description, "TEMPLATE", "yes")) - items.append((description, "ANTI_TEMPLATE", "no")) - - quiz_results = await run_concurrent( - _process_single_quiz, - items, - desc="Quizzing descriptions", - unit="description", - progress_bar=progress_bar, - ) - - for new_result in quiz_results: - if new_result: - for key, value in new_result.items(): - results[key].extend(value) - - for key, value in results.items(): - results[key] = list(set(value)) - rephrase_storage.upsert({key: results[key]}) - - return rephrase_storage + final_results = [] + for new_result in batch_results: + if new_result: + self.quiz_storage.upsert( + { + new_result["_description_id"]: { + "description": new_result["description"], + "quizzes": new_result["quizzes"], + } + } + ) + final_results.append(new_result) + self.quiz_storage.index_done_callback() + yield pd.DataFrame(final_results) diff --git a/graphgen/utils/run_concurrent.py b/graphgen/utils/run_concurrent.py index 7c85a6b6..d1a9b0e2 100644 --- a/graphgen/utils/run_concurrent.py +++ b/graphgen/utils/run_concurrent.py @@ -4,6 +4,7 @@ from tqdm.asyncio import tqdm as tqdm_async from graphgen.utils.log import logger + from .loop import create_event_loop T = TypeVar("T") @@ -27,7 +28,7 @@ async def _run_all(): try: result = await future results.append(result) - except Exception as e: # pylint: disable=broad-except + except 
Exception as e: logger.exception("Task failed: %s", e) results.append(e) From 157f0b0b35c98d30466806cc14b1be10a849616f Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 5 Dec 2025 14:13:16 +0800 Subject: [PATCH 12/28] fix: reload graph before quizzing --- graphgen/models/storage/__init__.py | 5 +- graphgen/models/storage/graph/__init__.py | 0 .../models/storage/graph/networkx_storage.py | 6 ++ graphgen/models/storage/kv/__init__.py | 0 graphgen/models/storage/kv/json_storage.py | 4 + graphgen/models/storage/kv/rocksdb_storage.py | 79 +++++++++++++++++++ graphgen/operators/quiz/quiz.py | 2 +- 7 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 graphgen/models/storage/graph/__init__.py create mode 100644 graphgen/models/storage/kv/__init__.py create mode 100644 graphgen/models/storage/kv/rocksdb_storage.py diff --git a/graphgen/models/storage/__init__.py b/graphgen/models/storage/__init__.py index 1e8f8341..0f8d9eeb 100644 --- a/graphgen/models/storage/__init__.py +++ b/graphgen/models/storage/__init__.py @@ -1,3 +1,4 @@ -from .json_storage import JsonKVStorage, JsonListStorage -from .networkx_storage import NetworkXStorage +from graphgen.models.storage.graph.networkx_storage import NetworkXStorage +from graphgen.models.storage.kv.json_storage import JsonKVStorage + from .rocksdb_cache import RocksDBCache diff --git a/graphgen/models/storage/graph/__init__.py b/graphgen/models/storage/graph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/storage/graph/networkx_storage.py b/graphgen/models/storage/graph/networkx_storage.py index 36bf1b5e..28024535 100644 --- a/graphgen/models/storage/graph/networkx_storage.py +++ b/graphgen/models/storage/graph/networkx_storage.py @@ -170,3 +170,9 @@ def clear(self): """ self._graph.clear() logger.info("Graph %s cleared.", self.namespace) + + def reload(self): + """ + Reload the graph from the GraphML file. 
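+        Call this after another actor has persisted graph updates via index_done_callback so they become visible in this process.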
+ """ + self.__post_init__() diff --git a/graphgen/models/storage/kv/__init__.py b/graphgen/models/storage/kv/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/storage/kv/json_storage.py b/graphgen/models/storage/kv/json_storage.py index 3fcae5a3..f0b6c995 100644 --- a/graphgen/models/storage/kv/json_storage.py +++ b/graphgen/models/storage/kv/json_storage.py @@ -54,3 +54,7 @@ def upsert(self, data: dict): def drop(self): if self._data: self._data.clear() + + def reload(self): + self._data = load_json(self._file_name) or {} + logger.info("Reload KV %s with %d data", self.namespace, len(self._data)) diff --git a/graphgen/models/storage/kv/rocksdb_storage.py b/graphgen/models/storage/kv/rocksdb_storage.py new file mode 100644 index 00000000..0cbe1145 --- /dev/null +++ b/graphgen/models/storage/kv/rocksdb_storage.py @@ -0,0 +1,79 @@ +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Set + +# rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it +# pylint: disable=no-name-in-module +from rocksdict import Rdict + +from graphgen.bases.base_storage import BaseKVStorage +from graphgen.utils import logger + + +@dataclass +class RocksDBKVStorage(BaseKVStorage): + _db: Rdict = None + _db_path: str = None + + def __post_init__(self): + self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db") + self._db = Rdict(self._db_path) + logger.info("Load KV (RocksDB) %s at %s", self.namespace, self._db_path) + + @property + def data(self): + return self._db + + def all_keys(self) -> List[str]: + return list(self._db.keys()) + + def index_done_callback(self): + self._db.flush() + logger.info("RocksDB flushed for %s", self.namespace) + + def get_by_id(self, id: str) -> Any: + return self._db.get(id, None) + + def get_by_ids(self, ids: List[str], fields: List[str] = None) -> List[Any]: + result = [] + for index in ids: + item = self._db.get(index, None) + if item is None: + result.append(None) + continue + + if fields is None: + result.append(item) + else: + result.append({k: v for k, v in item.items() if k in fields}) + return result + + def get_all(self) -> Dict[str, Dict]: + return dict(self._db) + + def filter_keys(self, data: List[str]) -> Set[str]: + return {s for s in data if s not in self._db} + + def upsert(self, data: Dict[str, Any]): + left_data = {} + for k, v in data.items(): + if k not in self._db: + left_data[k] = v + + if left_data: + for k, v in left_data.items(): + self._db[k] = v + + # if left_data is very large, it is recommended to use self._db.write_batch() for optimization + + return left_data + + def drop(self): + self._db.close() + Rdict.destroy(self._db_path) + self._db = Rdict(self._db_path) + logger.info("Dropped RocksDB %s", self.namespace) + + def close(self): + if self._db: + self._db.close() diff --git a/graphgen/operators/quiz/quiz.py b/graphgen/operators/quiz/quiz.py index 61dfde49..4fc4cb00 100644 --- a/graphgen/operators/quiz/quiz.py +++ b/graphgen/operators/quiz/quiz.py @@ -27,7 +27,7 @@ def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: # this operator does not consume any batch data # but for compatibility we keep the interface _ = batch.to_dict(orient="records") - + self.graph_storage.reload() yield from self.quiz() async def _process_single_quiz(self, item: str) -> dict | None: From ec2033b0d01ca6849c5cb04cad7c0df848187c4b Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Fri, 5 Dec 2025 
15:19:32 +0800 Subject: [PATCH 13/28] Potential fix for pull request finding 'Unreachable code' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- graphgen/engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index 1bdce370..5a0e67d3 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -119,8 +119,6 @@ def _filter_kwargs( for k, v in func_params.items(): if k in valid_keys or has_var_keywords: final_kwargs[k] = v - elif has_var_keywords: - final_kwargs[k] = v return final_kwargs if node.op_name not in self.functions: From bc07222ca79d94517e3f1fb8b63252db933e6f20 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 5 Dec 2025 16:01:20 +0800 Subject: [PATCH 14/28] fix: fix quiz params --- graphgen/operators/__init__.py | 2 + graphgen/operators/chunk/chunk_service.py | 12 +- graphgen/operators/extract/__init__.py | 2 +- .../extract/{extract_info.py => extract.py} | 0 graphgen/operators/judge/__init__.py | 1 + graphgen/operators/judge/judge.py | 139 -------------- graphgen/operators/judge/judge_service.py | 170 ++++++++++++++++++ graphgen/operators/quiz/__init__.py | 2 +- .../quiz/{quiz.py => quiz_service.py} | 7 +- 9 files changed, 190 insertions(+), 145 deletions(-) rename graphgen/operators/extract/{extract_info.py => extract.py} (100%) delete mode 100644 graphgen/operators/judge/judge.py create mode 100644 graphgen/operators/judge/judge_service.py rename graphgen/operators/quiz/{quiz.py => quiz_service.py} (94%) diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 83a843d5..1b56429c 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -6,12 +6,14 @@ from .quiz import QuizService from .read import read from .search import search_all +from .judge import JudgeService operators = { "read": read, "chunk": ChunkService, "build_kg": BuildKGService, "quiz": QuizService, + "judge": JudgeService, "extract_info": extract_info, "search_all": search_all, "partition_kg": partition_kg, diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py index a0abbe57..0fcb20a3 100644 --- a/graphgen/operators/chunk/chunk_service.py +++ b/graphgen/operators/chunk/chunk_service.py @@ -4,6 +4,7 @@ import pandas as pd +from graphgen.common import init_storage from graphgen.models import ( ChineseRecursiveTextSplitter, RecursiveCharacterSplitter, @@ -40,9 +41,14 @@ def split_chunks(text: str, language: str = "en", **kwargs) -> list: class ChunkService: - def __init__(self, **chunk_kwargs): + def __init__(self, working_dir: str = "cache", **chunk_kwargs): tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model) + self.chunk_storage = init_storage( + backend="json_kv", + working_dir=working_dir, + namespace="chunk", + ) self.chunk_kwargs = chunk_kwargs def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: @@ -88,4 +94,8 @@ def chunk_documents(self, new_docs: list) -> list: **doc, } ) + self.chunk_storage.upsert( + {chunk["_chunk_id"]: chunk for chunk in chunks} + ) + self.chunk_storage.index_done_callback() return chunks diff --git a/graphgen/operators/extract/__init__.py b/graphgen/operators/extract/__init__.py index ec576cb6..a5d29d7c 100644 --- a/graphgen/operators/extract/__init__.py +++ b/graphgen/operators/extract/__init__.py @@ -1 +1 @@ -from .extract_info import extract_info +from .extract import extract_info 
diff --git a/graphgen/operators/extract/extract_info.py b/graphgen/operators/extract/extract.py similarity index 100% rename from graphgen/operators/extract/extract_info.py rename to graphgen/operators/extract/extract.py diff --git a/graphgen/operators/judge/__init__.py b/graphgen/operators/judge/__init__.py index e69de29b..32ccf5c2 100644 --- a/graphgen/operators/judge/__init__.py +++ b/graphgen/operators/judge/__init__.py @@ -0,0 +1 @@ +from .judge_service import JudgeService diff --git a/graphgen/operators/judge/judge.py b/graphgen/operators/judge/judge.py deleted file mode 100644 index b5e35eb9..00000000 --- a/graphgen/operators/judge/judge.py +++ /dev/null @@ -1,139 +0,0 @@ -import math - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import JsonKVStorage, NetworkXStorage -from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT -from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy - - -async def judge_statement( # pylint: disable=too-many-statements - trainee_llm_client: BaseLLMWrapper, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - re_judge: bool = False, - progress_bar: gr.Progress = None, -) -> NetworkXStorage: - """ - Get all edges and nodes and judge them - - :param trainee_llm_client: judge the statements to get comprehension loss - :param graph_storage: graph storage instance - :param rephrase_storage: rephrase storage instance - :param re_judge: re-judge the relations - :param progress_bar - :return: - """ - - async def _judge_single_relation( - edge: tuple, - ): - source_id = edge[0] - target_id = edge[1] - edge_data = edge[2] - - if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None: - logger.debug( - "Edge %s -> %s already judged, loss: %s, skip", - source_id, - target_id, - edge_data["loss"], - ) - return source_id, target_id, edge_data - - description = edge_data["description"] - - try: - descriptions = rephrase_storage.get_by_id(description) - assert descriptions is not None - - judgements = [] - gts = [gt for _, gt in descriptions] - for description, gt in descriptions: - judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) - ) - judgements.append(judgement[0].top_candidates) - - loss = yes_no_loss_entropy(judgements, gts) - - logger.debug( - "Edge %s -> %s description: %s loss: %s", - source_id, - target_id, - description, - loss, - ) - - edge_data["loss"] = loss - except Exception as e: # pylint: disable=broad-except - logger.error( - "Error in judging relation %s -> %s: %s", source_id, target_id, e - ) - logger.info("Use default loss 0.1") - edge_data["loss"] = -math.log(0.1) - - graph_storage.update_edge(source_id, target_id, edge_data) - return source_id, target_id, edge_data - - edges = graph_storage.get_all_edges() - - await run_concurrent( - _judge_single_relation, - edges, - desc="Judging relations", - unit="relation", - progress_bar=progress_bar, - ) - - async def _judge_single_entity( - node: tuple, - ): - node_id = node[0] - node_data = node[1] - - if (not re_judge) and "loss" in node_data and node_data["loss"] is not None: - logger.debug( - "Node %s already judged, loss: %s, skip", node_id, node_data["loss"] - ) - return node_id, node_data - - description = node_data["description"] - - try: - descriptions = rephrase_storage.get_by_id(description) - assert descriptions is not None - - judgements = [] - gts = [gt for _, gt in descriptions] - for description, gt in 
descriptions: - judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) - ) - judgements.append(judgement[0].top_candidates) - - loss = yes_no_loss_entropy(judgements, gts) - - logger.debug("Node %s description: %s loss: %s", node_id, description, loss) - - node_data["loss"] = loss - except Exception as e: # pylint: disable=broad-except - logger.error("Error in judging entity %s: %s", node_id, e) - logger.error("Use default loss 0.1") - node_data["loss"] = -math.log(0.1) - - graph_storage.update_node(node_id, node_data) - return node_id, node_data - - nodes = graph_storage.get_all_nodes() - - await run_concurrent( - _judge_single_entity, - nodes, - desc="Judging entities", - unit="entity", - progress_bar=progress_bar, - ) - - return graph_storage diff --git a/graphgen/operators/judge/judge_service.py b/graphgen/operators/judge/judge_service.py new file mode 100644 index 00000000..20588e43 --- /dev/null +++ b/graphgen/operators/judge/judge_service.py @@ -0,0 +1,170 @@ +import math + +import gradio as gr + +from graphgen.bases import BaseLLMWrapper +from graphgen.models import JsonKVStorage, NetworkXStorage +from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT +from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy + + +import math +from collections.abc import Iterable + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.common import init_llm, init_storage +from graphgen.models import NetworkXStorage, JsonKVStorage +from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT +from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy + + +class JudgeService: + """Service for judging graph edges and nodes using a trainee LLM.""" + def __init__(self, working_dir: str = "cache"): + self.llm_client: BaseLLMWrapper = init_llm("trainee") + + def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame([{"status": "judging_completed"}]) + + def judge(self) -> Iterable[pd.DataFrame]: + """ + Judge the statements in the graph storage + + :param re_judge: re-judge the relations + :return: + """ + return + + + +# async def judge_statement( # pylint: disable=too-many-statements +# trainee_llm_client: BaseLLMWrapper, +# graph_storage: NetworkXStorage, +# rephrase_storage: JsonKVStorage, +# re_judge: bool = False, +# progress_bar: gr.Progress = None, +# ) -> NetworkXStorage: +# """ +# Get all edges and nodes and judge them +# +# :param trainee_llm_client: judge the statements to get comprehension loss +# :param graph_storage: graph storage instance +# :param rephrase_storage: rephrase storage instance +# :param re_judge: re-judge the relations +# :param progress_bar +# :return: +# """ +# +# async def _judge_single_relation( +# edge: tuple, +# ): +# source_id = edge[0] +# target_id = edge[1] +# edge_data = edge[2] +# +# if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None: +# logger.debug( +# "Edge %s -> %s already judged, loss: %s, skip", +# source_id, +# target_id, +# edge_data["loss"], +# ) +# return source_id, target_id, edge_data +# +# description = edge_data["description"] +# +# try: +# descriptions = rephrase_storage.get_by_id(description) +# assert descriptions is not None +# +# judgements = [] +# gts = [gt for _, gt in descriptions] +# for description, gt in descriptions: +# judgement = await trainee_llm_client.generate_topk_per_token( +# 
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) +# ) +# judgements.append(judgement[0].top_candidates) +# +# loss = yes_no_loss_entropy(judgements, gts) +# +# logger.debug( +# "Edge %s -> %s description: %s loss: %s", +# source_id, +# target_id, +# description, +# loss, +# ) +# +# edge_data["loss"] = loss +# except Exception as e: # pylint: disable=broad-except +# logger.error( +# "Error in judging relation %s -> %s: %s", source_id, target_id, e +# ) +# logger.info("Use default loss 0.1") +# edge_data["loss"] = -math.log(0.1) +# +# graph_storage.update_edge(source_id, target_id, edge_data) +# return source_id, target_id, edge_data +# +# edges = graph_storage.get_all_edges() +# +# await run_concurrent( +# _judge_single_relation, +# edges, +# desc="Judging relations", +# unit="relation", +# progress_bar=progress_bar, +# ) +# +# async def _judge_single_entity( +# node: tuple, +# ): +# node_id = node[0] +# node_data = node[1] +# +# if (not re_judge) and "loss" in node_data and node_data["loss"] is not None: +# logger.debug( +# "Node %s already judged, loss: %s, skip", node_id, node_data["loss"] +# ) +# return node_id, node_data +# +# description = node_data["description"] +# +# try: +# descriptions = rephrase_storage.get_by_id(description) +# assert descriptions is not None +# +# judgements = [] +# gts = [gt for _, gt in descriptions] +# for description, gt in descriptions: +# judgement = await trainee_llm_client.generate_topk_per_token( +# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) +# ) +# judgements.append(judgement[0].top_candidates) +# +# loss = yes_no_loss_entropy(judgements, gts) +# +# logger.debug("Node %s description: %s loss: %s", node_id, description, loss) +# +# node_data["loss"] = loss +# except Exception as e: # pylint: disable=broad-except +# logger.error("Error in judging entity %s: %s", node_id, e) +# logger.error("Use default loss 0.1") +# node_data["loss"] = -math.log(0.1) +# +# graph_storage.update_node(node_id, node_data) +# return node_id, node_data +# +# nodes = graph_storage.get_all_nodes() +# +# await run_concurrent( +# _judge_single_entity, +# nodes, +# desc="Judging entities", +# unit="entity", +# progress_bar=progress_bar, +# ) +# +# return graph_storage diff --git a/graphgen/operators/quiz/__init__.py b/graphgen/operators/quiz/__init__.py index 318ae515..2a931f4b 100644 --- a/graphgen/operators/quiz/__init__.py +++ b/graphgen/operators/quiz/__init__.py @@ -1 +1 @@ -from .quiz import QuizService +from .quiz_service import QuizService diff --git a/graphgen/operators/quiz/quiz.py b/graphgen/operators/quiz/quiz_service.py similarity index 94% rename from graphgen/operators/quiz/quiz.py rename to graphgen/operators/quiz/quiz_service.py index 4fc4cb00..f811157b 100644 --- a/graphgen/operators/quiz/quiz.py +++ b/graphgen/operators/quiz/quiz_service.py @@ -5,11 +5,11 @@ from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper from graphgen.common import init_llm, init_storage from graphgen.models import QuizGenerator -from graphgen.utils import compute_content_hash, logger, run_concurrent +from graphgen.utils import compute_content_hash, run_concurrent, logger class QuizService: - def __init__(self, working_dir: str = "cache", quiz_samples: int = 1): + def __init__(self, working_dir: str = "cache", quiz_samples: int = 1, concurrency_limit: int = 200): self.quiz_samples = quiz_samples self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.graph_storage: BaseGraphStorage = init_storage( @@ -21,7 +21,7 @@ 
def __init__(self, working_dir: str = "cache", quiz_samples: int = 1): ) self.generator = QuizGenerator(self.llm_client) - self.concurrency_limit = 20 + self.concurrency_limit = concurrency_limit def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: # this operator does not consume any batch data @@ -80,6 +80,7 @@ def quiz(self) -> Iterable[pd.DataFrame]: description = node_data["description"] items.append(description) + print("Total descriptions to quiz: %d", len(items)) logger.info("Total descriptions to quiz: %d", len(items)) for i in range(0, len(items), self.concurrency_limit): From c9435d795cf96734c94206419748ea687193c335 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 10 Dec 2025 15:18:04 +0800 Subject: [PATCH 15/28] refactor: refactor quiz&judge to ray actors --- .../generate/{generate_qas.py => generate.py} | 0 graphgen/operators/judge/judge_service.py | 72 ++++++++++++++----- .../{partition_kg.py => partition_service.py} | 64 +++++++++++++++-- graphgen/operators/partition/pre_tokenize.py | 55 -------------- graphgen/operators/quiz/quiz_service.py | 11 +-- requirements.txt | 2 + 6 files changed, 122 insertions(+), 82 deletions(-) rename graphgen/operators/generate/{generate_qas.py => generate.py} (100%) rename graphgen/operators/partition/{partition_kg.py => partition_service.py} (68%) delete mode 100644 graphgen/operators/partition/pre_tokenize.py diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate.py similarity index 100% rename from graphgen/operators/generate/generate_qas.py rename to graphgen/operators/generate/generate.py diff --git a/graphgen/operators/judge/judge_service.py b/graphgen/operators/judge/judge_service.py index 20588e43..f48b2d62 100644 --- a/graphgen/operators/judge/judge_service.py +++ b/graphgen/operators/judge/judge_service.py @@ -1,42 +1,76 @@ import math -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import JsonKVStorage, NetworkXStorage -from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT -from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy - - -import math -from collections.abc import Iterable - import pandas as pd -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper from graphgen.common import init_llm, init_storage -from graphgen.models import NetworkXStorage, JsonKVStorage from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy class JudgeService: """Service for judging graph edges and nodes using a trainee LLM.""" + def __init__(self, working_dir: str = "cache"): self.llm_client: BaseLLMWrapper = init_llm("trainee") + self.graph_storage: BaseGraphStorage = init_storage( + backend="networkx", + working_dir=working_dir, + namespace="graph", + ) def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") + self.graph_storage.reload() + self.judge(items) return pd.DataFrame([{"status": "judging_completed"}]) - def judge(self) -> Iterable[pd.DataFrame]: - """ - Judge the statements in the graph storage + async def _process_single_judge(self, item: dict) -> dict: + description = item["description"] + try: + judgement = await self.llm_client.generate_topk_per_token( + STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) + ) + top_candidates = judgement[0].top_candidates + gt = item.get("ground_truth", "yes") + loss = 
yes_no_loss_entropy([top_candidates], [gt]) + logger.debug("Description: %s Loss: %s", description, loss) + item["loss"] = loss + except Exception as e: # pylint: disable=broad-except + logger.error("Error in judging description: %s", e) + logger.info("Use default loss 0.1") + item["loss"] = -math.log(0.1) + return item - :param re_judge: re-judge the relations - :return: + def judge(self, items: list[dict]) -> None: + """ + Judge the description in the item and compute the loss. """ - return + results = run_concurrent( + self._process_single_judge, + items, + desc="Judging descriptions", + unit="description", + ) + # Update the graph storage with the computed losses + for item in results: + print(item) + node_id = item.get("node_id") + edge_source = item.get("edge_source") + edge_target = item.get("edge_target") + loss = item["loss"] + if node_id is not None: + node_data = self.graph_storage.get_node(node_id) + if node_data is not None: + node_data["loss"] = loss + self.graph_storage.update_node(node_id, node_data) + elif edge_source is not None and edge_target is not None: + edge_data = self.graph_storage.get_edge(edge_source, edge_target) + if edge_data is not None: + edge_data["loss"] = loss + self.graph_storage.update_edge(edge_source, edge_target, edge_data) + self.graph_storage.index_done_callback() # async def judge_statement( # pylint: disable=too-many-statements diff --git a/graphgen/operators/partition/partition_kg.py b/graphgen/operators/partition/partition_service.py similarity index 68% rename from graphgen/operators/partition/partition_kg.py rename to graphgen/operators/partition/partition_service.py index 4c4fdaa1..f7510abc 100644 --- a/graphgen/operators/partition/partition_kg.py +++ b/graphgen/operators/partition/partition_service.py @@ -10,10 +10,8 @@ ) from graphgen.utils import logger -from .pre_tokenize import pre_tokenize - -async def partition_kg( +def partition_kg( kg_instance: BaseGraphStorage, chunk_storage: BaseKVStorage, tokenizer: Any = BaseTokenizer, @@ -60,7 +58,7 @@ async def partition_kg( return batches -async def attach_additional_data_to_node( +def attach_additional_data_to_node( batches: list[ tuple[ list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] @@ -112,3 +110,61 @@ async def _attach_by_type( # We'll use the first image chunk found for this node. 
node_data["images"] = image_chunks[0] logger.debug("Attached image data to node %s", node_id) + + +import asyncio +from typing import List, Tuple + +import gradio as gr + +from graphgen.bases import BaseGraphStorage, BaseTokenizer +from graphgen.utils import run_concurrent + + +async def pre_tokenize( + graph_storage: BaseGraphStorage, + tokenizer: BaseTokenizer, + edges: List[Tuple], + nodes: List[Tuple], + progress_bar: gr.Progress = None, + max_concurrent: int = 1000, +) -> Tuple[List, List]: + """为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。""" + sem = asyncio.Semaphore(max_concurrent) + + async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: + async with sem: + data = obj[1] if is_node else obj[2] + if "length" not in data: + loop = asyncio.get_event_loop() + data["length"] = len( + await loop.run_in_executor( + None, tokenizer.encode, data["description"] + ) + ) + if is_node: + graph_storage.update_node(obj[0], obj[1]) + else: + graph_storage.update_edge(obj[0], obj[1], obj[2]) + return obj + + new_edges, new_nodes = await asyncio.gather( + run_concurrent( + lambda e: _patch_and_write(e, is_node=False), + edges, + desc="Pre-tokenizing edges", + unit="edge", + progress_bar=progress_bar, + ), + run_concurrent( + lambda n: _patch_and_write(n, is_node=True), + nodes, + desc="Pre-tokenizing nodes", + unit="node", + progress_bar=progress_bar, + ), + ) + + graph_storage.index_done_callback() + return new_edges, new_nodes + diff --git a/graphgen/operators/partition/pre_tokenize.py b/graphgen/operators/partition/pre_tokenize.py deleted file mode 100644 index 83e99060..00000000 --- a/graphgen/operators/partition/pre_tokenize.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio -from typing import List, Tuple - -import gradio as gr - -from graphgen.bases import BaseGraphStorage, BaseTokenizer -from graphgen.utils import run_concurrent - - -async def pre_tokenize( - graph_storage: BaseGraphStorage, - tokenizer: BaseTokenizer, - edges: List[Tuple], - nodes: List[Tuple], - progress_bar: gr.Progress = None, - max_concurrent: int = 1000, -) -> Tuple[List, List]: - """为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。""" - sem = asyncio.Semaphore(max_concurrent) - - async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: - async with sem: - data = obj[1] if is_node else obj[2] - if "length" not in data: - loop = asyncio.get_event_loop() - data["length"] = len( - await loop.run_in_executor( - None, tokenizer.encode, data["description"] - ) - ) - if is_node: - graph_storage.update_node(obj[0], obj[1]) - else: - graph_storage.update_edge(obj[0], obj[1], obj[2]) - return obj - - new_edges, new_nodes = await asyncio.gather( - run_concurrent( - lambda e: _patch_and_write(e, is_node=False), - edges, - desc="Pre-tokenizing edges", - unit="edge", - progress_bar=progress_bar, - ), - run_concurrent( - lambda n: _patch_and_write(n, is_node=True), - nodes, - desc="Pre-tokenizing nodes", - unit="node", - progress_bar=progress_bar, - ), - ) - - graph_storage.index_done_callback() - return new_edges, new_nodes diff --git a/graphgen/operators/quiz/quiz_service.py b/graphgen/operators/quiz/quiz_service.py index f811157b..415d4771 100644 --- a/graphgen/operators/quiz/quiz_service.py +++ b/graphgen/operators/quiz/quiz_service.py @@ -5,11 +5,16 @@ from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper from graphgen.common import init_llm, init_storage from graphgen.models import QuizGenerator -from graphgen.utils import compute_content_hash, run_concurrent, logger +from 
graphgen.utils import compute_content_hash, logger, run_concurrent class QuizService: - def __init__(self, working_dir: str = "cache", quiz_samples: int = 1, concurrency_limit: int = 200): + def __init__( + self, + working_dir: str = "cache", + quiz_samples: int = 1, + concurrency_limit: int = 200, + ): self.quiz_samples = quiz_samples self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.graph_storage: BaseGraphStorage = init_storage( @@ -20,7 +25,6 @@ def __init__(self, working_dir: str = "cache", quiz_samples: int = 1, concurrenc backend="json_kv", working_dir=working_dir, namespace="quiz" ) self.generator = QuizGenerator(self.llm_client) - self.concurrency_limit = concurrency_limit def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: @@ -80,7 +84,6 @@ def quiz(self) -> Iterable[pd.DataFrame]: description = node_data["description"] items.append(description) - print("Total descriptions to quiz: %d", len(items)) logger.info("Total descriptions to quiz: %d", len(items)) for i in range(0, len(items), self.concurrency_limit): diff --git a/requirements.txt b/requirements.txt index 85fc43e3..44079ab5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,8 @@ fastapi trafilatura aiohttp socksio +pydantic +ray==2.52.1 leidenalg igraph From d7d6c2abd644b6edcfd1ecd8f61b3898d4c13b9d Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 10 Dec 2025 19:51:48 +0800 Subject: [PATCH 16/28] fix: fix transferring quizzed data to JudgeService --- graphgen/operators/judge/judge_service.py | 147 ++-------------------- graphgen/operators/quiz/quiz_service.py | 26 ++-- 2 files changed, 20 insertions(+), 153 deletions(-) diff --git a/graphgen/operators/judge/judge_service.py b/graphgen/operators/judge/judge_service.py index f48b2d62..2801bb4e 100644 --- a/graphgen/operators/judge/judge_service.py +++ b/graphgen/operators/judge/judge_service.py @@ -52,153 +52,20 @@ def judge(self, items: list[dict]) -> None: desc="Judging descriptions", unit="description", ) - # Update the graph storage with the computed losses for item in results: - print(item) - node_id = item.get("node_id") - edge_source = item.get("edge_source") - edge_target = item.get("edge_target") + index = item["index"] loss = item["loss"] - if node_id is not None: + if isinstance(index, str): + node_id = index node_data = self.graph_storage.get_node(node_id) - if node_data is not None: + if node_data: node_data["loss"] = loss self.graph_storage.update_node(node_id, node_data) - elif edge_source is not None and edge_target is not None: + elif isinstance(index, tuple): + edge_source, edge_target = index edge_data = self.graph_storage.get_edge(edge_source, edge_target) - if edge_data is not None: + if edge_data: edge_data["loss"] = loss self.graph_storage.update_edge(edge_source, edge_target, edge_data) self.graph_storage.index_done_callback() - - -# async def judge_statement( # pylint: disable=too-many-statements -# trainee_llm_client: BaseLLMWrapper, -# graph_storage: NetworkXStorage, -# rephrase_storage: JsonKVStorage, -# re_judge: bool = False, -# progress_bar: gr.Progress = None, -# ) -> NetworkXStorage: -# """ -# Get all edges and nodes and judge them -# -# :param trainee_llm_client: judge the statements to get comprehension loss -# :param graph_storage: graph storage instance -# :param rephrase_storage: rephrase storage instance -# :param re_judge: re-judge the relations -# :param progress_bar -# :return: -# """ -# -# async def _judge_single_relation( -# edge: tuple, -# ): -# source_id = edge[0] -# target_id 
= edge[1] -# edge_data = edge[2] -# -# if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None: -# logger.debug( -# "Edge %s -> %s already judged, loss: %s, skip", -# source_id, -# target_id, -# edge_data["loss"], -# ) -# return source_id, target_id, edge_data -# -# description = edge_data["description"] -# -# try: -# descriptions = rephrase_storage.get_by_id(description) -# assert descriptions is not None -# -# judgements = [] -# gts = [gt for _, gt in descriptions] -# for description, gt in descriptions: -# judgement = await trainee_llm_client.generate_topk_per_token( -# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) -# ) -# judgements.append(judgement[0].top_candidates) -# -# loss = yes_no_loss_entropy(judgements, gts) -# -# logger.debug( -# "Edge %s -> %s description: %s loss: %s", -# source_id, -# target_id, -# description, -# loss, -# ) -# -# edge_data["loss"] = loss -# except Exception as e: # pylint: disable=broad-except -# logger.error( -# "Error in judging relation %s -> %s: %s", source_id, target_id, e -# ) -# logger.info("Use default loss 0.1") -# edge_data["loss"] = -math.log(0.1) -# -# graph_storage.update_edge(source_id, target_id, edge_data) -# return source_id, target_id, edge_data -# -# edges = graph_storage.get_all_edges() -# -# await run_concurrent( -# _judge_single_relation, -# edges, -# desc="Judging relations", -# unit="relation", -# progress_bar=progress_bar, -# ) -# -# async def _judge_single_entity( -# node: tuple, -# ): -# node_id = node[0] -# node_data = node[1] -# -# if (not re_judge) and "loss" in node_data and node_data["loss"] is not None: -# logger.debug( -# "Node %s already judged, loss: %s, skip", node_id, node_data["loss"] -# ) -# return node_id, node_data -# -# description = node_data["description"] -# -# try: -# descriptions = rephrase_storage.get_by_id(description) -# assert descriptions is not None -# -# judgements = [] -# gts = [gt for _, gt in descriptions] -# for description, gt in descriptions: -# judgement = await trainee_llm_client.generate_topk_per_token( -# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) -# ) -# judgements.append(judgement[0].top_candidates) -# -# loss = yes_no_loss_entropy(judgements, gts) -# -# logger.debug("Node %s description: %s loss: %s", node_id, description, loss) -# -# node_data["loss"] = loss -# except Exception as e: # pylint: disable=broad-except -# logger.error("Error in judging entity %s: %s", node_id, e) -# logger.error("Use default loss 0.1") -# node_data["loss"] = -math.log(0.1) -# -# graph_storage.update_node(node_id, node_data) -# return node_id, node_data -# -# nodes = graph_storage.get_all_nodes() -# -# await run_concurrent( -# _judge_single_entity, -# nodes, -# desc="Judging entities", -# unit="entity", -# progress_bar=progress_bar, -# ) -# -# return graph_storage diff --git a/graphgen/operators/quiz/quiz_service.py b/graphgen/operators/quiz/quiz_service.py index 415d4771..811c24aa 100644 --- a/graphgen/operators/quiz/quiz_service.py +++ b/graphgen/operators/quiz/quiz_service.py @@ -34,23 +34,22 @@ def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: self.graph_storage.reload() yield from self.quiz() - async def _process_single_quiz(self, item: str) -> dict | None: + async def _process_single_quiz(self, item: tuple) -> dict | None: # if quiz in quiz_storage exists already, directly get it - _description_id = compute_content_hash(item) + index, desc = item + _description_id = compute_content_hash(desc, prefix="quiz-") if 
self.quiz_storage.get_by_id(_description_id):
             return None
 
         tasks = []
         for i in range(self.quiz_samples):
             if i > 0:
-                tasks.append((item, "TEMPLATE", "yes"))
-            tasks.append((item, "ANTI_TEMPLATE", "no"))
+                tasks.append((desc, "TEMPLATE", "yes"))
+            tasks.append((desc, "ANTI_TEMPLATE", "no"))
 
         try:
             quizzes = []
-            for description, template_type, gt in tasks:
-                prompt = self.generator.build_prompt_for_description(
-                    description, template_type
-                )
+            for d, template_type, gt in tasks:
+                prompt = self.generator.build_prompt_for_description(d, template_type)
                 new_description = await self.llm_client.generate_answer(
                     prompt, temperature=1
                 )
@@ -58,7 +57,8 @@ async def _process_single_quiz(self, item: str) -> dict | None:
                 quizzes.append((rephrased_text, gt))
             return {
                 "_description_id": _description_id,
-                "description": item,
+                "description": desc,
+                "index": index,
                 "quizzes": quizzes,
             }
         except Exception as e:
@@ -76,13 +76,13 @@ def quiz(self) -> Iterable[pd.DataFrame]:
 
         for edge in edges:
             edge_data = edge[2]
-            description = edge_data["description"]
-            items.append(description)
+            desc = edge_data["description"]
+            items.append(((edge[0], edge[1]), desc))
 
         for node in nodes:
             node_data = node[1]
-            description = node_data["description"]
-            items.append(description)
+            desc = node_data["description"]
+            items.append((node[0], desc))
 
         logger.info("Total descriptions to quiz: %d", len(items))
 
From a6aedafba0be11f68b5825de94c187cd4a3361f9 Mon Sep 17 00:00:00 2001
From: chenzihong-gavin
Date: Wed, 10 Dec 2025 21:14:56 +0800
Subject: [PATCH 17/28] refactor: refactor partition to accommodate ray data

---
 graphgen/bases/base_partitioner.py            |  49 ++-
 .../partitioner/anchor_bfs_partitioner.py     |  23 +-
 .../models/partitioner/bfs_partitioner.py     |  13 +-
 .../models/partitioner/dfs_partitioner.py     |  12 +-
 .../models/partitioner/ece_partitioner.py     |  32 +-
 .../models/partitioner/leiden_partitioner.py  |  14 +-
 graphgen/operators/partition/__init__.py      |   2 +-
 .../operators/partition/partition_service.py  | 300 +++++++++---------
 8 files changed, 203 insertions(+), 242 deletions(-)

diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py
index d74ff563..d948e3a7 100644
--- a/graphgen/bases/base_partitioner.py
+++ b/graphgen/bases/base_partitioner.py
@@ -7,7 +7,7 @@ class BasePartitioner(ABC):
     @abstractmethod
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         **kwargs: Any,
@@ -20,39 +20,34 @@ async def partition(
 
     @staticmethod
-    async def community2batch(
-        communities: List[Community], g: BaseGraphStorage
-    ) -> list[
-        tuple[
-            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
-        ]
+    def community2batch(
+        comm: Community, g: BaseGraphStorage
+    ) -> tuple[
+        list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
     ]:
         """
         Convert a community to a batch of nodes and edges.
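+        Edge metadata is looked up as (u, v) first and then (v, u), since the
+        underlying graph may store an edge in either direction.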
-        :param communities
+        :param comm: Community
         :param g: Graph storage instance
-        :return: List of batches, each batch is a tuple of (nodes, edges)
+        :return: A single batch as a tuple of (nodes, edges)
         """
-        batches = []
-        for comm in communities:
-            nodes = comm.nodes
-            edges = comm.edges
-            nodes_data = []
-            for node in nodes:
-                node_data = g.get_node(node)
-                if node_data:
-                    nodes_data.append((node, node_data))
-            edges_data = []
-            for u, v in edges:
-                edge_data = g.get_edge(u, v)
+        nodes = comm.nodes
+        edges = comm.edges
+        nodes_data = []
+        for node in nodes:
+            node_data = g.get_node(node)
+            if node_data:
+                nodes_data.append((node, node_data))
+        edges_data = []
+        for u, v in edges:
+            edge_data = g.get_edge(u, v)
+            if edge_data:
+                edges_data.append((u, v, edge_data))
+            else:
+                edge_data = g.get_edge(v, u)
                 if edge_data:
-                    edges_data.append((u, v, edge_data))
-                else:
-                    edge_data = g.get_edge(v, u)
-                    if edge_data:
-                        edges_data.append((v, u, edge_data))
-            batches.append((nodes_data, edges_data))
-        return batches
+                    edges_data.append((v, u, edge_data))
+        return nodes_data, edges_data
 
     @staticmethod
     def _build_adjacency_list(
diff --git a/graphgen/models/partitioner/anchor_bfs_partitioner.py b/graphgen/models/partitioner/anchor_bfs_partitioner.py
index 6cc1400c..09133af7 100644
--- a/graphgen/models/partitioner/anchor_bfs_partitioner.py
+++ b/graphgen/models/partitioner/anchor_bfs_partitioner.py
@@ -1,6 +1,6 @@
 import random
 from collections import deque
-from typing import Any, List, Literal, Set, Tuple
+from typing import Any, Iterable, List, Literal, Set, Tuple
 
 from graphgen.bases import BaseGraphStorage
 from graphgen.bases.datatypes import Community
@@ -30,24 +30,23 @@ def __init__(
         self.anchor_type = anchor_type
         self.anchor_ids = anchor_ids
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()  # List[tuple[id, meta]]
         edges = g.get_all_edges()  # List[tuple[u, v, meta]]
 
         adj, _ = self._build_adjacency_list(nodes, edges)
 
-        anchors: Set[str] = await self._pick_anchor_ids(nodes)
+        anchors: Set[str] = self._pick_anchor_ids(nodes)
         if not anchors:
-            return []  # if no anchors, return empty list
+            return  # if no anchors, return nothing
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         seeds = list(anchors)
         random.shuffle(seeds)
@@ -55,17 +54,13 @@ async def partition(
         for seed_node in seeds:
             if seed_node in used_n:
                 continue
-            comm_n, comm_e = await self._grow_community(
+            comm_n, comm_e = self._grow_community(
                 seed_node, adj, max_units_per_community, used_n, used_e
             )
             if comm_n or comm_e:
-                communities.append(
-                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
-                )
+                yield Community(id=seed_node, nodes=comm_n, edges=comm_e)
 
-        return communities
-
-    async def _pick_anchor_ids(
+    def _pick_anchor_ids(
         self,
         nodes: List[tuple[str, dict]],
    ) -> Set[str]:
@@ -80,7 +75,7 @@ async def _pick_anchor_ids(
         return anchor_ids
 
     @staticmethod
-    async def _grow_community(
+    def _grow_community(
         seed: str,
         adj: dict[str, List[str]],
         max_units: int,
diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py
index 00895712..994e08e8 100644
--- a/graphgen/models/partitioner/bfs_partitioner.py
+++ b/graphgen/models/partitioner/bfs_partitioner.py
@@ -1,6 +1,6 @@
 import random
 from collections import deque
-from typing import Any, List
+from typing import Any, Iterable, List
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from
graphgen.bases.datatypes import Community @@ -17,12 +17,12 @@ class BFSPartitioner(BasePartitioner): (A unit is a node or an edge.) """ - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 1, **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes = g.get_all_nodes() edges = g.get_all_edges() @@ -30,7 +30,6 @@ async def partition( used_n: set[str] = set() used_e: set[frozenset[str]] = set() - communities: List[Community] = [] units = [(NODE_UNIT, n[0]) for n in nodes] + [ (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges @@ -74,8 +73,4 @@ async def partition( queue.append((NODE_UNIT, n)) if comm_n or comm_e: - communities.append( - Community(id=len(communities), nodes=comm_n, edges=comm_e) - ) - - return communities + yield Community(id=seed, nodes=comm_n, edges=comm_e) diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py index 6c394b10..36305842 100644 --- a/graphgen/models/partitioner/dfs_partitioner.py +++ b/graphgen/models/partitioner/dfs_partitioner.py @@ -1,4 +1,5 @@ import random +from collections.abc import Iterable from typing import Any, List from graphgen.bases import BaseGraphStorage, BasePartitioner @@ -16,12 +17,12 @@ class DFSPartitioner(BasePartitioner): (In GraphGen, a unit is defined as a node or an edge.) """ - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 1, **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes = g.get_all_nodes() edges = g.get_all_edges() @@ -29,7 +30,6 @@ async def partition( used_n: set[str] = set() used_e: set[frozenset[str]] = set() - communities: List[Community] = [] units = [(NODE_UNIT, n[0]) for n in nodes] + [ (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges @@ -71,8 +71,4 @@ async def partition( stack.append((NODE_UNIT, n)) if comm_n or comm_e: - communities.append( - Community(id=len(communities), nodes=comm_n, edges=comm_e) - ) - - return communities + yield Community(id=seed, nodes=comm_n, edges=comm_e) diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py index 7de73181..cb0ce861 100644 --- a/graphgen/models/partitioner/ece_partitioner.py +++ b/graphgen/models/partitioner/ece_partitioner.py @@ -1,8 +1,8 @@ -import asyncio import random +from collections import deque from typing import Any, Dict, List, Optional, Set, Tuple -from tqdm.asyncio import tqdm as tqdm_async +from tqdm import tqdm from graphgen.bases import BaseGraphStorage from graphgen.bases.datatypes import Community @@ -51,7 +51,7 @@ def _sort_units(units: list, edge_sampling: str) -> list: raise ValueError(f"Invalid edge sampling: {edge_sampling}") return units - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 10, @@ -73,21 +73,19 @@ async def partition( used_n: Set[str] = set() used_e: Set[frozenset[str]] = set() - communities: List = [] + communities: List[Community] = [] all_units = self._sort_units(all_units, unit_sampling) - async def _grow_community( - seed_unit: Tuple[str, Any, dict] - ) -> Optional[Community]: + def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Optional[Community]: nonlocal used_n, used_e community_nodes: Dict[str, dict] = {} community_edges: Dict[frozenset[str], dict] = {} - queue: asyncio.Queue = asyncio.Queue() + queue = deque() token_sum = 0 - async def _add_unit(u): + def _add_unit(u): nonlocal token_sum t, i, d = u if t == NODE_UNIT: # node @@ 
-103,11 +101,11 @@ async def _add_unit(u): token_sum += d.get("length", 0) return True - await _add_unit(seed_unit) - await queue.put(seed_unit) + _add_unit(seed_unit) + queue.append(seed_unit) # BFS - while not queue.empty(): + while queue: if ( len(community_nodes) + len(community_edges) >= max_units_per_community @@ -115,7 +113,7 @@ async def _add_unit(u): ): break - cur_type, cur_id, _ = await queue.get() + cur_type, cur_id, _ = queue.popleft() neighbors: List[Tuple[str, Any, dict]] = [] if cur_type == NODE_UNIT: @@ -136,8 +134,8 @@ async def _add_unit(u): or token_sum >= max_tokens_per_community ): break - if await _add_unit(nb): - await queue.put(nb) + if _add_unit(nb): + queue.append(nb) if len(community_nodes) + len(community_edges) < min_units_per_community: return None @@ -148,13 +146,13 @@ async def _add_unit(u): edges=[(u, v) for (u, v), _ in community_edges.items()], ) - async for unit in tqdm_async(all_units, desc="ECE partition"): + for unit in tqdm(all_units, desc="ECE partition"): utype, uid, _ = unit if (utype == NODE_UNIT and uid in used_n) or ( utype == EDGE_UNIT and uid in used_e ): continue - comm = await _grow_community(unit) + comm = _grow_community(unit) if comm is not None: communities.append(comm) diff --git a/graphgen/models/partitioner/leiden_partitioner.py b/graphgen/models/partitioner/leiden_partitioner.py index 1f85789b..b62b8544 100644 --- a/graphgen/models/partitioner/leiden_partitioner.py +++ b/graphgen/models/partitioner/leiden_partitioner.py @@ -13,7 +13,7 @@ class LeidenPartitioner(BasePartitioner): Leiden partitioner that partitions the graph into communities using the Leiden algorithm. """ - async def partition( + def partition( self, g: BaseGraphStorage, max_size: int = 20, @@ -37,12 +37,10 @@ async def partition( nodes = g.get_all_nodes() # List[Tuple[str, dict]] edges = g.get_all_edges() # List[Tuple[str, str, dict]] - node2cid: Dict[str, int] = await self._run_leiden( - nodes, edges, use_lcc, random_seed - ) + node2cid: Dict[str, int] = self._run_leiden(nodes, edges, use_lcc, random_seed) if max_size is not None and max_size > 0: - node2cid = await self._split_communities(node2cid, max_size) + node2cid = self._split_communities(node2cid, max_size) cid2nodes: Dict[int, List[str]] = defaultdict(list) for n, cid in node2cid.items(): @@ -58,7 +56,7 @@ async def partition( return communities @staticmethod - async def _run_leiden( + def _run_leiden( nodes: List[Tuple[str, dict]], edges: List[Tuple[str, str, dict]], use_lcc: bool = False, @@ -92,9 +90,7 @@ async def _run_leiden( return node2cid @staticmethod - async def _split_communities( - node2cid: Dict[str, int], max_size: int - ) -> Dict[str, int]: + def _split_communities(node2cid: Dict[str, int], max_size: int) -> Dict[str, int]: """ Split communities larger than max_size into smaller sub-communities. 
""" diff --git a/graphgen/operators/partition/__init__.py b/graphgen/operators/partition/__init__.py index 21f934b3..8d586b95 100644 --- a/graphgen/operators/partition/__init__.py +++ b/graphgen/operators/partition/__init__.py @@ -1 +1 @@ -from .partition_kg import partition_kg +from .partition_service import PartitionService diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py index f7510abc..e8b08628 100644 --- a/graphgen/operators/partition/partition_service.py +++ b/graphgen/operators/partition/partition_service.py @@ -1,170 +1,156 @@ -from typing import Any +import os +from typing import Any, Iterable + +import pandas as pd from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer +from graphgen.common import init_storage from graphgen.models import ( AnchorBFSPartitioner, BFSPartitioner, DFSPartitioner, ECEPartitioner, LeidenPartitioner, + Tokenizer, ) from graphgen.utils import logger -def partition_kg( - kg_instance: BaseGraphStorage, - chunk_storage: BaseKVStorage, - tokenizer: Any = BaseTokenizer, - partition_config: dict = None, -) -> list[ - tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]] -]: - method = partition_config["method"] - method_params = partition_config["method_params"] - if method == "bfs": - logger.info("Partitioning knowledge graph using BFS method.") - partitioner = BFSPartitioner() - elif method == "dfs": - logger.info("Partitioning knowledge graph using DFS method.") - partitioner = DFSPartitioner() - elif method == "ece": - logger.info("Partitioning knowledge graph using ECE method.") - # TODO: before ECE partitioning, we need to: - # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random - # 2. pre-tokenize nodes and edges to get the token length - edges = kg_instance.get_all_edges() - nodes = kg_instance.get_all_nodes() - await pre_tokenize(kg_instance, tokenizer, edges, nodes) - partitioner = ECEPartitioner() - elif method == "leiden": - logger.info("Partitioning knowledge graph using Leiden method.") - partitioner = LeidenPartitioner() - elif method == "anchor_bfs": - logger.info("Partitioning knowledge graph using Anchor BFS method.") - partitioner = AnchorBFSPartitioner( - anchor_type=method_params.get("anchor_type"), - anchor_ids=set(method_params.get("anchor_ids", [])) - if method_params.get("anchor_ids") - else None, +class PartitionService: + def __init__(self, working_dir: str = "cache", **partition_kwargs): + self.kg_instance: BaseGraphStorage = init_storage( + backend="networkx", + working_dir=working_dir, + namespace="graph", ) - else: - raise ValueError(f"Unsupported partition method: {method}") - - communities = await partitioner.partition(g=kg_instance, **method_params) - logger.info("Partitioned the graph into %d communities.", len(communities)) - batches = await partitioner.community2batch(communities, g=kg_instance) - - batches = await attach_additional_data_to_node(batches, chunk_storage) - return batches - - -def attach_additional_data_to_node( - batches: list[ - tuple[ - list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] - ] - ], - chunk_storage: BaseKVStorage, -) -> list[ - tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]] -]: - """ - Attach additional data from chunk_storage to nodes in the batches. 
-    :param batches:
-    :param chunk_storage:
-    :return:
-    """
-    for batch in batches:
-        for node_id, node_data in batch[0]:
-            await _attach_by_type(node_id, node_data, chunk_storage)
-    return batches
-
-
-async def _attach_by_type(
-    node_id: str,
-    node_data: dict,
-    chunk_storage: BaseKVStorage,
-) -> None:
-    """
-    Attach additional data to the node based on its entity type.
-    """
-    entity_type = (node_data.get("entity_type") or "").lower()
-    if not entity_type:
-        return
-
-    source_ids = [
-        sid.strip()
-        for sid in node_data.get("source_id", "").split("<SEP>")
-        if sid.strip()
-    ]
-
-    # Handle images
-    if "image" in entity_type:
-        image_chunks = [
-            data
-            for sid in source_ids
-            if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid))
-        ]
-        if image_chunks:
-            # The generator expects a dictionary with an 'img_path' key, not a list of captions.
-            # We'll use the first image chunk found for this node.
-            node_data["images"] = image_chunks[0]
-            logger.debug("Attached image data to node %s", node_id)
-
-
-import asyncio
-from typing import List, Tuple
-
-import gradio as gr
-
-from graphgen.bases import BaseGraphStorage, BaseTokenizer
-from graphgen.utils import run_concurrent
-
-
-async def pre_tokenize(
-    graph_storage: BaseGraphStorage,
-    tokenizer: BaseTokenizer,
-    edges: List[Tuple],
-    nodes: List[Tuple],
-    progress_bar: gr.Progress = None,
-    max_concurrent: int = 1000,
-) -> Tuple[List, List]:
-    """Backfill token lengths for edges/nodes and write them back to storage; concurrency 1000, with a progress bar."""
-    sem = asyncio.Semaphore(max_concurrent)
-
-    async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
-        async with sem:
-            data = obj[1] if is_node else obj[2]
-            if "length" not in data:
-                loop = asyncio.get_event_loop()
-                data["length"] = len(
-                    await loop.run_in_executor(
-                        None, tokenizer.encode, data["description"]
-                    )
-                )
-            if is_node:
-                graph_storage.update_node(obj[0], obj[1])
-            else:
-                graph_storage.update_edge(obj[0], obj[1], obj[2])
-            return obj
-
-    new_edges, new_nodes = await asyncio.gather(
-        run_concurrent(
-            lambda e: _patch_and_write(e, is_node=False),
-            edges,
-            desc="Pre-tokenizing edges",
-            unit="edge",
-            progress_bar=progress_bar,
-        ),
-        run_concurrent(
-            lambda n: _patch_and_write(n, is_node=True),
-            nodes,
-            desc="Pre-tokenizing nodes",
-            unit="node",
-            progress_bar=progress_bar,
-        ),
-    )
-
-    graph_storage.index_done_callback()
-    return new_edges, new_nodes
-
+        self.chunk_storage: BaseKVStorage = init_storage(
+            backend="json_kv",
+            working_dir=working_dir,
+            namespace="chunk",
+        )
+        tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
+        self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model)
+        self.partition_kwargs = partition_kwargs
+
+    def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
+        # this operator does not consume any batch data
+        # but for compatibility we keep the interface
+        _ = batch.to_dict(orient="records")
+        self.kg_instance.reload()
+        self.chunk_storage.reload()
+
+        yield from self.partition()
+
+    def partition(self) -> Iterable[pd.DataFrame]:
+        method = self.partition_kwargs["method"]
+        method_params = self.partition_kwargs["method_params"]
+        if method == "bfs":
+            logger.info("Partitioning knowledge graph using BFS method.")
+            partitioner = BFSPartitioner()
+        elif method == "dfs":
+            logger.info("Partitioning knowledge graph using DFS method.")
+            partitioner = DFSPartitioner()
+        elif method == "ece":
+            logger.info("Partitioning knowledge graph using ECE method.")
+            # TODO: before ECE partitioning, we need to:
+            # 1. 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
+            # 2. pre-tokenize nodes and edges to get the token length
+            self._pre_tokenize()
+            partitioner = ECEPartitioner()
+        elif method == "leiden":
+            logger.info("Partitioning knowledge graph using Leiden method.")
+            partitioner = LeidenPartitioner()
+        elif method == "anchor_bfs":
+            logger.info("Partitioning knowledge graph using Anchor BFS method.")
+            partitioner = AnchorBFSPartitioner(
+                anchor_type=method_params.get("anchor_type"),
+                anchor_ids=set(method_params.get("anchor_ids", []))
+                if method_params.get("anchor_ids")
+                else None,
+            )
+        else:
+            raise ValueError(f"Unsupported partition method: {method}")
+
+        # Several partitioners now yield communities lazily, so materialize
+        # them before taking len().
+        communities = list(partitioner.partition(g=self.kg_instance, **method_params))
+        logger.info("Partitioned the graph into %d communities.", len(communities))
+
+        for community in communities:
+            batch = partitioner.community2batch(community, g=self.kg_instance)
+            batch = self._attach_additional_data_to_node(batch)
+
+            yield pd.DataFrame(
+                {
+                    "nodes": [batch[0]],
+                    "edges": [batch[1]],
+                }
+            )
+
+    def _pre_tokenize(self) -> None:
+        """Pre-tokenize all nodes and edges to add token length information."""
+        logger.info("Starting pre-tokenization of nodes and edges...")
+
+        nodes = self.kg_instance.get_all_nodes()
+        edges = self.kg_instance.get_all_edges()
+
+        # Process nodes
+        for node_id, node_data in nodes:
+            if "length" not in node_data:
+                try:
+                    description = node_data.get("description", "")
+                    tokens = self.tokenizer_instance.encode(description)
+                    node_data["length"] = len(tokens)
+                    self.kg_instance.update_node(node_id, node_data)
+                except Exception as e:
+                    logger.warning(f"Failed to tokenize node {node_id}: {e}")
+                    node_data["length"] = 0
+
+        # Process edges
+        for u, v, edge_data in edges:
+            if "length" not in edge_data:
+                try:
+                    description = edge_data.get("description", "")
+                    tokens = self.tokenizer_instance.encode(description)
+                    edge_data["length"] = len(tokens)
+                    self.kg_instance.update_edge(u, v, edge_data)
+                except Exception as e:
+                    logger.warning(f"Failed to tokenize edge {u}-{v}: {e}")
+                    edge_data["length"] = 0
+
+        # Persist changes
+        self.kg_instance.index_done_callback()
+        logger.info("Pre-tokenization completed.")
+
+    def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
+        """
+        Attach additional data from chunk_storage to nodes in the batch.
+        :param batch: tuple of (nodes_data, edges_data)
+        :return: updated batch with additional data attached to nodes
+        """
+        nodes_data, edges_data = batch
+
+        for node_id, node_data in nodes_data:
+            entity_type = (node_data.get("entity_type") or "").lower()
+            if not entity_type:
+                continue
+
+            source_ids = [
+                sid.strip()
+                for sid in node_data.get("source_id", "").split("<SEP>")
+                if sid.strip()
+            ]
+
+            # Handle images
+            if "image" in entity_type:
+                image_chunks = [
+                    data
+                    for sid in source_ids
+                    if "image" in sid.lower()
+                    and (data := self.chunk_storage.get_by_id(sid))
+                ]
+                if image_chunks:
+                    # The generator expects a dictionary with an 'img_path' key, not a list of captions.
+                    # We'll use the first image chunk found for this node.
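+                    # image_chunks preserves source_id order, so "first" is the
+                    # earliest-listed source; any extra image chunks are dropped.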
+ node_data["images"] = image_chunks[0] + logger.debug("Attached image data to node %s", node_id) + + return nodes_data, edges_data From ea1603bedc2c18bd2b169ee4d3d357cd5f6c3c3f Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 10 Dec 2025 22:36:08 +0800 Subject: [PATCH 18/28] fix: fix lint problem --- graphgen/operators/__init__.py | 10 +++++----- graphgen/operators/chunk/chunk_service.py | 4 +--- graphgen/operators/extract/__init__.py | 2 +- graphgen/operators/extract/extract.py | 2 +- graphgen/operators/generate/__init__.py | 2 +- graphgen/operators/generate/generate.py | 1 - graphgen/operators/judge/judge_service.py | 10 ++++------ graphgen/operators/partition/partition_service.py | 6 +++--- graphgen/operators/quiz/quiz_service.py | 12 ++++++------ 9 files changed, 22 insertions(+), 27 deletions(-) diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 1b56429c..35b004fe 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,12 +1,12 @@ from .build_kg import BuildKGService from .chunk import ChunkService -from .extract import extract_info +from .extract import extract from .generate import generate_qas -from .partition import partition_kg +from .judge import JudgeService +from .partition import PartitionService from .quiz import QuizService from .read import read from .search import search_all -from .judge import JudgeService operators = { "read": read, @@ -14,8 +14,8 @@ "build_kg": BuildKGService, "quiz": QuizService, "judge": JudgeService, - "extract_info": extract_info, + "extract_info": extract, "search_all": search_all, - "partition_kg": partition_kg, + "partition": PartitionService, "generate_qas": generate_qas, } diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py index 0fcb20a3..307833ba 100644 --- a/graphgen/operators/chunk/chunk_service.py +++ b/graphgen/operators/chunk/chunk_service.py @@ -94,8 +94,6 @@ def chunk_documents(self, new_docs: list) -> list: **doc, } ) - self.chunk_storage.upsert( - {chunk["_chunk_id"]: chunk for chunk in chunks} - ) + self.chunk_storage.upsert({chunk["_chunk_id"]: chunk for chunk in chunks}) self.chunk_storage.index_done_callback() return chunks diff --git a/graphgen/operators/extract/__init__.py b/graphgen/operators/extract/__init__.py index a5d29d7c..d46dcdf1 100644 --- a/graphgen/operators/extract/__init__.py +++ b/graphgen/operators/extract/__init__.py @@ -1 +1 @@ -from .extract import extract_info +from .extract import extract diff --git a/graphgen/operators/extract/extract.py b/graphgen/operators/extract/extract.py index 8e65f1b2..ab69af40 100644 --- a/graphgen/operators/extract/extract.py +++ b/graphgen/operators/extract/extract.py @@ -7,7 +7,7 @@ from graphgen.utils import logger, run_concurrent -async def extract_info( +async def extract( llm_client: BaseLLMWrapper, chunk_storage: BaseKVStorage, extract_config: dict, diff --git a/graphgen/operators/generate/__init__.py b/graphgen/operators/generate/__init__.py index 035eca36..44c2111c 100644 --- a/graphgen/operators/generate/__init__.py +++ b/graphgen/operators/generate/__init__.py @@ -1 +1 @@ -from .generate_qas import generate_qas +from .generate import generate_qas diff --git a/graphgen/operators/generate/generate.py b/graphgen/operators/generate/generate.py index 86dbb9c9..62f92e6b 100644 --- a/graphgen/operators/generate/generate.py +++ b/graphgen/operators/generate/generate.py @@ -52,7 +52,6 @@ async def generate_qas( batches, desc="[4/4]Generating QAs", 
unit="batch", - progress_bar=progress_bar, ) # format diff --git a/graphgen/operators/judge/judge_service.py b/graphgen/operators/judge/judge_service.py index 2801bb4e..16e8af4c 100644 --- a/graphgen/operators/judge/judge_service.py +++ b/graphgen/operators/judge/judge_service.py @@ -59,13 +59,11 @@ def judge(self, items: list[dict]) -> None: if isinstance(index, str): node_id = index node_data = self.graph_storage.get_node(node_id) - if node_data: - node_data["loss"] = loss - self.graph_storage.update_node(node_id, node_data) + node_data["loss"] = loss + self.graph_storage.update_node(node_id, node_data) elif isinstance(index, tuple): edge_source, edge_target = index edge_data = self.graph_storage.get_edge(edge_source, edge_target) - if edge_data: - edge_data["loss"] = loss - self.graph_storage.update_edge(edge_source, edge_target, edge_data) + edge_data["loss"] = loss + self.graph_storage.update_edge(edge_source, edge_target, edge_data) self.graph_storage.index_done_callback() diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py index e8b08628..db1f6e87 100644 --- a/graphgen/operators/partition/partition_service.py +++ b/graphgen/operators/partition/partition_service.py @@ -1,5 +1,5 @@ import os -from typing import Any, Iterable +from typing import Iterable import pandas as pd @@ -101,7 +101,7 @@ def _pre_tokenize(self) -> None: node_data["length"] = len(tokens) self.kg_instance.update_node(node_id, node_data) except Exception as e: - logger.warning(f"Failed to tokenize node {node_id}: {e}") + logger.warning("Failed to tokenize node %s: %s", node_id, e) node_data["length"] = 0 # Process edges @@ -113,7 +113,7 @@ def _pre_tokenize(self) -> None: edge_data["length"] = len(tokens) self.kg_instance.update_edge(u, v, edge_data) except Exception as e: - logger.warning(f"Failed to tokenize edge {u}-{v}: {e}") + logger.warning("Failed to tokenize edge %s-%s: %s", u, v, e) edge_data["length"] = 0 # Persist changes diff --git a/graphgen/operators/quiz/quiz_service.py b/graphgen/operators/quiz/quiz_service.py index 811c24aa..9bbe99a3 100644 --- a/graphgen/operators/quiz/quiz_service.py +++ b/graphgen/operators/quiz/quiz_service.py @@ -5,7 +5,7 @@ from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper from graphgen.common import init_llm, init_storage from graphgen.models import QuizGenerator -from graphgen.utils import compute_content_hash, logger, run_concurrent +from graphgen.utils import compute_dict_hash, logger, run_concurrent class QuizService: @@ -20,7 +20,7 @@ def __init__( self.graph_storage: BaseGraphStorage = init_storage( backend="networkx", working_dir=working_dir, namespace="graph" ) - # { _description_id: { "description": str, "quizzes": List[Tuple[str, str]] } } + # { _quiz_id: { "description": str, "quizzes": List[Tuple[str, str]] } } self.quiz_storage: BaseKVStorage = init_storage( backend="json_kv", working_dir=working_dir, namespace="quiz" ) @@ -37,8 +37,8 @@ def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: async def _process_single_quiz(self, item: tuple) -> dict | None: # if quiz in quiz_storage exists already, directly get it index, desc = item - _description_id = compute_content_hash(desc, prefix="quiz-") - if self.quiz_storage.get_by_id(_description_id): + _quiz_id = compute_dict_hash({"index": index, "description": desc}) + if self.quiz_storage.get_by_id(_quiz_id): return None tasks = [] @@ -56,7 +56,7 @@ async def _process_single_quiz(self, item: tuple) -> dict | None: 
rephrased_text = self.generator.parse_rephrased_text(new_description) quizzes.append((rephrased_text, gt)) return { - "_description_id": _description_id, + "_quiz_id": _quiz_id, "description": desc, "index": index, "quizzes": quizzes, @@ -100,7 +100,7 @@ def quiz(self) -> Iterable[pd.DataFrame]: if new_result: self.quiz_storage.upsert( { - new_result["_description_id"]: { + new_result["_quiz_id"]: { "description": new_result["description"], "quizzes": new_result["quizzes"], } From 244deb46b65720702fa07e9a7039e7cb2ed6d066 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 11:48:50 +0800 Subject: [PATCH 19/28] refactor: refactor op generate --- .../aggregated_config.yaml | 119 ++++++++++++++++++ .../generate_aggregated.sh | 0 .../input_examples/csv_demo.csv | 0 .../input_examples/extract_demo.txt | 0 .../input_examples/graphml_demo.graphml | 0 ...b3064cf17c5435814edfbee42ae6b19aac37d2.jpg | Bin ...3ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg | Bin ...fe2ae309fee014082db00bc2d87187a6bb5dca.jpg | Bin ...37df02964c9c3da8d8e9567ea19240b14cc742.jpg | Bin ...7ffe56f793f287b3399345aea31cd20eed2824.jpg | Bin ...ea0129a0475b2ab5b920a4cff20a4fb623517d.jpg | Bin .../input_examples/json_demo.json | 0 .../input_examples/jsonl_demo.jsonl | 0 .../input_examples/pdf_demo.pdf | Bin .../input_examples/search_dna_demo.jsonl | 0 .../input_examples/search_protein_demo.jsonl | 0 .../input_examples/search_rna_demo.jsonl | 0 .../input_examples/txt_demo.txt | 0 .../input_examples/vqa_demo.json | 0 .../output_examples/aggregated_chatml.json | 0 .../output_examples/atomic_alpaca.json | 0 .../output_examples/cot_sharegpt.json | 0 .../output_examples/multi-hop_chatml.json | 0 graphgen/configs/aggregated_config.yaml | 41 ------ graphgen/operators/generate/__init__.py | 2 +- graphgen/operators/generate/generate.py | 65 ---------- .../operators/generate/generate_service.py | 62 +++++++++ 27 files changed, 182 insertions(+), 107 deletions(-) create mode 100644 examples/generate/generate_aggregated_qa/aggregated_config.yaml rename {scripts/generate => examples/generate/generate_aggregated_qa}/generate_aggregated.sh (100%) rename {resources => examples}/input_examples/csv_demo.csv (100%) rename {resources => examples}/input_examples/extract_demo.txt (100%) rename {resources => examples}/input_examples/graphml_demo.graphml (100%) rename {resources => examples}/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg (100%) rename {resources => examples}/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg (100%) rename {resources => examples}/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg (100%) rename {resources => examples}/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg (100%) rename {resources => examples}/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg (100%) rename {resources => examples}/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg (100%) rename {resources => examples}/input_examples/json_demo.json (100%) rename {resources => examples}/input_examples/jsonl_demo.jsonl (100%) rename {resources => examples}/input_examples/pdf_demo.pdf (100%) rename {resources => examples}/input_examples/search_dna_demo.jsonl (100%) rename {resources => examples}/input_examples/search_protein_demo.jsonl (100%) rename {resources => examples}/input_examples/search_rna_demo.jsonl 
(100%) rename {resources => examples}/input_examples/txt_demo.txt (100%) rename {resources => examples}/input_examples/vqa_demo.json (100%) rename {resources => examples}/output_examples/aggregated_chatml.json (100%) rename {resources => examples}/output_examples/atomic_alpaca.json (100%) rename {resources => examples}/output_examples/cot_sharegpt.json (100%) rename {resources => examples}/output_examples/multi-hop_chatml.json (100%) delete mode 100644 graphgen/configs/aggregated_config.yaml delete mode 100644 graphgen/operators/generate/generate.py create mode 100644 graphgen/operators/generate/generate_service.py diff --git a/examples/generate/generate_aggregated_qa/aggregated_config.yaml b/examples/generate/generate_aggregated_qa/aggregated_config.yaml new file mode 100644 index 00000000..ed4c3487 --- /dev/null +++ b/examples/generate/generate_aggregated_qa/aggregated_config.yaml @@ -0,0 +1,119 @@ +global_params: + working_dir: cache + +nodes: + - id: read_files + op_name: read + type: source + dependencies: [] + params: + input_path: + - resources/input_examples/jsonl_demo.jsonl + + - id: chunk_documents + op_name: chunk + type: map_batch + dependencies: + - read_files + execution_params: + replicas: 4 + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk_documents + execution_params: + replicas: 1 + batch_size: 128 + + - id: quiz + op_name: quiz + type: aggregate + dependencies: + - build_kg + execution_params: + replicas: 1 + batch_size: 128 + params: + quiz_samples: 2 + concurrency_limit: 200 + + - id: judge + op_name: judge + type: map_batch + dependencies: + - quiz + execution_params: + replicas: 1 + batch_size: 16 + + - id: partition + op_name: partition + type: aggregate + dependencies: + - judge + params: + method: ece + method_params: + max_units_per_community: 20 + min_units_per_community: 5 + max_tokens_per_community: 10240 + unit_sampling: max_loss + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 16 + params: + method: aggregated + data_format: ChatML + +#pipeline: +# - name: read_step # step name is unique in the pipeline, and can be referenced by other steps +# op_key: read +# params: +# input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples +# +# - name: chunk_step +# op_key: chunk +# deps: [read_step] # chunk_step depends on read_step +# params: +# chunk_size: 1024 # chunk size for text splitting +# chunk_overlap: 100 # chunk overlap for text splitting +# +# - name: build_kg_step +# op_key: build_kg +# deps: [chunk_step] # build_kg_step depends on chunk_step +# +# - name: quiz_and_judge_step +# op_key: quiz_and_judge +# deps: [build_kg_step] # quiz_and_judge depends on build_kg_step +# params: +# quiz_samples: 2 # number of quiz samples to generate +# re_judge: false # whether to re-judge the existing quiz samples +# +# - name: partition_step +# op_key: partition +# deps: [quiz_and_judge_step] # partition_step depends on quiz_and_judge_step +# params: +# method: ece # ece is a custom partition method based on comprehension loss +# method_params: +# max_units_per_community: 20 # max nodes and edges per community +# min_units_per_community: 5 # min nodes and edges per community +# max_tokens_per_community: 10240 # max tokens per community +# unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss +# +# - name: generate_step +# op_key: generate +# deps: [partition_step] # generate_step depends on partition_step +# params: +# method: aggregated # atomic, aggregated, multi_hop, cot, vqa +# data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/scripts/generate/generate_aggregated.sh b/examples/generate/generate_aggregated_qa/generate_aggregated.sh similarity index 100% rename from scripts/generate/generate_aggregated.sh rename to examples/generate/generate_aggregated_qa/generate_aggregated.sh diff --git a/resources/input_examples/csv_demo.csv b/examples/input_examples/csv_demo.csv similarity index 100% rename from resources/input_examples/csv_demo.csv rename to examples/input_examples/csv_demo.csv diff --git a/resources/input_examples/extract_demo.txt b/examples/input_examples/extract_demo.txt similarity index 100% rename from resources/input_examples/extract_demo.txt rename to examples/input_examples/extract_demo.txt diff --git a/resources/input_examples/graphml_demo.graphml b/examples/input_examples/graphml_demo.graphml similarity index 100% rename from resources/input_examples/graphml_demo.graphml rename to examples/input_examples/graphml_demo.graphml diff --git a/resources/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg b/examples/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg similarity index 100% rename from resources/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg rename to examples/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg diff --git a/resources/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg b/examples/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg similarity index 100% rename from resources/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg rename to examples/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg diff --git a/resources/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg b/examples/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg similarity index 100% rename from 
resources/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg rename to examples/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg diff --git a/resources/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg b/examples/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg similarity index 100% rename from resources/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg rename to examples/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg diff --git a/resources/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg b/examples/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg similarity index 100% rename from resources/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg rename to examples/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg diff --git a/resources/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg b/examples/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg similarity index 100% rename from resources/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg rename to examples/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg diff --git a/resources/input_examples/json_demo.json b/examples/input_examples/json_demo.json similarity index 100% rename from resources/input_examples/json_demo.json rename to examples/input_examples/json_demo.json diff --git a/resources/input_examples/jsonl_demo.jsonl b/examples/input_examples/jsonl_demo.jsonl similarity index 100% rename from resources/input_examples/jsonl_demo.jsonl rename to examples/input_examples/jsonl_demo.jsonl diff --git a/resources/input_examples/pdf_demo.pdf b/examples/input_examples/pdf_demo.pdf similarity index 100% rename from resources/input_examples/pdf_demo.pdf rename to examples/input_examples/pdf_demo.pdf diff --git a/resources/input_examples/search_dna_demo.jsonl b/examples/input_examples/search_dna_demo.jsonl similarity index 100% rename from resources/input_examples/search_dna_demo.jsonl rename to examples/input_examples/search_dna_demo.jsonl diff --git a/resources/input_examples/search_protein_demo.jsonl b/examples/input_examples/search_protein_demo.jsonl similarity index 100% rename from resources/input_examples/search_protein_demo.jsonl rename to examples/input_examples/search_protein_demo.jsonl diff --git a/resources/input_examples/search_rna_demo.jsonl b/examples/input_examples/search_rna_demo.jsonl similarity index 100% rename from resources/input_examples/search_rna_demo.jsonl rename to examples/input_examples/search_rna_demo.jsonl diff --git a/resources/input_examples/txt_demo.txt b/examples/input_examples/txt_demo.txt similarity index 100% rename from resources/input_examples/txt_demo.txt rename to examples/input_examples/txt_demo.txt diff --git a/resources/input_examples/vqa_demo.json b/examples/input_examples/vqa_demo.json similarity index 100% rename from resources/input_examples/vqa_demo.json rename to examples/input_examples/vqa_demo.json diff --git a/resources/output_examples/aggregated_chatml.json b/examples/output_examples/aggregated_chatml.json similarity index 100% 
rename from resources/output_examples/aggregated_chatml.json rename to examples/output_examples/aggregated_chatml.json diff --git a/resources/output_examples/atomic_alpaca.json b/examples/output_examples/atomic_alpaca.json similarity index 100% rename from resources/output_examples/atomic_alpaca.json rename to examples/output_examples/atomic_alpaca.json diff --git a/resources/output_examples/cot_sharegpt.json b/examples/output_examples/cot_sharegpt.json similarity index 100% rename from resources/output_examples/cot_sharegpt.json rename to examples/output_examples/cot_sharegpt.json diff --git a/resources/output_examples/multi-hop_chatml.json b/examples/output_examples/multi-hop_chatml.json similarity index 100% rename from resources/output_examples/multi-hop_chatml.json rename to examples/output_examples/multi-hop_chatml.json diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml deleted file mode 100644 index 9c53ec9c..00000000 --- a/graphgen/configs/aggregated_config.yaml +++ /dev/null @@ -1,41 +0,0 @@ -pipeline: - - name: read_step # step name is unique in the pipeline, and can be referenced by other steps - op_key: read - params: - input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg_step depends on chunk_step - - - name: quiz_and_judge_step - op_key: quiz_and_judge - deps: [build_kg_step] # quiz_and_judge depends on build_kg_step - params: - quiz_samples: 2 # number of quiz samples to generate - re_judge: false # whether to re-judge the existing quiz samples - - - name: partition_step - op_key: partition - deps: [quiz_and_judge_step] # partition_step depends on quiz_and_judge_step - params: - method: ece # ece is a custom partition method based on comprehension loss - method_params: - max_units_per_community: 20 # max nodes and edges per community - min_units_per_community: 5 # min nodes and edges per community - max_tokens_per_community: 10240 # max tokens per community - unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: aggregated # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/operators/generate/__init__.py b/graphgen/operators/generate/__init__.py index 44c2111c..04057ce6 100644 --- a/graphgen/operators/generate/__init__.py +++ b/graphgen/operators/generate/__init__.py @@ -1 +1 @@ -from .generate import generate_qas +from .generate_service import GenerateService diff --git a/graphgen/operators/generate/generate.py b/graphgen/operators/generate/generate.py deleted file mode 100644 index 62f92e6b..00000000 --- a/graphgen/operators/generate/generate.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Any - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import ( - AggregatedGenerator, - AtomicGenerator, - CoTGenerator, - MultiHopGenerator, - VQAGenerator, -) -from graphgen.utils import logger, run_concurrent - - -async def generate_qas( - llm_client: BaseLLMWrapper, - batches: list[ - tuple[ - 
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] - ] - ], - generation_config: dict, - progress_bar: gr.Progress = None, -) -> list[dict[str, Any]]: - """ - Generate question-answer pairs based on nodes and edges. - :param llm_client: LLM client - :param batches - :param generation_config - :param progress_bar - :return: QA pairs - """ - method = generation_config["method"] - logger.info("[Generation] mode: %s, batches: %d", method, len(batches)) - - if method == "atomic": - generator = AtomicGenerator(llm_client) - elif method == "aggregated": - generator = AggregatedGenerator(llm_client) - elif method == "multi_hop": - generator = MultiHopGenerator(llm_client) - elif method == "cot": - generator = CoTGenerator(llm_client) - elif method in ["vqa"]: - generator = VQAGenerator(llm_client) - else: - raise ValueError(f"Unsupported generation mode: {method}") - - results = await run_concurrent( - generator.generate, - batches, - desc="[4/4]Generating QAs", - unit="batch", - ) - - # format - data_format = generation_config["data_format"] - logger.info("Output data format: %s", data_format) - - results = generator.format_generation_results( - results, output_data_format=data_format - ) - - return results diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py new file mode 100644 index 00000000..8b0f78e9 --- /dev/null +++ b/graphgen/operators/generate/generate_service.py @@ -0,0 +1,62 @@ +import pandas as pd + +from graphgen.bases import BaseLLMWrapper +from graphgen.common import init_llm +from graphgen.models import ( + AggregatedGenerator, + AtomicGenerator, + CoTGenerator, + MultiHopGenerator, + VQAGenerator, +) +from graphgen.utils import logger, run_concurrent + + +class GenerateService: + """ + Generate question-answer pairs based on nodes and edges. + """ + + def __init__(self, method: str = "aggregated", data_format: str = "ChatML"): + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + + self.method = method + self.data_format = data_format + + if self.method == "atomic": + self.generator = AtomicGenerator(self.llm_client) + elif self.method == "aggregated": + self.generator = AggregatedGenerator(self.llm_client) + elif self.method == "multi_hop": + self.generator = MultiHopGenerator(self.llm_client) + elif self.method == "cot": + self.generator = CoTGenerator(self.llm_client) + elif self.method in ["vqa"]: + self.generator = VQAGenerator(self.llm_client) + else: + raise ValueError(f"Unsupported generation mode: {method}") + + def __call__(self, batches: pd.DataFrame) -> pd.DataFrame: + items = batches.to_dict(orient="records") + return pd.DataFrame(self.generate(items)) + + def generate(self, items: list[dict]) -> list[dict]: + """ + Generate question-answer pairs based on nodes and edges. 
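+        Illustrative usage (a sketch; assumes partition output rows carrying
+        "nodes" and "edges" keys, as produced by the partition step):
+            svc = GenerateService(method="aggregated", data_format="ChatML")
+            qa_pairs = svc.generate([{"nodes": [...], "edges": [...]}])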
+ :param items: partition batches, each a dict with "nodes" and "edges" + :return: QA pairs + """ + logger.info("[Generation] mode: %s, batches: %d", self.method, len(items)) + items = [(item["nodes"], item["edges"]) for item in items] + results = run_concurrent( + self.generator.generate, + items, + desc="[4/4]Generating QAs", + unit="batch", + ) + + results = self.generator.format_generation_results( + results, output_data_format=self.data_format + ) + + return results From d460a2af9b42ed4e4f124ea3dd1efc68ac350083 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 11:50:06 +0800 Subject: [PATCH 20/28] feat: write results in output folder --- graphgen/engine.py | 7 ++--- graphgen/run.py | 65 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 54 insertions(+), 18 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index 5a0e67d3..7e7243c0 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -198,7 +198,7 @@ def _find_leaf_nodes(nodes: List[Node]) -> Set[str]: deps_set.update(n.dependencies) return all_ids - deps_set - def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, List[Any]]: + def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]: sorted_nodes = self._topo_sort(self.config.nodes) for node in sorted_nodes: @@ -210,7 +210,4 @@ def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, List[Any]]: def _fetch_result(ds: ray.data.Dataset) -> List[Any]: return ds.take_all() - results = ray.get( - [_fetch_result.remote(self.datasets[node_id]) for node_id in leaf_nodes] - ) - return dict(zip(leaf_nodes, results)) + return {node_id: self.datasets[node_id] for node_id in leaf_nodes} diff --git a/graphgen/run.py b/graphgen/run.py index c300a6aa..7a8bb654 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -1,13 +1,17 @@ import argparse import os import time -from importlib.resources import files +from importlib import resources +from typing import Any, Dict +import ray import yaml from dotenv import load_dotenv +from ray.data.block import Block +from ray.data.datasource.filename_provider import FilenameProvider -from graphgen.engine import Context, Engine, collect_ops -from graphgen.graphgen import GraphGen +from graphgen.engine import Engine +from graphgen.operators import operators from graphgen.utils import logger, set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) @@ -28,12 +32,38 @@ def save_config(config_path, global_config): ) +class NodeFilenameProvider(FilenameProvider): + def __init__(self, node_id: str): + self.node_id = node_id + + def get_filename_for_block( + self, block: Block, write_uuid: str, task_index: int, block_index: int + ) -> str: + # format: {node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl + return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl" + + def get_filename_for_row( + self, + row: Dict[str, Any], + write_uuid: str, + task_index: int, + block_index: int, + row_index: int, + ) -> str: + raise NotImplementedError( + f"Row-based filenames are not supported by write_json. " + f"Node: {self.node_id}, write_uuid: {write_uuid}" + )
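+# Illustrative example of the block-filename scheme above (hypothetical ids):
+# for node_id="generate", task 3 writing its first block would produce
+#   generate_<write_uuid>_000003_000000.jsonl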
" + f"Node: {self.node_id}, write_uuid: {write_uuid}" + ) + + def main(): parser = argparse.ArgumentParser() parser.add_argument( "--config_file", help="Config parameters for GraphGen.", - default=files("graphgen").joinpath("configs", "aggregated_config.yaml"), + default=resources.files("graphgen") + .joinpath("configs") + .joinpath("aggregated_config.yaml"), type=str, ) parser.add_argument( @@ -51,6 +81,8 @@ def main(): with open(args.config_file, "r", encoding="utf-8") as f: config = yaml.load(f, Loader=yaml.FullLoader) + engine = Engine(config, operators) + unique_id = int(time.time()) output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}") @@ -65,15 +97,22 @@ def main(): unique_id, os.path.join(working_dir, f"{unique_id}.log"), ) - - graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir) - - # share context between different steps - ctx = Context(config=config, graph_gen=graph_gen) - ops = collect_ops(config, graph_gen) - - # run operations - Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx) + ds = ray.data.from_items([]) + results = engine.execute(ds) + + for node_id, dataset in results.items(): + output_path = os.path.join(output_path, f"{node_id}") + os.makedirs(output_path, exist_ok=True) + dataset.write_json( + output_path, + filename_provider=NodeFilenameProvider(node_id), + pandas_json_args_fn=lambda: { + "force_ascii": False, + "orient": "records", + "lines": True, + }, + ) + logger.info("Node %s results saved to %s", node_id, output_path) save_config(os.path.join(output_path, "config.yaml"), config) logger.info("GraphGen completed successfully. Data saved to %s", output_path) From cd011adf71e62635a84c9806096e8379fc86c40b Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 12:23:43 +0800 Subject: [PATCH 21/28] fix: raise error when no dataset is created --- .../baselines/generate_all_baselines.sh | 0 .../baselines/generate_bds.sh | 0 .../baselines/generate_entigraph.sh | 0 .../baselines/generate_genie.sh | 0 .../baselines/generate_longform.sh | 0 .../baselines/generate_selfqa.sh | 0 .../baselines/generate_wrap.sh | 0 {graphgen => examples}/configs/README.md | 0 .../configs/cot_config.yaml | 0 .../configs/multi_hop_config.yaml | 0 .../schema_guided_extraction_config.yaml | 0 .../configs/search_dna_config.yaml | 0 .../configs/search_protein_config.yaml | 0 .../configs/search_rna_config.yaml | 0 .../configs/vqa_config.yaml | 0 {scripts => examples}/evaluate/evaluate.sh | 0 .../extract/extract_schema_guided.sh | 0 .../generate/generate_aggregated_qa/README.md | 3 + .../aggregated_config.yaml | 66 ++++--------------- .../generate_aggregated.sh | 2 +- .../generate/generate_atomic_qa/README.md | 3 + .../generate_atomic_qa/atomic_config.yaml | 45 +++++++++++++ .../generate_atomic_qa/generate_atomic.sh | 3 + .../generate/generate_cot.sh | 0 .../generate/generate_multi_hop.sh | 0 .../generate/generate_vqa.sh | 0 .../search/build_db/build_dna_blast_db.sh | 0 .../search/build_db/build_protein_blast_db.sh | 0 .../search/build_db/build_rna_blast_db.sh | 0 {scripts => examples}/search/search_dna.sh | 0 {scripts => examples}/search/search_rna.sh | 0 .../search/search_uniprot.sh | 0 graphgen/configs/__init__.py | 0 graphgen/configs/atomic_config.yaml | 31 --------- graphgen/operators/read/read.py | 3 +- scripts/generate/generate_atomic.sh | 3 - 36 files changed, 68 insertions(+), 91 deletions(-) rename {scripts => examples}/baselines/generate_all_baselines.sh (100%) rename {scripts => examples}/baselines/generate_bds.sh (100%) 
rename {scripts => examples}/baselines/generate_entigraph.sh (100%) rename {scripts => examples}/baselines/generate_genie.sh (100%) rename {scripts => examples}/baselines/generate_longform.sh (100%) rename {scripts => examples}/baselines/generate_selfqa.sh (100%) rename {scripts => examples}/baselines/generate_wrap.sh (100%) rename {graphgen => examples}/configs/README.md (100%) rename {graphgen => examples}/configs/cot_config.yaml (100%) rename {graphgen => examples}/configs/multi_hop_config.yaml (100%) rename {graphgen => examples}/configs/schema_guided_extraction_config.yaml (100%) rename {graphgen => examples}/configs/search_dna_config.yaml (100%) rename {graphgen => examples}/configs/search_protein_config.yaml (100%) rename {graphgen => examples}/configs/search_rna_config.yaml (100%) rename {graphgen => examples}/configs/vqa_config.yaml (100%) rename {scripts => examples}/evaluate/evaluate.sh (100%) rename {scripts => examples}/extract/extract_schema_guided.sh (100%) create mode 100644 examples/generate/generate_aggregated_qa/README.md create mode 100644 examples/generate/generate_atomic_qa/README.md create mode 100644 examples/generate/generate_atomic_qa/atomic_config.yaml create mode 100644 examples/generate/generate_atomic_qa/generate_atomic.sh rename {scripts => examples}/generate/generate_cot.sh (100%) rename {scripts => examples}/generate/generate_multi_hop.sh (100%) rename {scripts => examples}/generate/generate_vqa.sh (100%) rename {scripts => examples}/search/build_db/build_dna_blast_db.sh (100%) rename {scripts => examples}/search/build_db/build_protein_blast_db.sh (100%) rename {scripts => examples}/search/build_db/build_rna_blast_db.sh (100%) rename {scripts => examples}/search/search_dna.sh (100%) rename {scripts => examples}/search/search_rna.sh (100%) rename {scripts => examples}/search/search_uniprot.sh (100%) delete mode 100644 graphgen/configs/__init__.py delete mode 100644 graphgen/configs/atomic_config.yaml delete mode 100644 scripts/generate/generate_atomic.sh diff --git a/scripts/baselines/generate_all_baselines.sh b/examples/baselines/generate_all_baselines.sh similarity index 100% rename from scripts/baselines/generate_all_baselines.sh rename to examples/baselines/generate_all_baselines.sh diff --git a/scripts/baselines/generate_bds.sh b/examples/baselines/generate_bds.sh similarity index 100% rename from scripts/baselines/generate_bds.sh rename to examples/baselines/generate_bds.sh diff --git a/scripts/baselines/generate_entigraph.sh b/examples/baselines/generate_entigraph.sh similarity index 100% rename from scripts/baselines/generate_entigraph.sh rename to examples/baselines/generate_entigraph.sh diff --git a/scripts/baselines/generate_genie.sh b/examples/baselines/generate_genie.sh similarity index 100% rename from scripts/baselines/generate_genie.sh rename to examples/baselines/generate_genie.sh diff --git a/scripts/baselines/generate_longform.sh b/examples/baselines/generate_longform.sh similarity index 100% rename from scripts/baselines/generate_longform.sh rename to examples/baselines/generate_longform.sh diff --git a/scripts/baselines/generate_selfqa.sh b/examples/baselines/generate_selfqa.sh similarity index 100% rename from scripts/baselines/generate_selfqa.sh rename to examples/baselines/generate_selfqa.sh diff --git a/scripts/baselines/generate_wrap.sh b/examples/baselines/generate_wrap.sh similarity index 100% rename from scripts/baselines/generate_wrap.sh rename to examples/baselines/generate_wrap.sh diff --git a/graphgen/configs/README.md 
b/examples/configs/README.md similarity index 100% rename from graphgen/configs/README.md rename to examples/configs/README.md diff --git a/graphgen/configs/cot_config.yaml b/examples/configs/cot_config.yaml similarity index 100% rename from graphgen/configs/cot_config.yaml rename to examples/configs/cot_config.yaml diff --git a/graphgen/configs/multi_hop_config.yaml b/examples/configs/multi_hop_config.yaml similarity index 100% rename from graphgen/configs/multi_hop_config.yaml rename to examples/configs/multi_hop_config.yaml diff --git a/graphgen/configs/schema_guided_extraction_config.yaml b/examples/configs/schema_guided_extraction_config.yaml similarity index 100% rename from graphgen/configs/schema_guided_extraction_config.yaml rename to examples/configs/schema_guided_extraction_config.yaml diff --git a/graphgen/configs/search_dna_config.yaml b/examples/configs/search_dna_config.yaml similarity index 100% rename from graphgen/configs/search_dna_config.yaml rename to examples/configs/search_dna_config.yaml diff --git a/graphgen/configs/search_protein_config.yaml b/examples/configs/search_protein_config.yaml similarity index 100% rename from graphgen/configs/search_protein_config.yaml rename to examples/configs/search_protein_config.yaml diff --git a/graphgen/configs/search_rna_config.yaml b/examples/configs/search_rna_config.yaml similarity index 100% rename from graphgen/configs/search_rna_config.yaml rename to examples/configs/search_rna_config.yaml diff --git a/graphgen/configs/vqa_config.yaml b/examples/configs/vqa_config.yaml similarity index 100% rename from graphgen/configs/vqa_config.yaml rename to examples/configs/vqa_config.yaml diff --git a/scripts/evaluate/evaluate.sh b/examples/evaluate/evaluate.sh similarity index 100% rename from scripts/evaluate/evaluate.sh rename to examples/evaluate/evaluate.sh diff --git a/scripts/extract/extract_schema_guided.sh b/examples/extract/extract_schema_guided.sh similarity index 100% rename from scripts/extract/extract_schema_guided.sh rename to examples/extract/extract_schema_guided.sh diff --git a/examples/generate/generate_aggregated_qa/README.md b/examples/generate/generate_aggregated_qa/README.md new file mode 100644 index 00000000..ab08693b --- /dev/null +++ b/examples/generate/generate_aggregated_qa/README.md @@ -0,0 +1,3 @@ +# Generate Aggregated QAs + +Aggregated mode is one of three question-answering scenarios in GraphGen (alongside atomic and multi-hop) designed to generate synthetic training data that incorporates complex, integrated knowledge from multiple sources. \ No newline at end of file diff --git a/examples/generate/generate_aggregated_qa/aggregated_config.yaml b/examples/generate/generate_aggregated_qa/aggregated_config.yaml index ed4c3487..8453b281 100644 --- a/examples/generate/generate_aggregated_qa/aggregated_config.yaml +++ b/examples/generate/generate_aggregated_qa/aggregated_config.yaml @@ -2,13 +2,13 @@ global_params: working_dir: cache nodes: - - id: read_files + - id: read_files # id is unique in the pipeline, and can be referenced by other steps op_name: read type: source dependencies: [] params: input_path: - - resources/input_examples/jsonl_demo.jsonl + - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. 
See examples/input_examples for examples - id: chunk_documents op_name: chunk @@ -18,8 +18,8 @@ nodes: execution_params: replicas: 4 params: - chunk_size: 1024 - chunk_overlap: 100 + chunk_size: 1024 # chunk size for text splitting + chunk_overlap: 100 # chunk overlap for text splitting - id: build_kg op_name: build_kg @@ -39,7 +39,7 @@ nodes: replicas: 1 batch_size: 128 params: - quiz_samples: 2 + quiz_samples: 2 # number of quiz samples to generate concurrency_limit: 200 - id: judge @@ -57,12 +57,12 @@ nodes: dependencies: - judge params: - method: ece + method: ece # ece is a custom partition method based on comprehension loss method_params: - max_units_per_community: 20 - min_units_per_community: 5 - max_tokens_per_community: 10240 - unit_sampling: max_loss + max_units_per_community: 20 # max nodes and edges per community + min_units_per_community: 5 # min nodes and edges per community + max_tokens_per_community: 10240 # max tokens per community + unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss - id: generate op_name: generate @@ -73,47 +73,5 @@ nodes: replicas: 1 batch_size: 16 params: - method: aggregated - data_format: ChatML - -#pipeline: -# - name: read_step # step name is unique in the pipeline, and can be referenced by other steps -# op_key: read -# params: -# input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples -# -# - name: chunk_step -# op_key: chunk -# deps: [read_step] # chunk_step depends on read_step -# params: -# chunk_size: 1024 # chunk size for text splitting -# chunk_overlap: 100 # chunk overlap for text splitting -# -# - name: build_kg_step -# op_key: build_kg -# deps: [chunk_step] # build_kg_step depends on chunk_step -# -# - name: quiz_and_judge_step -# op_key: quiz_and_judge -# deps: [build_kg_step] # quiz_and_judge depends on build_kg_step -# params: -# quiz_samples: 2 # number of quiz samples to generate -# re_judge: false # whether to re-judge the existing quiz samples -# -# - name: partition_step -# op_key: partition -# deps: [quiz_and_judge_step] # partition_step depends on quiz_and_judge_step -# params: -# method: ece # ece is a custom partition method based on comprehension loss -# method_params: -# max_units_per_community: 20 # max nodes and edges per community -# min_units_per_community: 5 # min nodes and edges per community -# max_tokens_per_community: 10240 # max tokens per community -# unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss -# -# - name: generate_step -# op_key: generate -# deps: [partition_step] # generate_step depends on partition_step -# params: -# method: aggregated # atomic, aggregated, multi_hop, cot, vqa -# data_format: ChatML # Alpaca, Sharegpt, ChatML + method: aggregated # atomic, aggregated, multi_hop, cot, vqa + data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/examples/generate/generate_aggregated_qa/generate_aggregated.sh b/examples/generate/generate_aggregated_qa/generate_aggregated.sh index 7117eff1..cae544ff 100644 --- a/examples/generate/generate_aggregated_qa/generate_aggregated.sh +++ b/examples/generate/generate_aggregated_qa/generate_aggregated.sh @@ -1,3 +1,3 @@ python3 -m graphgen.run \ ---config_file graphgen/configs/aggregated_config.yaml \ +--config_file examples/generate/generate_aggregated_qa/aggregated_config.yaml \ --output_dir cache/ diff --git a/examples/generate/generate_atomic_qa/README.md 
b/examples/generate/generate_atomic_qa/README.md new file mode 100644 index 00000000..e979b182 --- /dev/null +++ b/examples/generate/generate_atomic_qa/README.md @@ -0,0 +1,3 @@ +# Generate Atomic QAs + +Atomic mode generates question-answer pairs that test basic, isolated knowledge from individual facts or relationships in the knowledge graph. \ No newline at end of file diff --git a/examples/generate/generate_atomic_qa/atomic_config.yaml b/examples/generate/generate_atomic_qa/atomic_config.yaml new file mode 100644 index 00000000..35ebb086 --- /dev/null +++ b/examples/generate/generate_atomic_qa/atomic_config.yaml @@ -0,0 +1,45 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - resources/input_examples/json_demo.json + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk + + - id: partition + op_name: partition + type: aggregate + dependencies: + - build_kg + params: + method: dfs + method_params: + max_units_per_community: 1 + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + params: + method: atomic + data_format: Alpaca diff --git a/examples/generate/generate_atomic_qa/generate_atomic.sh b/examples/generate/generate_atomic_qa/generate_atomic.sh new file mode 100644 index 00000000..c9fdb977 --- /dev/null +++ b/examples/generate/generate_atomic_qa/generate_atomic.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_atomic_qa/atomic_config.yaml \ +--output_dir cache/ diff --git a/scripts/generate/generate_cot.sh b/examples/generate/generate_cot.sh similarity index 100% rename from scripts/generate/generate_cot.sh rename to examples/generate/generate_cot.sh diff --git a/scripts/generate/generate_multi_hop.sh b/examples/generate/generate_multi_hop.sh similarity index 100% rename from scripts/generate/generate_multi_hop.sh rename to examples/generate/generate_multi_hop.sh diff --git a/scripts/generate/generate_vqa.sh b/examples/generate/generate_vqa.sh similarity index 100% rename from scripts/generate/generate_vqa.sh rename to examples/generate/generate_vqa.sh diff --git a/scripts/search/build_db/build_dna_blast_db.sh b/examples/search/build_db/build_dna_blast_db.sh similarity index 100% rename from scripts/search/build_db/build_dna_blast_db.sh rename to examples/search/build_db/build_dna_blast_db.sh diff --git a/scripts/search/build_db/build_protein_blast_db.sh b/examples/search/build_db/build_protein_blast_db.sh similarity index 100% rename from scripts/search/build_db/build_protein_blast_db.sh rename to examples/search/build_db/build_protein_blast_db.sh diff --git a/scripts/search/build_db/build_rna_blast_db.sh b/examples/search/build_db/build_rna_blast_db.sh similarity index 100% rename from scripts/search/build_db/build_rna_blast_db.sh rename to examples/search/build_db/build_rna_blast_db.sh diff --git a/scripts/search/search_dna.sh b/examples/search/search_dna.sh similarity index 100% rename from scripts/search/search_dna.sh rename to examples/search/search_dna.sh diff --git a/scripts/search/search_rna.sh b/examples/search/search_rna.sh similarity index 100% rename from scripts/search/search_rna.sh rename to examples/search/search_rna.sh diff --git a/scripts/search/search_uniprot.sh b/examples/search/search_uniprot.sh similarity index 100% rename from 
scripts/search/search_uniprot.sh rename to examples/search/search_uniprot.sh diff --git a/graphgen/configs/__init__.py b/graphgen/configs/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml deleted file mode 100644 index f8ae2218..00000000 --- a/graphgen/configs/atomic_config.yaml +++ /dev/null @@ -1,31 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg] # partition_step depends on build_kg - params: - method: dfs # partition method, support: dfs, bfs, ece, leiden - method_params: - max_units_per_community: 1 # atomic partition, one node or edge per community - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: atomic # atomic, aggregated, multi_hop, cot, vqa - data_format: Alpaca # Alpaca, Sharegpt, ChatML diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py index f98a97ca..378316f8 100644 --- a/graphgen/operators/read/read.py +++ b/graphgen/operators/read/read.py @@ -106,8 +106,7 @@ def read( # 4. Combine all datasets if not read_tasks: - logger.warning("[READ] No datasets created") - return ray.data.from_items([]) + raise ValueError("No datasets created from the provided files.") if len(read_tasks) == 1: combined_ds = read_tasks[0] diff --git a/scripts/generate/generate_atomic.sh b/scripts/generate/generate_atomic.sh deleted file mode 100644 index 822d6c48..00000000 --- a/scripts/generate/generate_atomic.sh +++ /dev/null @@ -1,3 +0,0 @@ -python3 -m graphgen.run \ ---config_file graphgen/configs/atomic_config.yaml \ ---output_dir cache/ From aab743819717511c6c2316a13e8572377a2ced69 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 12:30:00 +0800 Subject: [PATCH 22/28] fix: return generator in ece_partitioner --- graphgen/operators/partition/partition_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py index db1f6e87..a914a7ac 100644 --- a/graphgen/operators/partition/partition_service.py +++ b/graphgen/operators/partition/partition_service.py @@ -72,7 +72,6 @@ def partition(self) -> Iterable[pd.DataFrame]: raise ValueError(f"Unsupported partition method: {method}") communities = partitioner.partition(g=self.kg_instance, **method_params) - logger.info("Partitioned the graph into %d communities.", len(communities)) for community in communities: batch = partitioner.community2batch(community, g=self.kg_instance) From 7643b9fa5a7012c195ee82e0ddcd3ddd8b7bb18f Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 12:30:49 +0800 Subject: [PATCH 23/28] fix: return generator in ece_partitioner --- graphgen/models/partitioner/ece_partitioner.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py index 
cb0ce861..fcf776c7 100644 --- a/graphgen/models/partitioner/ece_partitioner.py +++ b/graphgen/models/partitioner/ece_partitioner.py @@ -1,6 +1,6 @@ import random from collections import deque -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from tqdm import tqdm @@ -59,7 +59,7 @@ def partition( max_tokens_per_community: int = 10240, unit_sampling: str = "random", **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes: List[Tuple[str, dict]] = g.get_all_nodes() edges: List[Tuple[str, str, dict]] = g.get_all_edges() @@ -73,7 +73,6 @@ def partition( used_n: Set[str] = set() used_e: Set[frozenset[str]] = set() - communities: List[Community] = [] all_units = self._sort_units(all_units, unit_sampling) @@ -141,7 +140,7 @@ def _add_unit(u): return None return Community( - id=len(communities), + id=seed_unit[1], nodes=list(community_nodes.keys()), edges=[(u, v) for (u, v), _ in community_edges.items()], ) @@ -153,7 +152,5 @@ def _add_unit(u): ): continue comm = _grow_community(unit) - if comm is not None: - communities.append(comm) - - return communities + if comm: + yield comm From c42b60406de89ede6a46b6bdaba07c1a344510d2 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 17:11:58 +0800 Subject: [PATCH 24/28] refactor: refactor data format to support multi-modal input --- examples/configs/cot_config.yaml | 33 ---------- examples/configs/multi_hop_config.yaml | 34 ---------- examples/configs/vqa_config.yaml | 32 ---------- .../extract_schema_guided.sh | 0 .../schema_guided_extraction_config.yaml | 0 .../aggregated_config.yaml | 4 +- .../generate_atomic_qa/atomic_config.yaml | 10 ++- examples/generate/generate_cot.sh | 3 - examples/generate/generate_cot_qa/README.md | 1 + .../generate/generate_cot_qa/cot_config.yaml | 55 ++++++++++++++++ .../generate/generate_cot_qa/generate_cot.sh | 3 + examples/generate/generate_multi_hop.sh | 3 - .../generate/generate_multi_hop_qa/README.md | 1 + .../generate_multi_hop.sh | 3 + .../multi_hop_config.yaml | 56 +++++++++++++++++ examples/generate/generate_vqa.sh | 3 - examples/generate/generate_vqa/README.md | 1 + .../generate/generate_vqa/generate_vqa.sh | 3 + .../generate/generate_vqa/vqa_config.yaml | 57 +++++++++++++++++ examples/input_examples/vqa_demo.json | 63 ++++++++++--------- graphgen/bases/base_reader.py | 3 +- graphgen/models/__init__.py | 2 +- graphgen/models/generator/vqa_generator.py | 4 +- graphgen/models/reader/json_reader.py | 34 ++++++++-- 24 files changed, 261 insertions(+), 147 deletions(-) delete mode 100644 examples/configs/cot_config.yaml delete mode 100644 examples/configs/multi_hop_config.yaml delete mode 100644 examples/configs/vqa_config.yaml rename examples/extract/{ => extract_schema_guided}/extract_schema_guided.sh (100%) rename examples/{configs => extract/extract_schema_guided}/schema_guided_extraction_config.yaml (100%) delete mode 100644 examples/generate/generate_cot.sh create mode 100644 examples/generate/generate_cot_qa/README.md create mode 100644 examples/generate/generate_cot_qa/cot_config.yaml create mode 100644 examples/generate/generate_cot_qa/generate_cot.sh delete mode 100644 examples/generate/generate_multi_hop.sh create mode 100644 examples/generate/generate_multi_hop_qa/README.md create mode 100644 examples/generate/generate_multi_hop_qa/generate_multi_hop.sh create mode 100644 examples/generate/generate_multi_hop_qa/multi_hop_config.yaml delete mode 100644 examples/generate/generate_vqa.sh create mode 
100644 examples/generate/generate_vqa/README.md create mode 100644 examples/generate/generate_vqa/generate_vqa.sh create mode 100644 examples/generate/generate_vqa/vqa_config.yaml diff --git a/examples/configs/cot_config.yaml b/examples/configs/cot_config.yaml deleted file mode 100644 index b09e341d..00000000 --- a/examples/configs/cot_config.yaml +++ /dev/null @@ -1,33 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg - params: - method: leiden # leiden is a partitioner detection algorithm - method_params: - max_size: 20 # Maximum size of communities - use_lcc: false # whether to use the largest connected component - random_seed: 42 # random seed for partitioning - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: cot # atomic, aggregated, multi_hop, cot, vqa - data_format: Sharegpt # Alpaca, Sharegpt, ChatML diff --git a/examples/configs/multi_hop_config.yaml b/examples/configs/multi_hop_config.yaml deleted file mode 100644 index 4b8051b4..00000000 --- a/examples/configs/multi_hop_config.yaml +++ /dev/null @@ -1,34 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg_step depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg_step - params: - method: ece # ece is a custom partition method based on comprehension loss - method_params: - max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3 - min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3 - max_tokens_per_community: 10240 # max tokens per community - unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: multi_hop # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/examples/configs/vqa_config.yaml b/examples/configs/vqa_config.yaml deleted file mode 100644 index 06eba5c4..00000000 --- a/examples/configs/vqa_config.yaml +++ /dev/null @@ -1,32 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg_step - params: - method: anchor_bfs # partition method - method_params: - anchor_type: image # node type to select anchor nodes - max_units_per_community: 10 # atomic partition, one node or edge per community - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: vqa # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/examples/extract/extract_schema_guided.sh b/examples/extract/extract_schema_guided/extract_schema_guided.sh similarity index 100% rename from examples/extract/extract_schema_guided.sh rename to examples/extract/extract_schema_guided/extract_schema_guided.sh diff --git a/examples/configs/schema_guided_extraction_config.yaml b/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml similarity index 100% rename from examples/configs/schema_guided_extraction_config.yaml rename to examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml diff --git a/examples/generate/generate_aggregated_qa/aggregated_config.yaml b/examples/generate/generate_aggregated_qa/aggregated_config.yaml index 8453b281..09f95653 100644 --- a/examples/generate/generate_aggregated_qa/aggregated_config.yaml +++ b/examples/generate/generate_aggregated_qa/aggregated_config.yaml @@ -49,7 +49,7 @@ nodes: - quiz execution_params: replicas: 1 - batch_size: 16 + batch_size: 128 - id: partition op_name: partition @@ -71,7 +71,7 @@ nodes: - partition execution_params: replicas: 1 - batch_size: 16 + batch_size: 128 params: method: aggregated # atomic, aggregated, multi_hop, cot, vqa data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/examples/generate/generate_atomic_qa/atomic_config.yaml b/examples/generate/generate_atomic_qa/atomic_config.yaml index 35ebb086..a76272b9 100644 --- a/examples/generate/generate_atomic_qa/atomic_config.yaml +++ b/examples/generate/generate_atomic_qa/atomic_config.yaml @@ -8,13 +8,15 @@ nodes: dependencies: [] params: input_path: - - resources/input_examples/json_demo.json + - examples/input_examples/json_demo.json - id: chunk op_name: chunk type: map_batch dependencies: - read + execution_params: + replicas: 4 params: chunk_size: 1024 chunk_overlap: 100 @@ -22,6 +24,9 @@ nodes: - id: build_kg op_name: build_kg type: map_batch + execution_params: + replicas: 1 + batch_size: 128 dependencies: - chunk @@ -40,6 +45,9 @@ nodes: type: map_batch dependencies: - partition + execution_params: + replicas: 1 + batch_size: 128 params: method: atomic data_format: Alpaca diff --git a/examples/generate/generate_cot.sh b/examples/generate/generate_cot.sh deleted file mode 100644 index 9c2ee151..00000000 --- a/examples/generate/generate_cot.sh +++ /dev/null @@ -1,3 +0,0 @@ -python3 -m graphgen.run \ ---config_file graphgen/configs/cot_config.yaml \ ---output_dir cache/ diff --git a/examples/generate/generate_cot_qa/README.md b/examples/generate/generate_cot_qa/README.md new file mode 100644 index 00000000..37afe9c7 --- /dev/null +++ b/examples/generate/generate_cot_qa/README.md @@ 
-0,0 +1 @@ +# Generate CoT QAs diff --git a/examples/generate/generate_cot_qa/cot_config.yaml b/examples/generate/generate_cot_qa/cot_config.yaml new file mode 100644 index 00000000..1daf7fa1 --- /dev/null +++ b/examples/generate/generate_cot_qa/cot_config.yaml @@ -0,0 +1,55 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/txt_demo.txt + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + execution_params: + replicas: 1 + batch_size: 128 + dependencies: + - chunk + + - id: partition + op_name: partition + type: aggregate + dependencies: + - build_kg + params: + method: leiden + method_params: + max_size: 20 + use_lcc: false + random_seed: 42 + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: cot + data_format: Sharegpt diff --git a/examples/generate/generate_cot_qa/generate_cot.sh b/examples/generate/generate_cot_qa/generate_cot.sh new file mode 100644 index 00000000..d34d503f --- /dev/null +++ b/examples/generate/generate_cot_qa/generate_cot.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_cot_qa/cot_config.yaml \ +--output_dir cache/ diff --git a/examples/generate/generate_multi_hop.sh b/examples/generate/generate_multi_hop.sh deleted file mode 100644 index 6480e080..00000000 --- a/examples/generate/generate_multi_hop.sh +++ /dev/null @@ -1,3 +0,0 @@ -python3 -m graphgen.run \ ---config_file graphgen/configs/multi_hop_config.yaml \ ---output_dir cache/ diff --git a/examples/generate/generate_multi_hop_qa/README.md b/examples/generate/generate_multi_hop_qa/README.md new file mode 100644 index 00000000..dcee73be --- /dev/null +++ b/examples/generate/generate_multi_hop_qa/README.md @@ -0,0 +1 @@ +# Generate Multi-hop QAs diff --git a/examples/generate/generate_multi_hop_qa/generate_multi_hop.sh b/examples/generate/generate_multi_hop_qa/generate_multi_hop.sh new file mode 100644 index 00000000..2bfbc91c --- /dev/null +++ b/examples/generate/generate_multi_hop_qa/generate_multi_hop.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_multi_hop_qa/multi_hop_config.yaml \ +--output_dir cache/ diff --git a/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml b/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml new file mode 100644 index 00000000..1ef2f13f --- /dev/null +++ b/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml @@ -0,0 +1,56 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/csv_demo.csv + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk + execution_params: + replicas: 1 + batch_size: 128 + + - id: partition + op_name: partition + type: aggregate + dependencies: + - build_kg + params: + method: ece + method_params: + max_units_per_community: 3 + min_units_per_community: 3 + max_tokens_per_community: 10240 + unit_sampling: random + + - id: generate + op_name: generate + type: map_batch + 
dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: multi_hop + data_format: ChatML diff --git a/examples/generate/generate_vqa.sh b/examples/generate/generate_vqa.sh deleted file mode 100644 index f7fd2726..00000000 --- a/examples/generate/generate_vqa.sh +++ /dev/null @@ -1,3 +0,0 @@ -python3 -m graphgen.run \ ---config_file graphgen/configs/vqa_config.yaml \ ---output_dir cache/ diff --git a/examples/generate/generate_vqa/README.md b/examples/generate/generate_vqa/README.md new file mode 100644 index 00000000..42b13865 --- /dev/null +++ b/examples/generate/generate_vqa/README.md @@ -0,0 +1 @@ +# Generate VQAs \ No newline at end of file diff --git a/examples/generate/generate_vqa/generate_vqa.sh b/examples/generate/generate_vqa/generate_vqa.sh new file mode 100644 index 00000000..7c7313fa --- /dev/null +++ b/examples/generate/generate_vqa/generate_vqa.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_vqa/vqa_config.yaml \ +--output_dir cache/ diff --git a/examples/generate/generate_vqa/vqa_config.yaml b/examples/generate/generate_vqa/vqa_config.yaml new file mode 100644 index 00000000..335c5e5f --- /dev/null +++ b/examples/generate/generate_vqa/vqa_config.yaml @@ -0,0 +1,57 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/vqa_demo.json + modalities: + - text + - image + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk + execution_params: + replicas: 1 + batch_size: 128 + + - id: partition + op_name: partition + type: aggregate + dependencies: + - build_kg + params: + method: anchor_bfs + method_params: + anchor_type: image + max_units_per_community: 10 + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: vqa + data_format: ChatML \ No newline at end of file diff --git a/examples/input_examples/vqa_demo.json b/examples/input_examples/vqa_demo.json index 9d9661ec..d3aed723 100644 --- a/examples/input_examples/vqa_demo.json +++ b/examples/input_examples/vqa_demo.json @@ -9,11 +9,12 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg", - "image_caption": [ + "content":{ + "img_path": "examples/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg", + "image_caption": [ "Fig. 1. (A) Physical map of the hrp gene cluster of E. amylovora (4, 18, 29), showing restriction sites: B, Bam HI; E, Eco RI; H, Hind II. Gene hrpN, encoding harpin, is contained in the 1.3 kb Hind II fragment indicated by the solid bar. The shaded region (including hrpN) contains that part of the hrp gene cluster in which most transposon insertions, exemplified by K49, a Tn10 mini-kan (30) insertion, abolish the HR and pathogenicity phenotypes. 
Most " - ], - "image_footnote": [] + ] + } }, { "type": "text", @@ -25,11 +26,12 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg", - "image_caption": [ + "content": { + "img_path": "examples/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg", + "image_caption": [ "Fig. 2. Tobacco leaf showing responses 24 hours after infitration of sectors (7) with the following preparations: 1,, living E. coli DH5α (pCPP9) $( 1 \\times 1 0 ^ { 8 } / \\mathrm { m l } )$ ; 2, E. coli DH5α (pCPP430) $( 1 \\ \\times \\ 1 0 ^ { 8 } / \\mathrm { m l } )$ ; 3, E. coli DH5α (pCPP430K49) $( 1 \\times 1 0 ^ { 8 } / \\mathrm { m } )$ ; 4, E. amylovora Ea321 $( 1 \\times 1 0 ^ { 8 } / \\mathsf { m l } )$ ; 5, Ea321K49, an hrp mutant $( 1 \\times 1 0 ^ { 8 } / \\mathsf { m } )$ , 8, heat-treated CFEP from $\\pmb { \\varepsilon }$ coli ${ \\mathsf { D } } { \\mathsf { H } } { \\mathsf { S } } { \\mathsf { { \\alpha } } } ( { \\mathsf { P } } { \\mathsf { C } } { \\mathsf { P } } { \\mathsf { P } } { \\mathsf { 9 } } )$ ; 9,heat-treated CFEP from E. coli DH5α(pCPP430); 10, heat-treated CFEP from E. coli DH5α(pCPP430K49); 11, heattreated CFEP from $\\boldsymbol { \\varepsilon }$ amylovora Ea321; 12, heat-treated CFEP from Ea321K49; 6, harpin $( 1 . 1 \\mu M )$ from E. coli DH5α(pCPP430) eluted from SDS-polyacrylamide gel; 7, same preparation as 6, but protease treated for 2 hours then heated for io min to inactivate protease; 13, harpin $( 1 \\pmb { \\mu } \\pmb { M } )$ from E. amylovora Ea321 eluted from SDS-polyacrylamide gel; 14, same preparation as 13 but with protease treatment as sample 7. Harpin solutions $< - 0 . 3 \\mu \\mathsf { m }$ do not cause collapse of infitrated tissue; spotty and incomplete collapse is caused by harpin between 0.3 and $0 . 5 ~ { \\mu } \\mathsf { m }$ . " - ], - "image_footnote": [] + ] + } }, { "type": "text", @@ -41,10 +43,12 @@ }, { "type": "table", - "img_path": "resources/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg", - "table_caption": [], - "table_footnote": [], - "table_body": "
<html><body><table><tr><td>Protease per milliter</td><td>Tissue collapse</td><td>Harpin detected</td></tr><tr><td>0</td><td>+</td><td>+</td></tr><tr><td>5μg</td><td>+</td><td>+</td></tr><tr><td>10μg</td><td>+</td><td>+</td></tr><tr><td>20 μg</td><td>Weak</td><td>+</td></tr><tr><td>40 μg</td><td>-</td><td></td></tr><tr><td>80μg</td><td></td><td></td></tr><tr><td>80μg + 0.5 mM PMSF</td><td>+</td><td>+</td></tr><tr><td>Cell-free supernatant</td><td></td><td></td></tr></table></body></html>
" + "content": { + "img_path": "examples/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg", + "table_caption": [], + "table_footnote": [], + "table_body": "
<html><body><table><tr><td>Protease per milliter</td><td>Tissue collapse</td><td>Harpin detected</td></tr><tr><td>0</td><td>+</td><td>+</td></tr><tr><td>5μg</td><td>+</td><td>+</td></tr><tr><td>10μg</td><td>+</td><td>+</td></tr><tr><td>20 μg</td><td>Weak</td><td>+</td></tr><tr><td>40 μg</td><td>-</td><td></td></tr><tr><td>80μg</td><td></td><td></td></tr><tr><td>80μg + 0.5 mM PMSF</td><td>+</td><td>+</td></tr><tr><td>Cell-free supernatant</td><td></td><td></td></tr></table></body></html>
" + } }, { "type": "text", @@ -52,11 +56,12 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg", - "image_caption": [ - "Fig. 3. SDS-polyacrylamide gel electrophoresis of CFEPs and purified harpin. Lanes: 1, purified harpin $( 1 . 5 \\ \\mathsf { \\pmb { \\mu } } \\mathsf { \\pmb { \\mathsf { g } } } )$ from E. coli $\\mathsf { D M } 5 \\alpha ( \\mathsf { p C P } 4 3 0 )$ incubated with protease (9) for 1 hour; 2, purified harpin $( 1 . 5 \\mu \\mathfrak { g } )$ from E. amylovora Ea321 incubated with protease for 1 hour; 3, same as 1, but without treatment with protease; 4, same as 2, but without treatment with protease; 5, CFEP (5 ${ \\pmb { \\mu } } ( { \\pmb q } )$ from E. coli DH5α(pCPP9) treated at $1 0 0 ^ { \\circ } \\mathbb { C }$ for 10'min; 6, CFEP $( 5 \\ \\pmb { \\mu } \\pmb { \\mu } )$ from E. coli DH5a(pCPP430K49) treated at $\\pmb { 1 0 0 } \\pmb { \\circ } \\pmb { \\subset }$ for 10 min; 7, CFEP $( 5 ~ \\mu 9 )$ from E. amylovora Ea321 treated " - ], - "image_footnote": [] + "content": { + "img_path": "examples/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg", + "image_caption": [ + "Fig. 3. SDS-polyacrylamide gel electrophoresis of CFEPs and purified harpin. Lanes: 1, purified harpin $( 1 . 5 \\ \\mathsf { \\pmb { \\mu } } \\mathsf { \\pmb { \\mathsf { g } } } )$ from E. coli $\\mathsf { D M } 5 \\alpha ( \\mathsf { p C P } 4 3 0 )$ incubated with protease (9) for 1 hour; 2, purified harpin $( 1 . 5 \\mu \\mathfrak { g } )$ from E. amylovora Ea321 incubated with protease for 1 hour; 3, same as 1, but without treatment with protease; 4, same as 2, but without treatment with protease; 5, CFEP (5 ${ \\pmb { \\mu } } ( { \\pmb q } )$ from E. coli DH5α(pCPP9) treated at $1 0 0 ^ { \\circ } \\mathbb { C }$ for 10'min; 6, CFEP $( 5 \\ \\pmb { \\mu } \\pmb { \\mu } )$ from E. coli DH5a(pCPP430K49) treated at $\\pmb { 1 0 0 } \\pmb { \\circ } \\pmb { \\subset }$ for 10 min; 7, CFEP $( 5 ~ \\mu 9 )$ from E. amylovora Ea321 treated " + ] + } }, { "type": "text", @@ -64,12 +69,13 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg", - "image_caption": [ - "Fig. 4. Subcellular location of elicitor protein. Logphase cells $( 1 . 5 m )$ of strain Ea321(pCPP430) were fractionated (31). Proteins from each fraction were electrophoresed and transferred to Immobilon-P membrane (Millipore, Bedford, Massachusetts). The Amplified Alkaline Phosphatase Immuno-Blot Assay Kit (170-6412, Bio-Rad Richmond, California) was ", - "used in a Western blot to detect the elicitor protein with an antiserum raised in rabbit in response to harpin (15). (A) Fractions in lanes: 1, periplasm; 2, membrane; 3, whole cells; 4, supernatant; 5, cytoplasm. (B) Harpin purified by high-performance liquid chromatography (19) hybridized with antiserum. Arrows indicates $4 4 \\ k \\mathsf { D }$ based on the molecular weight markers used in Fig. 3. (C) Normal serum control. CFEP from E. coli DH5a(pCPP430) hybridized with pre-immune serum. " - ], - "image_footnote": [] + "content": { + "img_path": "examples/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg", + "image_caption": [ + "Fig. 4. Subcellular location of elicitor protein. Logphase cells $( 1 . 5 m )$ of strain Ea321(pCPP430) were fractionated (31). 
Proteins from each fraction were electrophoresed and transferred to Immobilon-P membrane (Millipore, Bedford, Massachusetts). The Amplified Alkaline Phosphatase Immuno-Blot Assay Kit (170-6412, Bio-Rad Richmond, California) was ", + "used in a Western blot to detect the elicitor protein with an antiserum raised in rabbit in response to harpin (15). (A) Fractions in lanes: 1, periplasm; 2, membrane; 3, whole cells; 4, supernatant; 5, cytoplasm. (B) Harpin purified by high-performance liquid chromatography (19) hybridized with antiserum. Arrows indicates $4 4 \\ k \\mathsf { D }$ based on the molecular weight markers used in Fig. 3. (C) Normal serum control. CFEP from E. coli DH5a(pCPP430) hybridized with pre-immune serum. " + ] + } }, { "type": "text", @@ -77,10 +83,11 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg", - "image_caption": [ - "Fig. 5. Changes in pH of bathing solution of tobacco cell-suspension cultures (TCSC). Control values (no additive) were subtracted. Open squares, harpin (60 nM); open circles, cells of E. coli $\\mathsf { D H } 5 \\alpha ( \\mathsf { p C P P } 4 3 0 )$ $( 5 ~ \\times ~ 1 0 ^ { 7 }$ cells per milliliter); filled squares, cells of E. amylovora Ea321 $( 5 \\times 1 0 ^ { 7 }$ cells per milliiter); triangles, cells of E. coli DH5α(pCPP430K49) $( 5 \\times 1 0 ^ { 7 }$ cells per milliter); diamonds, cells of $\\boldsymbol { \\varepsilon }$ amylovora Ea321K49 $( 5 ~ \\times ~ 1 0 ^ { 7 }$ cells per milliter); filled circles, cells of $\\boldsymbol { E } .$ coli DH5α(pCPP9) $( 5 \\times$ $\\pmb { 1 0 ^ { 6 } }$ cells per mililiter). TCSCs were shaken at room temperature with the indicated preparations. The pH was measured at the intervals indicated. All preparations that elicited HR in tobacco leaves (Fig. 2) also caused a pH increase in the TCSC medium. " - ], - "image_footnote": [] + "content": { + "img_path": "examples/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg", + "image_caption": [ + "Fig. 5. Changes in pH of bathing solution of tobacco cell-suspension cultures (TCSC). Control values (no additive) were subtracted. Open squares, harpin (60 nM); open circles, cells of E. coli $\\mathsf { D H } 5 \\alpha ( \\mathsf { p C P P } 4 3 0 )$ $( 5 ~ \\times ~ 1 0 ^ { 7 }$ cells per milliliter); filled squares, cells of E. amylovora Ea321 $( 5 \\times 1 0 ^ { 7 }$ cells per milliiter); triangles, cells of E. coli DH5α(pCPP430K49) $( 5 \\times 1 0 ^ { 7 }$ cells per milliter); diamonds, cells of $\\boldsymbol { \\varepsilon }$ amylovora Ea321K49 $( 5 ~ \\times ~ 1 0 ^ { 7 }$ cells per milliter); filled circles, cells of $\\boldsymbol { E } .$ coli DH5α(pCPP9) $( 5 \\times$ $\\pmb { 1 0 ^ { 6 } }$ cells per mililiter). TCSCs were shaken at room temperature with the indicated preparations. The pH was measured at the intervals indicated. All preparations that elicited HR in tobacco leaves (Fig. 2) also caused a pH increase in the TCSC medium. " + ] + } } ] \ No newline at end of file diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py index 91d55fcd..5d2af735 100644 --- a/graphgen/bases/base_reader.py +++ b/graphgen/bases/base_reader.py @@ -12,8 +12,9 @@ class BaseReader(ABC): Abstract base class for reading and processing data. 
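    The modalities list (defaulting to ["text"]) selects which content types a
    reader keeps; a multimodal reader might be constructed as, for example,
    JSONReader(text_column="content", modalities=["text", "image"]) — an
    illustrative sketch, not a fixed API.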
""" - def __init__(self, text_column: str = "content"): + def __init__(self, text_column: str = "content", modalities: list = None): self.text_column = text_column + self.modalities = modalities if modalities is not None else ["text"] @abstractmethod def read(self, input_path: Union[str, List[str]]) -> Dataset: diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 3716aa9a..17a7216d 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -32,5 +32,5 @@ from .searcher.web.bing_search import BingSearch from .searcher.web.google_search import GoogleSearch from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter -from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage, RocksDBCache +from .storage import JsonKVStorage, NetworkXStorage, RocksDBCache from .tokenizer import Tokenizer diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index eefbdd1c..91b44862 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -77,8 +77,8 @@ async def generate( nodes, _ = batch for node in nodes: node_data = node[1] - if "images" in node_data and node_data["images"]: - img_path = node_data["images"]["img_path"] + if "image_data" in node_data and node_data["image_data"]: + img_path = node_data["image_data"]["img_path"] for qa in qa_pairs.values(): qa["img_path"] = img_path result.update(qa_pairs) diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index b53c8b1d..6752e042 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -1,7 +1,8 @@ +import json from typing import List, Union import ray -from ray.data import Dataset +import ray.data from graphgen.bases.base_reader import BaseReader @@ -14,14 +15,39 @@ class JSONReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, input_path: Union[str, List[str]]) -> Dataset: + def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset: """ Read JSON file and return Ray Dataset. :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. :return: Ray Dataset containing validated and filtered data. """ - - ds = ray.data.read_json(input_path) + if self.modalities and len(self.modalities) >= 2: + ds: ray.data.Dataset = ray.data.from_items([]) + for file in input_path if isinstance(input_path, list) else [input_path]: + data = [] + if file.endswith(".jsonl"): + with open(file, "r", encoding="utf-8") as f: + for line in f: + item = json.loads(line) + data.append(item) + else: + with open(file, "r", encoding="utf-8") as f: + data = json.load(f) + data = self._unify_schema(data) + file_ds: ray.data.Dataset = ray.data.from_items(data) + ds = ds.union(file_ds) # type: ignore + else: + ds = ray.data.read_json(input_path) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds + + @staticmethod + def _unify_schema(data): + """ + Unify schema for JSON data. 
+ """ + for item in data: + if "content" in item and isinstance(item["content"], dict): + item["content"] = json.dumps(item["content"]) + return data From 42dc73e1753aeb9d5bf531b4256615c98f9399e5 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 17:13:27 +0800 Subject: [PATCH 25/28] fix: delete fetching schema to avoid ray's duplicate execution --- graphgen/engine.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/graphgen/engine.py b/graphgen/engine.py index 7e7243c0..6d7e1051 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -84,10 +84,6 @@ def _get_input_dataset( main_ds = self.datasets[deps[0]] other_dss = [self.datasets[d] for d in deps[1:]] - if not all(ds.schema() == main_ds.schema() for ds in other_dss): - raise ValueError( - f"Union requires all datasets to have the same schema for node {node.id}" - ) return main_ds.union(*other_dss) def _execute_node(self, node: Node, initial_ds: ray.data.Dataset): From 73f70a5f0df7804d14819eba9df369e2035df2ef Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 17:16:01 +0800 Subject: [PATCH 26/28] fix: fix operators' registry --- graphgen/operators/__init__.py | 8 ++++---- graphgen/operators/partition/partition_service.py | 3 ++- graphgen/operators/read/read.py | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 35b004fe..53600c3e 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,7 +1,7 @@ from .build_kg import BuildKGService from .chunk import ChunkService from .extract import extract -from .generate import generate_qas +from .generate import GenerateService from .judge import JudgeService from .partition import PartitionService from .quiz import QuizService @@ -14,8 +14,8 @@ "build_kg": BuildKGService, "quiz": QuizService, "judge": JudgeService, - "extract_info": extract, - "search_all": search_all, + "extract": extract, + "search": search_all, "partition": PartitionService, - "generate_qas": generate_qas, + "generate": GenerateService, } diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py index a914a7ac..cb5ca608 100644 --- a/graphgen/operators/partition/partition_service.py +++ b/graphgen/operators/partition/partition_service.py @@ -1,3 +1,4 @@ +import json import os from typing import Iterable @@ -149,7 +150,7 @@ def _attach_additional_data_to_node(self, batch: tuple) -> tuple: if image_chunks: # The generator expects a dictionary with an 'img_path' key, not a list of captions. # We'll use the first image chunk found for this node. - node_data["images"] = image_chunks[0] + node_data["image_data"] = json.loads(image_chunks[0]["content"]) logger.debug("Attached image data to node %s", node_id) return nodes_data, edges_data diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py index 378316f8..fbed377e 100644 --- a/graphgen/operators/read/read.py +++ b/graphgen/operators/read/read.py @@ -85,7 +85,7 @@ def read( logger.info("[READ] Found %d files to process", len(all_files)) if not all_files: - return ray.data.from_items([]) + raise ValueError("No files found to read.") # 2. 
Group files by suffix to use appropriate reader files_by_suffix = {} @@ -116,7 +116,7 @@ def read( combined_ds = combined_ds.map( lambda record: { **record, - "_doc_id": compute_mm_hash(record), + "_doc_id": compute_mm_hash(record, prefix="doc-"), } ) From 37cbfcfdb8c727859d5eb496afcba14cce7c5b6f Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 11 Dec 2025 19:26:59 +0800 Subject: [PATCH 27/28] feat: refactor schema_guided_extraction & add examples --- examples/configs/README.md | 1 - .../extract/extract_schema_guided/README.md | 1 + .../extract_schema_guided.sh | 2 +- .../schema_guided_extraction_config.yaml | 42 +++++++++++------ .../search_dna_config.yaml | 0 .../search_protein_config.yaml | 0 .../search_rna_config.yaml | 0 .../extractor/schema_guided_extractor.py | 8 ++-- graphgen/models/reader/txt_reader.py | 7 ++- graphgen/operators/__init__.py | 4 +- graphgen/operators/extract/__init__.py | 2 +- graphgen/operators/extract/extract.py | 47 ------------------- graphgen/operators/extract/extract_service.py | 44 +++++++++++++++++ 13 files changed, 85 insertions(+), 73 deletions(-) delete mode 100644 examples/configs/README.md create mode 100644 examples/extract/extract_schema_guided/README.md rename examples/{configs => search}/search_dna_config.yaml (100%) rename examples/{configs => search}/search_protein_config.yaml (100%) rename examples/{configs => search}/search_rna_config.yaml (100%) delete mode 100644 graphgen/operators/extract/extract.py create mode 100644 graphgen/operators/extract/extract_service.py diff --git a/examples/configs/README.md b/examples/configs/README.md deleted file mode 100644 index afa815cd..00000000 --- a/examples/configs/README.md +++ /dev/null @@ -1 +0,0 @@ -# Configs for GraphGen diff --git a/examples/extract/extract_schema_guided/README.md b/examples/extract/extract_schema_guided/README.md new file mode 100644 index 00000000..ab117c0f --- /dev/null +++ b/examples/extract/extract_schema_guided/README.md @@ -0,0 +1 @@ +# Extract Schema-Guided Information from Documents diff --git a/examples/extract/extract_schema_guided/extract_schema_guided.sh b/examples/extract/extract_schema_guided/extract_schema_guided.sh index 0badc174..6ffd0fde 100644 --- a/examples/extract/extract_schema_guided/extract_schema_guided.sh +++ b/examples/extract/extract_schema_guided/extract_schema_guided.sh @@ -1,3 +1,3 @@ python3 -m graphgen.run \ ---config_file graphgen/configs/schema_guided_extraction_config.yaml \ +--config_file examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml \ --output_dir cache/ diff --git a/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml b/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml index 8d142ef6..7bd359b3 100644 --- a/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml +++ b/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml @@ -1,20 +1,34 @@ -pipeline: - - name: read_step - op_key: read +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] params: - input_file: resources/input_examples/extract_demo.txt # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples + input_path: + - examples/input_examples/extract_demo.txt - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 params: - chunk_size: 20480 + chunk_size: 20480 # larger chunk size for better context chunk_overlap: 2000 - separators: [] - - name: extract_step - op_key: extract - deps: [chunk_step] # extract_step depends on chunk_step + - id: extract + op_name: extract + type: map_batch + dependencies: + - chunk + execution_params: + replicas: 1 + batch_size: 128 params: - method: schema_guided # extraction method, support: schema_guided - schema_file: graphgen/templates/extraction/schemas/legal_contract.json # schema file path for schema_guided method + method: schema_guided + schema_path: graphgen/templates/extraction/schemas/legal_contract.json diff --git a/examples/configs/search_dna_config.yaml b/examples/search/search_dna_config.yaml similarity index 100% rename from examples/configs/search_dna_config.yaml rename to examples/search/search_dna_config.yaml diff --git a/examples/configs/search_protein_config.yaml b/examples/search/search_protein_config.yaml similarity index 100% rename from examples/configs/search_protein_config.yaml rename to examples/search/search_protein_config.yaml diff --git a/examples/configs/search_rna_config.yaml b/examples/search/search_rna_config.yaml similarity index 100% rename from examples/configs/search_rna_config.yaml rename to examples/search/search_rna_config.yaml diff --git a/graphgen/models/extractor/schema_guided_extractor.py b/graphgen/models/extractor/schema_guided_extractor.py index 70c45502..74801946 100644 --- a/graphgen/models/extractor/schema_guided_extractor.py +++ b/graphgen/models/extractor/schema_guided_extractor.py @@ -60,8 +60,8 @@ def build_prompt(self, text: str) -> str: return prompt async def extract(self, chunk: dict) -> dict: - _chunk_id = list(chunk.keys())[0] - text = chunk[_chunk_id].get("content", "") + _chunk_id = chunk.get("_chunk_id", "") + text = chunk.get("content", "") prompt = self.build_prompt(text) response = await self.llm_client.generate_answer(prompt) @@ -88,9 +88,7 @@ async def extract(self, chunk: dict) -> dict: return {} @staticmethod - async def merge_extractions( - extraction_list: List[Dict[str, dict]] - ) -> Dict[str, dict]: + def merge_extractions(extraction_list: List[Dict[str, dict]]) -> Dict[str, dict]: """ Merge multiple extraction results based on their hashes. :param extraction_list: List of extraction results, each is a dict with hash as key and record as value. diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index 0194ca68..51a47de2 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -16,12 +16,15 @@ def read( :param input_path: Path to the input text file or list of text files. :return: Ray Dataset containing the read text data. 
""" - docs_ds = ray.data.read_text(input_path, encoding="utf-8") + docs_ds = ray.data.read_binary_files( + input_path, + include_paths=False, + ) docs_ds = docs_ds.map( lambda row: { "type": "text", - self.text_column: row["text"], + self.text_column: row["bytes"].decode("utf-8"), } ) diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 53600c3e..64c78af5 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,6 +1,6 @@ from .build_kg import BuildKGService from .chunk import ChunkService -from .extract import extract +from .extract import ExtractService from .generate import GenerateService from .judge import JudgeService from .partition import PartitionService @@ -14,7 +14,7 @@ "build_kg": BuildKGService, "quiz": QuizService, "judge": JudgeService, - "extract": extract, + "extract": ExtractService, "search": search_all, "partition": PartitionService, "generate": GenerateService, diff --git a/graphgen/operators/extract/__init__.py b/graphgen/operators/extract/__init__.py index d46dcdf1..6c7c2b94 100644 --- a/graphgen/operators/extract/__init__.py +++ b/graphgen/operators/extract/__init__.py @@ -1 +1 @@ -from .extract import extract +from .extract_service import ExtractService diff --git a/graphgen/operators/extract/extract.py b/graphgen/operators/extract/extract.py deleted file mode 100644 index ab69af40..00000000 --- a/graphgen/operators/extract/extract.py +++ /dev/null @@ -1,47 +0,0 @@ -import json - -import gradio as gr - -from graphgen.bases import BaseKVStorage, BaseLLMWrapper -from graphgen.models.extractor import SchemaGuidedExtractor -from graphgen.utils import logger, run_concurrent - - -async def extract( - llm_client: BaseLLMWrapper, - chunk_storage: BaseKVStorage, - extract_config: dict, - progress_bar: gr.Progress = None, -): - """ - Extract information from chunks - :param llm_client: LLM client - :param chunk_storage: storage for chunks - :param extract_config - :param progress_bar - :return: extracted information - """ - - method = extract_config.get("method") - if method == "schema_guided": - schema_file = extract_config.get("schema_file") - with open(schema_file, "r", encoding="utf-8") as f: - schema = json.load(f) - extractor = SchemaGuidedExtractor(llm_client, schema) - else: - raise ValueError(f"Unsupported extraction method: {method}") - - chunks = chunk_storage.get_all() - chunks = [{k: v} for k, v in chunks.items()] - logger.info("Start extracting information from %d chunks", len(chunks)) - - results = await run_concurrent( - extractor.extract, - chunks, - desc="Extracting information", - unit="chunk", - progress_bar=progress_bar, - ) - - results = await extractor.merge_extractions(results) - return results diff --git a/graphgen/operators/extract/extract_service.py b/graphgen/operators/extract/extract_service.py new file mode 100644 index 00000000..ee00f4d4 --- /dev/null +++ b/graphgen/operators/extract/extract_service.py @@ -0,0 +1,44 @@ +import json + +import pandas as pd + +from graphgen.bases import BaseLLMWrapper +from graphgen.common import init_llm +from graphgen.models.extractor import SchemaGuidedExtractor +from graphgen.utils import logger, run_concurrent + + +class ExtractService: + def __init__(self, working_dir: str = "cache", **extract_kwargs): + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.extract_kwargs = extract_kwargs + self.method = self.extract_kwargs.get("method") + if self.method == "schema_guided": + schema_file = self.extract_kwargs.get("schema_path") + with 
open(schema_file, "r", encoding="utf-8") as f: + schema = json.load(f) + self.extractor = SchemaGuidedExtractor(self.llm_client, schema) + else: + raise ValueError(f"Unsupported extraction method: {self.method}") + + def __call__(self, batches: pd.DataFrame) -> pd.DataFrame: + items = batches.to_dict(orient="records") + return pd.DataFrame(self.extract(items)) + + def extract(self, items: list[dict]) -> list[dict]: + + logger.info("Start extracting information from %d items", len(items)) + + results = run_concurrent( + self.extractor.extract, + items, + desc="Extracting information", + unit="item", + ) + results = self.extractor.merge_extractions(results) + + results = [ + {"_extract_id": key, "extracted_data": value} + for key, value in results.items() + ] + return results From b400d2ec53f51ca1da507966b8661958ac54d9cd Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Fri, 12 Dec 2025 15:23:26 +0800 Subject: [PATCH 28/28] feat: seperate ray logs and service logs --- graphgen/bases/__init__.py | 8 +- graphgen/bases/base_operator.py | 57 ++ graphgen/bases/base_storage.py | 6 + graphgen/graphgen.py | 590 +++++++++--------- .../models/storage/graph/networkx_storage.py | 31 +- graphgen/models/storage/kv/json_storage.py | 6 +- .../operators/build_kg/build_kg_service.py | 7 +- graphgen/operators/chunk/chunk_service.py | 6 +- graphgen/operators/extract/extract_service.py | 9 +- .../operators/generate/generate_service.py | 16 +- graphgen/operators/judge/judge_service.py | 7 +- .../operators/partition/partition_service.py | 7 +- graphgen/operators/quiz/quiz_service.py | 7 +- .../operators/read/parallel_file_scanner.py | 30 +- graphgen/run.py | 11 +- graphgen/utils/__init__.py | 2 +- graphgen/utils/log.py | 71 ++- 17 files changed, 482 insertions(+), 389 deletions(-) create mode 100644 graphgen/bases/base_operator.py diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index f4c3a0e8..41136974 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -2,15 +2,11 @@ from .base_generator import BaseGenerator from .base_kg_builder import BaseKGBuilder from .base_llm_wrapper import BaseLLMWrapper +from .base_operator import BaseOperator from .base_partitioner import BasePartitioner from .base_reader import BaseReader from .base_searcher import BaseSearcher from .base_splitter import BaseSplitter -from .base_storage import ( - BaseGraphStorage, - BaseKVStorage, - BaseListStorage, - StorageNameSpace, -) +from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace from .base_tokenizer import BaseTokenizer from .datatypes import Chunk, Config, Node, QAPair, Token diff --git a/graphgen/bases/base_operator.py b/graphgen/bases/base_operator.py new file mode 100644 index 00000000..300d3178 --- /dev/null +++ b/graphgen/bases/base_operator.py @@ -0,0 +1,57 @@ +import inspect +import os +from abc import ABC, abstractmethod +from typing import Iterable, Union + +import pandas as pd +import ray + +from graphgen.utils import CURRENT_LOGGER_VAR, set_logger + + +class BaseOperator(ABC): + def __init__(self, working_dir: str = "cache", op_name: str = None): + log_dir = os.path.join(working_dir, "logs") + self.op_name = op_name or self.__class__.__name__ + + try: + ctx = ray.get_runtime_context() + worker_id = ctx.get_actor_id() or ctx.get_worker_id() + worker_id_short = worker_id[-6:] if worker_id else "driver" + except Exception as e: + print( + "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:", + e, + ) + worker_id_short = "local" + + # e.g. 
cache/logs/ChunkService_a1b2c3.log + log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log") + + self.logger = set_logger( + log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True + ) + + self.logger.info( + "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short + ) + + def __call__( + self, batch: pd.DataFrame + ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]: + logger_token = CURRENT_LOGGER_VAR.set(self.logger) + try: + result = self.process(batch) + if inspect.isgenerator(result): + yield from result + else: + yield result + finally: + CURRENT_LOGGER_VAR.reset(logger_token) + + @abstractmethod + def process(self, batch): + raise NotImplementedError("Subclasses must implement the process method.") + + def get_logger(self): + return self.logger diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index 53610a5d..ff7d2d1a 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -41,6 +41,9 @@ def upsert(self, data: dict[str, T]): def drop(self): raise NotImplementedError + def reload(self): + raise NotImplementedError + class BaseGraphStorage(StorageNameSpace): def has_node(self, node_id: str) -> bool: @@ -88,3 +91,6 @@ def upsert_edge( def delete_node(self, node_id: str): raise NotImplementedError + + def reload(self): + raise NotImplementedError diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index bc7e7742..56e97469 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -1,295 +1,295 @@ -import os -import time -from typing import Dict - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.bases.datatypes import Chunk -from graphgen.models import ( - JsonKVStorage, - JsonListStorage, - NetworkXStorage, - OpenAIClient, - Tokenizer, -) -from graphgen.operators import ( - build_kg, - chunk_documents, - extract_info, - generate_qas, - init_llm, - judge_statement, - partition_kg, - quiz, - read_files, - search_all, -) -from graphgen.utils import async_to_sync_method, compute_mm_hash, logger - -sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - - -class GraphGen: - def __init__( - self, - unique_id: int = int(time.time()), - working_dir: str = os.path.join(sys_path, "cache"), - tokenizer_instance: Tokenizer = None, - synthesizer_llm_client: OpenAIClient = None, - trainee_llm_client: OpenAIClient = None, - progress_bar: gr.Progress = None, - ): - self.unique_id: int = unique_id - self.working_dir: str = working_dir - - # llm - self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer( - model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base") - ) - - self.synthesizer_llm_client: BaseLLMWrapper = ( - synthesizer_llm_client or init_llm("synthesizer") - ) - self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client - - self.full_docs_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="full_docs" - ) - self.chunks_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="chunks" - ) - self.graph_storage: NetworkXStorage = NetworkXStorage( - self.working_dir, namespace="graph" - ) - self.rephrase_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="rephrase" - ) - self.partition_storage: JsonListStorage = JsonListStorage( - self.working_dir, namespace="partition" - ) - self.search_storage: JsonKVStorage = JsonKVStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="search", - ) - self.qa_storage: JsonListStorage = 
JsonListStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="qa", - ) - self.extract_storage: JsonKVStorage = JsonKVStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="extraction", - ) - - # webui - self.progress_bar: gr.Progress = progress_bar - - @async_to_sync_method - async def read(self, read_config: Dict): - """ - read files from input sources - """ - doc_stream = read_files(**read_config, cache_dir=self.working_dir) - - batch = {} - for doc in doc_stream: - doc_id = compute_mm_hash(doc, prefix="doc-") - batch[doc_id] = doc - - # TODO: configurable whether to use coreference resolution - - _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys())) - new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys} - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - self.full_docs_storage.upsert(new_docs) - self.full_docs_storage.index_done_callback() - - @async_to_sync_method - async def chunk(self, chunk_config: Dict): - """ - chunk documents into smaller pieces from full_docs_storage if not already present - """ - - new_docs = self.full_docs_storage.get_all() - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - - inserting_chunks = await chunk_documents( - new_docs, - self.tokenizer_instance, - self.progress_bar, - **chunk_config, - ) - - _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys())) - inserting_chunks = { - k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys - } - - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - self.chunks_storage.upsert(inserting_chunks) - self.chunks_storage.index_done_callback() - - @async_to_sync_method - async def build_kg(self): - """ - build knowledge graph from text chunks - """ - # Step 1: get new chunks - inserting_chunks = self.chunks_storage.get_all() - - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks)) - # Step 2: build knowledge graph from new chunks - _add_entities_and_relations = await build_kg( - llm_client=self.synthesizer_llm_client, - kg_instance=self.graph_storage, - chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()], - progress_bar=self.progress_bar, - ) - if not _add_entities_and_relations: - logger.warning("No entities or relations extracted from text chunks") - return - - # Step 3: upsert new entities and relations to the graph storage - self.graph_storage.index_done_callback() - - return _add_entities_and_relations - - @async_to_sync_method - async def search(self, search_config: Dict): - logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) - - seeds = self.full_docs_storage.get_all() - if len(seeds) == 0: - logger.warning("All documents are already been searched") - return - search_results = await search_all( - seed_data=seeds, - search_config=search_config, - ) - - _add_search_keys = self.search_storage.filter_keys(list(search_results.keys())) - search_results = { - k: v for k, v in search_results.items() if k in _add_search_keys - } - if len(search_results) == 0: - logger.warning("All search results are already in the storage") - return - self.search_storage.upsert(search_results) - self.search_storage.index_done_callback() - - @async_to_sync_method - async def quiz_and_judge(self, 
quiz_and_judge_config: Dict): - logger.warning( - "Quiz and Judge operation needs trainee LLM client." - " Make sure to provide one." - ) - max_samples = quiz_and_judge_config["quiz_samples"] - await quiz( - self.synthesizer_llm_client, - self.graph_storage, - self.rephrase_storage, - max_samples, - progress_bar=self.progress_bar, - ) - - # TODO: assert trainee_llm_client is valid before judge - if not self.trainee_llm_client: - # TODO: shutdown existing synthesizer_llm_client properly - logger.info("No trainee LLM client provided, initializing a new one.") - self.synthesizer_llm_client.shutdown() - self.trainee_llm_client = init_llm("trainee") - - re_judge = quiz_and_judge_config["re_judge"] - _update_relations = await judge_statement( - self.trainee_llm_client, - self.graph_storage, - self.rephrase_storage, - re_judge, - progress_bar=self.progress_bar, - ) - - self.rephrase_storage.index_done_callback() - _update_relations.index_done_callback() - - logger.info("Shutting down trainee LLM client.") - self.trainee_llm_client.shutdown() - self.trainee_llm_client = None - logger.info("Restarting synthesizer LLM client.") - self.synthesizer_llm_client.restart() - - @async_to_sync_method - async def partition(self, partition_config: Dict): - batches = await partition_kg( - self.graph_storage, - self.chunks_storage, - self.tokenizer_instance, - partition_config, - ) - self.partition_storage.upsert(batches) - return batches - - @async_to_sync_method - async def extract(self, extract_config: Dict): - logger.info("Extracting information from given chunks...") - - results = await extract_info( - self.synthesizer_llm_client, - self.chunks_storage, - extract_config, - progress_bar=self.progress_bar, - ) - if not results: - logger.warning("No information extracted") - return - - self.extract_storage.upsert(results) - self.extract_storage.index_done_callback() - - @async_to_sync_method - async def generate(self, generate_config: Dict): - - batches = self.partition_storage.data - if not batches: - logger.warning("No partitions found for QA generation") - return - - # Step 2: generate QA pairs - results = await generate_qas( - self.synthesizer_llm_client, - batches, - generate_config, - progress_bar=self.progress_bar, - ) - - if not results: - logger.warning("No QA pairs generated") - return - - # Step 3: store the generated QA pairs - self.qa_storage.upsert(results) - self.qa_storage.index_done_callback() - - @async_to_sync_method - async def clear(self): - self.full_docs_storage.drop() - self.chunks_storage.drop() - self.search_storage.drop() - self.graph_storage.clear() - self.rephrase_storage.drop() - self.qa_storage.drop() - - logger.info("All caches are cleared") - - # TODO: add data filtering step here in the future - # graph_gen.filter(filter_config=config["filter"]) +# import os +# import time +# from typing import Dict +# +# import gradio as gr +# +# from graphgen.bases import BaseLLMWrapper +# from graphgen.bases.datatypes import Chunk +# from graphgen.models import ( +# JsonKVStorage, +# JsonListStorage, +# NetworkXStorage, +# OpenAIClient, +# Tokenizer, +# ) +# from graphgen.operators import ( +# build_kg, +# chunk_documents, +# extract_info, +# generate_qas, +# init_llm, +# judge_statement, +# partition_kg, +# quiz, +# read_files, +# search_all, +# ) +# from graphgen.utils import async_to_sync_method, compute_mm_hash, logger +# +# sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +# +# +# class GraphGen: +# def __init__( +# self, +# unique_id: int = int(time.time()), 
+# working_dir: str = os.path.join(sys_path, "cache"), +# tokenizer_instance: Tokenizer = None, +# synthesizer_llm_client: OpenAIClient = None, +# trainee_llm_client: OpenAIClient = None, +# progress_bar: gr.Progress = None, +# ): +# self.unique_id: int = unique_id +# self.working_dir: str = working_dir +# +# # llm +# self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer( +# model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base") +# ) +# +# self.synthesizer_llm_client: BaseLLMWrapper = ( +# synthesizer_llm_client or init_llm("synthesizer") +# ) +# self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client +# +# self.full_docs_storage: JsonKVStorage = JsonKVStorage( +# self.working_dir, namespace="full_docs" +# ) +# self.chunks_storage: JsonKVStorage = JsonKVStorage( +# self.working_dir, namespace="chunks" +# ) +# self.graph_storage: NetworkXStorage = NetworkXStorage( +# self.working_dir, namespace="graph" +# ) +# self.rephrase_storage: JsonKVStorage = JsonKVStorage( +# self.working_dir, namespace="rephrase" +# ) +# self.partition_storage: JsonListStorage = JsonListStorage( +# self.working_dir, namespace="partition" +# ) +# self.search_storage: JsonKVStorage = JsonKVStorage( +# os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), +# namespace="search", +# ) +# self.qa_storage: JsonListStorage = JsonListStorage( +# os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), +# namespace="qa", +# ) +# self.extract_storage: JsonKVStorage = JsonKVStorage( +# os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), +# namespace="extraction", +# ) +# +# # webui +# self.progress_bar: gr.Progress = progress_bar +# +# @async_to_sync_method +# async def read(self, read_config: Dict): +# """ +# read files from input sources +# """ +# doc_stream = read_files(**read_config, cache_dir=self.working_dir) +# +# batch = {} +# for doc in doc_stream: +# doc_id = compute_mm_hash(doc, prefix="doc-") +# batch[doc_id] = doc +# +# # TODO: configurable whether to use coreference resolution +# +# _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys())) +# new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys} +# if len(new_docs) == 0: +# logger.warning("All documents are already in the storage") +# return +# self.full_docs_storage.upsert(new_docs) +# self.full_docs_storage.index_done_callback() +# +# @async_to_sync_method +# async def chunk(self, chunk_config: Dict): +# """ +# chunk documents into smaller pieces from full_docs_storage if not already present +# """ +# +# new_docs = self.full_docs_storage.get_all() +# if len(new_docs) == 0: +# logger.warning("All documents are already in the storage") +# return +# +# inserting_chunks = await chunk_documents( +# new_docs, +# self.tokenizer_instance, +# self.progress_bar, +# **chunk_config, +# ) +# +# _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys())) +# inserting_chunks = { +# k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys +# } +# +# if len(inserting_chunks) == 0: +# logger.warning("All chunks are already in the storage") +# return +# +# self.chunks_storage.upsert(inserting_chunks) +# self.chunks_storage.index_done_callback() +# +# @async_to_sync_method +# async def build_kg(self): +# """ +# build knowledge graph from text chunks +# """ +# # Step 1: get new chunks +# inserting_chunks = self.chunks_storage.get_all() +# +# if len(inserting_chunks) == 0: +# logger.warning("All chunks are already in the storage") +# return +# +# 
logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks)) +# # Step 2: build knowledge graph from new chunks +# _add_entities_and_relations = await build_kg( +# llm_client=self.synthesizer_llm_client, +# kg_instance=self.graph_storage, +# chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()], +# progress_bar=self.progress_bar, +# ) +# if not _add_entities_and_relations: +# logger.warning("No entities or relations extracted from text chunks") +# return +# +# # Step 3: upsert new entities and relations to the graph storage +# self.graph_storage.index_done_callback() +# +# return _add_entities_and_relations +# +# @async_to_sync_method +# async def search(self, search_config: Dict): +# logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) +# +# seeds = self.full_docs_storage.get_all() +# if len(seeds) == 0: +# logger.warning("All documents are already been searched") +# return +# search_results = await search_all( +# seed_data=seeds, +# search_config=search_config, +# ) +# +# _add_search_keys = self.search_storage.filter_keys(list(search_results.keys())) +# search_results = { +# k: v for k, v in search_results.items() if k in _add_search_keys +# } +# if len(search_results) == 0: +# logger.warning("All search results are already in the storage") +# return +# self.search_storage.upsert(search_results) +# self.search_storage.index_done_callback() +# +# @async_to_sync_method +# async def quiz_and_judge(self, quiz_and_judge_config: Dict): +# logger.warning( +# "Quiz and Judge operation needs trainee LLM client." +# " Make sure to provide one." +# ) +# max_samples = quiz_and_judge_config["quiz_samples"] +# await quiz( +# self.synthesizer_llm_client, +# self.graph_storage, +# self.rephrase_storage, +# max_samples, +# progress_bar=self.progress_bar, +# ) +# +# # TODO: assert trainee_llm_client is valid before judge +# if not self.trainee_llm_client: +# # TODO: shutdown existing synthesizer_llm_client properly +# logger.info("No trainee LLM client provided, initializing a new one.") +# self.synthesizer_llm_client.shutdown() +# self.trainee_llm_client = init_llm("trainee") +# +# re_judge = quiz_and_judge_config["re_judge"] +# _update_relations = await judge_statement( +# self.trainee_llm_client, +# self.graph_storage, +# self.rephrase_storage, +# re_judge, +# progress_bar=self.progress_bar, +# ) +# +# self.rephrase_storage.index_done_callback() +# _update_relations.index_done_callback() +# +# logger.info("Shutting down trainee LLM client.") +# self.trainee_llm_client.shutdown() +# self.trainee_llm_client = None +# logger.info("Restarting synthesizer LLM client.") +# self.synthesizer_llm_client.restart() +# +# @async_to_sync_method +# async def partition(self, partition_config: Dict): +# batches = await partition_kg( +# self.graph_storage, +# self.chunks_storage, +# self.tokenizer_instance, +# partition_config, +# ) +# self.partition_storage.upsert(batches) +# return batches +# +# @async_to_sync_method +# async def extract(self, extract_config: Dict): +# logger.info("Extracting information from given chunks...") +# +# results = await extract_info( +# self.synthesizer_llm_client, +# self.chunks_storage, +# extract_config, +# progress_bar=self.progress_bar, +# ) +# if not results: +# logger.warning("No information extracted") +# return +# +# self.extract_storage.upsert(results) +# self.extract_storage.index_done_callback() +# +# @async_to_sync_method +# async def generate(self, generate_config: Dict): +# +# batches = self.partition_storage.data +# if not 
batches: +# logger.warning("No partitions found for QA generation") +# return +# +# # Step 2: generate QA pairs +# results = await generate_qas( +# self.synthesizer_llm_client, +# batches, +# generate_config, +# progress_bar=self.progress_bar, +# ) +# +# if not results: +# logger.warning("No QA pairs generated") +# return +# +# # Step 3: store the generated QA pairs +# self.qa_storage.upsert(results) +# self.qa_storage.index_done_callback() +# +# @async_to_sync_method +# async def clear(self): +# self.full_docs_storage.drop() +# self.chunks_storage.drop() +# self.search_storage.drop() +# self.graph_storage.clear() +# self.rephrase_storage.drop() +# self.qa_storage.drop() +# +# logger.info("All caches are cleared") +# +# # TODO: add data filtering step here in the future +# # graph_gen.filter(filter_config=config["filter"]) diff --git a/graphgen/models/storage/graph/networkx_storage.py b/graphgen/models/storage/graph/networkx_storage.py index 28024535..7fb73b79 100644 --- a/graphgen/models/storage/graph/networkx_storage.py +++ b/graphgen/models/storage/graph/networkx_storage.py @@ -6,7 +6,6 @@ import networkx as nx from graphgen.bases.base_storage import BaseGraphStorage -from graphgen.utils import logger @dataclass @@ -19,11 +18,6 @@ def load_nx_graph(file_name) -> Optional[nx.Graph]: @staticmethod def write_nx_graph(graph: nx.Graph, file_name): - logger.info( - "Writing graph with %d nodes, %d edges", - graph.number_of_nodes(), - graph.number_of_edges(), - ) nx.write_graphml(graph, file_name) @staticmethod @@ -82,12 +76,11 @@ def __post_init__(self): self.working_dir, f"{self.namespace}.graphml" ) preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - "Loaded graph from %s with %d nodes, %d edges", - self._graphml_xml_file, - preloaded_graph.number_of_nodes(), - preloaded_graph.number_of_edges(), + if preloaded_graph: + print( + f"Loaded graph from {self._graphml_xml_file} with " + f"{preloaded_graph.number_of_nodes()} nodes, " + f"{preloaded_graph.number_of_edges()} edges" ) self._graph = preloaded_graph or nx.Graph() @@ -133,7 +126,7 @@ def update_node(self, node_id: str, node_data: dict[str, str]): if self._graph.has_node(node_id): self._graph.nodes[node_id].update(node_data) else: - logger.warning("Node %s not found in the graph for update.", node_id) + print(f"Node {node_id} not found in the graph for update.") def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] @@ -146,10 +139,8 @@ def update_edge( if self._graph.has_edge(source_node_id, target_node_id): self._graph.edges[(source_node_id, target_node_id)].update(edge_data) else: - logger.warning( - "Edge %s -> %s not found in the graph for update.", - source_node_id, - target_node_id, + print( + f"Edge {source_node_id} -> {target_node_id} not found in the graph for update." ) def delete_node(self, node_id: str): @@ -160,16 +151,16 @@ def delete_node(self, node_id: str): """ if self._graph.has_node(node_id): self._graph.remove_node(node_id) - logger.info("Node %s deleted from the graph.", node_id) + print(f"Node {node_id} deleted from the graph.") else: - logger.warning("Node %s not found in the graph for deletion.", node_id) + print(f"Node {node_id} not found in the graph for deletion.") def clear(self): """ Clear the graph by removing all nodes and edges. 
""" self._graph.clear() - logger.info("Graph %s cleared.", self.namespace) + print(f"Graph {self.namespace} cleared.") def reload(self): """ diff --git a/graphgen/models/storage/kv/json_storage.py b/graphgen/models/storage/kv/json_storage.py index f0b6c995..aa7c6f42 100644 --- a/graphgen/models/storage/kv/json_storage.py +++ b/graphgen/models/storage/kv/json_storage.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from graphgen.bases.base_storage import BaseKVStorage -from graphgen.utils import load_json, logger, write_json +from graphgen.utils import load_json, write_json @dataclass @@ -12,7 +12,7 @@ class JsonKVStorage(BaseKVStorage): def __post_init__(self): self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json") self._data = load_json(self._file_name) or {} - logger.info("Load KV %s with %d data", self.namespace, len(self._data)) + print(f"Load KV {self.namespace} with {len(self._data)} data") @property def data(self): @@ -57,4 +57,4 @@ def drop(self): def reload(self): self._data = load_json(self._file_name) or {} - logger.info("Reload KV %s with %d data", self.namespace, len(self._data)) + print(f"Reload KV {self.namespace} with {len(self._data)} data") diff --git a/graphgen/operators/build_kg/build_kg_service.py b/graphgen/operators/build_kg/build_kg_service.py index c6842089..0ee54a80 100644 --- a/graphgen/operators/build_kg/build_kg_service.py +++ b/graphgen/operators/build_kg/build_kg_service.py @@ -2,7 +2,7 @@ import pandas as pd -from graphgen.bases import BaseGraphStorage, BaseLLMWrapper +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator from graphgen.bases.datatypes import Chunk from graphgen.common import init_llm, init_storage from graphgen.utils import logger @@ -11,14 +11,15 @@ from .build_text_kg import build_text_kg -class BuildKGService: +class BuildKGService(BaseOperator): def __init__(self, working_dir: str = "cache"): + super().__init__(working_dir=working_dir, op_name="build_kg_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.graph_storage: BaseGraphStorage = init_storage( backend="networkx", working_dir=working_dir, namespace="graph" ) - def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: + def process(self, batch: pd.DataFrame) -> pd.DataFrame: docs = batch.to_dict(orient="records") docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs] diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py index 307833ba..abd72e54 100644 --- a/graphgen/operators/chunk/chunk_service.py +++ b/graphgen/operators/chunk/chunk_service.py @@ -4,6 +4,7 @@ import pandas as pd +from graphgen.bases import BaseOperator from graphgen.common import init_storage from graphgen.models import ( ChineseRecursiveTextSplitter, @@ -40,8 +41,9 @@ def split_chunks(text: str, language: str = "en", **kwargs) -> list: return splitter.split_text(text) -class ChunkService: +class ChunkService(BaseOperator): def __init__(self, working_dir: str = "cache", **chunk_kwargs): + super().__init__(working_dir=working_dir, op_name="chunk_service") tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model) self.chunk_storage = init_storage( @@ -51,7 +53,7 @@ def __init__(self, working_dir: str = "cache", **chunk_kwargs): ) self.chunk_kwargs = chunk_kwargs - def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: + def process(self, batch: pd.DataFrame) -> pd.DataFrame: docs = batch.to_dict(orient="records") return 
pd.DataFrame(self.chunk_documents(docs)) diff --git a/graphgen/operators/extract/extract_service.py b/graphgen/operators/extract/extract_service.py index ee00f4d4..33987fcb 100644 --- a/graphgen/operators/extract/extract_service.py +++ b/graphgen/operators/extract/extract_service.py @@ -2,14 +2,15 @@ import pandas as pd -from graphgen.bases import BaseLLMWrapper +from graphgen.bases import BaseLLMWrapper, BaseOperator from graphgen.common import init_llm from graphgen.models.extractor import SchemaGuidedExtractor from graphgen.utils import logger, run_concurrent -class ExtractService: +class ExtractService(BaseOperator): def __init__(self, working_dir: str = "cache", **extract_kwargs): + super().__init__(working_dir=working_dir, op_name="extract_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.extract_kwargs = extract_kwargs self.method = self.extract_kwargs.get("method") @@ -21,8 +22,8 @@ def __init__(self, working_dir: str = "cache", **extract_kwargs): else: raise ValueError(f"Unsupported extraction method: {self.method}") - def __call__(self, batches: pd.DataFrame) -> pd.DataFrame: - items = batches.to_dict(orient="records") + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") return pd.DataFrame(self.extract(items)) def extract(self, items: list[dict]) -> list[dict]: diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py index 8b0f78e9..1ae2f067 100644 --- a/graphgen/operators/generate/generate_service.py +++ b/graphgen/operators/generate/generate_service.py @@ -1,6 +1,6 @@ import pandas as pd -from graphgen.bases import BaseLLMWrapper +from graphgen.bases import BaseLLMWrapper, BaseOperator from graphgen.common import init_llm from graphgen.models import ( AggregatedGenerator, @@ -12,12 +12,18 @@ from graphgen.utils import logger, run_concurrent -class GenerateService: +class GenerateService(BaseOperator): """ Generate question-answer pairs based on nodes and edges. 
""" - def __init__(self, method: str = "aggregated", data_format: str = "ChatML"): + def __init__( + self, + working_dir: str = "cache", + method: str = "aggregated", + data_format: str = "ChatML", + ): + super().__init__(working_dir=working_dir, op_name="generate_service") self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.method = method @@ -36,8 +42,8 @@ def __init__(self, method: str = "aggregated", data_format: str = "ChatML"): else: raise ValueError(f"Unsupported generation mode: {method}") - def __call__(self, batches: pd.DataFrame) -> pd.DataFrame: - items = batches.to_dict(orient="records") + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") return pd.DataFrame(self.generate(items)) def generate(self, items: list[dict]) -> list[dict]: diff --git a/graphgen/operators/judge/judge_service.py b/graphgen/operators/judge/judge_service.py index 16e8af4c..4d554a0b 100644 --- a/graphgen/operators/judge/judge_service.py +++ b/graphgen/operators/judge/judge_service.py @@ -2,16 +2,17 @@ import pandas as pd -from graphgen.bases import BaseGraphStorage, BaseLLMWrapper +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator from graphgen.common import init_llm, init_storage from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy -class JudgeService: +class JudgeService(BaseOperator): """Service for judging graph edges and nodes using a trainee LLM.""" def __init__(self, working_dir: str = "cache"): + super().__init__(working_dir=working_dir, op_name="judge_service") self.llm_client: BaseLLMWrapper = init_llm("trainee") self.graph_storage: BaseGraphStorage = init_storage( backend="networkx", @@ -19,7 +20,7 @@ def __init__(self, working_dir: str = "cache"): namespace="graph", ) - def __call__(self, batch: pd.DataFrame) -> pd.DataFrame: + def process(self, batch: pd.DataFrame) -> pd.DataFrame: items = batch.to_dict(orient="records") self.graph_storage.reload() self.judge(items) diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py index cb5ca608..b4c0eda0 100644 --- a/graphgen/operators/partition/partition_service.py +++ b/graphgen/operators/partition/partition_service.py @@ -4,7 +4,7 @@ import pandas as pd -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseOperator, BaseTokenizer from graphgen.common import init_storage from graphgen.models import ( AnchorBFSPartitioner, @@ -17,8 +17,9 @@ from graphgen.utils import logger -class PartitionService: +class PartitionService(BaseOperator): def __init__(self, working_dir: str = "cache", **partition_kwargs): + super().__init__(working_dir=working_dir, op_name="partition_service") self.kg_instance: BaseGraphStorage = init_storage( backend="networkx", working_dir=working_dir, @@ -33,7 +34,7 @@ def __init__(self, working_dir: str = "cache", **partition_kwargs): self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model) self.partition_kwargs = partition_kwargs - def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: + def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: # this operator does not consume any batch data # but for compatibility we keep the interface _ = batch.to_dict(orient="records") diff --git a/graphgen/operators/quiz/quiz_service.py b/graphgen/operators/quiz/quiz_service.py index 9bbe99a3..a5e1baf5 
100644 --- a/graphgen/operators/quiz/quiz_service.py +++ b/graphgen/operators/quiz/quiz_service.py @@ -2,19 +2,20 @@ import pandas as pd -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator from graphgen.common import init_llm, init_storage from graphgen.models import QuizGenerator from graphgen.utils import compute_dict_hash, logger, run_concurrent -class QuizService: +class QuizService(BaseOperator): def __init__( self, working_dir: str = "cache", quiz_samples: int = 1, concurrency_limit: int = 200, ): + super().__init__(working_dir=working_dir, op_name="quiz_service") self.quiz_samples = quiz_samples self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.graph_storage: BaseGraphStorage = init_storage( @@ -27,7 +28,7 @@ def __init__( self.generator = QuizGenerator(self.llm_client) self.concurrency_limit = concurrency_limit - def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: + def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: # this operator does not consume any batch data # but for compatibility we keep the interface _ = batch.to_dict(orient="records") diff --git a/graphgen/operators/read/parallel_file_scanner.py b/graphgen/operators/read/parallel_file_scanner.py index 73b477c3..db50d7af 100644 --- a/graphgen/operators/read/parallel_file_scanner.py +++ b/graphgen/operators/read/parallel_file_scanner.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Set, Union from graphgen.models import RocksDBCache -from graphgen.utils import logger class ParallelFileScanner: @@ -32,15 +31,12 @@ def scan( self._scan_files, Path(p).resolve(), recursive, set() ) future_to_path[future] = p - else: - logger.warning("[READ] Path does not exist: %s", p) for future in as_completed(future_to_path): path = future_to_path[future] try: results[path] = future.result() except Exception as e: - logger.error("[READ] Error scanning path %s: %s", path, e) results[path] = { "error": str(e), "files": [], @@ -56,17 +52,14 @@ def _scan_files( # Avoid cycles due to symlinks if path_str in visited: - logger.warning("[READ] Skipping already visited path: %s", path_str) return self._empty_result(path_str) # cache check cache_key = f"scan::{path_str}::recursive::{recursive}" cached = self.cache.get(cache_key) if cached and not self.rescan: - logger.info("[READ] Using cached scan result for path: %s", path_str) return cached["data"] - logger.info("[READ] Scanning path: %s", path_str) files, dirs = [], [] stats = {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0} @@ -108,7 +101,6 @@ def _scan_files( stats["errors"] += 1 except (PermissionError, FileNotFoundError, OSError) as e: - logger.error("[READ] Failed to scan path %s: %s", path_str, e) return {"error": str(e), "files": [], "dirs": [], "stats": stats} if recursive: @@ -171,7 +163,6 @@ def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, An try: results[path] = future.result() except Exception as e: - logger.error("[READ] Error scanning subdirectory %s: %s", path, e) results[path] = { "error": str(e), "files": [], @@ -183,18 +174,14 @@ def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, An def _cache_result(self, key: str, result: Dict, path: Path): """Cache the scan result""" - try: - self.cache.set( - key, - { - "data": result, - "dir_mtime": path.stat().st_mtime, - "cached_at": time.time(), - }, - ) - logger.info("[READ] Cached scan result for path: %s", 
path) - except OSError as e: - logger.error("[READ] Failed to cache scan result for path %s: %s", path, e) + self.cache.set( + key, + { + "data": result, + "dir_mtime": path.stat().st_mtime, + "cached_at": time.time(), + }, + ) def _is_allowed_file(self, path: Path) -> bool: """Check if the file has an allowed suffix""" @@ -209,7 +196,6 @@ def invalidate(self, path: str): keys = [k for k in self.cache if k.startswith(f"scan::{path}")] for k in keys: self.cache.delete(k) - logger.info("[READ] Invalidated cache for path: %s", path) def close(self): self.cache.close() diff --git a/graphgen/run.py b/graphgen/run.py index 7a8bb654..419fd7bd 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -12,7 +12,7 @@ from graphgen.engine import Engine from graphgen.operators import operators -from graphgen.utils import logger, set_logger +from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) @@ -88,14 +88,17 @@ def main(): output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}") set_working_dir(output_path) - set_logger( - os.path.join(output_path, f"{unique_id}.log"), + log_path = os.path.join(working_dir, "logs", "Driver.log") + driver_logger = set_logger( + log_path, + name="GraphGen", if_stream=True, ) + CURRENT_LOGGER_VAR.set(driver_logger) logger.info( "GraphGen with unique ID %s logging to %s", unique_id, - os.path.join(working_dir, f"{unique_id}.log"), + log_path, ) ds = ray.data.from_items([]) results = engine.execute(ds) diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index d3e6df7b..ec118816 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -16,7 +16,7 @@ compute_mm_hash, ) from .help_nltk import NLTKHelper -from .log import logger, parse_log, set_logger +from .log import CURRENT_LOGGER_VAR, logger, set_logger from .loop import create_event_loop from .run_concurrent import run_concurrent from .wrap import async_to_sync_method diff --git a/graphgen/utils/log.py b/graphgen/utils/log.py index 102b7b23..e29e994e 100644 --- a/graphgen/utils/log.py +++ b/graphgen/utils/log.py @@ -1,13 +1,15 @@ +import contextvars import logging +import os from logging.handlers import RotatingFileHandler +from typing import Any from rich.logging import RichHandler -logger = logging.getLogger("graphgen") - def set_logger( log_file: str, + name: str, file_level: int = logging.DEBUG, console_level: int = logging.INFO, *, @@ -17,26 +19,27 @@ def set_logger( force: bool = False, ): - if logger.hasHandlers() and not force: - return + current_logger = logging.getLogger(name) + if current_logger.hasHandlers() and not force: + return current_logger if force: - logger.handlers.clear() + current_logger.handlers.clear() - logger.setLevel( + current_logger.setLevel( min(file_level, console_level) ) # Set to the lowest level to capture all logs - logger.propagate = False + current_logger.propagate = False - if logger.handlers: - logger.handlers.clear() + if log_file: + os.makedirs(os.path.dirname(log_file), exist_ok=True) if if_stream: console = RichHandler( level=console_level, show_path=False, rich_tracebacks=True ) console.setFormatter(logging.Formatter("%(message)s")) - logger.addHandler(console) + current_logger.addHandler(console) file_handler = RotatingFileHandler( log_file, @@ -51,10 +54,48 @@ def set_logger( datefmt="%y-%m-%d %H:%M:%S", ) ) - logger.addHandler(file_handler) + current_logger.addHandler(file_handler) + return current_logger + + +CURRENT_LOGGER_VAR = 
contextvars.ContextVar("current_logger") + + +def get_current_logger() -> logging.Logger: + current_logger = CURRENT_LOGGER_VAR.get() + if not current_logger: + raise RuntimeError("No logger is set in the current context.") + return current_logger + + +class ContextAwareLogger: + @staticmethod + def _get_logger() -> logging.Logger: + return get_current_logger() + + def debug(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().debug(msg, *args, **kwargs) + + def info(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().info(msg, *args, **kwargs) + + def warning(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().warning(msg, *args, **kwargs) + + def error(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().error(msg, *args, **kwargs) + + def exception(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().exception(msg, *args, **kwargs) + + def critical(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().critical(msg, *args, **kwargs) + + def log(self, level: int, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().log(level, msg, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_logger(), name) -def parse_log(log_file: str): - with open(log_file, "r", encoding="utf-8") as f: - lines = f.readlines() - return lines +logger = ContextAwareLogger()