Official PyTorch implementation of SaGess.
SaGess is a discrete denoising diffusion model that extends DiGress with a divide-and-conquer strategy for generating large synthetic networks: it trains on subgraph samples and then reconstructs the overall graph.
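To make the divide-and-conquer idea concrete, here is a minimal sketch of a reconstruction step. It assumes each generated subgraph is an edge list over global node IDs, so subgraphs can be stitched together by taking the union of their edges. The function name merge_subgraphs and the min_votes threshold are hypothetical illustrations, not part of the SaGess code.

```python
from __future__ import annotations

from collections import Counter
from typing import Iterable


def merge_subgraphs(
    subgraphs: Iterable[Iterable[tuple[int, int]]], min_votes: int = 1
) -> set[tuple[int, int]]:
    """Stitch generated subgraphs into one undirected edge set.

    Hypothetical sketch: assumes every subgraph is a list of (u, v) edges
    whose endpoints are global node IDs shared across subgraphs.
    An edge is kept if it appears in at least `min_votes` subgraphs.
    """
    votes: Counter = Counter()
    for edges in subgraphs:
        # Normalise orientation, drop self-loops and deduplicate within a
        # subgraph so each subgraph casts at most one vote per edge.
        seen = {(min(u, v), max(u, v)) for u, v in edges if u != v}
        votes.update(seen)
    return {edge for edge, count in votes.items() if count >= min_votes}


if __name__ == "__main__":
    generated = [
        [(0, 1), (1, 2)],
        [(2, 1), (2, 3)],
        [(3, 4), (1, 0)],
    ]
    print(sorted(merge_subgraphs(generated)))               # [(0, 1), (1, 2), (2, 3), (3, 4)]
    print(sorted(merge_subgraphs(generated, min_votes=2)))  # [(0, 1), (1, 2)]
```

With min_votes=1 this is a plain union of edges; larger values keep only edges that several generated subgraphs agree on.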
Create the Anaconda environment
chmod +x *.sh
./install_conda_env.sh
Activate the environment
conda activate sagess
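After activating the environment, a short Python check can confirm that the main dependencies are importable. The module list is an assumption: torch, torch_geometric and wandb are named in this README, while hydra is inferred from the dataset=Cora override syntax. Adjust the list to match what install_conda_env.sh actually installs.

```python
import importlib.util

# Assumed dependency list: torch, torch_geometric and wandb are named in this
# README; hydra is an assumption based on the `dataset=Cora` override syntax.
REQUIRED_MODULES = ("torch", "torch_geometric", "wandb", "hydra")


def missing_modules(modules=REQUIRED_MODULES):
    """Return the names of top-level modules not found in the active environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules found.")
```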
By default, wandb stores the logs offline, so they need to be synced after training. To sync the logs to your account, set the 'entity' parameter in the setup_wandb() function located in src/run_sagess.py:
'entity': 'wandb_username'
For online syncing, change the 'wandb' parameter in configs/general/general_default.yaml to 'online'.
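Offline runs can be uploaded afterwards with the wandb sync command-line tool. The sketch below is a hypothetical helper that collects offline run directories under the outputs folder, assuming wandb's default offline-run-* directory naming, and passes them to wandb sync. It does not change how src/run_sagess.py logs, and it assumes you are already logged in to wandb.

```python
import subprocess
from pathlib import Path


def find_offline_runs(root: str = "outputs") -> list:
    """Return wandb offline run directories found anywhere under `root`.

    Assumes wandb's default naming, e.g. wandb/offline-run-20240101_120000-abc123.
    """
    return sorted(p for p in Path(root).rglob("offline-run-*") if p.is_dir())


def sync_offline_runs(root: str = "outputs") -> None:
    """Upload every offline run under `root` with the `wandb sync` CLI."""
    runs = find_offline_runs(root)
    if not runs:
        print(f"No offline runs found under {root!r}.")
        return
    # `wandb sync` accepts one or more run directories as positional arguments.
    subprocess.run(["wandb", "sync", *map(str, runs)], check=True)


if __name__ == "__main__":
    sync_offline_runs()
```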
The main script can be launched as follows:
python src/run_sagess.py dataset=Cora
Four datasets from torch_geometric are supported (Cora, Wiki, EmailEUCore and ego-facebook), along with one custom SBM (stochastic block model) dataset loaded from a .pkl file. All datasets are downloaded to, or placed in, the data folder.
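The layout expected for the custom SBM .pkl file is not described here. As a hedged illustration, the sketch below samples a small stochastic block model with the standard library and pickles it as a dict holding an edge list and block labels. The path data/sbm.pkl, the dict keys and the function names are all hypothetical and may not match what the SaGess data loader reads.

```python
import pickle
import random
from itertools import combinations
from pathlib import Path


def sample_sbm(block_sizes, p_in, p_out, seed=0):
    """Sample an undirected stochastic block model.

    Each pair of nodes is connected with probability `p_in` if both nodes
    belong to the same block and `p_out` otherwise.
    """
    rng = random.Random(seed)
    labels = [b for b, size in enumerate(block_sizes) for _ in range(size)]
    edges = [
        (u, v)
        for u, v in combinations(range(len(labels)), 2)
        if rng.random() < (p_in if labels[u] == labels[v] else p_out)
    ]
    return {"num_nodes": len(labels), "edges": edges, "labels": labels}


def save_sbm(graph, path="data/sbm.pkl"):
    """Pickle `graph` to `path` (hypothetical location and layout)."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(graph, f)


if __name__ == "__main__":
    graph = sample_sbm([50, 50, 50], p_in=0.3, p_out=0.01, seed=42)
    print(graph["num_nodes"], "nodes,", len(graph["edges"]), "edges")
```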
Saved checkpoints, the wandb log folder and other outputs can be found in the outputs folder.
Dataset-specific configuration, including the number of subgraphs to train on, their size and the sampling method, resides in the configs/dataset/*.yaml files.
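To illustrate what the number of subgraphs, their size and the sampling method control, here is a hedged sketch of two common samplers over an adjacency dict: uniform node sampling (taking the induced subgraph) and random-walk sampling. The parameter names num_subgraphs, subgraph_size and method are hypothetical stand-ins for the keys in configs/dataset/*.yaml, and the samplers SaGess actually implements may differ.

```python
import random


def induced_edges(adj, nodes):
    """Edges of `adj` with both endpoints in `nodes` (each undirected edge once)."""
    return sorted({(min(u, v), max(u, v)) for u in nodes for v in adj[u] if v in nodes})


def uniform_node_sample(adj, size, rng):
    """Pick `size` distinct nodes uniformly at random."""
    return set(rng.sample(sorted(adj), size))


def random_walk_sample(adj, size, rng, max_steps=10_000):
    """Walk from a random start node until `size` distinct nodes are visited.

    Stops early after `max_steps` steps, e.g. if the walk is stuck in a
    component smaller than `size`.
    """
    current = rng.choice(sorted(adj))
    visited = {current}
    for _ in range(max_steps):
        if len(visited) >= size or not adj[current]:
            break
        current = rng.choice(sorted(adj[current]))
        visited.add(current)
    return visited


def sample_subgraphs(adj, num_subgraphs, subgraph_size, method="random_walk", seed=0):
    """Return `num_subgraphs` (nodes, edges) samples drawn with `method`."""
    rng = random.Random(seed)
    sampler = {"uniform": uniform_node_sample, "random_walk": random_walk_sample}[method]
    samples = []
    for _ in range(num_subgraphs):
        nodes = sampler(adj, subgraph_size, rng)
        samples.append((sorted(nodes), induced_edges(adj, nodes)))
    return samples


if __name__ == "__main__":
    ring = {i: {(i - 1) % 6, (i + 1) % 6} for i in range(6)}  # 6-node cycle
    for nodes, edges in sample_subgraphs(ring, num_subgraphs=2, subgraph_size=3):
        print(nodes, edges)
```

Random-walk samples are always connected, because every step follows an edge between visited nodes, whereas uniform node samples of a sparse graph often contain few or no edges.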
Other default parameters for DiGress are found in configs/train/train_default.yaml, configs/model/discrete.yaml and configs/general/general_default.yaml.
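Parameters from these files can also be overridden on the command line instead of editing the YAML. The sketch below builds such a launch command in Python, assuming the script accepts Hydra-style key=value overrides (as the dataset=Cora argument suggests) and that the general config is mounted under the general key, as in DiGress, so the 'wandb' parameter becomes general.wandb. The helper names build_command and launch are hypothetical.

```python
import shlex
import subprocess
import sys


def build_command(dataset, **overrides):
    """Build a launch command for src/run_sagess.py with Hydra-style overrides.

    Keyword names use "__" in place of "." so nested keys such as
    general.wandb can be passed as general__wandb (a convention of this sketch only).
    """
    args = [f"dataset={dataset}"]
    args += [f"{key.replace('__', '.')}={value}" for key, value in overrides.items()]
    return [sys.executable, "src/run_sagess.py", *args]


def launch(dataset, **overrides):
    """Run training from the repository root with the given overrides."""
    subprocess.run(build_command(dataset, **overrides), check=True)


if __name__ == "__main__":
    # Print the command instead of running it.
    print(shlex.join(build_command("Cora", general__wandb="online")))
```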
To build and run the Docker container, use the docker_build.sh and run_docker_container.sh scripts, respectively.