diff --git a/scripts/baselines/generate_all_baselines.sh b/examples/baselines/generate_all_baselines.sh similarity index 100% rename from scripts/baselines/generate_all_baselines.sh rename to examples/baselines/generate_all_baselines.sh diff --git a/scripts/baselines/generate_bds.sh b/examples/baselines/generate_bds.sh similarity index 100% rename from scripts/baselines/generate_bds.sh rename to examples/baselines/generate_bds.sh diff --git a/scripts/baselines/generate_entigraph.sh b/examples/baselines/generate_entigraph.sh similarity index 100% rename from scripts/baselines/generate_entigraph.sh rename to examples/baselines/generate_entigraph.sh diff --git a/scripts/baselines/generate_genie.sh b/examples/baselines/generate_genie.sh similarity index 100% rename from scripts/baselines/generate_genie.sh rename to examples/baselines/generate_genie.sh diff --git a/scripts/baselines/generate_longform.sh b/examples/baselines/generate_longform.sh similarity index 100% rename from scripts/baselines/generate_longform.sh rename to examples/baselines/generate_longform.sh diff --git a/scripts/baselines/generate_selfqa.sh b/examples/baselines/generate_selfqa.sh similarity index 100% rename from scripts/baselines/generate_selfqa.sh rename to examples/baselines/generate_selfqa.sh diff --git a/scripts/baselines/generate_wrap.sh b/examples/baselines/generate_wrap.sh similarity index 100% rename from scripts/baselines/generate_wrap.sh rename to examples/baselines/generate_wrap.sh diff --git a/scripts/evaluate/evaluate.sh b/examples/evaluate/evaluate.sh similarity index 100% rename from scripts/evaluate/evaluate.sh rename to examples/evaluate/evaluate.sh diff --git a/examples/extract/extract_schema_guided/README.md b/examples/extract/extract_schema_guided/README.md new file mode 100644 index 00000000..ab117c0f --- /dev/null +++ b/examples/extract/extract_schema_guided/README.md @@ -0,0 +1 @@ +# Extract Schema-Guided Information from Documents diff --git a/examples/extract/extract_schema_guided/extract_schema_guided.sh b/examples/extract/extract_schema_guided/extract_schema_guided.sh new file mode 100644 index 00000000..6ffd0fde --- /dev/null +++ b/examples/extract/extract_schema_guided/extract_schema_guided.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml \ +--output_dir cache/ diff --git a/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml b/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml new file mode 100644 index 00000000..7bd359b3 --- /dev/null +++ b/examples/extract/extract_schema_guided/schema_guided_extraction_config.yaml @@ -0,0 +1,34 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/extract_demo.txt + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 20480 # larger chunk size for better context + chunk_overlap: 2000 + + - id: extract + op_name: extract + type: map_batch + dependencies: + - chunk + execution_params: + replicas: 1 + batch_size: 128 + params: + method: schema_guided + schema_path: graphgen/templates/extraction/schemas/legal_contract.json diff --git a/examples/generate/generate_aggregated_qa/README.md b/examples/generate/generate_aggregated_qa/README.md new file mode 100644 index 00000000..ab08693b --- /dev/null +++ 
b/examples/generate/generate_aggregated_qa/README.md @@ -0,0 +1,3 @@ +# Generate Aggregated QAs + +Aggregated mode is one of three question-answering scenarios in GraphGen (alongside atomic and multi-hop) designed to generate synthetic training data that incorporates complex, integrated knowledge from multiple sources. \ No newline at end of file diff --git a/examples/generate/generate_aggregated_qa/aggregated_config.yaml b/examples/generate/generate_aggregated_qa/aggregated_config.yaml new file mode 100644 index 00000000..09f95653 --- /dev/null +++ b/examples/generate/generate_aggregated_qa/aggregated_config.yaml @@ -0,0 +1,77 @@ +global_params: + working_dir: cache + +nodes: + - id: read_files # id is unique in the pipeline, and can be referenced by other steps + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples + + - id: chunk_documents + op_name: chunk + type: map_batch + dependencies: + - read_files + execution_params: + replicas: 4 + params: + chunk_size: 1024 # chunk size for text splitting + chunk_overlap: 100 # chunk overlap for text splitting + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk_documents + execution_params: + replicas: 1 + batch_size: 128 + + - id: quiz + op_name: quiz + type: aggregate + dependencies: + - build_kg + execution_params: + replicas: 1 + batch_size: 128 + params: + quiz_samples: 2 # number of quiz samples to generate + concurrency_limit: 200 + + - id: judge + op_name: judge + type: map_batch + dependencies: + - quiz + execution_params: + replicas: 1 + batch_size: 128 + + - id: partition + op_name: partition + type: aggregate + dependencies: + - judge + params: + method: ece # ece is a custom partition method based on comprehension loss + method_params: + max_units_per_community: 20 # max nodes and edges per community + min_units_per_community: 5 # min nodes and edges per community + max_tokens_per_community: 10240 # max tokens per community + unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: aggregated # atomic, aggregated, multi_hop, cot, vqa + data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/examples/generate/generate_aggregated_qa/generate_aggregated.sh b/examples/generate/generate_aggregated_qa/generate_aggregated.sh new file mode 100644 index 00000000..cae544ff --- /dev/null +++ b/examples/generate/generate_aggregated_qa/generate_aggregated.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_aggregated_qa/aggregated_config.yaml \ +--output_dir cache/ diff --git a/examples/generate/generate_atomic_qa/README.md b/examples/generate/generate_atomic_qa/README.md new file mode 100644 index 00000000..e979b182 --- /dev/null +++ b/examples/generate/generate_atomic_qa/README.md @@ -0,0 +1,3 @@ +# Generate Atomic QAs + +Atomic mode generates question-answer pairs that test basic, isolated knowledge from individual facts or relationships in the knowledge graph. 
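+
+With `data_format: Alpaca` (set in `atomic_config.yaml` below), each output record is a standard Alpaca-style JSON object with `instruction`, `input`, and `output` fields; see `examples/output_examples/atomic_alpaca.json` for a generated sample.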
\ No newline at end of file diff --git a/examples/generate/generate_atomic_qa/atomic_config.yaml b/examples/generate/generate_atomic_qa/atomic_config.yaml new file mode 100644 index 00000000..a76272b9 --- /dev/null +++ b/examples/generate/generate_atomic_qa/atomic_config.yaml @@ -0,0 +1,53 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/json_demo.json + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + execution_params: + replicas: 1 + batch_size: 128 + dependencies: + - chunk + + - id: partition + op_name: partition + type: aggregate + dependencies: + - build_kg + params: + method: dfs + method_params: + max_units_per_community: 1 + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: atomic + data_format: Alpaca diff --git a/examples/generate/generate_atomic_qa/generate_atomic.sh b/examples/generate/generate_atomic_qa/generate_atomic.sh new file mode 100644 index 00000000..c9fdb977 --- /dev/null +++ b/examples/generate/generate_atomic_qa/generate_atomic.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_atomic_qa/atomic_config.yaml \ +--output_dir cache/ diff --git a/examples/generate/generate_cot_qa/README.md b/examples/generate/generate_cot_qa/README.md new file mode 100644 index 00000000..37afe9c7 --- /dev/null +++ b/examples/generate/generate_cot_qa/README.md @@ -0,0 +1 @@ +# Generate CoT QAs diff --git a/examples/generate/generate_cot_qa/cot_config.yaml b/examples/generate/generate_cot_qa/cot_config.yaml new file mode 100644 index 00000000..1daf7fa1 --- /dev/null +++ b/examples/generate/generate_cot_qa/cot_config.yaml @@ -0,0 +1,55 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/txt_demo.txt + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + execution_params: + replicas: 1 + batch_size: 128 + dependencies: + - chunk + + - id: partition + op_name: partition + type: aggregate + dependencies: + - build_kg + params: + method: leiden + method_params: + max_size: 20 + use_lcc: false + random_seed: 42 + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: cot + data_format: Sharegpt diff --git a/examples/generate/generate_cot_qa/generate_cot.sh b/examples/generate/generate_cot_qa/generate_cot.sh new file mode 100644 index 00000000..d34d503f --- /dev/null +++ b/examples/generate/generate_cot_qa/generate_cot.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_cot_qa/cot_config.yaml \ +--output_dir cache/ diff --git a/examples/generate/generate_multi_hop_qa/README.md b/examples/generate/generate_multi_hop_qa/README.md new file mode 100644 index 00000000..dcee73be --- /dev/null +++ b/examples/generate/generate_multi_hop_qa/README.md @@ -0,0 +1 @@ +# Generate Multi-hop QAs diff --git a/examples/generate/generate_multi_hop_qa/generate_multi_hop.sh 
b/examples/generate/generate_multi_hop_qa/generate_multi_hop.sh new file mode 100644 index 00000000..2bfbc91c --- /dev/null +++ b/examples/generate/generate_multi_hop_qa/generate_multi_hop.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_multi_hop_qa/multi_hop_config.yaml \ +--output_dir cache/ diff --git a/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml b/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml new file mode 100644 index 00000000..1ef2f13f --- /dev/null +++ b/examples/generate/generate_multi_hop_qa/multi_hop_config.yaml @@ -0,0 +1,56 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/csv_demo.csv + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk + execution_params: + replicas: 1 + batch_size: 128 + + - id: partition + op_name: partition + type: aggregate + dependencies: + - build_kg + params: + method: ece + method_params: + max_units_per_community: 3 + min_units_per_community: 3 + max_tokens_per_community: 10240 + unit_sampling: random + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: multi_hop + data_format: ChatML diff --git a/examples/generate/generate_vqa/README.md b/examples/generate/generate_vqa/README.md new file mode 100644 index 00000000..42b13865 --- /dev/null +++ b/examples/generate/generate_vqa/README.md @@ -0,0 +1 @@ +# Generate VQAs \ No newline at end of file diff --git a/examples/generate/generate_vqa/generate_vqa.sh b/examples/generate/generate_vqa/generate_vqa.sh new file mode 100644 index 00000000..7c7313fa --- /dev/null +++ b/examples/generate/generate_vqa/generate_vqa.sh @@ -0,0 +1,3 @@ +python3 -m graphgen.run \ +--config_file examples/generate/generate_vqa/vqa_config.yaml \ +--output_dir cache/ diff --git a/examples/generate/generate_vqa/vqa_config.yaml b/examples/generate/generate_vqa/vqa_config.yaml new file mode 100644 index 00000000..335c5e5f --- /dev/null +++ b/examples/generate/generate_vqa/vqa_config.yaml @@ -0,0 +1,57 @@ +global_params: + working_dir: cache + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/vqa_demo.json + modalities: + - text + - image + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 1024 + chunk_overlap: 100 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk + execution_params: + replicas: 1 + batch_size: 128 + + - id: partition + op_name: partition + type: aggregate + dependencies: + - build_kg + params: + method: anchor_bfs + method_params: + anchor_type: image + max_units_per_community: 10 + + - id: generate + op_name: generate + type: map_batch + dependencies: + - partition + execution_params: + replicas: 1 + batch_size: 128 + params: + method: vqa + data_format: ChatML \ No newline at end of file diff --git a/resources/input_examples/csv_demo.csv b/examples/input_examples/csv_demo.csv similarity index 100% rename from resources/input_examples/csv_demo.csv rename to examples/input_examples/csv_demo.csv diff --git a/resources/input_examples/extract_demo.txt 
b/examples/input_examples/extract_demo.txt similarity index 100% rename from resources/input_examples/extract_demo.txt rename to examples/input_examples/extract_demo.txt diff --git a/resources/input_examples/graphml_demo.graphml b/examples/input_examples/graphml_demo.graphml similarity index 100% rename from resources/input_examples/graphml_demo.graphml rename to examples/input_examples/graphml_demo.graphml diff --git a/resources/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg b/examples/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg similarity index 100% rename from resources/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg rename to examples/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg diff --git a/resources/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg b/examples/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg similarity index 100% rename from resources/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg rename to examples/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg diff --git a/resources/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg b/examples/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg similarity index 100% rename from resources/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg rename to examples/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg diff --git a/resources/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg b/examples/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg similarity index 100% rename from resources/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg rename to examples/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg diff --git a/resources/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg b/examples/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg similarity index 100% rename from resources/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg rename to examples/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg diff --git a/resources/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg b/examples/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg similarity index 100% rename from resources/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg rename to examples/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg diff --git a/resources/input_examples/json_demo.json b/examples/input_examples/json_demo.json similarity index 100% rename from resources/input_examples/json_demo.json rename to examples/input_examples/json_demo.json diff --git a/resources/input_examples/jsonl_demo.jsonl b/examples/input_examples/jsonl_demo.jsonl similarity index 100% rename from 
resources/input_examples/jsonl_demo.jsonl rename to examples/input_examples/jsonl_demo.jsonl diff --git a/resources/input_examples/pdf_demo.pdf b/examples/input_examples/pdf_demo.pdf similarity index 100% rename from resources/input_examples/pdf_demo.pdf rename to examples/input_examples/pdf_demo.pdf diff --git a/resources/input_examples/search_dna_demo.jsonl b/examples/input_examples/search_dna_demo.jsonl similarity index 100% rename from resources/input_examples/search_dna_demo.jsonl rename to examples/input_examples/search_dna_demo.jsonl diff --git a/resources/input_examples/search_protein_demo.jsonl b/examples/input_examples/search_protein_demo.jsonl similarity index 100% rename from resources/input_examples/search_protein_demo.jsonl rename to examples/input_examples/search_protein_demo.jsonl diff --git a/resources/input_examples/search_rna_demo.jsonl b/examples/input_examples/search_rna_demo.jsonl similarity index 100% rename from resources/input_examples/search_rna_demo.jsonl rename to examples/input_examples/search_rna_demo.jsonl diff --git a/resources/input_examples/txt_demo.txt b/examples/input_examples/txt_demo.txt similarity index 100% rename from resources/input_examples/txt_demo.txt rename to examples/input_examples/txt_demo.txt diff --git a/resources/input_examples/vqa_demo.json b/examples/input_examples/vqa_demo.json similarity index 66% rename from resources/input_examples/vqa_demo.json rename to examples/input_examples/vqa_demo.json index 9d9661ec..d3aed723 100644 --- a/resources/input_examples/vqa_demo.json +++ b/examples/input_examples/vqa_demo.json @@ -9,11 +9,12 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg", - "image_caption": [ + "content":{ + "img_path": "examples/input_examples/images/8fb93cfc0d6b0ebb3e5d5aaae237df02964c9c3da8d8e9567ea19240b14cc742.jpg", + "image_caption": [ "Fig. 1. (A) Physical map of the hrp gene cluster of E. amylovora (4, 18, 29), showing restriction sites: B, Bam HI; E, Eco RI; H, Hind II. Gene hrpN, encoding harpin, is contained in the 1.3 kb Hind II fragment indicated by the solid bar. The shaded region (including hrpN) contains that part of the hrp gene cluster in which most transposon insertions, exemplified by K49, a Tn10 mini-kan (30) insertion, abolish the HR and pathogenicity phenotypes. Most " - ], - "image_footnote": [] + ] + } }, { "type": "text", @@ -25,11 +26,12 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg", - "image_caption": [ + "content": { + "img_path": "examples/input_examples/images/cc5b36e3c972b210d8b56d34fc7ffe56f793f287b3399345aea31cd20eed2824.jpg", + "image_caption": [ "Fig. 2. Tobacco leaf showing responses 24 hours after infitration of sectors (7) with the following preparations: 1,, living E. coli DH5α (pCPP9) $( 1 \\times 1 0 ^ { 8 } / \\mathrm { m l } )$ ; 2, E. coli DH5α (pCPP430) $( 1 \\ \\times \\ 1 0 ^ { 8 } / \\mathrm { m l } )$ ; 3, E. coli DH5α (pCPP430K49) $( 1 \\times 1 0 ^ { 8 } / \\mathrm { m } )$ ; 4, E. amylovora Ea321 $( 1 \\times 1 0 ^ { 8 } / \\mathsf { m l } )$ ; 5, Ea321K49, an hrp mutant $( 1 \\times 1 0 ^ { 8 } / \\mathsf { m } )$ , 8, heat-treated CFEP from $\\pmb { \\varepsilon }$ coli ${ \\mathsf { D } } { \\mathsf { H } } { \\mathsf { S } } { \\mathsf { { \\alpha } } } ( { \\mathsf { P } } { \\mathsf { C } } { \\mathsf { P } } { \\mathsf { P } } { \\mathsf { 9 } } )$ ; 9,heat-treated CFEP from E. 
coli DH5α(pCPP430); 10, heat-treated CFEP from E. coli DH5α(pCPP430K49); 11, heattreated CFEP from $\\boldsymbol { \\varepsilon }$ amylovora Ea321; 12, heat-treated CFEP from Ea321K49; 6, harpin $( 1 . 1 \\mu M )$ from E. coli DH5α(pCPP430) eluted from SDS-polyacrylamide gel; 7, same preparation as 6, but protease treated for 2 hours then heated for io min to inactivate protease; 13, harpin $( 1 \\pmb { \\mu } \\pmb { M } )$ from E. amylovora Ea321 eluted from SDS-polyacrylamide gel; 14, same preparation as 13 but with protease treatment as sample 7. Harpin solutions $< - 0 . 3 \\mu \\mathsf { m }$ do not cause collapse of infitrated tissue; spotty and incomplete collapse is caused by harpin between 0.3 and $0 . 5 ~ { \\mu } \\mathsf { m }$ . " - ], - "image_footnote": [] + ] + } }, { "type": "text", @@ -41,10 +43,12 @@ }, { "type": "table", - "img_path": "resources/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg", - "table_caption": [], - "table_footnote": [], - "table_body": "
Protease per milliterTissue collapseHarpin detected
0++
5μg++
10μg++
20 μgWeak+
40 μg-
80μg
80μg + 0.5 mM PMSF++
Cell-free supernatant
" + "content": { + "img_path": "examples/input_examples/images/0f25783fdfa99042db274ba9f6b3064cf17c5435814edfbee42ae6b19aac37d2.jpg", + "table_caption": [], + "table_footnote": [], + "table_body": "
Protease per milliterTissue collapseHarpin detected
0++
5μg++
10μg++
20 μgWeak+
40 μg-
80μg
80μg + 0.5 mM PMSF++
Cell-free supernatant
" + } }, { "type": "text", @@ -52,11 +56,12 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg", - "image_caption": [ - "Fig. 3. SDS-polyacrylamide gel electrophoresis of CFEPs and purified harpin. Lanes: 1, purified harpin $( 1 . 5 \\ \\mathsf { \\pmb { \\mu } } \\mathsf { \\pmb { \\mathsf { g } } } )$ from E. coli $\\mathsf { D M } 5 \\alpha ( \\mathsf { p C P } 4 3 0 )$ incubated with protease (9) for 1 hour; 2, purified harpin $( 1 . 5 \\mu \\mathfrak { g } )$ from E. amylovora Ea321 incubated with protease for 1 hour; 3, same as 1, but without treatment with protease; 4, same as 2, but without treatment with protease; 5, CFEP (5 ${ \\pmb { \\mu } } ( { \\pmb q } )$ from E. coli DH5α(pCPP9) treated at $1 0 0 ^ { \\circ } \\mathbb { C }$ for 10'min; 6, CFEP $( 5 \\ \\pmb { \\mu } \\pmb { \\mu } )$ from E. coli DH5a(pCPP430K49) treated at $\\pmb { 1 0 0 } \\pmb { \\circ } \\pmb { \\subset }$ for 10 min; 7, CFEP $( 5 ~ \\mu 9 )$ from E. amylovora Ea321 treated " - ], - "image_footnote": [] + "content": { + "img_path": "examples/input_examples/images/4abc534d1dea2b706e44aaac26fe2ae309fee014082db00bc2d87187a6bb5dca.jpg", + "image_caption": [ + "Fig. 3. SDS-polyacrylamide gel electrophoresis of CFEPs and purified harpin. Lanes: 1, purified harpin $( 1 . 5 \\ \\mathsf { \\pmb { \\mu } } \\mathsf { \\pmb { \\mathsf { g } } } )$ from E. coli $\\mathsf { D M } 5 \\alpha ( \\mathsf { p C P } 4 3 0 )$ incubated with protease (9) for 1 hour; 2, purified harpin $( 1 . 5 \\mu \\mathfrak { g } )$ from E. amylovora Ea321 incubated with protease for 1 hour; 3, same as 1, but without treatment with protease; 4, same as 2, but without treatment with protease; 5, CFEP (5 ${ \\pmb { \\mu } } ( { \\pmb q } )$ from E. coli DH5α(pCPP9) treated at $1 0 0 ^ { \\circ } \\mathbb { C }$ for 10'min; 6, CFEP $( 5 \\ \\pmb { \\mu } \\pmb { \\mu } )$ from E. coli DH5a(pCPP430K49) treated at $\\pmb { 1 0 0 } \\pmb { \\circ } \\pmb { \\subset }$ for 10 min; 7, CFEP $( 5 ~ \\mu 9 )$ from E. amylovora Ea321 treated " + ] + } }, { "type": "text", @@ -64,12 +69,13 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg", - "image_caption": [ - "Fig. 4. Subcellular location of elicitor protein. Logphase cells $( 1 . 5 m )$ of strain Ea321(pCPP430) were fractionated (31). Proteins from each fraction were electrophoresed and transferred to Immobilon-P membrane (Millipore, Bedford, Massachusetts). The Amplified Alkaline Phosphatase Immuno-Blot Assay Kit (170-6412, Bio-Rad Richmond, California) was ", - "used in a Western blot to detect the elicitor protein with an antiserum raised in rabbit in response to harpin (15). (A) Fractions in lanes: 1, periplasm; 2, membrane; 3, whole cells; 4, supernatant; 5, cytoplasm. (B) Harpin purified by high-performance liquid chromatography (19) hybridized with antiserum. Arrows indicates $4 4 \\ k \\mathsf { D }$ based on the molecular weight markers used in Fig. 3. (C) Normal serum control. CFEP from E. coli DH5a(pCPP430) hybridized with pre-immune serum. " - ], - "image_footnote": [] + "content": { + "img_path": "examples/input_examples/images/390516e39e77030092027ded523ee99e96ffa8b6df4476c9b12d7bb1dd20d635.jpg", + "image_caption": [ + "Fig. 4. Subcellular location of elicitor protein. Logphase cells $( 1 . 5 m )$ of strain Ea321(pCPP430) were fractionated (31). 
Proteins from each fraction were electrophoresed and transferred to Immobilon-P membrane (Millipore, Bedford, Massachusetts). The Amplified Alkaline Phosphatase Immuno-Blot Assay Kit (170-6412, Bio-Rad Richmond, California) was ", + "used in a Western blot to detect the elicitor protein with an antiserum raised in rabbit in response to harpin (15). (A) Fractions in lanes: 1, periplasm; 2, membrane; 3, whole cells; 4, supernatant; 5, cytoplasm. (B) Harpin purified by high-performance liquid chromatography (19) hybridized with antiserum. Arrows indicates $4 4 \\ k \\mathsf { D }$ based on the molecular weight markers used in Fig. 3. (C) Normal serum control. CFEP from E. coli DH5a(pCPP430) hybridized with pre-immune serum. " + ] + } }, { "type": "text", @@ -77,10 +83,11 @@ }, { "type": "image", - "img_path": "resources/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg", - "image_caption": [ - "Fig. 5. Changes in pH of bathing solution of tobacco cell-suspension cultures (TCSC). Control values (no additive) were subtracted. Open squares, harpin (60 nM); open circles, cells of E. coli $\\mathsf { D H } 5 \\alpha ( \\mathsf { p C P P } 4 3 0 )$ $( 5 ~ \\times ~ 1 0 ^ { 7 }$ cells per milliliter); filled squares, cells of E. amylovora Ea321 $( 5 \\times 1 0 ^ { 7 }$ cells per milliiter); triangles, cells of E. coli DH5α(pCPP430K49) $( 5 \\times 1 0 ^ { 7 }$ cells per milliter); diamonds, cells of $\\boldsymbol { \\varepsilon }$ amylovora Ea321K49 $( 5 ~ \\times ~ 1 0 ^ { 7 }$ cells per milliter); filled circles, cells of $\\boldsymbol { E } .$ coli DH5α(pCPP9) $( 5 \\times$ $\\pmb { 1 0 ^ { 6 } }$ cells per mililiter). TCSCs were shaken at room temperature with the indicated preparations. The pH was measured at the intervals indicated. All preparations that elicited HR in tobacco leaves (Fig. 2) also caused a pH increase in the TCSC medium. " - ], - "image_footnote": [] + "content": { + "img_path": "examples/input_examples/images/eda01885ec54011f15e7a4a56bea0129a0475b2ab5b920a4cff20a4fb623517d.jpg", + "image_caption": [ + "Fig. 5. Changes in pH of bathing solution of tobacco cell-suspension cultures (TCSC). Control values (no additive) were subtracted. Open squares, harpin (60 nM); open circles, cells of E. coli $\\mathsf { D H } 5 \\alpha ( \\mathsf { p C P P } 4 3 0 )$ $( 5 ~ \\times ~ 1 0 ^ { 7 }$ cells per milliliter); filled squares, cells of E. amylovora Ea321 $( 5 \\times 1 0 ^ { 7 }$ cells per milliiter); triangles, cells of E. coli DH5α(pCPP430K49) $( 5 \\times 1 0 ^ { 7 }$ cells per milliter); diamonds, cells of $\\boldsymbol { \\varepsilon }$ amylovora Ea321K49 $( 5 ~ \\times ~ 1 0 ^ { 7 }$ cells per milliter); filled circles, cells of $\\boldsymbol { E } .$ coli DH5α(pCPP9) $( 5 \\times$ $\\pmb { 1 0 ^ { 6 } }$ cells per mililiter). TCSCs were shaken at room temperature with the indicated preparations. The pH was measured at the intervals indicated. All preparations that elicited HR in tobacco leaves (Fig. 2) also caused a pH increase in the TCSC medium. 
" + ] + } } ] \ No newline at end of file diff --git a/resources/output_examples/aggregated_chatml.json b/examples/output_examples/aggregated_chatml.json similarity index 100% rename from resources/output_examples/aggregated_chatml.json rename to examples/output_examples/aggregated_chatml.json diff --git a/resources/output_examples/atomic_alpaca.json b/examples/output_examples/atomic_alpaca.json similarity index 100% rename from resources/output_examples/atomic_alpaca.json rename to examples/output_examples/atomic_alpaca.json diff --git a/resources/output_examples/cot_sharegpt.json b/examples/output_examples/cot_sharegpt.json similarity index 100% rename from resources/output_examples/cot_sharegpt.json rename to examples/output_examples/cot_sharegpt.json diff --git a/resources/output_examples/multi-hop_chatml.json b/examples/output_examples/multi-hop_chatml.json similarity index 100% rename from resources/output_examples/multi-hop_chatml.json rename to examples/output_examples/multi-hop_chatml.json diff --git a/scripts/search/build_db/build_dna_blast_db.sh b/examples/search/build_db/build_dna_blast_db.sh similarity index 100% rename from scripts/search/build_db/build_dna_blast_db.sh rename to examples/search/build_db/build_dna_blast_db.sh diff --git a/scripts/search/build_db/build_protein_blast_db.sh b/examples/search/build_db/build_protein_blast_db.sh similarity index 100% rename from scripts/search/build_db/build_protein_blast_db.sh rename to examples/search/build_db/build_protein_blast_db.sh diff --git a/scripts/search/build_db/build_rna_blast_db.sh b/examples/search/build_db/build_rna_blast_db.sh similarity index 100% rename from scripts/search/build_db/build_rna_blast_db.sh rename to examples/search/build_db/build_rna_blast_db.sh diff --git a/scripts/search/search_dna.sh b/examples/search/search_dna.sh similarity index 100% rename from scripts/search/search_dna.sh rename to examples/search/search_dna.sh diff --git a/graphgen/configs/search_dna_config.yaml b/examples/search/search_dna_config.yaml similarity index 100% rename from graphgen/configs/search_dna_config.yaml rename to examples/search/search_dna_config.yaml diff --git a/graphgen/configs/search_protein_config.yaml b/examples/search/search_protein_config.yaml similarity index 100% rename from graphgen/configs/search_protein_config.yaml rename to examples/search/search_protein_config.yaml diff --git a/scripts/search/search_rna.sh b/examples/search/search_rna.sh similarity index 100% rename from scripts/search/search_rna.sh rename to examples/search/search_rna.sh diff --git a/graphgen/configs/search_rna_config.yaml b/examples/search/search_rna_config.yaml similarity index 100% rename from graphgen/configs/search_rna_config.yaml rename to examples/search/search_rna_config.yaml diff --git a/scripts/search/search_uniprot.sh b/examples/search/search_uniprot.sh similarity index 100% rename from scripts/search/search_uniprot.sh rename to examples/search/search_uniprot.sh diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index 3d0bc800..41136974 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -2,15 +2,11 @@ from .base_generator import BaseGenerator from .base_kg_builder import BaseKGBuilder from .base_llm_wrapper import BaseLLMWrapper +from .base_operator import BaseOperator from .base_partitioner import BasePartitioner from .base_reader import BaseReader from .base_searcher import BaseSearcher from .base_splitter import BaseSplitter -from .base_storage import ( - BaseGraphStorage, - 
BaseKVStorage, - BaseListStorage, - StorageNameSpace, -) +from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace from .base_tokenizer import BaseTokenizer -from .datatypes import Chunk, QAPair, Token +from .datatypes import Chunk, Config, Node, QAPair, Token diff --git a/graphgen/bases/base_operator.py b/graphgen/bases/base_operator.py new file mode 100644 index 00000000..300d3178 --- /dev/null +++ b/graphgen/bases/base_operator.py @@ -0,0 +1,57 @@ +import inspect +import os +from abc import ABC, abstractmethod +from typing import Iterable, Union + +import pandas as pd +import ray + +from graphgen.utils import CURRENT_LOGGER_VAR, set_logger + + +class BaseOperator(ABC): + def __init__(self, working_dir: str = "cache", op_name: str = None): + log_dir = os.path.join(working_dir, "logs") + self.op_name = op_name or self.__class__.__name__ + + try: + ctx = ray.get_runtime_context() + worker_id = ctx.get_actor_id() or ctx.get_worker_id() + worker_id_short = worker_id[-6:] if worker_id else "driver" + except Exception as e: + print( + "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:", + e, + ) + worker_id_short = "local" + + # e.g. cache/logs/ChunkService_a1b2c3.log + log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log") + + self.logger = set_logger( + log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True + ) + + self.logger.info( + "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short + ) + + def __call__( + self, batch: pd.DataFrame + ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]: + logger_token = CURRENT_LOGGER_VAR.set(self.logger) + try: + result = self.process(batch) + if inspect.isgenerator(result): + yield from result + else: + yield result + finally: + CURRENT_LOGGER_VAR.reset(logger_token) + + @abstractmethod + def process(self, batch): + raise NotImplementedError("Subclasses must implement the process method.") + + def get_logger(self): + return self.logger diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py index d74ff563..d948e3a7 100644 --- a/graphgen/bases/base_partitioner.py +++ b/graphgen/bases/base_partitioner.py @@ -7,7 +7,7 @@ class BasePartitioner(ABC): @abstractmethod - async def partition( + def partition( self, g: BaseGraphStorage, **kwargs: Any, @@ -20,39 +20,34 @@ async def partition( """ @staticmethod - async def community2batch( - communities: List[Community], g: BaseGraphStorage - ) -> list[ - tuple[ - list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] - ] + def community2batch( + comm: Community, g: BaseGraphStorage + ) -> tuple[ + list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] ]: """ Convert communities to batches of nodes and edges. 
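+ (For each edge (u, v), storage is queried for (u, v) first and then (v, u), since stored edges may be undirected.)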
- :param communities + :param comm: Community :param g: Graph storage instance - :return: List of batches, each batch is a tuple of (nodes, edges) + :return: Tuple of (nodes_data, edges_data) for the given community """ - batches = [] - for comm in communities: - nodes = comm.nodes - edges = comm.edges - nodes_data = [] - for node in nodes: - node_data = g.get_node(node) - if node_data: - nodes_data.append((node, node_data)) - edges_data = [] - for u, v in edges: - edge_data = g.get_edge(u, v) + nodes = comm.nodes + edges = comm.edges + nodes_data = [] + for node in nodes: + node_data = g.get_node(node) + if node_data: + nodes_data.append((node, node_data)) + edges_data = [] + for u, v in edges: + edge_data = g.get_edge(u, v) + if edge_data: + edges_data.append((u, v, edge_data)) + else: + edge_data = g.get_edge(v, u) if edge_data: - edges_data.append((u, v, edge_data)) - else: - edge_data = g.get_edge(v, u) - if edge_data: - edges_data.append((v, u, edge_data)) - batches.append((nodes_data, edges_data)) - return batches + edges_data.append((v, u, edge_data)) + return nodes_data, edges_data @staticmethod def _build_adjacency_list(
diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py index 89778469..5d2af735 100644 --- a/graphgen/bases/base_reader.py +++ b/graphgen/bases/base_reader.py @@ -1,8 +1,10 @@ import os from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Union +import pandas as pd import requests +from ray.data import Dataset class BaseReader(ABC): @@ -10,56 +12,70 @@ class BaseReader(ABC): """ Abstract base class for reading and processing data. """ - def __init__(self, text_column: str = "content"): + def __init__(self, text_column: str = "content", modalities: list = None): self.text_column = text_column + self.modalities = modalities if modalities is not None else ["text"] @abstractmethod - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read(self, input_path: Union[str, List[str]]) -> Dataset: """ Read data from the specified file path. - :param file_path: Path to the input file. - :return: List of dictionaries containing the data. + :param input_path: Path to the input file or list of file paths. + :return: Ray Dataset containing the read data. """ - @staticmethod - def filter(data: List[dict]) -> List[dict]: + def _should_keep_item(self, item: Dict[str, Any]) -> bool: + """ + Determine whether to keep the given item based on its type and text content. + + :param item: Dictionary representing a data entry. + :return: True if the item should be kept, False otherwise. """ - Filter out entries with empty or missing text in the specified column. + item_type = item.get("type") + assert item_type in [ + "text", + "image", + "table", + "equation", + "protein", + ], f"Unsupported item type: {item_type}" + if item_type == "text": + content = item.get(self.text_column, "").strip() + return bool(content) + return True - :param data: List of dictionaries containing the data. - :return: Filtered list of dictionaries. + def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame: + """ + Validate data format. """ + if "type" not in batch.columns: + raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}") - def _image_exists(path_or_url: str, timeout: int = 3) -> bool: - """ - Check if an image exists at the given local path or URL. - :param path_or_url: Local file path or remote URL of the image. - :param timeout: Timeout for remote URL requests in seconds. - :return: True if the image exists, False otherwise.
- """ - if not path_or_url: - return False - if not path_or_url.startswith(("http://", "https://", "ftp://")): - path = path_or_url.replace("file://", "", 1) - path = os.path.abspath(path) - return os.path.isfile(path) - try: - resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) - return resp.status_code == 200 - except requests.RequestException: - return False + if "text" in batch["type"].values: + if self.text_column not in batch.columns: + raise ValueError( + f"Missing '{self.text_column}' column for text documents" + ) - filtered_data = [] - for item in data: - if item.get("type") == "text": - content = item.get("content", "").strip() - if content: - filtered_data.append(item) - elif item.get("type") in ("image", "table", "equation"): - img_path = item.get("img_path") - if _image_exists(img_path): - filtered_data.append(item) - else: - filtered_data.append(item) - return filtered_data + return batch + + @staticmethod + def _image_exists(path_or_url: str, timeout: int = 3) -> bool: + """ + Check if an image exists at the given local path or URL. + :param path_or_url: Local file path or remote URL of the image. + :param timeout: Timeout for remote URL requests in seconds. + :return: True if the image exists, False otherwise. + """ + if not path_or_url: + return False + if not path_or_url.startswith(("http://", "https://", "ftp://")): + path = path_or_url.replace("file://", "", 1) + path = os.path.abspath(path) + return os.path.isfile(path) + try: + resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) + return resp.status_code == 200 + except requests.RequestException: + return False diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py index b2d1ad3a..f77be6e4 100644 --- a/graphgen/bases/base_splitter.py +++ b/graphgen/bases/base_splitter.py @@ -4,7 +4,7 @@ from typing import Callable, Iterable, List, Literal, Optional, Union from graphgen.bases.datatypes import Chunk -from graphgen.utils import logger +from graphgen.utils.log import logger class BaseSplitter(ABC): @@ -33,7 +33,7 @@ def split_text(self, text: str) -> List[str]: """ Split the input text into smaller chunks. - :param text: The input text to be split. + :param text: The input text to be chunk. :return: A list of text chunks. """ @@ -111,7 +111,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: def _split_text_with_regex( text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]] ) -> List[str]: - # Now that we have the separator, split the text + # Now that we have the separator, chunk the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. 
diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index bfcd658c..ff7d2d1a 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -16,23 +16,6 @@ def query_done_callback(self): """commit the storage operations after querying""" -class BaseListStorage(Generic[T], StorageNameSpace): - def all_items(self) -> list[T]: - raise NotImplementedError - - def get_by_index(self, index: int) -> Union[T, None]: - raise NotImplementedError - - def append(self, data: T): - raise NotImplementedError - - def upsert(self, data: list[T]): - raise NotImplementedError - - def drop(self): - raise NotImplementedError - - class BaseKVStorage(Generic[T], StorageNameSpace): def all_keys(self) -> list[str]: raise NotImplementedError @@ -58,6 +41,9 @@ def upsert(self, data: dict[str, T]): def drop(self): raise NotImplementedError + def reload(self): + raise NotImplementedError + class BaseGraphStorage(StorageNameSpace): def has_node(self, node_id: str) -> bool: @@ -105,3 +91,6 @@ def upsert_edge( def delete_node(self, node_id: str): raise NotImplementedError + + def reload(self): + raise NotImplementedError
diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py index cb3be345..df719fdf 100644 --- a/graphgen/bases/datatypes.py +++ b/graphgen/bases/datatypes.py @@ -2,6 +2,8 @@ from dataclasses import dataclass, field from typing import List, Union +from pydantic import BaseModel, Field, field_validator + @dataclass class Chunk: @@ -48,3 +50,45 @@ class Community: nodes: List[str] = field(default_factory=list) edges: List[tuple] = field(default_factory=list) metadata: dict = field(default_factory=dict) + + +class Node(BaseModel): + id: str = Field(..., description="unique node id") + op_name: str = Field(..., description="operator name") + type: str = Field( + ..., description="task type, e.g., source, map, filter, flatmap, aggregate, map_batch" + ) + params: dict = Field(default_factory=dict, description="operator parameters") + dependencies: List[str] = Field( + default_factory=list, description="list of dependent node ids" + ) + execution_params: dict = Field( + default_factory=dict, description="execution parameters like replicas, batch_size" + ) + + @field_validator("type") + @classmethod + def validate_type(cls, v: str) -> str: + valid_types = {"source", "map", "filter", "flatmap", "aggregate", "map_batch"} + if v not in valid_types: + raise ValueError(f"Invalid node type: {v}.
Must be one of {valid_types}.") + return v + + +class Config(BaseModel): + global_params: dict = Field( + default_factory=dict, description="global context for the computation graph" + ) + + nodes: List[Node] = Field( + ..., min_length=1, description="list of nodes in the computation graph" + ) + + @field_validator("nodes") + @classmethod + def validate_unique_ids(cls, v: List[Node]) -> List[Node]: + ids = [node.id for node in v] + if len(ids) != len(set(ids)): + duplicates = {id_ for id_ in ids if ids.count(id_) > 1} + raise ValueError(f"Duplicate node ids found: {duplicates}") + return v
diff --git a/graphgen/common/__init__.py b/graphgen/common/__init__.py new file mode 100644 index 00000000..deb99459 --- /dev/null +++ b/graphgen/common/__init__.py @@ -0,0 +1,2 @@ +from .init_llm import init_llm +from .init_storage import init_storage
diff --git a/graphgen/operators/init/init_llm.py b/graphgen/common/init_llm.py similarity index 97% rename from graphgen/operators/init/init_llm.py rename to graphgen/common/init_llm.py index e294d2c3..79a8677b 100644 --- a/graphgen/operators/init/init_llm.py +++ b/graphgen/common/init_llm.py @@ -29,6 +29,7 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper: return HTTPClient(**config) if backend in ("openai_api", "azure_openai_api"): from graphgen.models.llm.api.openai_client import OpenAIClient + # pass in concrete backend to the OpenAIClient so that internally we can distinguish # between OpenAI and Azure OpenAI return OpenAIClient(**config, backend=backend) @@ -79,3 +80,6 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]: backend = config.pop("backend") llm_wrapper = LLMFactory.create_llm_wrapper(backend, config) return llm_wrapper + + +# TODO: use ray serve when loading large models to avoid re-loading in each actor
diff --git a/graphgen/common/init_storage.py b/graphgen/common/init_storage.py new file mode 100644 index 00000000..f9c4de57 --- /dev/null +++ b/graphgen/common/init_storage.py @@ -0,0 +1,28 @@ +from graphgen.models import JsonKVStorage, NetworkXStorage + + +class StorageFactory: + """ + Factory class to create storage instances based on backend. + Supported backends: + kv_storage (key-value storage): + - json_kv: JsonKVStorage + graph_storage: + - networkx: NetworkXStorage (graph storage) + """ + + @staticmethod + def create_storage(backend: str, working_dir: str, namespace: str): + if backend == "json_kv": + return JsonKVStorage(working_dir, namespace=namespace) + + if backend == "networkx": + return NetworkXStorage(working_dir, namespace=namespace) + + raise NotImplementedError( + f"Storage backend '{backend}' is not implemented yet." + ) + + +def init_storage(backend: str, working_dir: str, namespace: str): + return StorageFactory.create_storage(backend, working_dir, namespace)
diff --git a/graphgen/configs/README.md b/graphgen/configs/README.md deleted file mode 100644 index afa815cd..00000000 --- a/graphgen/configs/README.md +++ /dev/null @@ -1 +0,0 @@ -# Configs for GraphGen
diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml deleted file mode 100644 index 9c53ec9c..00000000 --- a/graphgen/configs/aggregated_config.yaml +++ /dev/null @@ -1,41 +0,0 @@ -pipeline: - - name: read_step # step name is unique in the pipeline, and can be referenced by other steps - op_key: read - params: - input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf.
See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg_step depends on chunk_step - - - name: quiz_and_judge_step - op_key: quiz_and_judge - deps: [build_kg_step] # quiz_and_judge depends on build_kg_step - params: - quiz_samples: 2 # number of quiz samples to generate - re_judge: false # whether to re-judge the existing quiz samples - - - name: partition_step - op_key: partition - deps: [quiz_and_judge_step] # partition_step depends on quiz_and_judge_step - params: - method: ece # ece is a custom partition method based on comprehension loss - method_params: - max_units_per_community: 20 # max nodes and edges per community - min_units_per_community: 5 # min nodes and edges per community - max_tokens_per_community: 10240 # max tokens per community - unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: aggregated # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml deleted file mode 100644 index f8ae2218..00000000 --- a/graphgen/configs/atomic_config.yaml +++ /dev/null @@ -1,31 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg] # partition_step depends on build_kg - params: - method: dfs # partition method, support: dfs, bfs, ece, leiden - method_params: - max_units_per_community: 1 # atomic partition, one node or edge per community - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: atomic # atomic, aggregated, multi_hop, cot, vqa - data_format: Alpaca # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml deleted file mode 100644 index b09e341d..00000000 --- a/graphgen/configs/cot_config.yaml +++ /dev/null @@ -1,33 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg - params: - method: leiden # leiden is a partitioner detection algorithm - method_params: - max_size: 20 # Maximum size of communities - use_lcc: false # whether to use the largest connected component - random_seed: 42 # random seed for partitioning - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: cot # atomic, aggregated, multi_hop, cot, vqa - data_format: Sharegpt # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml deleted file mode 100644 index 4b8051b4..00000000 --- a/graphgen/configs/multi_hop_config.yaml +++ /dev/null @@ -1,34 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg_step depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg_step - params: - method: ece # ece is a custom partition method based on comprehension loss - method_params: - max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3 - min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3 - max_tokens_per_community: 10240 # max tokens per community - unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: multi_hop # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/schema_guided_extraction_config.yaml b/graphgen/configs/schema_guided_extraction_config.yaml deleted file mode 100644 index 8d142ef6..00000000 --- a/graphgen/configs/schema_guided_extraction_config.yaml +++ /dev/null @@ -1,20 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/extract_demo.txt # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 20480 - chunk_overlap: 2000 - separators: [] - - - name: extract_step - op_key: extract - deps: [chunk_step] # extract_step depends on chunk_step - params: - method: schema_guided # extraction method, support: schema_guided - schema_file: graphgen/templates/extraction/schemas/legal_contract.json # schema file path for schema_guided method diff --git a/graphgen/configs/vqa_config.yaml b/graphgen/configs/vqa_config.yaml deleted file mode 100644 index 06eba5c4..00000000 --- a/graphgen/configs/vqa_config.yaml +++ /dev/null @@ -1,32 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg_step - params: - method: anchor_bfs # partition method - method_params: - anchor_type: image # node type to select anchor nodes - max_units_per_community: 10 # atomic partition, one node or edge per community - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: vqa # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/engine.py b/graphgen/engine.py index 2989226c..6d7e1051 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -1,125 +1,209 @@ -""" -orchestration engine for GraphGen -""" +import inspect +import logging +from collections import defaultdict, deque +from functools import wraps +from typing import Any, Callable, Dict, List, Set -import threading -import traceback -from typing import Any, Callable, List +import ray +import ray.data +from graphgen.bases import Config, Node -class Context(dict): - _lock = threading.Lock() - def set(self, k, v): - with self._lock: - self[k] = v - - def get(self, k, default=None): - with self._lock: - return super().get(k, default) - - -class OpNode: +class Engine: def __init__( - self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any] + self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs ): - self.name, self.deps, self.func = name, deps, func - + self.config = Config(**config) + self.global_params = self.config.global_params + self.functions = functions + self.datasets: Dict[str, ray.data.Dataset] = {} + + if not ray.is_initialized(): + context = ray.init( + ignore_reinit_error=True, + logging_level=logging.ERROR, + log_to_driver=True, + **ray_init_kwargs, + ) + print(f"Ray Dashboard URL: {context.dashboard_url}") -class Engine: - def __init__(self, max_workers: int = 4): - self.max_workers = max_workers - - def run(self, ops: List[OpNode], ctx: Context): - self._validate(ops) - name2op = {operation.name: operation for operation in ops} - - # topological sort - graph = {n: set(name2op[n].deps) for n in name2op} - topo = [] - q = [n for n, d in graph.items() if not d] - while q: - cur = q.pop(0) - topo.append(cur) - for child in [c for c, d in graph.items() if cur in 
d]: - graph[child].remove(cur) - if not graph[child]: - q.append(child) - - if len(topo) != len(ops): + @staticmethod + def _topo_sort(nodes: List[Node]) -> List[Node]: + id_to_node: Dict[str, Node] = {} + for n in nodes: + id_to_node[n.id] = n + + indeg: Dict[str, int] = {nid: 0 for nid in id_to_node} + adj: Dict[str, List[str]] = defaultdict(list) + + for n in nodes: + nid = n.id + deps: List[str] = n.dependencies + uniq_deps: Set[str] = set(deps) + for d in uniq_deps: + if d not in id_to_node: + raise ValueError( + f"The dependency node id {d} of node {nid} is not defined in the configuration." + ) + indeg[nid] += 1 + adj[d].append(nid) + + zero_deg: deque = deque( + [id_to_node[nid] for nid, deg in indeg.items() if deg == 0] + ) + sorted_nodes: List[Node] = [] + + while zero_deg: + cur = zero_deg.popleft() + sorted_nodes.append(cur) + cur_id = cur.id + for nb_id in adj.get(cur_id, []): + indeg[nb_id] -= 1 + if indeg[nb_id] == 0: + zero_deg.append(id_to_node[nb_id]) + + if len(sorted_nodes) != len(nodes): + remaining = [nid for nid, deg in indeg.items() if deg > 0] raise ValueError( - "Cyclic dependencies detected among operations." - "Please check your configuration." + f"The configuration contains cycles and cannot be executed. Remaining nodes with indegree > 0: {remaining}" ) - # semaphore for max_workers - sem = threading.Semaphore(self.max_workers) - done = {n: threading.Event() for n in name2op} - exc = {} - - def _exec(n: str): - with sem: - for d in name2op[n].deps: - done[d].wait() - if any(d in exc for d in name2op[n].deps): - exc[n] = Exception("Skipped due to failed dependencies") - done[n].set() - return - try: - name2op[n].func(name2op[n], ctx) - except Exception: - exc[n] = traceback.format_exc() - done[n].set() - - ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo] - for t in ts: - t.start() - for t in ts: - t.join() - if exc: - raise RuntimeError( - "Some operations failed:\n" - + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items()) + return sorted_nodes + + def _get_input_dataset( + self, node: Node, initial_ds: ray.data.Dataset + ) -> ray.data.Dataset: + deps = node.dependencies + + if not deps: + return initial_ds + + if len(deps) == 1: + return self.datasets[deps[0]] + + main_ds = self.datasets[deps[0]] + other_dss = [self.datasets[d] for d in deps[1:]] + return main_ds.union(*other_dss) + + def _execute_node(self, node: Node, initial_ds: ray.data.Dataset): + def _filter_kwargs( + func_or_class: Callable, + global_params: Dict[str, Any], + func_params: Dict[str, Any], + ) -> Dict[str, Any]: + """ + 1. global_params: passed only when the name appears in the function signature + 2.
func_params: explicitly named params are passed first; the rest are passed only if the function accepts **kwargs + """ + try: + sig = inspect.signature(func_or_class) + except ValueError: + return {} + + params = sig.parameters + final_kwargs = {} + + has_var_keywords = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + valid_keys = set(params.keys()) + for k, v in global_params.items(): + if k in valid_keys: + final_kwargs[k] = v + + for k, v in func_params.items(): + if k in valid_keys or has_var_keywords: + final_kwargs[k] = v + return final_kwargs + + if node.op_name not in self.functions: + raise ValueError(f"Operator {node.op_name} not found for node {node.id}") + + op_handler = self.functions[node.op_name] + node_params = _filter_kwargs(op_handler, self.global_params, node.params or {}) + + if node.type == "source": + self.datasets[node.id] = op_handler(**node_params) + return + + input_ds = self._get_input_dataset(node, initial_ds) + + if inspect.isclass(op_handler): + execution_params = node.execution_params or {} + replicas = execution_params.get("replicas", 1) + batch_size = ( + int(execution_params.get("batch_size")) + if "batch_size" in execution_params + else "default" ) + compute_resources = execution_params.get("compute_resources", {}) + + if node.type == "aggregate": + self.datasets[node.id] = input_ds.repartition(1).map_batches( + op_handler, + compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1), + batch_size=None, # aggregate processes the whole dataset at once + num_gpus=compute_resources.get("num_gpus", 0) + if compute_resources + else 0, + fn_constructor_kwargs=node_params, + batch_format="pandas", + ) + else: + # non-aggregate class operators (e.g. map_batch) let the actor pool process data batch by batch + self.datasets[node.id] = input_ds.map_batches( + op_handler, + compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas), + batch_size=batch_size, + num_gpus=compute_resources.get("num_gpus", 0) + if compute_resources + else 0, + fn_constructor_kwargs=node_params, + batch_format="pandas", + ) - @staticmethod - def _validate(ops: List[OpNode]): - name_set = set() - for op in ops: - if op.name in name_set: - raise ValueError(f"Duplicate operation name: {op.name}") - name_set.add(op.name) - for op in ops: - for dep in op.deps: - if dep not in name_set: - raise ValueError( - f"Operation {op.name} has unknown dependency: {dep}" - ) + else: + @wraps(op_handler) + def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]: + return op_handler(row_or_batch, **node_params) + + if node.type == "map": + self.datasets[node.id] = input_ds.map(func_wrapper) + elif node.type == "filter": + self.datasets[node.id] = input_ds.filter(func_wrapper) + elif node.type == "flatmap": + self.datasets[node.id] = input_ds.flat_map(func_wrapper) + elif node.type == "aggregate": + self.datasets[node.id] = input_ds.repartition(1).map_batches( + func_wrapper, batch_format="default" + ) + elif node.type == "map_batch": + self.datasets[node.id] = input_ds.map_batches(func_wrapper) + else: + raise ValueError( + f"Unsupported node type {node.type} for node {node.id}" + ) -def collect_ops(config: dict, graph_gen) -> List[OpNode]: - """ - build operation nodes from yaml config - :param config - :param graph_gen - """ - ops: List[OpNode] = [] - for stage in config["pipeline"]: - name = stage["name"] - method_name = stage.get("op_key") - method = getattr(graph_gen, method_name) - deps = stage.get("deps", []) + @staticmethod + def _find_leaf_nodes(nodes: List[Node]) -> Set[str]: + all_ids = {n.id for n in nodes} +
deps_set = set() + for n in nodes: + deps_set.update(n.dependencies) + return all_ids - deps_set - if "params" in stage: + def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]: + sorted_nodes = self._topo_sort(self.config.nodes) - def func(self, ctx, _method=method, _params=stage.get("params", {})): - return _method(_params) + for node in sorted_nodes: + self._execute_node(node, initial_ds) - else: + leaf_nodes = self._find_leaf_nodes(sorted_nodes) - def func(self, ctx, _method=method): - return _method() + @ray.remote + def _fetch_result(ds: ray.data.Dataset) -> List[Any]: + return ds.take_all() - op_node = OpNode(name=name, deps=deps, func=func) - ops.append(op_node) - return ops + return {node_id: self.datasets[node_id] for node_id in leaf_nodes} diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index bc7e7742..56e97469 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -1,295 +1,295 @@ -import os -import time -from typing import Dict - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.bases.datatypes import Chunk -from graphgen.models import ( - JsonKVStorage, - JsonListStorage, - NetworkXStorage, - OpenAIClient, - Tokenizer, -) -from graphgen.operators import ( - build_kg, - chunk_documents, - extract_info, - generate_qas, - init_llm, - judge_statement, - partition_kg, - quiz, - read_files, - search_all, -) -from graphgen.utils import async_to_sync_method, compute_mm_hash, logger - -sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - - -class GraphGen: - def __init__( - self, - unique_id: int = int(time.time()), - working_dir: str = os.path.join(sys_path, "cache"), - tokenizer_instance: Tokenizer = None, - synthesizer_llm_client: OpenAIClient = None, - trainee_llm_client: OpenAIClient = None, - progress_bar: gr.Progress = None, - ): - self.unique_id: int = unique_id - self.working_dir: str = working_dir - - # llm - self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer( - model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base") - ) - - self.synthesizer_llm_client: BaseLLMWrapper = ( - synthesizer_llm_client or init_llm("synthesizer") - ) - self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client - - self.full_docs_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="full_docs" - ) - self.chunks_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="chunks" - ) - self.graph_storage: NetworkXStorage = NetworkXStorage( - self.working_dir, namespace="graph" - ) - self.rephrase_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="rephrase" - ) - self.partition_storage: JsonListStorage = JsonListStorage( - self.working_dir, namespace="partition" - ) - self.search_storage: JsonKVStorage = JsonKVStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="search", - ) - self.qa_storage: JsonListStorage = JsonListStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="qa", - ) - self.extract_storage: JsonKVStorage = JsonKVStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="extraction", - ) - - # webui - self.progress_bar: gr.Progress = progress_bar - - @async_to_sync_method - async def read(self, read_config: Dict): - """ - read files from input sources - """ - doc_stream = read_files(**read_config, cache_dir=self.working_dir) - - batch = {} - for doc in doc_stream: - doc_id = compute_mm_hash(doc, 
prefix="doc-") - batch[doc_id] = doc - - # TODO: configurable whether to use coreference resolution - - _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys())) - new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys} - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - self.full_docs_storage.upsert(new_docs) - self.full_docs_storage.index_done_callback() - - @async_to_sync_method - async def chunk(self, chunk_config: Dict): - """ - chunk documents into smaller pieces from full_docs_storage if not already present - """ - - new_docs = self.full_docs_storage.get_all() - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - - inserting_chunks = await chunk_documents( - new_docs, - self.tokenizer_instance, - self.progress_bar, - **chunk_config, - ) - - _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys())) - inserting_chunks = { - k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys - } - - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - self.chunks_storage.upsert(inserting_chunks) - self.chunks_storage.index_done_callback() - - @async_to_sync_method - async def build_kg(self): - """ - build knowledge graph from text chunks - """ - # Step 1: get new chunks - inserting_chunks = self.chunks_storage.get_all() - - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks)) - # Step 2: build knowledge graph from new chunks - _add_entities_and_relations = await build_kg( - llm_client=self.synthesizer_llm_client, - kg_instance=self.graph_storage, - chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()], - progress_bar=self.progress_bar, - ) - if not _add_entities_and_relations: - logger.warning("No entities or relations extracted from text chunks") - return - - # Step 3: upsert new entities and relations to the graph storage - self.graph_storage.index_done_callback() - - return _add_entities_and_relations - - @async_to_sync_method - async def search(self, search_config: Dict): - logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) - - seeds = self.full_docs_storage.get_all() - if len(seeds) == 0: - logger.warning("All documents are already been searched") - return - search_results = await search_all( - seed_data=seeds, - search_config=search_config, - ) - - _add_search_keys = self.search_storage.filter_keys(list(search_results.keys())) - search_results = { - k: v for k, v in search_results.items() if k in _add_search_keys - } - if len(search_results) == 0: - logger.warning("All search results are already in the storage") - return - self.search_storage.upsert(search_results) - self.search_storage.index_done_callback() - - @async_to_sync_method - async def quiz_and_judge(self, quiz_and_judge_config: Dict): - logger.warning( - "Quiz and Judge operation needs trainee LLM client." - " Make sure to provide one." 
- ) - max_samples = quiz_and_judge_config["quiz_samples"] - await quiz( - self.synthesizer_llm_client, - self.graph_storage, - self.rephrase_storage, - max_samples, - progress_bar=self.progress_bar, - ) - - # TODO: assert trainee_llm_client is valid before judge - if not self.trainee_llm_client: - # TODO: shutdown existing synthesizer_llm_client properly - logger.info("No trainee LLM client provided, initializing a new one.") - self.synthesizer_llm_client.shutdown() - self.trainee_llm_client = init_llm("trainee") - - re_judge = quiz_and_judge_config["re_judge"] - _update_relations = await judge_statement( - self.trainee_llm_client, - self.graph_storage, - self.rephrase_storage, - re_judge, - progress_bar=self.progress_bar, - ) - - self.rephrase_storage.index_done_callback() - _update_relations.index_done_callback() - - logger.info("Shutting down trainee LLM client.") - self.trainee_llm_client.shutdown() - self.trainee_llm_client = None - logger.info("Restarting synthesizer LLM client.") - self.synthesizer_llm_client.restart() - - @async_to_sync_method - async def partition(self, partition_config: Dict): - batches = await partition_kg( - self.graph_storage, - self.chunks_storage, - self.tokenizer_instance, - partition_config, - ) - self.partition_storage.upsert(batches) - return batches - - @async_to_sync_method - async def extract(self, extract_config: Dict): - logger.info("Extracting information from given chunks...") - - results = await extract_info( - self.synthesizer_llm_client, - self.chunks_storage, - extract_config, - progress_bar=self.progress_bar, - ) - if not results: - logger.warning("No information extracted") - return - - self.extract_storage.upsert(results) - self.extract_storage.index_done_callback() - - @async_to_sync_method - async def generate(self, generate_config: Dict): - - batches = self.partition_storage.data - if not batches: - logger.warning("No partitions found for QA generation") - return - - # Step 2: generate QA pairs - results = await generate_qas( - self.synthesizer_llm_client, - batches, - generate_config, - progress_bar=self.progress_bar, - ) - - if not results: - logger.warning("No QA pairs generated") - return - - # Step 3: store the generated QA pairs - self.qa_storage.upsert(results) - self.qa_storage.index_done_callback() - - @async_to_sync_method - async def clear(self): - self.full_docs_storage.drop() - self.chunks_storage.drop() - self.search_storage.drop() - self.graph_storage.clear() - self.rephrase_storage.drop() - self.qa_storage.drop() - - logger.info("All caches are cleared") - - # TODO: add data filtering step here in the future - # graph_gen.filter(filter_config=config["filter"]) +# import os +# import time +# from typing import Dict +# +# import gradio as gr +# +# from graphgen.bases import BaseLLMWrapper +# from graphgen.bases.datatypes import Chunk +# from graphgen.models import ( +# JsonKVStorage, +# JsonListStorage, +# NetworkXStorage, +# OpenAIClient, +# Tokenizer, +# ) +# from graphgen.operators import ( +# build_kg, +# chunk_documents, +# extract_info, +# generate_qas, +# init_llm, +# judge_statement, +# partition_kg, +# quiz, +# read_files, +# search_all, +# ) +# from graphgen.utils import async_to_sync_method, compute_mm_hash, logger +# +# sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +# +# +# class GraphGen: +# def __init__( +# self, +# unique_id: int = int(time.time()), +# working_dir: str = os.path.join(sys_path, "cache"), +# tokenizer_instance: Tokenizer = None, +# synthesizer_llm_client: 
OpenAIClient = None, +# trainee_llm_client: OpenAIClient = None, +# progress_bar: gr.Progress = None, +# ): +# self.unique_id: int = unique_id +# self.working_dir: str = working_dir +# +# # llm +# self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer( +# model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base") +# ) +# +# self.synthesizer_llm_client: BaseLLMWrapper = ( +# synthesizer_llm_client or init_llm("synthesizer") +# ) +# self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client +# +# self.full_docs_storage: JsonKVStorage = JsonKVStorage( +# self.working_dir, namespace="full_docs" +# ) +# self.chunks_storage: JsonKVStorage = JsonKVStorage( +# self.working_dir, namespace="chunks" +# ) +# self.graph_storage: NetworkXStorage = NetworkXStorage( +# self.working_dir, namespace="graph" +# ) +# self.rephrase_storage: JsonKVStorage = JsonKVStorage( +# self.working_dir, namespace="rephrase" +# ) +# self.partition_storage: JsonListStorage = JsonListStorage( +# self.working_dir, namespace="partition" +# ) +# self.search_storage: JsonKVStorage = JsonKVStorage( +# os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), +# namespace="search", +# ) +# self.qa_storage: JsonListStorage = JsonListStorage( +# os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), +# namespace="qa", +# ) +# self.extract_storage: JsonKVStorage = JsonKVStorage( +# os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), +# namespace="extraction", +# ) +# +# # webui +# self.progress_bar: gr.Progress = progress_bar +# +# @async_to_sync_method +# async def read(self, read_config: Dict): +# """ +# read files from input sources +# """ +# doc_stream = read_files(**read_config, cache_dir=self.working_dir) +# +# batch = {} +# for doc in doc_stream: +# doc_id = compute_mm_hash(doc, prefix="doc-") +# batch[doc_id] = doc +# +# # TODO: configurable whether to use coreference resolution +# +# _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys())) +# new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys} +# if len(new_docs) == 0: +# logger.warning("All documents are already in the storage") +# return +# self.full_docs_storage.upsert(new_docs) +# self.full_docs_storage.index_done_callback() +# +# @async_to_sync_method +# async def chunk(self, chunk_config: Dict): +# """ +# chunk documents into smaller pieces from full_docs_storage if not already present +# """ +# +# new_docs = self.full_docs_storage.get_all() +# if len(new_docs) == 0: +# logger.warning("All documents are already in the storage") +# return +# +# inserting_chunks = await chunk_documents( +# new_docs, +# self.tokenizer_instance, +# self.progress_bar, +# **chunk_config, +# ) +# +# _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys())) +# inserting_chunks = { +# k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys +# } +# +# if len(inserting_chunks) == 0: +# logger.warning("All chunks are already in the storage") +# return +# +# self.chunks_storage.upsert(inserting_chunks) +# self.chunks_storage.index_done_callback() +# +# @async_to_sync_method +# async def build_kg(self): +# """ +# build knowledge graph from text chunks +# """ +# # Step 1: get new chunks +# inserting_chunks = self.chunks_storage.get_all() +# +# if len(inserting_chunks) == 0: +# logger.warning("All chunks are already in the storage") +# return +# +# logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks)) +# # Step 2: build knowledge graph from new chunks +# 
_add_entities_and_relations = await build_kg( +# llm_client=self.synthesizer_llm_client, +# kg_instance=self.graph_storage, +# chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()], +# progress_bar=self.progress_bar, +# ) +# if not _add_entities_and_relations: +# logger.warning("No entities or relations extracted from text chunks") +# return +# +# # Step 3: upsert new entities and relations to the graph storage +# self.graph_storage.index_done_callback() +# +# return _add_entities_and_relations +# +# @async_to_sync_method +# async def search(self, search_config: Dict): +# logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) +# +# seeds = self.full_docs_storage.get_all() +# if len(seeds) == 0: +# logger.warning("All documents are already been searched") +# return +# search_results = await search_all( +# seed_data=seeds, +# search_config=search_config, +# ) +# +# _add_search_keys = self.search_storage.filter_keys(list(search_results.keys())) +# search_results = { +# k: v for k, v in search_results.items() if k in _add_search_keys +# } +# if len(search_results) == 0: +# logger.warning("All search results are already in the storage") +# return +# self.search_storage.upsert(search_results) +# self.search_storage.index_done_callback() +# +# @async_to_sync_method +# async def quiz_and_judge(self, quiz_and_judge_config: Dict): +# logger.warning( +# "Quiz and Judge operation needs trainee LLM client." +# " Make sure to provide one." +# ) +# max_samples = quiz_and_judge_config["quiz_samples"] +# await quiz( +# self.synthesizer_llm_client, +# self.graph_storage, +# self.rephrase_storage, +# max_samples, +# progress_bar=self.progress_bar, +# ) +# +# # TODO: assert trainee_llm_client is valid before judge +# if not self.trainee_llm_client: +# # TODO: shutdown existing synthesizer_llm_client properly +# logger.info("No trainee LLM client provided, initializing a new one.") +# self.synthesizer_llm_client.shutdown() +# self.trainee_llm_client = init_llm("trainee") +# +# re_judge = quiz_and_judge_config["re_judge"] +# _update_relations = await judge_statement( +# self.trainee_llm_client, +# self.graph_storage, +# self.rephrase_storage, +# re_judge, +# progress_bar=self.progress_bar, +# ) +# +# self.rephrase_storage.index_done_callback() +# _update_relations.index_done_callback() +# +# logger.info("Shutting down trainee LLM client.") +# self.trainee_llm_client.shutdown() +# self.trainee_llm_client = None +# logger.info("Restarting synthesizer LLM client.") +# self.synthesizer_llm_client.restart() +# +# @async_to_sync_method +# async def partition(self, partition_config: Dict): +# batches = await partition_kg( +# self.graph_storage, +# self.chunks_storage, +# self.tokenizer_instance, +# partition_config, +# ) +# self.partition_storage.upsert(batches) +# return batches +# +# @async_to_sync_method +# async def extract(self, extract_config: Dict): +# logger.info("Extracting information from given chunks...") +# +# results = await extract_info( +# self.synthesizer_llm_client, +# self.chunks_storage, +# extract_config, +# progress_bar=self.progress_bar, +# ) +# if not results: +# logger.warning("No information extracted") +# return +# +# self.extract_storage.upsert(results) +# self.extract_storage.index_done_callback() +# +# @async_to_sync_method +# async def generate(self, generate_config: Dict): +# +# batches = self.partition_storage.data +# if not batches: +# logger.warning("No partitions found for QA generation") +# return +# +# # Step 2: generate QA pairs +# results = await 
generate_qas( +# self.synthesizer_llm_client, +# batches, +# generate_config, +# progress_bar=self.progress_bar, +# ) +# +# if not results: +# logger.warning("No QA pairs generated") +# return +# +# # Step 3: store the generated QA pairs +# self.qa_storage.upsert(results) +# self.qa_storage.index_done_callback() +# +# @async_to_sync_method +# async def clear(self): +# self.full_docs_storage.drop() +# self.chunks_storage.drop() +# self.search_storage.drop() +# self.graph_storage.clear() +# self.rephrase_storage.drop() +# self.qa_storage.drop() +# +# logger.info("All caches are cleared") +# +# # TODO: add data filtering step here in the future +# # graph_gen.filter(filter_config=config["filter"]) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 3ef1ff69..17a7216d 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -18,7 +18,6 @@ ) from .reader import ( CSVReader, - JSONLReader, JSONReader, ParquetReader, PDFReader, @@ -33,5 +32,5 @@ from .searcher.web.bing_search import BingSearch from .searcher.web.google_search import GoogleSearch from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter -from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage, RocksDBCache +from .storage import JsonKVStorage, NetworkXStorage, RocksDBCache from .tokenizer import Tokenizer diff --git a/graphgen/models/extractor/schema_guided_extractor.py b/graphgen/models/extractor/schema_guided_extractor.py index 70c45502..74801946 100644 --- a/graphgen/models/extractor/schema_guided_extractor.py +++ b/graphgen/models/extractor/schema_guided_extractor.py @@ -60,8 +60,8 @@ def build_prompt(self, text: str) -> str: return prompt async def extract(self, chunk: dict) -> dict: - _chunk_id = list(chunk.keys())[0] - text = chunk[_chunk_id].get("content", "") + _chunk_id = chunk.get("_chunk_id", "") + text = chunk.get("content", "") prompt = self.build_prompt(text) response = await self.llm_client.generate_answer(prompt) @@ -88,9 +88,7 @@ async def extract(self, chunk: dict) -> dict: return {} @staticmethod - async def merge_extractions( - extraction_list: List[Dict[str, dict]] - ) -> Dict[str, dict]: + def merge_extractions(extraction_list: List[Dict[str, dict]]) -> Dict[str, dict]: """ Merge multiple extraction results based on their hashes. :param extraction_list: List of extraction results, each is a dict with hash as key and record as value. 
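Note on the schema_guided_extractor hunk above: extract() now receives one flat row (the shape Ray Data hands to batch operators) instead of a single-entry {chunk_id: payload} dict, and merge_extractions() is an ordinary synchronous staticmethod. A minimal sketch of the new calling convention, assuming an already-constructed SchemaGuidedExtractor instance named `extractor`; the row values and merged payloads are illustrative only, not from the repo:

import asyncio

row = {"_chunk_id": "chunk-0001", "content": "This Agreement is made between ..."}

async def run_extraction(extractor):
    # One LLM call per row; returns {hash: record} on success, {} on failure.
    record = await extractor.extract(row)
    # Deduplicating by hash across rows no longer needs an event loop.
    return extractor.merge_extractions([record, {"h2": {"term": "24 months"}}])

# asyncio.run(run_extraction(extractor))  # construction of `extractor` omitted here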
diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index eefbdd1c..91b44862 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -77,8 +77,8 @@ async def generate( nodes, _ = batch for node in nodes: node_data = node[1] - if "images" in node_data and node_data["images"]: - img_path = node_data["images"]["img_path"] + if "image_data" in node_data and node_data["image_data"]: + img_path = node_data["image_data"]["img_path"] for qa in qa_pairs.values(): qa["img_path"] = img_path result.update(qa_pairs) diff --git a/graphgen/models/partitioner/anchor_bfs_partitioner.py b/graphgen/models/partitioner/anchor_bfs_partitioner.py index 6cc1400c..09133af7 100644 --- a/graphgen/models/partitioner/anchor_bfs_partitioner.py +++ b/graphgen/models/partitioner/anchor_bfs_partitioner.py @@ -1,6 +1,6 @@ import random from collections import deque -from typing import Any, List, Literal, Set, Tuple +from typing import Any, Iterable, List, Literal, Set, Tuple from graphgen.bases import BaseGraphStorage from graphgen.bases.datatypes import Community @@ -30,24 +30,23 @@ def __init__( self.anchor_type = anchor_type self.anchor_ids = anchor_ids - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 1, **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes = g.get_all_nodes() # List[tuple[id, meta]] edges = g.get_all_edges() # List[tuple[u, v, meta]] adj, _ = self._build_adjacency_list(nodes, edges) - anchors: Set[str] = await self._pick_anchor_ids(nodes) + anchors: Set[str] = self._pick_anchor_ids(nodes) if not anchors: - return [] # if no anchors, return empty list + return # if no anchors, return nothing used_n: set[str] = set() used_e: set[frozenset[str]] = set() - communities: List[Community] = [] seeds = list(anchors) random.shuffle(seeds) @@ -55,17 +54,13 @@ async def partition( for seed_node in seeds: if seed_node in used_n: continue - comm_n, comm_e = await self._grow_community( + comm_n, comm_e = self._grow_community( seed_node, adj, max_units_per_community, used_n, used_e ) if comm_n or comm_e: - communities.append( - Community(id=len(communities), nodes=comm_n, edges=comm_e) - ) + yield Community(id=seed_node, nodes=comm_n, edges=comm_e) - return communities - - async def _pick_anchor_ids( + def _pick_anchor_ids( self, nodes: List[tuple[str, dict]], ) -> Set[str]: @@ -80,7 +75,7 @@ async def _pick_anchor_ids( return anchor_ids @staticmethod - async def _grow_community( + def _grow_community( seed: str, adj: dict[str, List[str]], max_units: int, diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py index 00895712..994e08e8 100644 --- a/graphgen/models/partitioner/bfs_partitioner.py +++ b/graphgen/models/partitioner/bfs_partitioner.py @@ -1,6 +1,6 @@ import random from collections import deque -from typing import Any, List +from typing import Any, Iterable, List from graphgen.bases import BaseGraphStorage, BasePartitioner from graphgen.bases.datatypes import Community @@ -17,12 +17,12 @@ class BFSPartitioner(BasePartitioner): (A unit is a node or an edge.) 
""" - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 1, **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes = g.get_all_nodes() edges = g.get_all_edges() @@ -30,7 +30,6 @@ async def partition( used_n: set[str] = set() used_e: set[frozenset[str]] = set() - communities: List[Community] = [] units = [(NODE_UNIT, n[0]) for n in nodes] + [ (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges @@ -74,8 +73,4 @@ async def partition( queue.append((NODE_UNIT, n)) if comm_n or comm_e: - communities.append( - Community(id=len(communities), nodes=comm_n, edges=comm_e) - ) - - return communities + yield Community(id=seed, nodes=comm_n, edges=comm_e) diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py index 6c394b10..36305842 100644 --- a/graphgen/models/partitioner/dfs_partitioner.py +++ b/graphgen/models/partitioner/dfs_partitioner.py @@ -1,4 +1,5 @@ import random +from collections.abc import Iterable from typing import Any, List from graphgen.bases import BaseGraphStorage, BasePartitioner @@ -16,12 +17,12 @@ class DFSPartitioner(BasePartitioner): (In GraphGen, a unit is defined as a node or an edge.) """ - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 1, **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes = g.get_all_nodes() edges = g.get_all_edges() @@ -29,7 +30,6 @@ async def partition( used_n: set[str] = set() used_e: set[frozenset[str]] = set() - communities: List[Community] = [] units = [(NODE_UNIT, n[0]) for n in nodes] + [ (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges @@ -71,8 +71,4 @@ async def partition( stack.append((NODE_UNIT, n)) if comm_n or comm_e: - communities.append( - Community(id=len(communities), nodes=comm_n, edges=comm_e) - ) - - return communities + yield Community(id=seed, nodes=comm_n, edges=comm_e) diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py index 7de73181..fcf776c7 100644 --- a/graphgen/models/partitioner/ece_partitioner.py +++ b/graphgen/models/partitioner/ece_partitioner.py @@ -1,8 +1,8 @@ -import asyncio import random -from typing import Any, Dict, List, Optional, Set, Tuple +from collections import deque +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple -from tqdm.asyncio import tqdm as tqdm_async +from tqdm import tqdm from graphgen.bases import BaseGraphStorage from graphgen.bases.datatypes import Community @@ -51,7 +51,7 @@ def _sort_units(units: list, edge_sampling: str) -> list: raise ValueError(f"Invalid edge sampling: {edge_sampling}") return units - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 10, @@ -59,7 +59,7 @@ async def partition( max_tokens_per_community: int = 10240, unit_sampling: str = "random", **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes: List[Tuple[str, dict]] = g.get_all_nodes() edges: List[Tuple[str, str, dict]] = g.get_all_edges() @@ -73,21 +73,18 @@ async def partition( used_n: Set[str] = set() used_e: Set[frozenset[str]] = set() - communities: List = [] all_units = self._sort_units(all_units, unit_sampling) - async def _grow_community( - seed_unit: Tuple[str, Any, dict] - ) -> Optional[Community]: + def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Optional[Community]: nonlocal used_n, used_e community_nodes: Dict[str, dict] = {} community_edges: Dict[frozenset[str], dict] = {} - 
queue: asyncio.Queue = asyncio.Queue() + queue = deque() token_sum = 0 - async def _add_unit(u): + def _add_unit(u): nonlocal token_sum t, i, d = u if t == NODE_UNIT: # node @@ -103,11 +100,11 @@ async def _add_unit(u): token_sum += d.get("length", 0) return True - await _add_unit(seed_unit) - await queue.put(seed_unit) + _add_unit(seed_unit) + queue.append(seed_unit) # BFS - while not queue.empty(): + while queue: if ( len(community_nodes) + len(community_edges) >= max_units_per_community @@ -115,7 +112,7 @@ async def _add_unit(u): ): break - cur_type, cur_id, _ = await queue.get() + cur_type, cur_id, _ = queue.popleft() neighbors: List[Tuple[str, Any, dict]] = [] if cur_type == NODE_UNIT: @@ -136,26 +133,24 @@ async def _add_unit(u): or token_sum >= max_tokens_per_community ): break - if await _add_unit(nb): - await queue.put(nb) + if _add_unit(nb): + queue.append(nb) if len(community_nodes) + len(community_edges) < min_units_per_community: return None return Community( - id=len(communities), + id=seed_unit[1], nodes=list(community_nodes.keys()), edges=[(u, v) for (u, v), _ in community_edges.items()], ) - async for unit in tqdm_async(all_units, desc="ECE partition"): + for unit in tqdm(all_units, desc="ECE partition"): utype, uid, _ = unit if (utype == NODE_UNIT and uid in used_n) or ( utype == EDGE_UNIT and uid in used_e ): continue - comm = await _grow_community(unit) - if comm is not None: - communities.append(comm) - - return communities + comm = _grow_community(unit) + if comm: + yield comm diff --git a/graphgen/models/partitioner/leiden_partitioner.py b/graphgen/models/partitioner/leiden_partitioner.py index 1f85789b..b62b8544 100644 --- a/graphgen/models/partitioner/leiden_partitioner.py +++ b/graphgen/models/partitioner/leiden_partitioner.py @@ -13,7 +13,7 @@ class LeidenPartitioner(BasePartitioner): Leiden partitioner that partitions the graph into communities using the Leiden algorithm. """ - async def partition( + def partition( self, g: BaseGraphStorage, max_size: int = 20, @@ -37,12 +37,10 @@ async def partition( nodes = g.get_all_nodes() # List[Tuple[str, dict]] edges = g.get_all_edges() # List[Tuple[str, str, dict]] - node2cid: Dict[str, int] = await self._run_leiden( - nodes, edges, use_lcc, random_seed - ) + node2cid: Dict[str, int] = self._run_leiden(nodes, edges, use_lcc, random_seed) if max_size is not None and max_size > 0: - node2cid = await self._split_communities(node2cid, max_size) + node2cid = self._split_communities(node2cid, max_size) cid2nodes: Dict[int, List[str]] = defaultdict(list) for n, cid in node2cid.items(): @@ -58,7 +56,7 @@ async def partition( return communities @staticmethod - async def _run_leiden( + def _run_leiden( nodes: List[Tuple[str, dict]], edges: List[Tuple[str, str, dict]], use_lcc: bool = False, @@ -92,9 +90,7 @@ async def _run_leiden( return node2cid @staticmethod - async def _split_communities( - node2cid: Dict[str, int], max_size: int - ) -> Dict[str, int]: + def _split_communities(node2cid: Dict[str, int], max_size: int) -> Dict[str, int]: """ Split communities larger than max_size into smaller sub-communities. 
""" diff --git a/graphgen/models/reader/__init__.py b/graphgen/models/reader/__init__.py index 600ffb4a..220460c3 100644 --- a/graphgen/models/reader/__init__.py +++ b/graphgen/models/reader/__init__.py @@ -1,6 +1,5 @@ from .csv_reader import CSVReader from .json_reader import JSONReader -from .jsonl_reader import JSONLReader from .parquet_reader import ParquetReader from .pdf_reader import PDFReader from .pickle_reader import PickleReader diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index bc865a3b..a0343d97 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,13 +14,15 @@ class CSVReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read(self, input_path: Union[str, List[str]]) -> Dataset: + """ + Read CSV files and return Ray Dataset. - df = pd.read_csv(file_path) - for _, row in df.iterrows(): - assert "type" in row, f"Missing 'type' column in document: {row.to_dict()}" - if row["type"] == "text" and self.text_column not in row: - raise ValueError( - f"Missing '{self.text_column}' in document: {row.to_dict()}" - ) - return self.filter(df.to_dict(orient="records")) + :param input_path: Path to CSV file or list of CSV files. + :return: Ray Dataset containing validated and filtered data. + """ + + ds = ray.data.read_csv(input_path) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index 8253041c..6752e042 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -1,26 +1,53 @@ import json -from typing import Any, Dict, List +from typing import List, Union + +import ray +import ray.data from graphgen.bases.base_reader import BaseReader class JSONReader(BaseReader): """ - Reader for JSON files. + Reader for JSON and JSONL files. Columns: - type: The type of the document (e.g., "text", "image", etc.) - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, list): - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - return self.filter(data) - raise ValueError("JSON file must contain a list of documents.") + def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset: + """ + Read JSON file and return Ray Dataset. + :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. + :return: Ray Dataset containing validated and filtered data. 
+ """ + if self.modalities and len(self.modalities) >= 2: + ds: ray.data.Dataset = ray.data.from_items([]) + for file in input_path if isinstance(input_path, list) else [input_path]: + data = [] + if file.endswith(".jsonl"): + with open(file, "r", encoding="utf-8") as f: + for line in f: + item = json.loads(line) + data.append(item) + else: + with open(file, "r", encoding="utf-8") as f: + data = json.load(f) + data = self._unify_schema(data) + file_ds: ray.data.Dataset = ray.data.from_items(data) + ds = ds.union(file_ds) # type: ignore + else: + ds = ray.data.read_json(input_path) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds + + @staticmethod + def _unify_schema(data): + """ + Unify schema for JSON data. + """ + for item in data: + if "content" in item and isinstance(item["content"], dict): + item["content"] = json.dumps(item["content"]) + return data diff --git a/graphgen/models/reader/jsonl_reader.py b/graphgen/models/reader/jsonl_reader.py deleted file mode 100644 index 31bc3195..00000000 --- a/graphgen/models/reader/jsonl_reader.py +++ /dev/null @@ -1,30 +0,0 @@ -import json -from typing import Any, Dict, List - -from graphgen.bases.base_reader import BaseReader -from graphgen.utils import logger - - -class JSONLReader(BaseReader): - """ - Reader for JSONL files. - Columns: - - type: The type of the document (e.g., "text", "image", etc.) - - if type is "text", "content" column must be present. - """ - - def read(self, file_path: str) -> List[Dict[str, Any]]: - docs = [] - with open(file_path, "r", encoding="utf-8") as f: - for line in f: - try: - doc = json.loads(line) - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - docs.append(doc) - except json.JSONDecodeError as e: - logger.error("Error decoding JSON line: %s. Error: %s", line, e) - return self.filter(docs) diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index a325b876..dd289e31 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,12 +14,17 @@ class ParquetReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - df = pd.read_parquet(file_path) - data: List[Dict[str, Any]] = df.to_dict(orient="records") + def read(self, input_path: Union[str, List[str]]) -> Dataset: + """ + Read Parquet files using Ray Data. - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError(f"Missing '{self.text_column}' in document: {doc}") - return self.filter(data) + :param input_path: Path to Parquet file or list of Parquet files. + :return: Ray Dataset containing validated documents. 
+ """ + if not ray.is_initialized(): + ray.init() + + ds = ray.data.read_parquet(input_path) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/pdf_reader.py b/graphgen/models/reader/pdf_reader.py index 94562cb5..55dab30b 100644 --- a/graphgen/models/reader/pdf_reader.py +++ b/graphgen/models/reader/pdf_reader.py @@ -5,6 +5,9 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +import ray +from ray.data import Dataset + from graphgen.bases.base_reader import BaseReader from graphgen.models.reader.txt_reader import TXTReader from graphgen.utils import logger, pick_device @@ -62,19 +65,31 @@ def __init__( self.parser = MinerUParser() self.txt_reader = TXTReader() - def read(self, file_path: str, **override) -> List[Dict[str, Any]]: - """ - file_path - **override: override MinerU parameters - """ - pdf_path = Path(file_path).expanduser().resolve() - if not pdf_path.is_file(): - raise FileNotFoundError(pdf_path) + def read( + self, + input_path: Union[str, List[str]], + **override, + ) -> Dataset: + + # Ensure input_path is a list + if isinstance(input_path, str): + input_path = [input_path] + + paths_ds = ray.data.from_items(input_path) + + def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + try: + pdf_path = row["item"] + kwargs = {**self._default_kwargs, **override} + return self._call_mineru(Path(pdf_path), kwargs) + except Exception as e: + logger.error("Failed to process %s: %s", row, e) + return [] - kwargs = {**self._default_kwargs, **override} + docs_ds = paths_ds.flat_map(process_pdf) + docs_ds = docs_ds.filter(self._should_keep_item) - mineru_result = self._call_mineru(pdf_path, kwargs) - return self.filter(mineru_result) + return docs_ds def _call_mineru( self, pdf_path: Path, kwargs: Dict[str, Any] @@ -161,18 +176,18 @@ def _try_load_cached_result( base = os.path.dirname(json_file) results = [] - for item in data: + for it in data: for key in ("img_path", "table_img_path", "equation_img_path"): - rel_path = item.get(key) + rel_path = it.get(key) if rel_path: - item[key] = str(Path(base).joinpath(rel_path).resolve()) - if item["type"] == "text": - item["content"] = item["text"] - del item["text"] + it[key] = str(Path(base).joinpath(rel_path).resolve()) + if it["type"] == "text": + it["content"] = it["text"] + del it["text"] for key in ("page_idx", "bbox", "text_level"): - if item.get(key) is not None: - del item[key] - results.append(item) + if it.get(key) is not None: + del it[key] + results.append(it) return results @staticmethod diff --git a/graphgen/models/reader/pickle_reader.py b/graphgen/models/reader/pickle_reader.py index 1a11dc11..6e3d1949 100644 --- a/graphgen/models/reader/pickle_reader.py +++ b/graphgen/models/reader/pickle_reader.py @@ -1,30 +1,78 @@ import pickle -from typing import Any, Dict, List +from typing import List, Union + +import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class PickleReader(BaseReader): """ - Read pickle files, requiring the top-level object to be List[Dict[str, Any]]. - - Columns: + Read pickle files, requiring the schema to be restored to List[Dict[str, Any]]. + Each pickle file should contain a list of dictionaries with at least: - type: The type of the document (e.g., "text", "image", etc.) - if type is "text", "content" column must be present. 
+ + Note: uses ray.data.read_binary_files because ray.data.read_pickle is not available in every Ray release. + If your Ray version provides read_pickle, it can be used instead. + """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "rb") as f: - data = pickle.load(f) + def read( + self, + input_path: Union[str, List[str]], + ) -> Dataset: + """ + Read Pickle files using Ray Data. + + :param input_path: Path to pickle file or list of pickle files. + :return: Ray Dataset containing validated documents. + """ + if not ray.is_initialized(): + ray.init() + + # Use read_binary_files as a reliable alternative to read_pickle + ds = ray.data.read_binary_files(input_path, include_paths=True) + + # Deserialize pickle files and flatten into individual records + def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame: + all_records = [] + for _, row in batch.iterrows(): + try: + # Load pickle data from bytes + data = pickle.loads(row["bytes"]) + + # Validate structure + if not isinstance(data, list): + logger.error( + "Pickle file %s must contain a list, got %s", row["path"], type(data) + ) + continue + + if not all(isinstance(item, dict) for item in data): + logger.error( + "Pickle file %s must contain a list of dictionaries", row["path"] + ) + continue + + # Flatten: each dict in the list becomes a separate row + all_records.extend(data) + except Exception as e: + logger.error( + "Failed to deserialize pickle file %s: %s", row["path"], str(e) + ) + continue + + return pd.DataFrame(all_records) - if not isinstance(data, list): - raise ValueError("Pickle file must contain a list of documents.") + # Apply deserialization and flattening + ds = ds.map_batches(deserialize_batch, batch_format="pandas") - for doc in data: - if not isinstance(doc, dict): - raise ValueError("Every item in the list must be a dict.") - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError(f"Missing '{self.text_column}' in document: {doc}") + # Validate the schema + ds = ds.map_batches(self._validate_batch, batch_format="pandas") - return self.filter(data) + # Filter valid items + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/rdf_reader.py b/graphgen/models/reader/rdf_reader.py index cce167c1..9670107a 100644 --- a/graphgen/models/reader/rdf_reader.py +++ b/graphgen/models/reader/rdf_reader.py @@ -1,48 +1,128 @@ -from typing import Any, Dict, List +from pathlib import Path +from typing import Any, Dict, List, Union +import ray import rdflib +from ray.data import Dataset from rdflib import Literal from rdflib.util import guess_format from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class RDFReader(BaseReader): """ Reader for RDF files that extracts triples and represents them as dictionaries. + + Uses Ray Data for distributed processing of multiple RDF files. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: + def __init__(self, *, text_column: str = "content", **kwargs): + """ + Initialize RDFReader. + + :param text_column: The column name for text content (default: "content"). + """ + super().__init__(**kwargs) + self.text_column = text_column + + def read( + self, + input_path: Union[str, List[str]], + ) -> Dataset: + """ + Read RDF file(s) using Ray Data. + + :param input_path: Path to RDF file or list of RDF files. + :return: Ray Dataset containing extracted documents.
+ """ + if not ray.is_initialized(): + ray.init() + + # Ensure input_path is a list to prevent Ray from splitting string into characters + if isinstance(input_path, str): + input_path = [input_path] + + # Create dataset from file paths + paths_ds = ray.data.from_items(input_path) + + def process_rdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + """Process a single RDF file and return list of documents.""" + try: + file_path = row["item"] + return self._parse_rdf_file(Path(file_path)) + except Exception as e: + logger.error( + "Failed to process RDF file %s: %s", row.get("item", "unknown"), e + ) + return [] + + # Process files in parallel and flatten results + docs_ds = paths_ds.flat_map(process_rdf) + + # Filter valid documents + docs_ds = docs_ds.filter(self._should_keep_item) + + return docs_ds + + def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]: + """ + Parse a single RDF file and extract documents. + + :param file_path: Path to RDF file. + :return: List of document dictionaries. + """ + if not file_path.is_file(): + raise FileNotFoundError(f"RDF file not found: {file_path}") + g = rdflib.Graph() - fmt = guess_format(file_path) + fmt = guess_format(str(file_path)) + try: - g.parse(file_path, format=fmt) + g.parse(str(file_path), format=fmt) except Exception as e: raise ValueError(f"Cannot parse RDF file {file_path}: {e}") from e docs: List[Dict[str, Any]] = [] - text_col = self.text_column + # Process each unique subject in the RDF graph for subj in set(g.subjects()): literals = [] props = {} + + # Extract all triples for this subject for _, pred, obj in g.triples((subj, None, None)): pred_str = str(pred) + obj_str = str(obj) + + # Collect literal values as text content if isinstance(obj, Literal): - literals.append(str(obj)) - props.setdefault(pred_str, []).append(str(obj)) + literals.append(obj_str) + + # Store all properties (including non-literals) + props.setdefault(pred_str, []).append(obj_str) + # Join all literal values as the text content text = " ".join(literals).strip() if not text: - raise ValueError( - f"Subject {subj} has no literal values; " - f"missing '{text_col}' for text column." + logger.warning( + "Subject %s in %s has no literal values; document will have empty '%s' field.", + subj, + file_path, + self.text_column, ) - doc = {"id": str(subj), text_col: text, "properties": props} + # Create document dictionary + doc = { + "id": str(subj), + self.text_column: text, + "properties": props, + "source_file": str(file_path), + } docs.append(doc) if not docs: - raise ValueError("RDF file contains no valid documents.") + logger.warning("RDF file %s contains no valid documents.", file_path) - return self.filter(docs) + return docs diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index ec2ff747..51a47de2 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -1,10 +1,32 @@ -from typing import Any, Dict, List +from typing import List, Union + +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader class TXTReader(BaseReader): - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - docs = [{"type": "text", self.text_column: f.read()}] - return self.filter(docs) + def read( + self, + input_path: Union[str, List[str]], + ) -> Dataset: + """ + Read text files from the specified input path. + :param input_path: Path to the input text file or list of text files. 
+ :return: Ray Dataset containing the read text data. + """ + docs_ds = ray.data.read_binary_files( + input_path, + include_paths=False, + ) + + docs_ds = docs_ds.map( + lambda row: { + "type": "text", + self.text_column: row["bytes"].decode("utf-8"), + } + ) + + docs_ds = docs_ds.filter(self._should_keep_item) + return docs_ds diff --git a/graphgen/models/splitter/character_splitter.py b/graphgen/models/splitter/character_splitter.py index 1c91877e..8877c861 100644 --- a/graphgen/models/splitter/character_splitter.py +++ b/graphgen/models/splitter/character_splitter.py @@ -17,7 +17,7 @@ def __init__( def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" - # First we naively split the large input into a bunch of smaller ones. + # First we naively chunk the large input into a bunch of smaller ones. separator = ( self._separator if self._is_separator_regex else re.escape(self._separator) ) diff --git a/graphgen/models/splitter/markdown_splitter.py b/graphgen/models/splitter/markdown_splitter.py index 03def6ae..40b6a44e 100644 --- a/graphgen/models/splitter/markdown_splitter.py +++ b/graphgen/models/splitter/markdown_splitter.py @@ -6,12 +6,12 @@ class MarkdownTextRefSplitter(RecursiveCharacterSplitter): - """Attempts to split the text along Markdown-formatted headings.""" + """Attempts to chunk the text along Markdown-formatted headings.""" def __init__(self, **kwargs: Any) -> None: """Initialize a MarkdownTextRefSplitter.""" separators = [ - # First, try to split along Markdown headings (starting with level 2) + # First, try to chunk along Markdown headings (starting with level 2) "\n#{1,6} ", # Note the alternative syntax for headings (below) is not handled here # Heading level 2 diff --git a/graphgen/models/splitter/recursive_character_splitter.py b/graphgen/models/splitter/recursive_character_splitter.py index c9d7c543..b1ee8e06 100644 --- a/graphgen/models/splitter/recursive_character_splitter.py +++ b/graphgen/models/splitter/recursive_character_splitter.py @@ -7,7 +7,7 @@ class RecursiveCharacterSplitter(BaseSplitter): """Splitting text by recursively look at characters. - Recursively tries to split by different characters to find one that works. + Recursively tries to chunk by different characters to find one that works. """ def __init__( @@ -88,7 +88,7 @@ def __init__( def _split_text_with_regex_from_end( self, text: str, separator: str, keep_separator: bool ) -> List[str]: - # Now that we have the separator, split the text + # Now that we have the separator, chunk the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. 
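Note on the splitter hunks above: the "parentheses in the pattern keep the delimiters" comment refers to re.split's behavior with a capturing group, which is what lets the markdown splitter reattach each heading separator to the chunk that follows it. A self-contained illustration using the heading separator from the diff; the sample text is ours:

import re

text = "Intro paragraph.\n## Section A\nBody of section A."
# With a capturing group, re.split returns each matched separator as its own list item.
parts = re.split(r"(\n#{1,6} )", text)
print(parts)  # ['Intro paragraph.', '\n## ', 'Section A\nBody of section A.']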
diff --git a/graphgen/models/storage/__init__.py b/graphgen/models/storage/__init__.py index 1e8f8341..0f8d9eeb 100644 --- a/graphgen/models/storage/__init__.py +++ b/graphgen/models/storage/__init__.py @@ -1,3 +1,4 @@ -from .json_storage import JsonKVStorage, JsonListStorage -from .networkx_storage import NetworkXStorage +from graphgen.models.storage.graph.networkx_storage import NetworkXStorage +from graphgen.models.storage.kv.json_storage import JsonKVStorage + from .rocksdb_cache import RocksDBCache diff --git a/graphgen/configs/__init__.py b/graphgen/models/storage/graph/__init__.py similarity index 100% rename from graphgen/configs/__init__.py rename to graphgen/models/storage/graph/__init__.py diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/graph/networkx_storage.py similarity index 85% rename from graphgen/models/storage/networkx_storage.py rename to graphgen/models/storage/graph/networkx_storage.py index 36bf1b5e..7fb73b79 100644 --- a/graphgen/models/storage/networkx_storage.py +++ b/graphgen/models/storage/graph/networkx_storage.py @@ -6,7 +6,6 @@ import networkx as nx from graphgen.bases.base_storage import BaseGraphStorage -from graphgen.utils import logger @dataclass @@ -19,11 +18,6 @@ def load_nx_graph(file_name) -> Optional[nx.Graph]: @staticmethod def write_nx_graph(graph: nx.Graph, file_name): - logger.info( - "Writing graph with %d nodes, %d edges", - graph.number_of_nodes(), - graph.number_of_edges(), - ) nx.write_graphml(graph, file_name) @staticmethod @@ -82,12 +76,11 @@ def __post_init__(self): self.working_dir, f"{self.namespace}.graphml" ) preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - "Loaded graph from %s with %d nodes, %d edges", - self._graphml_xml_file, - preloaded_graph.number_of_nodes(), - preloaded_graph.number_of_edges(), + if preloaded_graph: + print( + f"Loaded graph from {self._graphml_xml_file} with " + f"{preloaded_graph.number_of_nodes()} nodes, " + f"{preloaded_graph.number_of_edges()} edges" ) self._graph = preloaded_graph or nx.Graph() @@ -133,7 +126,7 @@ def update_node(self, node_id: str, node_data: dict[str, str]): if self._graph.has_node(node_id): self._graph.nodes[node_id].update(node_data) else: - logger.warning("Node %s not found in the graph for update.", node_id) + print(f"Node {node_id} not found in the graph for update.") def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] @@ -146,10 +139,8 @@ def update_edge( if self._graph.has_edge(source_node_id, target_node_id): self._graph.edges[(source_node_id, target_node_id)].update(edge_data) else: - logger.warning( - "Edge %s -> %s not found in the graph for update.", - source_node_id, - target_node_id, + print( + f"Edge {source_node_id} -> {target_node_id} not found in the graph for update." ) def delete_node(self, node_id: str): @@ -160,13 +151,19 @@ def delete_node(self, node_id: str): """ if self._graph.has_node(node_id): self._graph.remove_node(node_id) - logger.info("Node %s deleted from the graph.", node_id) + print(f"Node {node_id} deleted from the graph.") else: - logger.warning("Node %s not found in the graph for deletion.", node_id) + print(f"Node {node_id} not found in the graph for deletion.") def clear(self): """ Clear the graph by removing all nodes and edges. 
""" self._graph.clear() - logger.info("Graph %s cleared.", self.namespace) + print(f"Graph {self.namespace} cleared.") + + def reload(self): + """ + Reload the graph from the GraphML file. + """ + self.__post_init__() diff --git a/graphgen/models/storage/kv/__init__.py b/graphgen/models/storage/kv/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/kv/json_storage.py similarity index 53% rename from graphgen/models/storage/json_storage.py rename to graphgen/models/storage/kv/json_storage.py index 53962117..aa7c6f42 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/kv/json_storage.py @@ -1,8 +1,8 @@ import os from dataclasses import dataclass -from graphgen.bases.base_storage import BaseKVStorage, BaseListStorage -from graphgen.utils import load_json, logger, write_json +from graphgen.bases.base_storage import BaseKVStorage +from graphgen.utils import load_json, write_json @dataclass @@ -12,7 +12,7 @@ class JsonKVStorage(BaseKVStorage): def __post_init__(self): self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json") self._data = load_json(self._file_name) or {} - logger.info("Load KV %s with %d data", self.namespace, len(self._data)) + print(f"Load KV {self.namespace} with {len(self._data)} data") @property def data(self): @@ -55,40 +55,6 @@ def drop(self): if self._data: self._data.clear() - -@dataclass -class JsonListStorage(BaseListStorage): - working_dir: str = None - namespace: str = None - _data: list = None - - def __post_init__(self): - self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json") - self._data = load_json(self._file_name) or [] - logger.info("Load List %s with %d data", self.namespace, len(self._data)) - - @property - def data(self): - return self._data - - def all_items(self) -> list: - return self._data - - def index_done_callback(self): - write_json(self._data, self._file_name) - - def get_by_index(self, index: int): - if index < 0 or index >= len(self._data): - return None - return self._data[index] - - def append(self, data): - self._data.append(data) - - def upsert(self, data: list): - left_data = [d for d in data if d not in self._data] - self._data.extend(left_data) - return left_data - - def drop(self): - self._data = [] + def reload(self): + self._data = load_json(self._file_name) or {} + print(f"Reload KV {self.namespace} with {len(self._data)} data") diff --git a/graphgen/models/storage/kv/rocksdb_storage.py b/graphgen/models/storage/kv/rocksdb_storage.py new file mode 100644 index 00000000..0cbe1145 --- /dev/null +++ b/graphgen/models/storage/kv/rocksdb_storage.py @@ -0,0 +1,79 @@ +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Set + +# rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it +# pylint: disable=no-name-in-module +from rocksdict import Rdict + +from graphgen.bases.base_storage import BaseKVStorage +from graphgen.utils import logger + + +@dataclass +class RocksDBKVStorage(BaseKVStorage): + _db: Rdict = None + _db_path: str = None + + def __post_init__(self): + self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db") + self._db = Rdict(self._db_path) + logger.info("Load KV (RocksDB) %s at %s", self.namespace, self._db_path) + + @property + def data(self): + return self._db + + def all_keys(self) -> List[str]: + return list(self._db.keys()) + + def index_done_callback(self): + self._db.flush() + 
logger.info("RocksDB flushed for %s", self.namespace) + + def get_by_id(self, id: str) -> Any: + return self._db.get(id, None) + + def get_by_ids(self, ids: List[str], fields: List[str] = None) -> List[Any]: + result = [] + for index in ids: + item = self._db.get(index, None) + if item is None: + result.append(None) + continue + + if fields is None: + result.append(item) + else: + result.append({k: v for k, v in item.items() if k in fields}) + return result + + def get_all(self) -> Dict[str, Dict]: + return dict(self._db) + + def filter_keys(self, data: List[str]) -> Set[str]: + return {s for s in data if s not in self._db} + + def upsert(self, data: Dict[str, Any]): + left_data = {} + for k, v in data.items(): + if k not in self._db: + left_data[k] = v + + if left_data: + for k, v in left_data.items(): + self._db[k] = v + + # if left_data is very large, it is recommended to use self._db.write_batch() for optimization + + return left_data + + def drop(self): + self._db.close() + Rdict.destroy(self._db_path) + self._db = Rdict(self._db_path) + logger.info("Dropped RocksDB %s", self.namespace) + + def close(self): + if self._db: + self._db.close() diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 97f4b3c8..64c78af5 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,9 +1,21 @@ -from .build_kg import build_kg -from .extract import extract_info -from .generate import generate_qas -from .init import init_llm -from .partition import partition_kg -from .quiz_and_judge import judge_statement, quiz -from .read import read_files +from .build_kg import BuildKGService +from .chunk import ChunkService +from .extract import ExtractService +from .generate import GenerateService +from .judge import JudgeService +from .partition import PartitionService +from .quiz import QuizService +from .read import read from .search import search_all -from .split import chunk_documents + +operators = { + "read": read, + "chunk": ChunkService, + "build_kg": BuildKGService, + "quiz": QuizService, + "judge": JudgeService, + "extract": ExtractService, + "search": search_all, + "partition": PartitionService, + "generate": GenerateService, +} diff --git a/graphgen/operators/build_kg/__init__.py b/graphgen/operators/build_kg/__init__.py index 18766fe6..a8b22ce9 100644 --- a/graphgen/operators/build_kg/__init__.py +++ b/graphgen/operators/build_kg/__init__.py @@ -1 +1 @@ -from .build_kg import build_kg +from .build_kg_service import BuildKGService diff --git a/graphgen/operators/build_kg/build_kg.py b/graphgen/operators/build_kg/build_kg.py deleted file mode 100644 index a8a6146d..00000000 --- a/graphgen/operators/build_kg/build_kg.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import List - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.bases.base_storage import BaseGraphStorage -from graphgen.bases.datatypes import Chunk -from graphgen.utils import logger - -from .build_mm_kg import build_mm_kg -from .build_text_kg import build_text_kg - - -async def build_kg( - llm_client: BaseLLMWrapper, - kg_instance: BaseGraphStorage, - chunks: List[Chunk], - progress_bar: gr.Progress = None, -): - """ - Build knowledge graph (KG) and merge into kg_instance - :param llm_client: Synthesizer LLM model to extract entities and relationships - :param kg_instance - :param chunks - :param anchor_type: get this type of information from chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction - :return: - """ - - 
text_chunks = [chunk for chunk in chunks if chunk.type == "text"] - mm_chunks = [ - chunk - for chunk in chunks - if chunk.type in ("image", "video", "table", "formula") - ] - - if len(text_chunks) == 0: - logger.info("All text chunks are already in the storage") - else: - logger.info("[Text Entity and Relation Extraction] processing ...") - await build_text_kg( - llm_client=llm_client, - kg_instance=kg_instance, - chunks=text_chunks, - progress_bar=progress_bar, - ) - - if len(mm_chunks) == 0: - logger.info("All multi-modal chunks are already in the storage") - else: - logger.info("[Multi-modal Entity and Relation Extraction] processing ...") - await build_mm_kg( - llm_client=llm_client, - kg_instance=kg_instance, - chunks=mm_chunks, - progress_bar=progress_bar, - ) - - return kg_instance diff --git a/graphgen/operators/build_kg/build_kg_service.py b/graphgen/operators/build_kg/build_kg_service.py new file mode 100644 index 00000000..0ee54a80 --- /dev/null +++ b/graphgen/operators/build_kg/build_kg_service.py @@ -0,0 +1,60 @@ +from typing import List + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator +from graphgen.bases.datatypes import Chunk +from graphgen.common import init_llm, init_storage +from graphgen.utils import logger + +from .build_mm_kg import build_mm_kg +from .build_text_kg import build_text_kg + + +class BuildKGService(BaseOperator): + def __init__(self, working_dir: str = "cache"): + super().__init__(working_dir=working_dir, op_name="build_kg_service") + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.graph_storage: BaseGraphStorage = init_storage( + backend="networkx", working_dir=working_dir, namespace="graph" + ) + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + docs = batch.to_dict(orient="records") + docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs] + + # consume the chunks and build kg + self.build_kg(docs) + return pd.DataFrame([{"status": "kg_building_completed"}]) + + def build_kg(self, chunks: List[Chunk]) -> None: + """ + Build knowledge graph (KG) and merge into kg_instance + """ + text_chunks = [chunk for chunk in chunks if chunk.type == "text"] + mm_chunks = [ + chunk + for chunk in chunks + if chunk.type in ("image", "video", "table", "formula") + ] + + if len(text_chunks) == 0: + logger.info("All text chunks are already in the storage") + else: + logger.info("[Text Entity and Relation Extraction] processing ...") + build_text_kg( + llm_client=self.llm_client, + kg_instance=self.graph_storage, + chunks=text_chunks, + ) + if len(mm_chunks) == 0: + logger.info("All multi-modal chunks are already in the storage") + else: + logger.info("[Multi-modal Entity and Relation Extraction] processing ...") + build_mm_kg( + llm_client=self.llm_client, + kg_instance=self.graph_storage, + chunks=mm_chunks, + ) + + self.graph_storage.index_done_callback() diff --git a/graphgen/operators/build_kg/build_mm_kg.py b/graphgen/operators/build_kg/build_mm_kg.py index 624b10ad..ee0459ea 100644 --- a/graphgen/operators/build_kg/build_mm_kg.py +++ b/graphgen/operators/build_kg/build_mm_kg.py @@ -1,8 +1,6 @@ from collections import defaultdict from typing import List -import gradio as gr - from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk @@ -10,28 +8,25 @@ from graphgen.utils import run_concurrent -async def build_mm_kg( +def build_mm_kg( llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, 
chunks: List[Chunk], - progress_bar: gr.Progress = None, ): """ Build multi-modal KG and merge into kg_instance :param llm_client: Synthesizer LLM model to extract entities and relationships :param kg_instance :param chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction :return: """ mm_builder = MMKGBuilder(llm_client=llm_client) - results = await run_concurrent( + results = run_concurrent( mm_builder.extract, chunks, desc="[2/4] Extracting entities and relationships from multi-modal chunks", unit="chunk", - progress_bar=progress_bar, ) nodes = defaultdict(list) @@ -42,16 +37,14 @@ async def build_mm_kg( for k, v in e.items(): edges[tuple(sorted(k))].extend(v) - await run_concurrent( + run_concurrent( lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance), list(nodes.items()), desc="Inserting entities into storage", ) - await run_concurrent( + run_concurrent( lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance), list(edges.items()), desc="Inserting relationships into storage", ) - - return kg_instance diff --git a/graphgen/operators/build_kg/build_text_kg.py b/graphgen/operators/build_kg/build_text_kg.py index 3c75f022..1b5a8762 100644 --- a/graphgen/operators/build_kg/build_text_kg.py +++ b/graphgen/operators/build_kg/build_text_kg.py @@ -1,8 +1,6 @@ from collections import defaultdict from typing import List -import gradio as gr - from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk @@ -10,28 +8,25 @@ from graphgen.utils import run_concurrent -async def build_text_kg( +def build_text_kg( llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], - progress_bar: gr.Progress = None, ): """ :param llm_client: Synthesizer LLM model to extract entities and relationships :param kg_instance :param chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction :return: """ kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3) - results = await run_concurrent( + results = run_concurrent( kg_builder.extract, chunks, desc="[2/4]Extracting entities and relationships from chunks", unit="chunk", - progress_bar=progress_bar, ) nodes = defaultdict(list) @@ -42,16 +37,14 @@ async def build_text_kg( for k, v in e.items(): edges[tuple(sorted(k))].extend(v) - await run_concurrent( + run_concurrent( lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance), list(nodes.items()), desc="Inserting entities into storage", ) - await run_concurrent( + run_concurrent( lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance), list(edges.items()), desc="Inserting relationships into storage", ) - - return kg_instance diff --git a/graphgen/operators/chunk/__init__.py b/graphgen/operators/chunk/__init__.py new file mode 100644 index 00000000..f2f116f7 --- /dev/null +++ b/graphgen/operators/chunk/__init__.py @@ -0,0 +1 @@ +from .chunk_service import ChunkService diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py new file mode 100644 index 00000000..abd72e54 --- /dev/null +++ b/graphgen/operators/chunk/chunk_service.py @@ -0,0 +1,101 @@ +import os +from functools import lru_cache +from typing import Union + +import pandas as pd + +from graphgen.bases import BaseOperator +from graphgen.common import init_storage +from graphgen.models import ( + ChineseRecursiveTextSplitter, + RecursiveCharacterSplitter, + Tokenizer, +) +from graphgen.utils import compute_content_hash, 
detect_main_language + +_MAPPING = { + "en": RecursiveCharacterSplitter, + "zh": ChineseRecursiveTextSplitter, +} + +SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] + + +@lru_cache(maxsize=None) +def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: + cls = _MAPPING[language] + kwargs = dict(frozen_kwargs) + return cls(**kwargs) + + +def split_chunks(text: str, language: str = "en", **kwargs) -> list: + if language not in _MAPPING: + raise ValueError( + f"Unsupported language: {language}. " + f"Supported languages are: {list(_MAPPING.keys())}" + ) + frozen_kwargs = frozenset( + (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items() + ) + splitter = _get_splitter(language, frozen_kwargs) + return splitter.split_text(text) + + +class ChunkService(BaseOperator): + def __init__(self, working_dir: str = "cache", **chunk_kwargs): + super().__init__(working_dir=working_dir, op_name="chunk_service") + tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") + self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model) + self.chunk_storage = init_storage( + backend="json_kv", + working_dir=working_dir, + namespace="chunk", + ) + self.chunk_kwargs = chunk_kwargs + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + docs = batch.to_dict(orient="records") + return pd.DataFrame(self.chunk_documents(docs)) + + def chunk_documents(self, new_docs: list) -> list: + chunks = [] + for doc in new_docs: + doc_id = doc.get("_doc_id") + doc_type = doc.get("type") + + if doc_type == "text": + doc_language = detect_main_language(doc["content"]) + text_chunks = split_chunks( + doc["content"], + language=doc_language, + **self.chunk_kwargs, + ) + + chunks.extend( + [ + { + "_chunk_id": compute_content_hash( + chunk_text, prefix="chunk-" + ), + "content": chunk_text, + "type": "text", + "_doc_id": doc_id, + "length": len(self.tokenizer_instance.encode(chunk_text)) + if self.tokenizer_instance + else len(chunk_text), + "language": doc_language, + } + for chunk_text in text_chunks + ] + ) + else: + # other types of documents(images, sequences) are not chunked + chunks.append( + { + "_chunk_id": doc_id.replace("doc-", f"{doc_type}-"), + **doc, + } + ) + self.chunk_storage.upsert({chunk["_chunk_id"]: chunk for chunk in chunks}) + self.chunk_storage.index_done_callback() + return chunks diff --git a/graphgen/operators/evaluate/__init__.py b/graphgen/operators/evaluate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/evaluate.py b/graphgen/operators/evaluate/evaluate.py similarity index 97% rename from graphgen/evaluate.py rename to graphgen/operators/evaluate/evaluate.py index d1e2413b..fdbfbf82 100644 --- a/graphgen/evaluate.py +++ b/graphgen/operators/evaluate/evaluate.py @@ -9,9 +9,13 @@ from dotenv import load_dotenv from graphgen.bases.datatypes import QAPair - -from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator -from .utils import logger, set_logger +from graphgen.models import ( + LengthEvaluator, + MTLDEvaluator, + RewardEvaluator, + UniEvaluator, +) +from graphgen.utils import logger, set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log")) diff --git a/graphgen/operators/extract/__init__.py b/graphgen/operators/extract/__init__.py index ec576cb6..6c7c2b94 100644 --- a/graphgen/operators/extract/__init__.py +++ b/graphgen/operators/extract/__init__.py @@ -1 +1 @@ -from .extract_info 
import extract_info +from .extract_service import ExtractService diff --git a/graphgen/operators/extract/extract_info.py b/graphgen/operators/extract/extract_info.py deleted file mode 100644 index 8e65f1b2..00000000 --- a/graphgen/operators/extract/extract_info.py +++ /dev/null @@ -1,47 +0,0 @@ -import json - -import gradio as gr - -from graphgen.bases import BaseKVStorage, BaseLLMWrapper -from graphgen.models.extractor import SchemaGuidedExtractor -from graphgen.utils import logger, run_concurrent - - -async def extract_info( - llm_client: BaseLLMWrapper, - chunk_storage: BaseKVStorage, - extract_config: dict, - progress_bar: gr.Progress = None, -): - """ - Extract information from chunks - :param llm_client: LLM client - :param chunk_storage: storage for chunks - :param extract_config - :param progress_bar - :return: extracted information - """ - - method = extract_config.get("method") - if method == "schema_guided": - schema_file = extract_config.get("schema_file") - with open(schema_file, "r", encoding="utf-8") as f: - schema = json.load(f) - extractor = SchemaGuidedExtractor(llm_client, schema) - else: - raise ValueError(f"Unsupported extraction method: {method}") - - chunks = chunk_storage.get_all() - chunks = [{k: v} for k, v in chunks.items()] - logger.info("Start extracting information from %d chunks", len(chunks)) - - results = await run_concurrent( - extractor.extract, - chunks, - desc="Extracting information", - unit="chunk", - progress_bar=progress_bar, - ) - - results = await extractor.merge_extractions(results) - return results diff --git a/graphgen/operators/extract/extract_service.py b/graphgen/operators/extract/extract_service.py new file mode 100644 index 00000000..33987fcb --- /dev/null +++ b/graphgen/operators/extract/extract_service.py @@ -0,0 +1,45 @@ +import json + +import pandas as pd + +from graphgen.bases import BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm +from graphgen.models.extractor import SchemaGuidedExtractor +from graphgen.utils import logger, run_concurrent + + +class ExtractService(BaseOperator): + def __init__(self, working_dir: str = "cache", **extract_kwargs): + super().__init__(working_dir=working_dir, op_name="extract_service") + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.extract_kwargs = extract_kwargs + self.method = self.extract_kwargs.get("method") + if self.method == "schema_guided": + schema_file = self.extract_kwargs.get("schema_path") + with open(schema_file, "r", encoding="utf-8") as f: + schema = json.load(f) + self.extractor = SchemaGuidedExtractor(self.llm_client, schema) + else: + raise ValueError(f"Unsupported extraction method: {self.method}") + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") + return pd.DataFrame(self.extract(items)) + + def extract(self, items: list[dict]) -> list[dict]: + + logger.info("Start extracting information from %d items", len(items)) + + results = run_concurrent( + self.extractor.extract, + items, + desc="Extracting information", + unit="item", + ) + results = self.extractor.merge_extractions(results) + + results = [ + {"_extract_id": key, "extracted_data": value} + for key, value in results.items() + ] + return results diff --git a/graphgen/operators/generate/__init__.py b/graphgen/operators/generate/__init__.py index 035eca36..04057ce6 100644 --- a/graphgen/operators/generate/__init__.py +++ b/graphgen/operators/generate/__init__.py @@ -1 +1 @@ -from .generate_qas import generate_qas +from .generate_service 
import GenerateService diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py deleted file mode 100644 index 86dbb9c9..00000000 --- a/graphgen/operators/generate/generate_qas.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Any - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import ( - AggregatedGenerator, - AtomicGenerator, - CoTGenerator, - MultiHopGenerator, - VQAGenerator, -) -from graphgen.utils import logger, run_concurrent - - -async def generate_qas( - llm_client: BaseLLMWrapper, - batches: list[ - tuple[ - list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] - ] - ], - generation_config: dict, - progress_bar: gr.Progress = None, -) -> list[dict[str, Any]]: - """ - Generate question-answer pairs based on nodes and edges. - :param llm_client: LLM client - :param batches - :param generation_config - :param progress_bar - :return: QA pairs - """ - method = generation_config["method"] - logger.info("[Generation] mode: %s, batches: %d", method, len(batches)) - - if method == "atomic": - generator = AtomicGenerator(llm_client) - elif method == "aggregated": - generator = AggregatedGenerator(llm_client) - elif method == "multi_hop": - generator = MultiHopGenerator(llm_client) - elif method == "cot": - generator = CoTGenerator(llm_client) - elif method in ["vqa"]: - generator = VQAGenerator(llm_client) - else: - raise ValueError(f"Unsupported generation mode: {method}") - - results = await run_concurrent( - generator.generate, - batches, - desc="[4/4]Generating QAs", - unit="batch", - progress_bar=progress_bar, - ) - - # format - data_format = generation_config["data_format"] - logger.info("Output data format: %s", data_format) - - results = generator.format_generation_results( - results, output_data_format=data_format - ) - - return results diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py new file mode 100644 index 00000000..1ae2f067 --- /dev/null +++ b/graphgen/operators/generate/generate_service.py @@ -0,0 +1,68 @@ +import pandas as pd + +from graphgen.bases import BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm +from graphgen.models import ( + AggregatedGenerator, + AtomicGenerator, + CoTGenerator, + MultiHopGenerator, + VQAGenerator, +) +from graphgen.utils import logger, run_concurrent + + +class GenerateService(BaseOperator): + """ + Generate question-answer pairs based on nodes and edges. 
+ """ + + def __init__( + self, + working_dir: str = "cache", + method: str = "aggregated", + data_format: str = "ChatML", + ): + super().__init__(working_dir=working_dir, op_name="generate_service") + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + + self.method = method + self.data_format = data_format + + if self.method == "atomic": + self.generator = AtomicGenerator(self.llm_client) + elif self.method == "aggregated": + self.generator = AggregatedGenerator(self.llm_client) + elif self.method == "multi_hop": + self.generator = MultiHopGenerator(self.llm_client) + elif self.method == "cot": + self.generator = CoTGenerator(self.llm_client) + elif self.method in ["vqa"]: + self.generator = VQAGenerator(self.llm_client) + else: + raise ValueError(f"Unsupported generation mode: {method}") + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") + return pd.DataFrame(self.generate(items)) + + def generate(self, items: list[dict]) -> list[dict]: + """ + Generate question-answer pairs based on nodes and edges. + :param items + :return: QA pairs + """ + logger.info("[Generation] mode: %s, batches: %d", self.method, len(items)) + items = [(item["nodes"], item["edges"]) for item in items] + results = run_concurrent( + self.generator.generate, + items, + desc="[4/4]Generating QAs", + unit="batch", + ) + + results = self.generator.format_generation_results( + results, output_data_format=self.data_format + ) + + return results diff --git a/graphgen/operators/init/__init__.py b/graphgen/operators/init/__init__.py deleted file mode 100644 index ec604441..00000000 --- a/graphgen/operators/init/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .init_llm import init_llm diff --git a/graphgen/operators/judge/__init__.py b/graphgen/operators/judge/__init__.py new file mode 100644 index 00000000..32ccf5c2 --- /dev/null +++ b/graphgen/operators/judge/__init__.py @@ -0,0 +1 @@ +from .judge_service import JudgeService diff --git a/graphgen/operators/judge/judge_service.py b/graphgen/operators/judge/judge_service.py new file mode 100644 index 00000000..4d554a0b --- /dev/null +++ b/graphgen/operators/judge/judge_service.py @@ -0,0 +1,70 @@ +import math + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm, init_storage +from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT +from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy + + +class JudgeService(BaseOperator): + """Service for judging graph edges and nodes using a trainee LLM.""" + + def __init__(self, working_dir: str = "cache"): + super().__init__(working_dir=working_dir, op_name="judge_service") + self.llm_client: BaseLLMWrapper = init_llm("trainee") + self.graph_storage: BaseGraphStorage = init_storage( + backend="networkx", + working_dir=working_dir, + namespace="graph", + ) + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") + self.graph_storage.reload() + self.judge(items) + return pd.DataFrame([{"status": "judging_completed"}]) + + async def _process_single_judge(self, item: dict) -> dict: + description = item["description"] + try: + judgement = await self.llm_client.generate_topk_per_token( + STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) + ) + top_candidates = judgement[0].top_candidates + gt = item.get("ground_truth", "yes") + loss = yes_no_loss_entropy([top_candidates], [gt]) + logger.debug("Description: %s Loss: %s", 
description, loss) + item["loss"] = loss + except Exception as e: # pylint: disable=broad-except + logger.error("Error in judging description: %s", e) + logger.info("Use default loss 0.1") + item["loss"] = -math.log(0.1) + return item + + def judge(self, items: list[dict]) -> None: + """ + Judge the description in the item and compute the loss. + """ + results = run_concurrent( + self._process_single_judge, + items, + desc="Judging descriptions", + unit="description", + ) + # Update the graph storage with the computed losses + for item in results: + index = item["index"] + loss = item["loss"] + if isinstance(index, str): + node_id = index + node_data = self.graph_storage.get_node(node_id) + node_data["loss"] = loss + self.graph_storage.update_node(node_id, node_data) + elif isinstance(index, tuple): + edge_source, edge_target = index + edge_data = self.graph_storage.get_edge(edge_source, edge_target) + edge_data["loss"] = loss + self.graph_storage.update_edge(edge_source, edge_target, edge_data) + self.graph_storage.index_done_callback() diff --git a/graphgen/operators/partition/__init__.py b/graphgen/operators/partition/__init__.py index 21f934b3..8d586b95 100644 --- a/graphgen/operators/partition/__init__.py +++ b/graphgen/operators/partition/__init__.py @@ -1 +1 @@ -from .partition_kg import partition_kg +from .partition_service import PartitionService diff --git a/graphgen/operators/partition/partition_kg.py b/graphgen/operators/partition/partition_kg.py deleted file mode 100644 index 4c4fdaa1..00000000 --- a/graphgen/operators/partition/partition_kg.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import Any - -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer -from graphgen.models import ( - AnchorBFSPartitioner, - BFSPartitioner, - DFSPartitioner, - ECEPartitioner, - LeidenPartitioner, -) -from graphgen.utils import logger - -from .pre_tokenize import pre_tokenize - - -async def partition_kg( - kg_instance: BaseGraphStorage, - chunk_storage: BaseKVStorage, - tokenizer: Any = BaseTokenizer, - partition_config: dict = None, -) -> list[ - tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]] -]: - method = partition_config["method"] - method_params = partition_config["method_params"] - if method == "bfs": - logger.info("Partitioning knowledge graph using BFS method.") - partitioner = BFSPartitioner() - elif method == "dfs": - logger.info("Partitioning knowledge graph using DFS method.") - partitioner = DFSPartitioner() - elif method == "ece": - logger.info("Partitioning knowledge graph using ECE method.") - # TODO: before ECE partitioning, we need to: - # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random - # 2. 
pre-tokenize nodes and edges to get the token length - edges = kg_instance.get_all_edges() - nodes = kg_instance.get_all_nodes() - await pre_tokenize(kg_instance, tokenizer, edges, nodes) - partitioner = ECEPartitioner() - elif method == "leiden": - logger.info("Partitioning knowledge graph using Leiden method.") - partitioner = LeidenPartitioner() - elif method == "anchor_bfs": - logger.info("Partitioning knowledge graph using Anchor BFS method.") - partitioner = AnchorBFSPartitioner( - anchor_type=method_params.get("anchor_type"), - anchor_ids=set(method_params.get("anchor_ids", [])) - if method_params.get("anchor_ids") - else None, - ) - else: - raise ValueError(f"Unsupported partition method: {method}") - - communities = await partitioner.partition(g=kg_instance, **method_params) - logger.info("Partitioned the graph into %d communities.", len(communities)) - batches = await partitioner.community2batch(communities, g=kg_instance) - - batches = await attach_additional_data_to_node(batches, chunk_storage) - return batches - - -async def attach_additional_data_to_node( - batches: list[ - tuple[ - list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] - ] - ], - chunk_storage: BaseKVStorage, -) -> list[ - tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]] -]: - """ - Attach additional data from chunk_storage to nodes in the batches. - :param batches: - :param chunk_storage: - :return: - """ - for batch in batches: - for node_id, node_data in batch[0]: - await _attach_by_type(node_id, node_data, chunk_storage) - return batches - - -async def _attach_by_type( - node_id: str, - node_data: dict, - chunk_storage: BaseKVStorage, -) -> None: - """ - Attach additional data to the node based on its entity type. - """ - entity_type = (node_data.get("entity_type") or "").lower() - if not entity_type: - return - - source_ids = [ - sid.strip() - for sid in node_data.get("source_id", "").split("<SEP>") - if sid.strip() - ] - - # Handle images - if "image" in entity_type: - image_chunks = [ - data - for sid in source_ids - if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid)) - ] - if image_chunks: - # The generator expects a dictionary with an 'img_path' key, not a list of captions. - # We'll use the first image chunk found for this node.
- node_data["images"] = image_chunks[0] - logger.debug("Attached image data to node %s", node_id) diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py new file mode 100644 index 00000000..b4c0eda0 --- /dev/null +++ b/graphgen/operators/partition/partition_service.py @@ -0,0 +1,157 @@ +import json +import os +from typing import Iterable + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseOperator, BaseTokenizer +from graphgen.common import init_storage +from graphgen.models import ( + AnchorBFSPartitioner, + BFSPartitioner, + DFSPartitioner, + ECEPartitioner, + LeidenPartitioner, + Tokenizer, +) +from graphgen.utils import logger + + +class PartitionService(BaseOperator): + def __init__(self, working_dir: str = "cache", **partition_kwargs): + super().__init__(working_dir=working_dir, op_name="partition_service") + self.kg_instance: BaseGraphStorage = init_storage( + backend="networkx", + working_dir=working_dir, + namespace="graph", + ) + self.chunk_storage: BaseKVStorage = init_storage( + backend="json_kv", + working_dir=working_dir, + namespace="chunk", + ) + tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") + self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model) + self.partition_kwargs = partition_kwargs + + def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: + # this operator does not consume any batch data + # but for compatibility we keep the interface + _ = batch.to_dict(orient="records") + self.kg_instance.reload() + self.chunk_storage.reload() + + yield from self.partition() + + def partition(self) -> Iterable[pd.DataFrame]: + method = self.partition_kwargs["method"] + method_params = self.partition_kwargs["method_params"] + if method == "bfs": + logger.info("Partitioning knowledge graph using BFS method.") + partitioner = BFSPartitioner() + elif method == "dfs": + logger.info("Partitioning knowledge graph using DFS method.") + partitioner = DFSPartitioner() + elif method == "ece": + logger.info("Partitioning knowledge graph using ECE method.") + # TODO: before ECE partitioning, we need to: + # 1. 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random + # 2. 
pre-tokenize nodes and edges to get the token length + self._pre_tokenize() + partitioner = ECEPartitioner() + elif method == "leiden": + logger.info("Partitioning knowledge graph using Leiden method.") + partitioner = LeidenPartitioner() + elif method == "anchor_bfs": + logger.info("Partitioning knowledge graph using Anchor BFS method.") + partitioner = AnchorBFSPartitioner( + anchor_type=method_params.get("anchor_type"), + anchor_ids=set(method_params.get("anchor_ids", [])) + if method_params.get("anchor_ids") + else None, + ) + else: + raise ValueError(f"Unsupported partition method: {method}") + + communities = partitioner.partition(g=self.kg_instance, **method_params) + + for community in communities: + batch = partitioner.community2batch(community, g=self.kg_instance) + batch = self._attach_additional_data_to_node(batch) + + yield pd.DataFrame( + { + "nodes": [batch[0]], + "edges": [batch[1]], + } + ) + + def _pre_tokenize(self) -> None: + """Pre-tokenize all nodes and edges to add token length information.""" + logger.info("Starting pre-tokenization of nodes and edges...") + + nodes = self.kg_instance.get_all_nodes() + edges = self.kg_instance.get_all_edges() + + # Process nodes + for node_id, node_data in nodes: + if "length" not in node_data: + try: + description = node_data.get("description", "") + tokens = self.tokenizer_instance.encode(description) + node_data["length"] = len(tokens) + self.kg_instance.update_node(node_id, node_data) + except Exception as e: + logger.warning("Failed to tokenize node %s: %s", node_id, e) + node_data["length"] = 0 + + # Process edges + for u, v, edge_data in edges: + if "length" not in edge_data: + try: + description = edge_data.get("description", "") + tokens = self.tokenizer_instance.encode(description) + edge_data["length"] = len(tokens) + self.kg_instance.update_edge(u, v, edge_data) + except Exception as e: + logger.warning("Failed to tokenize edge %s-%s: %s", u, v, e) + edge_data["length"] = 0 + + # Persist changes + self.kg_instance.index_done_callback() + logger.info("Pre-tokenization completed.") + + def _attach_additional_data_to_node(self, batch: tuple) -> tuple: + """ + Attach additional data from chunk_storage to nodes in the batch. + :param batch: tuple of (nodes_data, edges_data) + :return: updated batch with additional data attached to nodes + """ + nodes_data, edges_data = batch + + for node_id, node_data in nodes_data: + entity_type = (node_data.get("entity_type") or "").lower() + if not entity_type: + continue + + source_ids = [ + sid.strip() + for sid in node_data.get("source_id", "").split("<SEP>") + if sid.strip() + ] + + # Handle images + if "image" in entity_type: + image_chunks = [ + data + for sid in source_ids + if "image" in sid.lower() + and (data := self.chunk_storage.get_by_id(sid)) + ] + if image_chunks: + # The generator expects a dictionary with an 'img_path' key, not a list of captions. + # We'll use the first image chunk found for this node.
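+ # NOTE (assumed convention, not asserted by this diff): "source_id" is a
+ # "<SEP>"-joined list of chunk ids, e.g.
+ #   "chunk-1f3a<SEP>image-9c2e".split("<SEP>") == ["chunk-1f3a", "image-9c2e"],
+ # so the split above recovers the individual chunk references for lookup.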
+ node_data["image_data"] = json.loads(image_chunks[0]["content"]) + logger.debug("Attached image data to node %s", node_id) + + return nodes_data, edges_data diff --git a/graphgen/operators/partition/pre_tokenize.py b/graphgen/operators/partition/pre_tokenize.py deleted file mode 100644 index 83e99060..00000000 --- a/graphgen/operators/partition/pre_tokenize.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio -from typing import List, Tuple - -import gradio as gr - -from graphgen.bases import BaseGraphStorage, BaseTokenizer -from graphgen.utils import run_concurrent - - -async def pre_tokenize( - graph_storage: BaseGraphStorage, - tokenizer: BaseTokenizer, - edges: List[Tuple], - nodes: List[Tuple], - progress_bar: gr.Progress = None, - max_concurrent: int = 1000, -) -> Tuple[List, List]: - """为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。""" - sem = asyncio.Semaphore(max_concurrent) - - async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: - async with sem: - data = obj[1] if is_node else obj[2] - if "length" not in data: - loop = asyncio.get_event_loop() - data["length"] = len( - await loop.run_in_executor( - None, tokenizer.encode, data["description"] - ) - ) - if is_node: - graph_storage.update_node(obj[0], obj[1]) - else: - graph_storage.update_edge(obj[0], obj[1], obj[2]) - return obj - - new_edges, new_nodes = await asyncio.gather( - run_concurrent( - lambda e: _patch_and_write(e, is_node=False), - edges, - desc="Pre-tokenizing edges", - unit="edge", - progress_bar=progress_bar, - ), - run_concurrent( - lambda n: _patch_and_write(n, is_node=True), - nodes, - desc="Pre-tokenizing nodes", - unit="node", - progress_bar=progress_bar, - ), - ) - - graph_storage.index_done_callback() - return new_edges, new_nodes diff --git a/graphgen/operators/quiz/__init__.py b/graphgen/operators/quiz/__init__.py new file mode 100644 index 00000000..2a931f4b --- /dev/null +++ b/graphgen/operators/quiz/__init__.py @@ -0,0 +1 @@ +from .quiz_service import QuizService diff --git a/graphgen/operators/quiz/quiz_service.py b/graphgen/operators/quiz/quiz_service.py new file mode 100644 index 00000000..a5e1baf5 --- /dev/null +++ b/graphgen/operators/quiz/quiz_service.py @@ -0,0 +1,112 @@ +from collections.abc import Iterable + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm, init_storage +from graphgen.models import QuizGenerator +from graphgen.utils import compute_dict_hash, logger, run_concurrent + + +class QuizService(BaseOperator): + def __init__( + self, + working_dir: str = "cache", + quiz_samples: int = 1, + concurrency_limit: int = 200, + ): + super().__init__(working_dir=working_dir, op_name="quiz_service") + self.quiz_samples = quiz_samples + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.graph_storage: BaseGraphStorage = init_storage( + backend="networkx", working_dir=working_dir, namespace="graph" + ) + # { _quiz_id: { "description": str, "quizzes": List[Tuple[str, str]] } } + self.quiz_storage: BaseKVStorage = init_storage( + backend="json_kv", working_dir=working_dir, namespace="quiz" + ) + self.generator = QuizGenerator(self.llm_client) + self.concurrency_limit = concurrency_limit + + def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: + # this operator does not consume any batch data + # but for compatibility we keep the interface + _ = batch.to_dict(orient="records") + self.graph_storage.reload() + yield from self.quiz() + + async def 
_process_single_quiz(self, item: tuple) -> dict | None: + # if quiz in quiz_storage exists already, directly get it + index, desc = item + _quiz_id = compute_dict_hash({"index": index, "description": desc}) + if self.quiz_storage.get_by_id(_quiz_id): + return None + + tasks = [] + for i in range(self.quiz_samples): + if i > 0: + tasks.append((desc, "TEMPLATE", "yes")) + tasks.append((desc, "ANTI_TEMPLATE", "no")) + try: + quizzes = [] + for d, template_type, gt in tasks: + prompt = self.generator.build_prompt_for_description(d, template_type) + new_description = await self.llm_client.generate_answer( + prompt, temperature=1 + ) + rephrased_text = self.generator.parse_rephrased_text(new_description) + quizzes.append((rephrased_text, gt)) + return { + "_quiz_id": _quiz_id, + "description": desc, + "index": index, + "quizzes": quizzes, + } + except Exception as e: + logger.error("Error when quizzing description %s: %s", item, e) + return None + + def quiz(self) -> Iterable[pd.DataFrame]: + """ + Get all nodes and edges and quiz their descriptions using QuizGenerator. + """ + edges = self.graph_storage.get_all_edges() + nodes = self.graph_storage.get_all_nodes() + + items = [] + + for edge in edges: + edge_data = edge[2] + desc = edge_data["description"] + items.append(((edge[0], edge[1]), desc)) + + for node in nodes: + node_data = node[1] + desc = node_data["description"] + items.append((node[0], desc)) + + logger.info("Total descriptions to quiz: %d", len(items)) + + for i in range(0, len(items), self.concurrency_limit): + batch_items = items[i : i + self.concurrency_limit] + batch_results = run_concurrent( + self._process_single_quiz, + batch_items, + desc=f"Quizzing descriptions ({i} / {i + len(batch_items)})", + unit="description", + ) + + final_results = [] + for new_result in batch_results: + if new_result: + self.quiz_storage.upsert( + { + new_result["_quiz_id"]: { + "description": new_result["description"], + "quizzes": new_result["quizzes"], + } + } + ) + final_results.append(new_result) + self.quiz_storage.index_done_callback() + yield pd.DataFrame(final_results) diff --git a/graphgen/operators/quiz_and_judge/__init__.py b/graphgen/operators/quiz_and_judge/__init__.py deleted file mode 100644 index cb73251a..00000000 --- a/graphgen/operators/quiz_and_judge/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .judge import judge_statement -from .quiz import quiz diff --git a/graphgen/operators/quiz_and_judge/judge.py b/graphgen/operators/quiz_and_judge/judge.py deleted file mode 100644 index b5e35eb9..00000000 --- a/graphgen/operators/quiz_and_judge/judge.py +++ /dev/null @@ -1,139 +0,0 @@ -import math - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import JsonKVStorage, NetworkXStorage -from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT -from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy - - -async def judge_statement( # pylint: disable=too-many-statements - trainee_llm_client: BaseLLMWrapper, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - re_judge: bool = False, - progress_bar: gr.Progress = None, -) -> NetworkXStorage: - """ - Get all edges and nodes and judge them - - :param trainee_llm_client: judge the statements to get comprehension loss - :param graph_storage: graph storage instance - :param rephrase_storage: rephrase storage instance - :param re_judge: re-judge the relations - :param progress_bar - :return: - """ - - async def _judge_single_relation( - edge: tuple, - ): - source_id = 
edge[0] - target_id = edge[1] - edge_data = edge[2] - - if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None: - logger.debug( - "Edge %s -> %s already judged, loss: %s, skip", - source_id, - target_id, - edge_data["loss"], - ) - return source_id, target_id, edge_data - - description = edge_data["description"] - - try: - descriptions = rephrase_storage.get_by_id(description) - assert descriptions is not None - - judgements = [] - gts = [gt for _, gt in descriptions] - for description, gt in descriptions: - judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) - ) - judgements.append(judgement[0].top_candidates) - - loss = yes_no_loss_entropy(judgements, gts) - - logger.debug( - "Edge %s -> %s description: %s loss: %s", - source_id, - target_id, - description, - loss, - ) - - edge_data["loss"] = loss - except Exception as e: # pylint: disable=broad-except - logger.error( - "Error in judging relation %s -> %s: %s", source_id, target_id, e - ) - logger.info("Use default loss 0.1") - edge_data["loss"] = -math.log(0.1) - - graph_storage.update_edge(source_id, target_id, edge_data) - return source_id, target_id, edge_data - - edges = graph_storage.get_all_edges() - - await run_concurrent( - _judge_single_relation, - edges, - desc="Judging relations", - unit="relation", - progress_bar=progress_bar, - ) - - async def _judge_single_entity( - node: tuple, - ): - node_id = node[0] - node_data = node[1] - - if (not re_judge) and "loss" in node_data and node_data["loss"] is not None: - logger.debug( - "Node %s already judged, loss: %s, skip", node_id, node_data["loss"] - ) - return node_id, node_data - - description = node_data["description"] - - try: - descriptions = rephrase_storage.get_by_id(description) - assert descriptions is not None - - judgements = [] - gts = [gt for _, gt in descriptions] - for description, gt in descriptions: - judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) - ) - judgements.append(judgement[0].top_candidates) - - loss = yes_no_loss_entropy(judgements, gts) - - logger.debug("Node %s description: %s loss: %s", node_id, description, loss) - - node_data["loss"] = loss - except Exception as e: # pylint: disable=broad-except - logger.error("Error in judging entity %s: %s", node_id, e) - logger.error("Use default loss 0.1") - node_data["loss"] = -math.log(0.1) - - graph_storage.update_node(node_id, node_data) - return node_id, node_data - - nodes = graph_storage.get_all_nodes() - - await run_concurrent( - _judge_single_entity, - nodes, - desc="Judging entities", - unit="entity", - progress_bar=progress_bar, - ) - - return graph_storage diff --git a/graphgen/operators/quiz_and_judge/quiz.py b/graphgen/operators/quiz_and_judge/quiz.py deleted file mode 100644 index 9aadb34b..00000000 --- a/graphgen/operators/quiz_and_judge/quiz.py +++ /dev/null @@ -1,93 +0,0 @@ -from collections import defaultdict - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import JsonKVStorage, NetworkXStorage, QuizGenerator -from graphgen.utils import logger, run_concurrent - - -async def quiz( - synth_llm_client: BaseLLMWrapper, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - max_samples: int = 1, - progress_bar: gr.Progress = None, -) -> JsonKVStorage: - """ - Get all edges and quiz them using QuizGenerator. 
- - :param synth_llm_client: generate statements - :param graph_storage: graph storage instance - :param rephrase_storage: rephrase storage instance - :param max_samples: max samples for each edge - :param progress_bar - :return: - """ - - generator = QuizGenerator(synth_llm_client) - - async def _process_single_quiz(item: tuple[str, str, str]): - description, template_type, gt = item - try: - # if rephrase_storage exists already, directly get it - descriptions = rephrase_storage.get_by_id(description) - if descriptions: - return None - - prompt = generator.build_prompt_for_description(description, template_type) - new_description = await synth_llm_client.generate_answer( - prompt, temperature=1 - ) - rephrased_text = generator.parse_rephrased_text(new_description) - return {description: [(rephrased_text, gt)]} - - except Exception as e: # pylint: disable=broad-except - logger.error("Error when quizzing description %s: %s", description, e) - return None - - edges = graph_storage.get_all_edges() - nodes = graph_storage.get_all_nodes() - - results = defaultdict(list) - items = [] - for edge in edges: - edge_data = edge[2] - description = edge_data["description"] - - results[description] = [(description, "yes")] - - for i in range(max_samples): - if i > 0: - items.append((description, "TEMPLATE", "yes")) - items.append((description, "ANTI_TEMPLATE", "no")) - - for node in nodes: - node_data = node[1] - description = node_data["description"] - - results[description] = [(description, "yes")] - - for i in range(max_samples): - if i > 0: - items.append((description, "TEMPLATE", "yes")) - items.append((description, "ANTI_TEMPLATE", "no")) - - quiz_results = await run_concurrent( - _process_single_quiz, - items, - desc="Quizzing descriptions", - unit="description", - progress_bar=progress_bar, - ) - - for new_result in quiz_results: - if new_result: - for key, value in new_result.items(): - results[key].extend(value) - - for key, value in results.items(): - results[key] = list(set(value)) - rephrase_storage.upsert({key: results[key]}) - - return rephrase_storage diff --git a/graphgen/operators/read/__init__.py b/graphgen/operators/read/__init__.py index 075ae938..cda44587 100644 --- a/graphgen/operators/read/__init__.py +++ b/graphgen/operators/read/__init__.py @@ -1 +1 @@ -from .read_files import read_files +from .read import read diff --git a/graphgen/operators/read/parallel_file_scanner.py b/graphgen/operators/read/parallel_file_scanner.py index 73b477c3..db50d7af 100644 --- a/graphgen/operators/read/parallel_file_scanner.py +++ b/graphgen/operators/read/parallel_file_scanner.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Set, Union from graphgen.models import RocksDBCache -from graphgen.utils import logger class ParallelFileScanner: @@ -32,15 +31,12 @@ def scan( self._scan_files, Path(p).resolve(), recursive, set() ) future_to_path[future] = p - else: - logger.warning("[READ] Path does not exist: %s", p) for future in as_completed(future_to_path): path = future_to_path[future] try: results[path] = future.result() except Exception as e: - logger.error("[READ] Error scanning path %s: %s", path, e) results[path] = { "error": str(e), "files": [], @@ -56,17 +52,14 @@ def _scan_files( # Avoid cycles due to symlinks if path_str in visited: - logger.warning("[READ] Skipping already visited path: %s", path_str) return self._empty_result(path_str) # cache check cache_key = f"scan::{path_str}::recursive::{recursive}" cached = self.cache.get(cache_key) if cached and not self.rescan: - 
logger.info("[READ] Using cached scan result for path: %s", path_str) return cached["data"] - logger.info("[READ] Scanning path: %s", path_str) files, dirs = [], [] stats = {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0} @@ -108,7 +101,6 @@ def _scan_files( stats["errors"] += 1 except (PermissionError, FileNotFoundError, OSError) as e: - logger.error("[READ] Failed to scan path %s: %s", path_str, e) return {"error": str(e), "files": [], "dirs": [], "stats": stats} if recursive: @@ -171,7 +163,6 @@ def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, An try: results[path] = future.result() except Exception as e: - logger.error("[READ] Error scanning subdirectory %s: %s", path, e) results[path] = { "error": str(e), "files": [], @@ -183,18 +174,14 @@ def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, An def _cache_result(self, key: str, result: Dict, path: Path): """Cache the scan result""" - try: - self.cache.set( - key, - { - "data": result, - "dir_mtime": path.stat().st_mtime, - "cached_at": time.time(), - }, - ) - logger.info("[READ] Cached scan result for path: %s", path) - except OSError as e: - logger.error("[READ] Failed to cache scan result for path %s: %s", path, e) + self.cache.set( + key, + { + "data": result, + "dir_mtime": path.stat().st_mtime, + "cached_at": time.time(), + }, + ) def _is_allowed_file(self, path: Path) -> bool: """Check if the file has an allowed suffix""" @@ -209,7 +196,6 @@ def invalidate(self, path: str): keys = [k for k in self.cache if k.startswith(f"scan::{path}")] for k in keys: self.cache.delete(k) - logger.info("[READ] Invalidated cache for path: %s", path) def close(self): self.cache.close() diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py new file mode 100644 index 00000000..fbed377e --- /dev/null +++ b/graphgen/operators/read/read.py @@ -0,0 +1,128 @@ +from pathlib import Path +from typing import Any, List, Optional, Union + +import ray + +from graphgen.models import ( + CSVReader, + JSONReader, + ParquetReader, + PDFReader, + PickleReader, + RDFReader, + TXTReader, +) +from graphgen.utils import compute_mm_hash, logger + +from .parallel_file_scanner import ParallelFileScanner + +_MAPPING = { + "jsonl": JSONReader, + "json": JSONReader, + "txt": TXTReader, + "csv": CSVReader, + "md": TXTReader, + "pdf": PDFReader, + "parquet": ParquetReader, + "pickle": PickleReader, + "rdf": RDFReader, + "owl": RDFReader, + "ttl": RDFReader, +} + + +def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): + """Factory function to build appropriate reader instance""" + suffix = suffix.lower() + reader_cls = _MAPPING.get(suffix) + if not reader_cls: + raise ValueError(f"Unsupported file suffix: {suffix}") + + # Special handling for PDFReader which needs output_dir + if suffix == "pdf": + if cache_dir is None: + raise ValueError("cache_dir must be provided for PDFReader") + return reader_cls(output_dir=cache_dir, **reader_kwargs) + + return reader_cls(**reader_kwargs) + + +def read( + input_path: Union[str, List[str]], + allowed_suffix: Optional[List[str]] = None, + cache_dir: Optional[str] = "cache", + parallelism: int = 4, + recursive: bool = True, + **reader_kwargs: Any, +) -> ray.data.Dataset: + """ + Unified entry point to read files of multiple types using Ray Data. 
+ + :param input_path: File or directory path(s) to read from + :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) + :param cache_dir: Directory to cache intermediate files (PDF processing) + :param parallelism: Number of parallel workers + :param recursive: Whether to scan directories recursively + :param reader_kwargs: Additional kwargs passed to readers + :return: Ray Dataset containing all documents + """ + try: + # 1. Scan all paths to discover files + logger.info("[READ] Scanning paths: %s", input_path) + scanner = ParallelFileScanner( + cache_dir=cache_dir, + allowed_suffix=allowed_suffix, + rescan=False, + max_workers=parallelism if parallelism > 0 else 1, + ) + + all_files = [] + scan_results = scanner.scan(input_path, recursive=recursive) + + for result in scan_results.values(): + all_files.extend(result.get("files", [])) + + logger.info("[READ] Found %d files to process", len(all_files)) + + if not all_files: + raise ValueError("No files found to read.") + + # 2. Group files by suffix to use appropriate reader + files_by_suffix = {} + for file_info in all_files: + suffix = Path(file_info["path"]).suffix.lower().lstrip(".") + if allowed_suffix and suffix not in [ + s.lower().lstrip(".") for s in allowed_suffix + ]: + continue + files_by_suffix.setdefault(suffix, []).append(file_info["path"]) + + # 3. Create read tasks + read_tasks = [] + for suffix, file_paths in files_by_suffix.items(): + reader = _build_reader(suffix, cache_dir, **reader_kwargs) + ds = reader.read(file_paths) + read_tasks.append(ds) + + # 4. Combine all datasets + if not read_tasks: + raise ValueError("No datasets created from the provided files.") + + if len(read_tasks) == 1: + combined_ds = read_tasks[0] + else: + combined_ds = read_tasks[0].union(*read_tasks[1:]) + + combined_ds = combined_ds.map( + lambda record: { + **record, + "_doc_id": compute_mm_hash(record, prefix="doc-"), + } + ) + + logger.info("[READ] Successfully read files from %s", input_path) + return combined_ds + + except Exception as e: + logger.error("[READ] Failed to read files from %s: %s", input_path, e) + raise diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py deleted file mode 100644 index d9e7f673..00000000 --- a/graphgen/operators/read/read_files.py +++ /dev/null @@ -1,99 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional - -from graphgen.models import ( - CSVReader, - JSONLReader, - JSONReader, - ParquetReader, - PDFReader, - PickleReader, - RDFReader, - TXTReader, -) -from graphgen.utils import logger - -from .parallel_file_scanner import ParallelFileScanner - -_MAPPING = { - "jsonl": JSONLReader, - "json": JSONReader, - "txt": TXTReader, - "csv": CSVReader, - "md": TXTReader, - "pdf": PDFReader, - "parquet": ParquetReader, - "pickle": PickleReader, - "rdf": RDFReader, - "owl": RDFReader, - "ttl": RDFReader, -} - - -def _build_reader(suffix: str, cache_dir: str | None): - suffix = suffix.lower() - if suffix == "pdf" and cache_dir is not None: - return _MAPPING[suffix](output_dir=cache_dir) - return _MAPPING[suffix]() - - -def read_files( - input_file: str, - allowed_suffix: Optional[List[str]] = None, - cache_dir: Optional[str] = None, - max_workers: int = 4, - rescan: bool = False, -) -> Iterator[Dict[str, Any]]: - """ - Read files from a path using parallel scanning and appropriate readers. - - Args: - input_file: Path to a file or directory - allowed_suffix: List of file suffixes to read. 
diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py
deleted file mode 100644
index d9e7f673..00000000
--- a/graphgen/operators/read/read_files.py
+++ /dev/null
@@ -1,99 +0,0 @@
-from pathlib import Path
-from typing import Any, Dict, Iterator, List, Optional
-
-from graphgen.models import (
-    CSVReader,
-    JSONLReader,
-    JSONReader,
-    ParquetReader,
-    PDFReader,
-    PickleReader,
-    RDFReader,
-    TXTReader,
-)
-from graphgen.utils import logger
-
-from .parallel_file_scanner import ParallelFileScanner
-
-_MAPPING = {
-    "jsonl": JSONLReader,
-    "json": JSONReader,
-    "txt": TXTReader,
-    "csv": CSVReader,
-    "md": TXTReader,
-    "pdf": PDFReader,
-    "parquet": ParquetReader,
-    "pickle": PickleReader,
-    "rdf": RDFReader,
-    "owl": RDFReader,
-    "ttl": RDFReader,
-}
-
-
-def _build_reader(suffix: str, cache_dir: str | None):
-    suffix = suffix.lower()
-    if suffix == "pdf" and cache_dir is not None:
-        return _MAPPING[suffix](output_dir=cache_dir)
-    return _MAPPING[suffix]()
-
-
-def read_files(
-    input_file: str,
-    allowed_suffix: Optional[List[str]] = None,
-    cache_dir: Optional[str] = None,
-    max_workers: int = 4,
-    rescan: bool = False,
-) -> Iterator[Dict[str, Any]]:
-    """
-    Read files from a path using parallel scanning and appropriate readers.
-
-    Args:
-        input_file: Path to a file or directory
-        allowed_suffix: List of file suffixes to read. If None, uses all supported types
-        cache_dir: Directory for caching PDF extraction and scan results
-        max_workers: Number of workers for parallel scanning
-        rescan: Whether to force rescan even if cached results exist
-    """
-
-    path = Path(input_file).expanduser()
-    if not path.exists():
-        raise FileNotFoundError(f"input_path not found: {input_file}")
-
-    if allowed_suffix is None:
-        support_suffix = set(_MAPPING.keys())
-    else:
-        support_suffix = {s.lower().lstrip(".") for s in allowed_suffix}
-
-    with ParallelFileScanner(
-        cache_dir=cache_dir or "cache",
-        allowed_suffix=support_suffix,
-        rescan=rescan,
-        max_workers=max_workers,
-    ) as scanner:
-        scan_results = scanner.scan(str(path), recursive=True)
-
-    # Extract files from scan results
-    files_to_read = []
-    for path_result in scan_results.values():
-        if "error" in path_result:
-            logger.warning("Error scanning %s: %s", path_result.path, path_result.error)
-            continue
-        files_to_read.extend(path_result.get("files", []))
-
-    logger.info(
-        "Found %d eligible file(s) under folder %s (allowed_suffix=%s)",
-        len(files_to_read),
-        input_file,
-        support_suffix,
-    )
-
-    for file_info in files_to_read:
-        try:
-            file_path = file_info["path"]
-            suffix = Path(file_path).suffix.lstrip(".").lower()
-            reader = _build_reader(suffix, cache_dir)
-
-            yield from reader.read(file_path)
-
-        except Exception as e:  # pylint: disable=broad-except
-            logger.exception("Error reading %s: %s", file_info.get("path"), e)
diff --git a/graphgen/operators/split/__init__.py b/graphgen/operators/split/__init__.py
deleted file mode 100644
index 2afc738d..00000000
--- a/graphgen/operators/split/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .split_chunks import chunk_documents
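The generator-based read_files above is superseded by the Ray-native read operator. A hedged sketch of equivalent call sites, assuming an illustrative data directory and a placeholder process function:

# before: pull-based iteration over plain dicts
#   for doc in read_files("data/", allowed_suffix=["txt"]):
#       process(doc)

# after: the same documents arrive as rows of a distributed Ray Dataset
from graphgen.operators.read.read import read

for doc in read("data/", allowed_suffix=["txt"]).iter_rows():
    process(doc)  # process() stands in for downstream handling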
" - f"Supported languages are: {list(_MAPPING.keys())}" - ) - frozen_kwargs = frozenset( - (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items() - ) - splitter = _get_splitter(language, frozen_kwargs) - return splitter.split_text(text) - - -async def chunk_documents( - new_docs: dict, - tokenizer_instance: Tokenizer = None, - progress_bar=None, - **kwargs, -) -> dict: - inserting_chunks = {} - cur_index = 1 - doc_number = len(new_docs) - async for doc_key, doc in tqdm_async( - new_docs.items(), desc="[1/4]Chunking documents", unit="doc" - ): - doc_type = doc.get("type") - if doc_type == "text": - doc_language = detect_main_language(doc["content"]) - - text_chunks = split_chunks( - doc["content"], - language=doc_language, - **kwargs, - ) - - chunks = { - compute_content_hash(txt, prefix="chunk-"): { - "content": txt, - "type": "text", - "_full_docs_id": doc_key, - "length": len(tokenizer_instance.encode(txt)) - if tokenizer_instance - else len(txt), - "language": doc_language, - } - for txt in text_chunks - } - else: - chunks = {doc_key.replace("doc-", f"{doc_type}-"): {**doc}} - - inserting_chunks.update(chunks) - - if progress_bar is not None: - progress_bar(cur_index / doc_number, f"Chunking {doc_key}") - cur_index += 1 - - return inserting_chunks diff --git a/graphgen/operators/storage.py b/graphgen/operators/storage.py deleted file mode 100644 index ea5488ac..00000000 --- a/graphgen/operators/storage.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -from typing import Any - -import ray - -from graphgen.models import JsonKVStorage, JsonListStorage, NetworkXStorage - - -@ray.remote -class StorageManager: - """ - Centralized storage for all operators - - Example Usage: - ---------- - # init - storage_manager = StorageManager.remote(working_dir="/path/to/dir", unique_id=123) - - # visit storage in tasks - @ray.remote - def some_task(storage_manager): - full_docs_storage = ray.get(storage_manager.get_storage.remote("full_docs")) - - # visit storage in other actors - @ray.remote - class SomeOperator: - def __init__(self, storage_manager): - self.storage_manager = storage_manager - def some_method(self): - full_docs_storage = ray.get(self.storage_manager.get_storage.remote("full_docs")) - """ - - def __init__(self, working_dir: str, unique_id: int): - self.working_dir = working_dir - self.unique_id = unique_id - - # Initialize all storage backends - self.storages = { - "full_docs": JsonKVStorage(working_dir, namespace="full_docs"), - "chunks": JsonKVStorage(working_dir, namespace="chunks"), - "graph": NetworkXStorage(working_dir, namespace="graph"), - "rephrase": JsonKVStorage(working_dir, namespace="rephrase"), - "partition": JsonListStorage(working_dir, namespace="partition"), - "search": JsonKVStorage( - os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), - namespace="search", - ), - "extraction": JsonKVStorage( - os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), - namespace="extraction", - ), - "qa": JsonListStorage( - os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), - namespace="qa", - ), - } - - def get_storage(self, name: str) -> Any: - return self.storages.get(name) diff --git a/graphgen/run.py b/graphgen/run.py index c300a6aa..419fd7bd 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -1,14 +1,18 @@ import argparse import os import time -from importlib.resources import files +from importlib import resources +from typing import Any, Dict +import ray import yaml from dotenv import load_dotenv +from ray.data.block import Block +from 
diff --git a/graphgen/run.py b/graphgen/run.py
index c300a6aa..419fd7bd 100644
--- a/graphgen/run.py
+++ b/graphgen/run.py
@@ -1,14 +1,18 @@
 import argparse
 import os
 import time
-from importlib.resources import files
+from importlib import resources
+from typing import Any, Dict
 
+import ray
 import yaml
 from dotenv import load_dotenv
+from ray.data.block import Block
+from ray.data.datasource.filename_provider import FilenameProvider
 
-from graphgen.engine import Context, Engine, collect_ops
-from graphgen.graphgen import GraphGen
-from graphgen.utils import logger, set_logger
+from graphgen.engine import Engine
+from graphgen.operators import operators
+from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger
 
 sys_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -28,12 +32,38 @@ def save_config(config_path, global_config):
     )
 
 
+class NodeFilenameProvider(FilenameProvider):
+    def __init__(self, node_id: str):
+        self.node_id = node_id
+
+    def get_filename_for_block(
+        self, block: Block, write_uuid: str, task_index: int, block_index: int
+    ) -> str:
+        # format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.jsonl
+        return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl"
+
+    def get_filename_for_row(
+        self,
+        row: Dict[str, Any],
+        write_uuid: str,
+        task_index: int,
+        block_index: int,
+        row_index: int,
+    ) -> str:
+        raise NotImplementedError(
+            f"Row-based filenames are not supported by write_json. "
+            f"Node: {self.node_id}, write_uuid: {write_uuid}"
+        )
+
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--config_file",
         help="Config parameters for GraphGen.",
-        default=files("graphgen").joinpath("configs", "aggregated_config.yaml"),
+        default=resources.files("graphgen")
+        .joinpath("configs")
+        .joinpath("aggregated_config.yaml"),
         type=str,
     )
     parser.add_argument(
@@ -51,29 +81,41 @@ def main():
     with open(args.config_file, "r", encoding="utf-8") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
 
+    engine = Engine(config, operators)
+
     unique_id = int(time.time())
     output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
     set_working_dir(output_path)
 
-    set_logger(
-        os.path.join(output_path, f"{unique_id}.log"),
+    log_path = os.path.join(working_dir, "logs", "Driver.log")
+    driver_logger = set_logger(
+        log_path,
+        name="GraphGen",
         if_stream=True,
     )
+    CURRENT_LOGGER_VAR.set(driver_logger)
     logger.info(
         "GraphGen with unique ID %s logging to %s",
        unique_id,
-        os.path.join(working_dir, f"{unique_id}.log"),
+        log_path,
    )
 
-    graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
-
-    # share context between different steps
-    ctx = Context(config=config, graph_gen=graph_gen)
-    ops = collect_ops(config, graph_gen)
-
-    # run operations
-    Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
+    ds = ray.data.from_items([])
+    results = engine.execute(ds)
+
+    for node_id, dataset in results.items():
+        # use a per-node path instead of rebinding output_path, which would
+        # nest each node's directory inside the previous one
+        node_output_path = os.path.join(output_path, f"{node_id}")
+        os.makedirs(node_output_path, exist_ok=True)
+        dataset.write_json(
+            node_output_path,
+            filename_provider=NodeFilenameProvider(node_id),
+            pandas_json_args_fn=lambda: {
+                "force_ascii": False,
+                "orient": "records",
+                "lines": True,
+            },
+        )
+        logger.info("Node %s results saved to %s", node_id, node_output_path)
 
     save_config(os.path.join(output_path, "config.yaml"), config)
     logger.info("GraphGen completed successfully. Data saved to %s", output_path)
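For clarity, the on-disk layout this produces: each node gets its own subdirectory under the run's output path, with one JSONL file per written block, named by NodeFilenameProvider. An illustrative run (the timestamp, write uuid, and indices below are made up):

cache/data/graphgen/1700000000/
    config.yaml
    partition/
        partition_36f1ab_000000_000000.jsonl
    generate/
        generate_36f1ab_000000_000000.jsonl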
Data saved to %s", output_path) diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index d3e6df7b..ec118816 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -16,7 +16,7 @@ compute_mm_hash, ) from .help_nltk import NLTKHelper -from .log import logger, parse_log, set_logger +from .log import CURRENT_LOGGER_VAR, logger, set_logger from .loop import create_event_loop from .run_concurrent import run_concurrent from .wrap import async_to_sync_method diff --git a/graphgen/utils/log.py b/graphgen/utils/log.py index 102b7b23..e29e994e 100644 --- a/graphgen/utils/log.py +++ b/graphgen/utils/log.py @@ -1,13 +1,15 @@ +import contextvars import logging +import os from logging.handlers import RotatingFileHandler +from typing import Any from rich.logging import RichHandler -logger = logging.getLogger("graphgen") - def set_logger( log_file: str, + name: str, file_level: int = logging.DEBUG, console_level: int = logging.INFO, *, @@ -17,26 +19,27 @@ def set_logger( force: bool = False, ): - if logger.hasHandlers() and not force: - return + current_logger = logging.getLogger(name) + if current_logger.hasHandlers() and not force: + return current_logger if force: - logger.handlers.clear() + current_logger.handlers.clear() - logger.setLevel( + current_logger.setLevel( min(file_level, console_level) ) # Set to the lowest level to capture all logs - logger.propagate = False + current_logger.propagate = False - if logger.handlers: - logger.handlers.clear() + if log_file: + os.makedirs(os.path.dirname(log_file), exist_ok=True) if if_stream: console = RichHandler( level=console_level, show_path=False, rich_tracebacks=True ) console.setFormatter(logging.Formatter("%(message)s")) - logger.addHandler(console) + current_logger.addHandler(console) file_handler = RotatingFileHandler( log_file, @@ -51,10 +54,48 @@ def set_logger( datefmt="%y-%m-%d %H:%M:%S", ) ) - logger.addHandler(file_handler) + current_logger.addHandler(file_handler) + return current_logger + + +CURRENT_LOGGER_VAR = contextvars.ContextVar("current_logger") + + +def get_current_logger() -> logging.Logger: + current_logger = CURRENT_LOGGER_VAR.get() + if not current_logger: + raise RuntimeError("No logger is set in the current context.") + return current_logger + + +class ContextAwareLogger: + @staticmethod + def _get_logger() -> logging.Logger: + return get_current_logger() + + def debug(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().debug(msg, *args, **kwargs) + + def info(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().info(msg, *args, **kwargs) + + def warning(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().warning(msg, *args, **kwargs) + + def error(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().error(msg, *args, **kwargs) + + def exception(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().exception(msg, *args, **kwargs) + + def critical(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().critical(msg, *args, **kwargs) + + def log(self, level: int, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().log(level, msg, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_logger(), name) -def parse_log(log_file: str): - with open(log_file, "r", encoding="utf-8") as f: - lines = f.readlines() - return lines +logger = ContextAwareLogger() diff --git a/graphgen/utils/run_concurrent.py 
diff --git a/graphgen/utils/run_concurrent.py b/graphgen/utils/run_concurrent.py
index ac63f87b..d1a9b0e2 100644
--- a/graphgen/utils/run_concurrent.py
+++ b/graphgen/utils/run_concurrent.py
@@ -1,55 +1,44 @@
 import asyncio
-from typing import Awaitable, Callable, List, Optional, TypeVar
+from typing import Awaitable, Callable, List, TypeVar
 
-import gradio as gr
 from tqdm.asyncio import tqdm as tqdm_async
 
 from graphgen.utils.log import logger
 
+from .loop import create_event_loop
+
 T = TypeVar("T")
 R = TypeVar("R")
 
 
-async def run_concurrent(
+def run_concurrent(
     coro_fn: Callable[[T], Awaitable[R]],
     items: List[T],
     *,
     desc: str = "processing",
     unit: str = "item",
-    progress_bar: Optional[gr.Progress] = None,
 ) -> List[R]:
-    tasks = [asyncio.create_task(coro_fn(it)) for it in items]
-
-    completed_count = 0
-    results = []
-
-    pbar = tqdm_async(total=len(items), desc=desc, unit=unit)
-
-    if progress_bar is not None:
-        progress_bar(0.0, desc=f"{desc} (0/{len(items)})")
-
-    for future in asyncio.as_completed(tasks):
-        try:
-            result = await future
-            results.append(result)
-        except Exception as e:  # pylint: disable=broad-except
-            logger.exception("Task failed: %s", e)
-            # even if failed, record it to keep results consistent with tasks
-            results.append(e)
-
-        completed_count += 1
-        pbar.update(1)
-
-        if progress_bar is not None:
-            progress = completed_count / len(items)
-            progress_bar(progress, desc=f"{desc} ({completed_count}/{len(items)})")
-
-    pbar.close()
-
-    if progress_bar is not None:
-        progress_bar(1.0, desc=f"{desc} (completed)")
-
-    # filter out exceptions
-    results = [res for res in results if not isinstance(res, Exception)]
-
-    return results
+    async def _run_all():
+        tasks = [asyncio.create_task(coro_fn(item)) for item in items]
+
+        results = []
+        pbar = tqdm_async(total=len(items), desc=desc, unit=unit)
+
+        for future in asyncio.as_completed(tasks):
+            try:
+                result = await future
+                results.append(result)
+            except Exception as e:
+                logger.exception("Task failed: %s", e)
+                results.append(e)
+
+            pbar.update(1)
+
+        pbar.close()
+        return [res for res in results if not isinstance(res, Exception)]
+
+    loop = create_event_loop()
+    try:
+        return loop.run_until_complete(_run_all())
+    finally:
+        loop.close()
diff --git a/requirements.txt b/requirements.txt
index 85fc43e3..44079ab5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,6 +21,8 @@
 fastapi
 trafilatura
 aiohttp
 socksio
+pydantic
+ray==2.52.1
 leidenalg
 igraph
diff --git a/scripts/extract/extract_schema_guided.sh b/scripts/extract/extract_schema_guided.sh
deleted file mode 100644
index 0badc174..00000000
--- a/scripts/extract/extract_schema_guided.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m graphgen.run \
---config_file graphgen/configs/schema_guided_extraction_config.yaml \
---output_dir cache/
diff --git a/scripts/generate/generate_aggregated.sh b/scripts/generate/generate_aggregated.sh
deleted file mode 100644
index 7117eff1..00000000
--- a/scripts/generate/generate_aggregated.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m graphgen.run \
---config_file graphgen/configs/aggregated_config.yaml \
---output_dir cache/
diff --git a/scripts/generate/generate_atomic.sh b/scripts/generate/generate_atomic.sh
deleted file mode 100644
index 822d6c48..00000000
--- a/scripts/generate/generate_atomic.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m graphgen.run \
---config_file graphgen/configs/atomic_config.yaml \
---output_dir cache/
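Since run_concurrent is now synchronous, callers no longer manage an event loop themselves. A minimal usage sketch (the fetch coroutine is illustrative):

import asyncio

from graphgen.utils import run_concurrent

async def fetch(i: int) -> int:
    await asyncio.sleep(0.01)  # stand-in for an async LLM or IO call
    return i * i

squares = run_concurrent(fetch, list(range(10)), desc="squaring", unit="num")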
diff --git a/scripts/generate/generate_cot.sh b/scripts/generate/generate_cot.sh
deleted file mode 100644
index 9c2ee151..00000000
--- a/scripts/generate/generate_cot.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m graphgen.run \
---config_file graphgen/configs/cot_config.yaml \
---output_dir cache/
diff --git a/scripts/generate/generate_multi_hop.sh b/scripts/generate/generate_multi_hop.sh
deleted file mode 100644
index 6480e080..00000000
--- a/scripts/generate/generate_multi_hop.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m graphgen.run \
---config_file graphgen/configs/multi_hop_config.yaml \
---output_dir cache/
diff --git a/scripts/generate/generate_vqa.sh b/scripts/generate/generate_vqa.sh
deleted file mode 100644
index f7fd2726..00000000
--- a/scripts/generate/generate_vqa.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-python3 -m graphgen.run \
---config_file graphgen/configs/vqa_config.yaml \
---output_dir cache/