-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 44bf601
Showing
24 changed files
with
3,494 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
cover/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
.pybuilder/ | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
# For a library or package, you might want to ignore these files since the code is | ||
# intended to run in multiple environments; otherwise, check them in: | ||
# .python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# poetry | ||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
# This is especially recommended for binary packages to ensure reproducibility, and is more | ||
# commonly ignored for libraries. | ||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
#poetry.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# pytype static type analyzer | ||
.pytype/ | ||
|
||
# Cython debug symbols | ||
cython_debug/ | ||
|
||
# PyCharm | ||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||
# and can be added to the global gitignore or merged into this file. For a more nuclear | ||
# option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
#.idea/ | ||
|
||
# Ours | ||
.ipynb_checkpoints | ||
__pycache__ | ||
ckpt | ||
log | ||
script | ||
data | ||
tb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Hypergraph Transformer: Weakly-supervised Multi-hop Reasoning for Knowledge-based Visual Question Answering | ||
Source code for ACL 2022 paper: "Hypergraph Transformer: Weakly-supervised Multi-hop Reasoning for Knowledge-based Visual Question Answering", Yu-Jung Heo, Eun-Sol Kim, Woo Suk Choi, and Byoung-Tak Zhang | ||
* [[Paper]](https://aclanthology.org/2022.acl-long.29.pdf) [[Slides]](https://www.dropbox.com/s/kkbmmm1sy7f1ldr/acl22_HGT_slides.pdf) | ||
|
||
> To answer complex questions requiring multi-hop reasoning under weak supervision is considered as a challenging problem since i) no supervision is given to the reasoning process and ii) high-order semantics of multi-hop knowledge facts need to be captured. In this paper, we introduce a concept of hypergraph to encode high-level semantics of a question and a knowledge base, and to learn high-order associations between them. The proposed model, Hypergraph Transformer, constructs a question hypergraph and a query-aware knowledge hypergraph, and infers an answer by encoding inter-associations between two hypergraphs and intra-associations in both hypergraph itself. | ||
 | ||
