# Do captions in different languages produce different images?: Efficiently training Multilingual Diffusion 🇬🇧🇩🇪🇫🇷

In this project, we ask: is it worth adding non-English support to monolingual Text-to-Image models, or can we simply get away with translating non-English prompts to English? We train a multilingual diffusion model, based on Stable Diffusion v2.1, to support German and French in addition to English. To do so, we first construct two high-quality training datasets for our proposed training method by filtering the WIT dataset and the Conceptual Captions 12M (CC12M) dataset. Next, we perform two stages of training:

1. Teacher Learning: adds multilingual capabilities to the CLIP ViT-H/14 text encoder (a minimal sketch of this stage follows the list).
2. Concept Alignment: aligns Stable Diffusion with the new text encoder by fine-tuning the U-Net with LoRA (rank 4).
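
The sketch below illustrates the idea behind Stage 1, not the exact code in `teacher_learning.py`: a frozen copy of the text encoder (teacher, fed English captions) is distilled into a trainable copy (student, fed the parallel German/French captions) so that both produce matching embeddings. The model ID, MSE loss, and data handling are assumptions:

```python
# Hedged sketch of Stage 1 (Teacher Learning): align the student's embedding of a
# translated caption with the frozen teacher's embedding of the English caption.
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer

MODEL_ID = "stabilityai/stable-diffusion-2-1"  # assumed; its text encoder is ViT-H/14 based
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
teacher = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder").eval()
student = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")
teacher.requires_grad_(False)  # the teacher stays frozen

optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)

def distill_step(english: list[str], translated: list[str]) -> float:
    """One MSE distillation step on a batch of parallel (EN, DE/FR) caption pairs."""
    def encode(model, texts):
        tokens = tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt")
        return model(**tokens).last_hidden_state  # (batch, seq_len, dim)

    with torch.no_grad():
        target = encode(teacher, english)
    loss = F.mse_loss(encode(student, translated), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```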

Finally, we evaluate our multilingual diffusion model (which we dub RKS-Diffusion) and the standard Stable Diffusion v2.1 model on our high-quality WIT test subset (see the results below).

This project was completed as part of 10-623 under the guidance of Prof. Matt Gormley and Henry Chai at CMU. For more details, refer to our poster.

## Main Contributions

- Our Fine-tuned Multilingual Diffusion Model: RKS-Diffusion (supports English, French, and German)
- Our Fine-tuned Multilingual CLIP Text Model: RKS-CLIP-Text-Encoder
- Our high-quality subset of WIT: Multilingual-RKS-WIT (train/test split: 6k/1.5k image-caption pairs, with equal English, German, and French representation in the captions)

Note: RKS is a reference to Arceus.

## Results

FID and IS scores for the images generated by the two models, across three languages: English (EN), German (DE), and French (FR).

| Model | FID EN (↓) | FID DE (↓) | FID FR (↓) | IS EN (↑) | IS DE (↑) | IS FR (↑) |
|---|---|---|---|---|---|---|
| Stable Diffusion v2.1 (Baseline) | 1.08 | 1.16 | 1.30 | 11.33 | 11.45 | 11.17 |
| RKS-Diffusion (Ours) | 0.99 | 1.04 | 0.95 | 11.73 | 12.42 | 11.63 |

For the baseline, English achieves the best FID score and, surprisingly, German the best IS score (perhaps due to its similarity to English). RKS-Diffusion outperforms the baseline on all metrics for French and German without sacrificing English performance. French sees the largest improvement in FID, and German the largest improvement in IS.

## Directory Structure

- `data`
  - `utils`
    - `clip.py`: CLIP score generator (used in `wit_dataset_filtering.py`)
    - `translator.py`: translates German and French captions with NLLB-200 (used in `wit_dataset_filtering.py`)
  - `cc_dataset_filtering.py`: filters the CC12M dataset
  - `wit_dataset_filtering.py`: filters the WIT dataset
  - `final_dataset_translated.csv`: the complete high-quality filtered WIT sample
  - `final_test.csv`: test split of the WIT sample (used for evaluation)
  - `final_train.csv`: train split of the WIT sample (used for Stage-2 training of the U-Net)
  - `teacher_set.csv`: train set built from CC12M (used for Stage-1 training of the CLIP text encoder)
- `scripts`
  - `evaluation.py`: evaluation (FID and IS) code for the generated images
  - `lora_inference.py`: inference code for the trained RKS-Diffusion model with the trained RKS-CLIP-Text-Encoder (using a `diffusers` pipeline)
  - `manual_inference.py`: inference code for Stable Diffusion v2.1 (from scratch, i.e. manually performing the reverse denoising process; a rough sketch follows this list)
  - `teacher_learning.py`: training Stage-1 code for the RKS-CLIP-Text-Encoder
  - `train_text_to_image_lora_rks.py`: training Stage-2 code for fine-tuning Stable Diffusion with LoRA (adapted from this blog by Hugging Face)
  - `lora.sh`: the hyperparameters for the LoRA fine-tuning
  - `visualize.py`: code to compare and visualize the output of the same prompt in three languages with the baseline and our model
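
For reference, a stripped-down version of the manual reverse-denoising loop that `manual_inference.py` implements might look like the following (classifier-free guidance and image post-processing omitted for brevity; the scheduler choice and prompt are assumptions):

```python
# Hedged sketch: manually denoising with Stable Diffusion v2.1 components.
import torch
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

MODEL_ID = "stabilityai/stable-diffusion-2-1"
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae")
scheduler = DDIMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

prompt = "a dog playing in the snow"
tokens = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt")
with torch.no_grad():
    cond = text_encoder(tokens.input_ids).last_hidden_state

scheduler.set_timesteps(50)
# SD v2.1 generates 768x768 images, i.e. 96x96 latents after the 8x VAE downsampling.
latents = torch.randn(1, unet.config.in_channels, 96, 96) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = unet(latents, t, encoder_hidden_states=cond).sample
    latents = scheduler.step(noise_pred, t, latents).prev_sample  # one reverse step

with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor).sample  # (1, 3, 768, 768)
```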

## Reproduce the Results

Recreate the environment (importantly, `diffusers` must be installed from source):

```bash
conda env create --file requirements.txt -n genai
conda activate genai
pip install git+https://github.com/huggingface/diffusers
```

### Recreate Dataset

First, download the data files from CC12M and WIT. Then run:

```bash
cd data
python cc_dataset_filtering.py
python wit_dataset_filtering.py
```
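
The filtering scripts rank image-caption pairs by CLIP similarity; the sketch below illustrates that idea. The CLIP model name and threshold are placeholders, not the values used in `data/utils/clip.py`:

```python
# Hedged sketch of CLIP-score filtering: keep only image-caption pairs whose
# CLIP image-text similarity clears a threshold.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

CLIP_ID = "openai/clip-vit-base-patch32"  # placeholder model
model = CLIPModel.from_pretrained(CLIP_ID).eval()
processor = CLIPProcessor.from_pretrained(CLIP_ID)

def clip_score(image: Image.Image, caption: str) -> float:
    """Scaled cosine similarity between CLIP image and text embeddings."""
    inputs = processor(text=[caption], images=image, return_tensors="pt",
                       padding=True, truncation=True)
    with torch.no_grad():
        return model(**inputs).logits_per_image.item()

def keep_pair(image: Image.Image, caption: str, threshold: float = 25.0) -> bool:
    """Filter rule: keep the pair only if its CLIP score beats the threshold."""
    return clip_score(image, caption) >= threshold
```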

### Training

```bash
python teacher_learning.py  # Teacher Learning
bash lora.sh                # Concept Alignment
```
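
After Stage 2, the LoRA checkpoint and the fine-tuned text encoder can be loaded for inference roughly as follows. This is a hedged sketch, not `lora_inference.py` itself, and the paths in angle brackets are placeholders:

```python
# Hedged sketch: run RKS-Diffusion by swapping in the multilingual text encoder
# and attaching the rank-4 LoRA weights produced by lora.sh.
import torch
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)
pipe.text_encoder = CLIPTextModel.from_pretrained(
    "<rks_clip_text_encoder_path>", torch_dtype=torch.float16  # placeholder path
)
pipe.load_lora_weights("<your_checkpoint_path>")  # placeholder LoRA checkpoint
pipe = pipe.to("cuda")

# A German prompt: "A dog playing in the snow"
image = pipe("Ein Hund, der im Schnee spielt", num_inference_steps=50).images[0]
image.save("sample.png")
```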

### Evaluation

```bash
python manual_inference.py --output_dir <baseline_output_dir>
python lora_inference.py --checkpoint_path <your_checkpoint_path> --output_dir <rks_output_dir>
python evaluation.py --generated_image_dir <baseline_output_dir> > baseline_results.txt
python evaluation.py --generated_image_dir <rks_output_dir> > rks_diffusion_results.txt
```
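
Under the hood, `evaluation.py` computes FID and IS; a minimal sketch of those metrics with `torchmetrics`, assuming the generated and reference images are already loaded as `uint8` tensors of shape `(N, 3, H, W)` (the actual script may differ):

```python
# Hedged sketch of FID and Inception Score computation with torchmetrics.
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

def compute_metrics(real: torch.Tensor, fake: torch.Tensor) -> tuple[float, float]:
    """FID between real/generated images, and IS of the generated images."""
    fid = FrechetInceptionDistance(feature=2048)
    fid.update(real, real=True)
    fid.update(fake, real=False)

    inception = InceptionScore()
    inception.update(fake)
    is_mean, _is_std = inception.compute()
    return fid.compute().item(), is_mean.item()
```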

### Visualize

```bash
python visualize.py --checkpoint_path <your_checkpoint_path> --step_size <the step size to iterate over: 1000, 2000, ...>
```

## Training Runs

You can find W&B dashboards of our training runs here: