To reproduce the paper's results, see the section Reproducing the paper results below.
Best practice is to create a new conda/mamba environment and install everything fresh from the conda-forge channel, e.g. by installing Miniforge.
For running the multimodal representation algorithms:
- JAX:
  mamba install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
  - additionally on the IWR server:
    mamba install cuda-cupti
- flax:
  mamba install -c conda-forge flax
- optax:
  mamba install -c conda-forge optax
- orbax:
  mamba install -c conda-forge orbax
- Huggingface Datasets:
  mamba install -c huggingface -c conda-forge datasets
For running the scripts analyzing the embeddings:
- UMAP:
  mamba install -c conda-forge umap-learn
- cycler:
  mamba install -c conda-forge cycler
- Cmasher:
  mamba install -c conda-forge cmasher
- mpl-scatter-density:
  mamba install -c conda-forge mpl-scatter-density
When using UMAP plotting functionality, be aware of its dependencies: https://umap-learn.readthedocs.io/en/latest/plotting.html.
For a working Contrastive Learning training setup, three things are needed:
The dataset needs to be a huggingface dataset saved to disk. Exemplary scripts for creating such datasets from e.g. numpy data or CSV files are in the dataprep folder.
Consider giving the features expressive names that are easily distinguishable from one another. Data normalization should take place here, as part of data preparation.
The huggingface dataset (dict) has to have the top-level entries train and valid.
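As an illustration (not one of the repository's dataprep scripts), a minimal sketch of building such a dataset from numpy arrays could look like the following; the feature names, array shapes, normalization, and save path are assumptions:

```python
# Hypothetical dataprep sketch: build a huggingface dataset with "train" and
# "valid" splits from numpy arrays and save it to disk. Feature names, shapes,
# the normalization, and the save path are placeholders.
import os
import numpy as np
from datasets import Dataset, DatasetDict

rvs = np.random.rand(1000, 800).astype("float32")  # placeholder RVS spectra
xp = np.random.rand(1000, 110).astype("float32")   # placeholder XP coefficients

# Normalization belongs in this preparation step, e.g. simple standardization.
rvs = (rvs - rvs.mean()) / rvs.std()
xp = (xp - xp.mean()) / xp.std()

full = Dataset.from_dict({"spectrum_rvs": rvs.tolist(), "spectrum_xp": xp.tolist()})
split = full.train_test_split(test_size=0.1, seed=0)

# The top-level entries must be named "train" and "valid".
dataset = DatasetDict({"train": split["train"], "valid": split["test"]})
dataset.save_to_disk(os.path.expanduser("~/workspace/ba_multispecs/data/huggingface_datasets/my_dataset"))
```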
The models for the different modalities should be created in different files in multispecs/modalities/encoders. The name
of each file should be exactly the same as the name given to the corresponding feature in the huggingface dataset. There
can be several models for the same modality in one file.
All models need to be subclasses of flax.nnx.Module from the flax package.
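As a minimal, purely illustrative sketch (the class name, layer sizes, and input dimension are assumptions, not the repository's actual encoders), such a model could look like this; if the corresponding feature were called spectrum_rvs, the file would be multispecs/modalities/encoders/spectrum_rvs.py:

```python
# Hypothetical encoder sketch; names and sizes are illustrative only.
import jax.numpy as jnp
from flax import nnx

class SimpleMLPEncoder(nnx.Module):
    """Maps a 1D spectrum to an embedding vector."""

    def __init__(self, in_features: int, embed_dim: int, *, rngs: nnx.Rngs):
        self.dense1 = nnx.Linear(in_features, 256, rngs=rngs)
        self.dense2 = nnx.Linear(256, embed_dim, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = nnx.relu(self.dense1(x))
        return self.dense2(x)

# Example instantiation:
# model = SimpleMLPEncoder(in_features=800, embed_dim=128, rngs=nnx.Rngs(0))
```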
The basic configuration has to be given in a config file, structured in yaml format. See configs for example configurations.
Paths can be given relative to the user's home directory (~/...).
Parameters can be overridden by CLI arguments; run with the help argument for more information on the available arguments.
With all preparations complete, CL training can be started by running the main function with the train argument,
followed by the path to the config file.
Make sure the project's root is included in python's PYTHONPATH environment variable,
e.g. when calling the main function:
PYTHONPATH="${PYTHONPATH}:<path-to-project>/ba_multispecs/" python main.py train <path-to-config>
Alternatively, let the IDE take care of this, or add it permanently to the environment.
The scripts used to run tests and create visualisations can be found in scripts.
Here is a step-by-step guide to reproduce the results from the paper "Deep Multimodal Representation Learning for Stellar Spectra":
This guide assumes a Linux OS.
The scripts are written for a working directory of ~/workspace. If your working directory (that is, the directory you're going to git clone this repo into) deviates from that,
please search the mentioned files for any occurrences of ~/workspace and replace them with <your-workspace>.
Install all the dependencies from above, and activate your python environment:
mamba activate <your_mamba/conda_environment>
Go into your working directory and clone this repository into it:
cd <your_workspace>
git clone https://codeberg.org/cschwarz/ba_multispecs.git
cd ba_multispecs
Create or link a data and output folder (you might want to link to a big storage drive):
mkdir data output output/tests_ba
or
ln -s <path_to_existing_data_folder> data
ln -s <path_to_existing_output_folder> output
Download dataset from huggingface:
huggingface-cli download christianschwarz/deep-multimodal-representation-learning-for-stellar-spectra --repo-type dataset --local-dir ./data/huggingface_datasets/rvs_xp_w_types
Replace ~/workspace with your working directory in:
- config/_tests/07-longrun-combinations/01-config_rvs-cnn_xp-1layer.yaml
- scripts/bachelor_thesis/09-encoder-runs/cl_combinations_longrun_tenth.py (when taking the script route)
Either train the model manually (see Usage) four times with the config file config/_tests/07-longrun-combinations/01-config_rvs-cnn_xp-1layer.yaml.
Take care to divide the learning rate's peak value (peak_value) in config/_tests/07-longrun-combinations/02-learning-rate-cycle.yaml by ten at every restart,
and to point setup/encoders/params to the best-performing models of the previous run, which can be found in the output folder.
Or run the script scripts/bachelor_thesis/09-encoder-runs/cl_combinations_longrun_tenth.py:
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/09-encoder-runs/cl_combinations_longrun_tenth.py
where you can change CUDA_VISIBLE_DEVICES to the GPU number you want to use and set <your-workspace>. This might take some time.
This should result in trained models in your <your-workspace>/ba_multispecs/output/tests_ba/09-encoder-longrun-tenth/rvs-cnn_xp-1layer folder, where each run is saved with a timestamp.
To save compute time, we run the encoders once on all data and save the resulting embeddings to a new dataset. For this, just run the following script:
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/10-embedding/save_embeddings_to_dataset.py
where you again have to replace ~/workspace with yours in the script.
The new dataset with embeddings is saved in <your-workspace>/ba_multispecs/output/tests_ba/10-embeddings/.
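If you want to sanity-check the result, the embeddings dataset can be loaded like any other huggingface dataset; the exact path layout and column names below are assumptions:

```python
# Optional sanity check (not part of the pipeline): inspect the embeddings dataset.
import os
from datasets import load_from_disk

ds = load_from_disk(os.path.expanduser("~/workspace/ba_multispecs/output/tests_ba/10-embeddings"))
print(ds)                        # splits and number of rows
print(ds["train"].column_names)  # embedding and label columns
```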
To run UMAP on the embeddings, run the following script:
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/10-embedding/save_embeddings_to_dataset.py
where you again have to replace ~/workspace with yours in the script.
This will run UMAP with different parameters, and color the resulting embedding with different physical parameters.
The resulting images are saved in <your-workspace>/ba_multispecs/output/tests_ba/12-embeddings-visualization.
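Conceptually (this is an illustration, not the script itself), each such reduction is a single UMAP fit on the embedding matrix; the parameter values here are placeholders, not the script's actual grid:

```python
# Illustrative only: one UMAP reduction of an (N, d) embedding matrix.
import numpy as np
import umap

embeddings = np.random.rand(5000, 128).astype("float32")  # placeholder embeddings

reducer = umap.UMAP(n_neighbors=30, min_dist=0.1, n_components=2, random_state=0)
embedding_2d = reducer.fit_transform(embeddings)  # shape (5000, 2), ready for coloring and plotting
```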
Prepare the k-NN algorithm by first finding the nearest neighbors with the following script:
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/11-k-NN/k-nn-preprocessing.py
where you again have to replace ~/workspace with yours in the script.
Then you can run regression, classification, and retrieval with:
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/11-k-NN/k-nn-regression.py
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/11-k-NN/k-nn-classification.py
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/11-k-NN/k-nn-lookup.py
where you again have to replace ~/workspace with yours in the script.
The results, including statistics, are saved in <your-workspace>/ba_multispecs/output/tests_ba/11-k-NN.
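For orientation (this is not the repository's implementation), k-NN regression on the embeddings boils down to something like the following scikit-learn sketch; the arrays, the target quantity, and k are placeholders:

```python
# Conceptual k-NN regression on embeddings (illustration, not the repo's code).
import numpy as np
from sklearn.neighbors import KNeighborsRegressor

train_emb = np.random.rand(4000, 128)                  # embeddings of training stars
valid_emb = np.random.rand(1000, 128)                  # embeddings of validation stars
train_teff = np.random.uniform(3500, 8000, size=4000)  # placeholder labels, e.g. T_eff

knn = KNeighborsRegressor(n_neighbors=10, weights="distance")
knn.fit(train_emb, train_teff)
pred_teff = knn.predict(valid_emb)
```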
The script (where you again have to replace ~/workspace with yours)
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/13-fusion-regression/train_fusion_regression.py
trains several fusion models for regression. The script
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/13-fusion-regression/create_study_figures_fusion_reg.py
creates the corresponding statistics and images in <your-workspace>/ba_multispecs/output/tests_ba/13-fusion-regression.
The script (where you again have to replace ~/workspace with yours)
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/15-fusion-decoding/train_fusion_decode.py
trains several models for cross-modal generation (and autoencoding). The script
PYTHONPATH="${PYTHONPATH}:<your-workspace>/ba_multispecs/" CUDA_VISIBLE_DEVICES="0" python ./scripts/bachelor_thesis/15-fusion-decoding/create_study_figures_fusion_gen.py
creates the corresponding statistics and example generations in <your-workspace>/ba_multispecs/output/tests_ba/15-fusion-decoding.
If you run into out-of-memory errors in any script that calls JAX functions, consider reading JAX's GPU memory allocation page.
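For example, the environment variables documented there can be used to disable or limit JAX's default GPU memory preallocation; they have to be set before JAX is imported (or exported in the shell beforehand):

```python
# Set before importing jax (see JAX's GPU memory allocation documentation).
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # do not preallocate most of the GPU memory
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.50"  # or cap the preallocated fraction instead

import jax  # noqa: E402  (imported after setting the variables)
```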