davidemodolo/HoMM-DomainAdaptation

HoMM Domain Adaptation

This project implements an unsupervised domain adaptation model based on the paper "HoMM: Higher-order Moment Matching for Unsupervised Domain Adaptation". The method extends standard discrepancy-based losses (MMD, CORAL) by matching higher-order moments in the feature space.

Overview

  • Architecture:
    The model uses a ResNet34 backbone with a custom adaptation layer (using a tanh activation instead of ReLU) to extract features, followed by a classification layer that produces the final predictions. Two loss components are used:

    • Domain discrepancy loss (HoMM loss):
      This loss matches higher-order statistics between the source and target feature distributions. Several variants are implemented, including a direct third-order version, a grouped fourth-order version, and an arbitrary-order variant that approximates the moment tensor by random sampling.
    • Discriminative clustering loss:
      This loss pulls confidently pseudo-labeled target samples toward their respective class centers; the centers are updated with a moving average.
  • Data:
    A custom SubsetImageFolder class is used to load only a subset of classes from the dataset. Two domains are considered, for example, product_images and real_life.

  • Training:
    Training is performed with two modes:

    • A full UDA training step that combines the classification, HoMM, and clustering losses.
    • A baseline training step that uses cross-entropy loss only.

    Hyperparameters such as batch size, learning rate, HoMM order, and lambda factors can be tuned. For instance, the script uses a batch size of 128, a learning rate of 0.001, and lambda values to weight the loss contributions.
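The arbitrary-order HoMM variant described above can be sketched as follows. This is a minimal illustration of the random-sampling formulation from the HoMM paper, not the repository's exact code; the function name, the centering step, and the tensor shapes are assumptions:

```python
import torch

def homm_loss(xs, xt, order=3, num_samples=1000):
    """Sketch of arbitrary-order HoMM via random sampling.

    xs, xt: (batch, dim) source/target feature activations.
    Instead of building the full dim**order moment tensor, sample
    `num_samples` random index tuples and match the corresponding
    higher-order moments between the two domains.
    """
    # center the features so products become higher-order moments
    xs = xs - xs.mean(dim=0, keepdim=True)
    xt = xt - xt.mean(dim=0, keepdim=True)
    dim = xs.size(1)
    # random index tuples of shape (num_samples, order)
    idx = torch.randint(0, dim, (num_samples, order), device=xs.device)
    # product over the sampled coordinates, averaged over the batch
    xs_mom = xs[:, idx].prod(dim=2).mean(dim=0)  # (num_samples,)
    xt_mom = xt[:, idx].prod(dim=2).mean(dim=0)
    return torch.mean((xs_mom - xt_mom) ** 2)
```

With the default hyperparameters below (`num_samples=350000`, batch size 128), the intermediate `(batch, num_samples, order)` tensor is large, so a chunked implementation may be preferable in practice.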
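The discriminative clustering loss with moving-average center updates could look like the following minimal sketch; the function name, the confidence-threshold selection, and the exact EMA form are assumptions consistent with the description above:

```python
import torch
import torch.nn.functional as F

def clustering_loss(features, logits, centers, alpha=0.7, threshold=0.65):
    """Pull confident pseudo-labeled target samples toward their class
    centers; update the centers with a moving average (factor alpha).

    features: (batch, feat_dim) target features
    logits:   (batch, num_classes) target predictions
    centers:  (num_classes, feat_dim) running class centers
    """
    probs = F.softmax(logits, dim=1)
    conf, pseudo = probs.max(dim=1)
    mask = conf > threshold                  # keep confident samples only
    if not mask.any():                       # nothing confident this batch
        return features.new_zeros(()), centers
    feats, labels = features[mask], pseudo[mask]
    new_centers = centers.clone()
    for c in labels.unique():                # EMA update per observed class
        batch_mean = feats[labels == c].mean(dim=0)
        new_centers[c] = alpha * centers[c] + (1 - alpha) * batch_mean
    # squared distance of each confident sample to its class center
    loss = ((feats - new_centers[labels]) ** 2).sum(dim=1).mean()
    return loss, new_centers.detach()
```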
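Putting the pieces together, a full UDA training step combining the three losses might look like this sketch. The model interface (returning features and logits) and the loss callables are assumptions, not the repository's actual signatures:

```python
import torch
import torch.nn.functional as F

def uda_train_step(model, src_x, src_y, tgt_x,
                   disc_loss_fn, clust_loss_fn,
                   optimizer, lambda_d=100.0, lambda_dc=0.1):
    """One full UDA step: source cross-entropy plus weighted discrepancy
    and clustering terms. `model(x)` is assumed to return
    (features, logits); the two loss arguments are scalar-valued callables.
    """
    optimizer.zero_grad()
    src_feat, src_logits = model(src_x)
    tgt_feat, tgt_logits = model(tgt_x)
    ce = F.cross_entropy(src_logits, src_y)     # supervised source loss
    d = disc_loss_fn(src_feat, tgt_feat)        # HoMM domain discrepancy
    dc = clust_loss_fn(tgt_feat, tgt_logits)    # discriminative clustering
    loss = ce + lambda_d * d + lambda_dc * dc
    loss.backward()
    optimizer.step()
    return loss.item()
```

The baseline mode corresponds to keeping only the `ce` term (equivalently, setting both lambda weights to zero).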

Dependencies

  • Python with PyTorch and torchvision
  • tqdm
  • google.colab (for mounting Google Drive when running in Colab)
  • matplotlib (for plotting training curves)

How to Run

  1. Preparation:

    • Download the dataset (e.g., Adaptiope.zip) and unzip it.
    • Adjust the img_root and subset names as needed.
  2. Running the Experiment:

    • To run the full UDA model, execute the notebook cells (or run the following via a main function):
      main()
    • To run the baseline model, call:
      main(baseline=True)
    • To reverse the domains, use:
      main(reverse=True)
    • Both reverse and baseline options can be combined:
      main(reverse=True, baseline=True)
  3. Monitoring:

    • Training progress is printed during training; accuracy and loss plots are generated at the end.

Hyperparameters

Some key hyperparameter settings in the main function:

  • batch_size=128
  • homm_order=4
  • num_samples=350000
  • lambda_d=100 (discrepancy loss weight)
  • lambda_dc=0.1 (clustering loss weight)
  • alpha=0.7 (moving-average factor for center updates)
  • threshold=0.65 (confidence threshold for selecting target samples for clustering)

About

Project for "Deep Learning" course at UNITN - Unsupervised Domain Adaptation on Adaptiope Dataset
