Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llm generated explanations to TAGDataset #9918

Open
wants to merge 24 commits into
base: master
Choose a base branch
from

Conversation

xnuohz
Copy link
Contributor

@xnuohz xnuohz commented Jan 5, 2025

Issue

#9361

Script

python examples/llm/glem.py --dataset arxiv --train_without_ext_pred --text_type llm_explanation

@xnuohz xnuohz requested a review from wsad1 as a code owner January 5, 2025 10:32
@puririshi98
Copy link
Contributor

this is awesome, but can you add an example to run? (also make sure the CI checks are all green)

@xnuohz
Copy link
Contributor Author

xnuohz commented Jan 7, 2025

Thanks, @puririshi98 Can you clarify more about the runnable example? A simple way is to apply it to the GLEM model. However, a better way is to add a TAPE model, which is worth opening a new PR for easy review.
btw, there is a little issue with GLEM's example. Also plz take a look if you have bandwidth:)

Hi @akihironitta CI error is weird, seems like the error has nothing to do with my changes, can you help take a look?

@puririshi98
Copy link
Contributor

puririshi98 commented Jan 7, 2025

@xnuohz i think for now just adding it as an optional flag to GLEM example is okay. feel free to submit a seperate PR for TAPE. plz ping me on slack since i have github emails heavily filtered otherwise my inbox would explode. feel free to include this as a flag there as well when you do it

@puririshi98
Copy link
Contributor

puririshi98 commented Jan 7, 2025

Hi @akihironitta CI error is weird, seems like the error has nothing to do with my changes, can you help take a look?

@xnuohz ignore those for now, i was previously just talking about the linters that were red, your CI was functionally green before. once you address my above comment im sure these new issues will go away since their unrelated to your code, ive had this happen to me many times and they always go away on future respins. just my experience.

@xnuohz
Copy link
Contributor Author

xnuohz commented Jan 10, 2025

Namespace(gpu=0, num_runs=10, num_em_iters=1, dataset='arxiv', text_type='llm_explanation', pl_ratio=0.5, hf_model='prajjwal1/bert-tiny', gnn_model='SAGE', gnn_hidden_channels=256, gnn_num_layers=3, gat_heads=4, lm_batch_size=256, gnn_batch_size=1024, external_pred_path=None, alpha=0.5, beta=0.5, lm_epochs=10, gnn_epochs=50, gnn_lr=0.002, lm_lr=0.001, patience=3, verbose=False, em_order='lm', lm_use_lora=False, token_on_disk=False, out_dir='output/', train_without_ext_pred=True)
Running on: NVIDIA GeForce RTX 3090
/home/ubuntu/Softwares/anaconda3/envs/pyg-dev/lib/python3.9/site-packages/ogb/nodeproppred/dataset_pyg.py:69: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  self.data, self.slices = torch.load(self.processed_paths[0])
/home/ubuntu/Projects/pytorch_geometric/torch_geometric/data/in_memory_dataset.py:300: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.
  warnings.warn(msg)