|
||
## Requirements | ||
This code runs on Python 3.7 and PyTorch 1.5.1. We recommend using Anaconda to install all dependencies. | ||
``` | ||
git clone https://github.com/YuJungHeo/kbvqa-public.git | ||
conda env create --file environment.yml --name kbvqa | ||
``` | ||
|
||
## Setup | ||
Download preprocessed KVQA, PQ-{2H, 3H, M}, PQL-{2H, 3H, M} datasets. | ||
``` | ||
bash download.sh | ||
``` | ||
|
||
## Training | ||
|
||
Train on KVQA dataset | ||
* `--cfg` specifies a configuration yaml file. | ||
* `--q_opt` specifies a question type among original (ORG) and paraphrased (PRP). | ||
* `--n_hop` specifies a number of graph walk (e.g., 1-hop, 2-hop, and 3-hop). | ||
|
||
``` | ||
# ORG, 3-hop on oracle setting | ||
python main.py --model_name ht --cfg ht_kvqa --n_hop 3 --q_opt org --lr 0.00001 --exp_name ht_kvqa_org_3hop | ||
``` | ||
|
||
Train on PathQuestions (PQ) dataset | ||
``` | ||
# PQ-2H | ||
python main.py --data_name pq --model_name ht --cfg ht_pq2h --n_hop 2 --num_workers 2 --lr 0.0001 --abl_ans_fc --exp_name ht_pq2h | ||
# PQ-3H | ||
python main.py --data_name pq --model_name ht --cfg ht_pq3h --n_hop 3 --num_workers 2 --lr 0.0001 --abl_ans_fc --exp_name ht_pq3h | ||
# PQ-M (a mixture of the PQ-2H and PQ-3H questions) | ||
python main.py --data_name pq --model_name ht --cfg ht_pqM --n_hop 3 --num_workers 2 --lr 0.0001 --abl_ans_fc --exp_name ht_pqM | ||
``` | ||
|
||
Train on PathQuestions-Large (PQL) dataset | ||
``` | ||
# PQL-2H | ||
python main.py --data_name pql --model_name ht --cfg ht_pql2h --n_hop 2 --num_workers 2 --lr 0.0001 --abl_ans_fc --exp_name ht_pql2h --split_seed 789 | ||
# PQL-3H-More | ||
python main.py --data_name pql --model_name ht --cfg ht_pql3h_more --n_hop 3 --num_workers 2 --lr 0.0001 --abl_ans_fc --exp_name ht_pql3h_more --split_seed 789 | ||
# PQL-M (a mixture of the PQL-2H and PQL-3H questions) | ||
python main.py --data_name pql --model_name ht --cfg ht_pqlM --n_hop 3 --num_workers 2 --lr 0.0001 --abl_ans_fc --exp_name ht_pqlM --split_seed 789 | ||
``` | ||
|
||
## Evaluation | ||
We release the trained model checkpoints (1-hop, 2-hop, 3-hop on ORG and PRP questions) that we have used for our experiments on KVQA dataset. | ||
``` | ||
bash download_kvqa_checkpoints.sh | ||
# for original (ORG) questions on oracle setting in Table 1 | ||
python main.py --model_name ht --cfg ht_kvqa --n_hop 1 --q_opt org --exp_name ht_kvqa_org_1hop_dist --inference | ||
python main.py --model_name ht --cfg ht_kvqa --n_hop 2 --q_opt org --exp_name ht_kvqa_org_2hop_dist --inference | ||
python main.py --model_name ht --cfg ht_kvqa --n_hop 3 --q_opt org --exp_name ht_kvqa_org_3hop_dist --inference | ||
# for phraphrased (PRP) questions on oracle setting in Table 1 | ||
python main.py --model_name ht --cfg ht_kvqa --n_hop 1 --q_opt prp --exp_name ht_kvqa_prp_1hop_dist --inference | ||
python main.py --model_name ht --cfg ht_kvqa --n_hop 2 --q_opt prp --exp_name ht_kvqa_prp_2hop_dist --inference | ||
python main.py --model_name ht --cfg ht_kvqa --n_hop 3 --q_opt prp --exp_name ht_kvqa_prp_3hop_dist --inference | ||
``` | ||
|
||
We also release the trained model checkpoints that we have achieved best performance on the five repeated runs of different data splits in PQ and PQL dataset. | ||
``` | ||
# for PQ dataset | ||
bash download_pq_checkpoints.sh | ||
python main.py --data_name pq --model_name ht --cfg ht_pq2h --n_hop 2 --num_workers 2 --abl_ans_fc --inference --exp_name ht_pq2h_dist | ||
python main.py --data_name pq --model_name ht --cfg ht_pq3h --n_hop 3 --num_workers 2 --abl_ans_fc --inference --exp_name ht_pq3h_dist | ||
python main.py --data_name pq --model_name ht --cfg ht_pqM --n_hop 3 --num_workers 2 --abl_ans_fc --inference --exp_name ht_pqM_dist | ||
# for PQL dataset | ||
bash download_pql_checkpoints.sh | ||
python main.py --data_name pql --model_name ht --cfg ht_pql2h --n_hop 2 --num_workers 2 --abl_ans_fc --inference --exp_name ht_pql2h_dist --split_seed 789 | ||
python main.py --data_name pql --model_name ht --cfg ht_pql3h --n_hop 3 --num_workers 2 --abl_ans_fc --inference --exp_name ht_pql3h_dist --split_seed 789 | ||
python main.py --data_name pql --model_name ht --cfg ht_pql3h_more --n_hop 3 --num_workers 2 --abl_ans_fc --inference --exp_name ht_pql3h_more_dist --split_seed 789 | ||
python main.py --data_name pql --model_name ht --cfg ht_pqlM --n_hop 3 --num_workers 2 --abl_ans_fc --inference --exp_name ht_pqlM_dist --split_seed 789 | ||
``` | ||
|
||
## Credits | ||
* Parts of the code were adapted from [Multimodal Transformer](https://github.com/yaohungt/Multimodal-Transformer) by Yao-Hung Hubert Tsai. | ||
|
||
## Citation | ||
``` | ||
@inproceedings{heo-etal-2022-hypergraph, | ||
title = "Hypergraph {T}ransformer: {W}eakly-Supervised Multi-hop Reasoning for Knowledge-based Visual Question Answering", | ||
author = "Heo, Yu-Jung and Kim, Eun-Sol and Choi, Woo Suk and Zhang, Byoung-Tak", | ||
booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", | ||
month = may, | ||
year = "2022", | ||
address = "Dublin, Ireland", | ||
publisher = "Association for Computational Linguistics", | ||
url = "https://aclanthology.org/2022.acl-long.29", | ||
pages = "373--390" | ||
} | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
DATASET: | ||
NAME: "kvqa" | ||
RAW_DATA: "data/kvqa/raw/dataset.json" | ||
PROC_DATA: "data/kvqa/processed/proc_data.pkl" | ||
FACT: "data/kvqa/processed/fact_with_index.pkl" | ||
REL2IDX: "data/kvqa/processed/rel_index.pkl" | ||
AVOCAB2IDX: "data/kvqa/processed/ans_vocab2idx.pkl" | ||
IDX2VOCAB: "data/kvqa/processed/kg_qc_ans_idx2vocab.pkl" | ||
VOCAB2IDX: "data/kvqa/processed/kg_qc_ans_vocab2idx.pkl" | ||
NE2QID: "data/kvqa/processed/ne2qid.pkl" | ||
QID2NE: "data/kvqa/processed/qid2ne.pkl" | ||
IDX2QTYPE: "data/kvqa/processed/idx2qtype.pkl" | ||
KG_1hop: "data/kvqa/processed/kg_1hop.pkl" | ||
KG_2hop: "data/kvqa/processed/kg_2hop.pkl" | ||
KG_3hop: "data/kvqa/processed/kg_3hop.pkl" | ||
KG_spat: "data/kvqa/processed/kg_spatial.pkl" | ||
GLOVE: "data/kvqa/processed/glove_embs_kvqa.pkl" | ||
GLOVE_ANS_CAND: "data/kvqa/processed/glove_embs_kvqa_ans.pkl" | ||
RES: | ||
TB: "tb/" | ||
CKPT: "ckpt/" | ||
LOG: "log/" | ||
MODEL: | ||
SEED: 1234 | ||
NUM_EDGE: 19 | ||
NUM_MAX_Q: 15 | ||
NUM_MAX_C: 20 | ||
NUM_MAX_QNODE: 3 | ||
NUM_MAX_HK_1H: 50 | ||
NUM_MAX_HK_2H: 100 | ||
NUM_MAX_HK_3H: 150 | ||
NUM_MAX_KNODE_1H: 6 | ||
NUM_MAX_KNODE_2H: 10 | ||
NUM_MAX_KNODE_3H: 12 | ||
FC_HID_COEFF: 4 | ||
NUM_OUT: 300 | ||
NUM_ANS: 19360 | ||
NUM_WORD_EMB : 300 | ||
NUM_HIDDEN: 256 | ||
NUM_HEAD: 4 | ||
NUM_LAYER: 2 | ||
INP_DROPOUT: 0.0 | ||
ATTN_DROPOUT_K: 0.0 | ||
ATTN_DROPOUT_Q: 0.0 | ||
RELU_DROPOUT: 0.0 | ||
RES_DROPOUT: 0.0 | ||
EMB_DROPOUT: 0.0 | ||
ATTN_MASK: True | ||
BATCH_SIZE: 256 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
DATASET: | ||
NAME: "pq2h" | ||
PROC_DATA: "data/PathQuestion/processed/pq2h/proc_data.pkl" | ||
AVOCAB2IDX: "data/PathQuestion/processed/pq2h/ans_vocab2idx.pkl" | ||
IDX2VOCAB: "data/PathQuestion/processed/pq2h/idx2vocab.pkl" | ||
VOCAB2IDX: "data/PathQuestion/processed/pq2h/vocab2idx.pkl" | ||
KG_2hop: "data/PathQuestion/processed/pq2h/kg_2hop.pkl" | ||
GLOVE: "data/PathQuestion/processed/pq2h/glove_embs.pkl" | ||
GLOVE_ANS_CAND: "data/PathQuestion/processed/pq2h/glove_embs_ans.pkl" | ||
RES: | ||
TB: "tb/" | ||
CKPT: "ckpt/" | ||
LOG: "log/" | ||
MODEL: | ||
SEED: 1234 | ||
NUM_EDGE: 13 | ||
NUM_MAX_Q: 11 | ||
NUM_MAX_ASET: 2 | ||
NUM_MAX_QNODE: 3 | ||
NUM_MAX_HK_2H: 10 | ||
NUM_MAX_KNODE_2H: 5 | ||
FC_HID_COEFF: 1 | ||
NUM_OUT: 300 | ||
NUM_ANS: 305 | ||
NUM_WORD_EMB : 300 | ||
NUM_HIDDEN: 256 | ||
NUM_HEAD: 4 | ||
NUM_LAYER: 2 | ||
INP_DROPOUT: 0.2 | ||
ATTN_DROPOUT_K: 0.1 | ||
ATTN_DROPOUT_Q: 0.1 | ||
RELU_DROPOUT: 0.1 | ||
RES_DROPOUT: 0.1 | ||
EMB_DROPOUT: 0.1 | ||
ATTN_MASK: True | ||
BATCH_SIZE: 128 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
DATASET: | ||
NAME: "pq3h" | ||
PROC_DATA: "data/PathQuestion/processed/pq3h/proc_data.pkl" | ||
AVOCAB2IDX: "data/PathQuestion/processed/pq3h/ans_vocab2idx.pkl" | ||
IDX2VOCAB: "data/PathQuestion/processed/pq3h/idx2vocab.pkl" | ||
VOCAB2IDX: "data/PathQuestion/processed/pq3h/vocab2idx.pkl" | ||
KG_3hop: "data/PathQuestion/processed/pq3h/kg_3hop.pkl" | ||
GLOVE: "data/PathQuestion/processed/pq3h/glove_embs.pkl" | ||
GLOVE_ANS_CAND: "data/PathQuestion/processed/pq3h/glove_embs_ans.pkl" | ||
RES: | ||
TB: "tb/" | ||
CKPT: "ckpt/" | ||
LOG: "log/" | ||
MODEL: | ||
SEED: 1234 | ||
NUM_EDGE: 13 | ||
NUM_MAX_Q: 15 | ||
NUM_MAX_ASET: 5 | ||
NUM_MAX_QNODE: 3 | ||
NUM_MAX_HK_3H: 45 | ||
NUM_MAX_KNODE_3H: 7 | ||
FC_HID_COEFF: 1 | ||
NUM_OUT: 300 | ||
NUM_ANS: 1009 | ||
NUM_WORD_EMB : 300 | ||
NUM_HIDDEN: 256 | ||
NUM_HEAD: 4 | ||
NUM_LAYER: 2 | ||
INP_DROPOUT: 0.2 | ||
ATTN_DROPOUT_K: 0.1 | ||
ATTN_DROPOUT_Q: 0.1 | ||
RELU_DROPOUT: 0.1 | ||
RES_DROPOUT: 0.1 | ||
EMB_DROPOUT: 0.1 | ||
ATTN_MASK: True | ||
BATCH_SIZE: 128 |
Oops, something went wrong.