Processing...
Done!
Tokenizing Text Attributed Graph raw_text: 100%|█████████████████████████████████████████████████████████████████| 169343/169343 [00:22<00:00, 7604.87it/s]
Tokenizing Text Attributed Graph llm_explanation: 100%|██████████████████████████████████████████████████████████| 169343/169343 [00:20<00:00, 8320.77it/s]
40 ['node-feat.csv.gz', 'node-label.csv.gz', 'ogbn-arxiv.csv', 'num-edge-list.csv.gz', 'num-node-list.csv.gz', 'node-gpt-response.csv.gz', 'edge.csv.gz', 'node_year.csv.gz', 'node-text.csv.gz']
train_idx: 136411, gold_idx: 90941, pseudo labels ratio: 0.5, 0.49999450192982264
Building language model dataloader...-->done
GPU memory usage -- data to gpu: 0.10 GB
build GNN dataloader(GraphSAGE NeighborLoader)--># GNN Params: 217640
2025-01-10 01:08:52.467527: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-10 01:08:52.485008: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-10 01:08:52.485033: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-10 01:08:52.485046: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-10 01:08:52.488697: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-01-10 01:08:52.887660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# LM Params: 4391080
pretraining gnn to generate pseudo labels
Epoch: 01 Loss: 2.1608 Approx. Train: 0.4124
Epoch: 02 Loss: 1.5093 Approx. Train: 0.5615
Epoch: 03 Loss: 1.3932 Approx. Train: 0.5870
Epoch: 04 Loss: 1.3258 Approx. Train: 0.6046
Epoch: 05 Loss: 1.2801 Approx. Train: 0.6159
Train: 0.6067, Val: 0.5852
Epoch: 06 Loss: 1.2459 Approx. Train: 0.6250
Train: 0.6145, Val: 0.5911
Epoch: 07 Loss: 1.2151 Approx. Train: 0.6317
Train: 0.6196, Val: 0.5999
Epoch: 08 Loss: 1.1907 Approx. Train: 0.6374
Train: 0.6213, Val: 0.5876
Epoch: 09 Loss: 1.1649 Approx. Train: 0.6445
Train: 0.6297, Val: 0.6033
Epoch: 10 Loss: 1.1433 Approx. Train: 0.6514
Train: 0.6290, Val: 0.5988
Epoch: 11 Loss: 1.1221 Approx. Train: 0.6560
Train: 0.6420, Val: 0.5989
Epoch: 12 Loss: 1.0989 Approx. Train: 0.6615
Train: 0.6392, Val: 0.6019
Pretrain Early stopped by Epoch: 12
Pretrain gnn time: 10.77s
Saved predictions to output/preds/arxiv/gnn_pretrain.pt
Pretraining acc: 0.6392, Val: 0.6019, Test: 0.5453
EM iteration: 1, EM phase: lm
Move lm model from cpu memory
Epoch 01 Loss: 1.5116 Approx. Train: 0.6574
Epoch 02 Loss: 1.1643 Approx. Train: 0.7199
Epoch 03 Loss: 1.0531 Approx. Train: 0.7243
Epoch 04 Loss: 0.9468 Approx. Train: 0.7283
Epoch 05 Loss: 0.8540 Approx. Train: 0.7320
Train: 0.8205, Val: 0.6925,
Epoch 06 Loss: 0.7706 Approx. Train: 0.7373
Train: 0.8343, Val: 0.6895,
Epoch 07 Loss: 0.7037 Approx. Train: 0.7413
Train: 0.8464, Val: 0.6699,
Epoch 08 Loss: 0.6463 Approx. Train: 0.7451
Train: 0.8590, Val: 0.6741,
Epoch 09 Loss: 0.6028 Approx. Train: 0.7487
Train: 0.8680, Val: 0.6777,
Early stopped by Epoch: 9,                             Best acc: 0.6925400181214135
EM iteration: 2, EM phase: gnn
Move gnn model from cpu memory
Epoch: 01 Loss: 0.9413 Approx. Train: 0.6264
Epoch: 02 Loss: 0.9080 Approx. Train: 0.6299
Epoch: 03 Loss: 0.8870 Approx. Train: 0.6345
Epoch: 04 Loss: 0.8745 Approx. Train: 0.6363
Epoch: 05 Loss: 0.8623 Approx. Train: 0.6394
Train: 0.6444, Val: 0.6100,
Epoch: 06 Loss: 0.8464 Approx. Train: 0.6423
Train: 0.6546, Val: 0.6163,
Epoch: 07 Loss: 0.8352 Approx. Train: 0.6439
Train: 0.6560, Val: 0.6143,
Epoch: 08 Loss: 0.8229 Approx. Train: 0.6460
Train: 0.6628, Val: 0.6180,
Epoch: 09 Loss: 0.8094 Approx. Train: 0.6495
Train: 0.6485, Val: 0.6083,
Epoch: 10 Loss: 0.7965 Approx. Train: 0.6522
Train: 0.6647, Val: 0.6136,
Epoch: 11 Loss: 0.7855 Approx. Train: 0.6547
Train: 0.6693, Val: 0.6173,
Early stopped by Epoch: 11,                             Best acc: 0.6179737575086413
Best GNN validation acc: 0.6179737575086413,LM validation acc: 0.6925400181214135
============================
Best test acc: 0.6018352776577578, model: lm
Total running time: 0.08 hours

Copy link
Contributor

@puririshi98 puririshi98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good but can you make an argparser option that combines the raw text and the llm explanation and run that? curious how it effects accuracy

…ch_geometric into tagdataset/add-llm-exp-pred
@xnuohz
Copy link
Contributor Author

xnuohz commented Jan 12, 2025

I tested some configurations, and the acc shows that LLM explanation > Raw text + LLM explanation > Raw text. The reason may be that the explanation is inferred from the original text through LLM, which already contains enough semantic information. Therefore, combining the two text types does not further improve the acc.
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants