This project implements an unsupervised domain adaptation model based on the paper "HoMM: Higher-order Moment Matching for Unsupervised Domain Adaptation". The method extends standard discrepancy-based losses (MMD, CORAL) by matching higher-order moments in the feature space.
Architecture:
The model uses a ResNet34 backbone with a custom adaptation layer (tanh activation instead of ReLU) to extract features. A classification layer on top produces the final predictions. Two loss components are used:
- Domain discrepancy loss (HoMM loss): matches higher-order statistics between the source and target feature distributions. Several versions are implemented, including direct 3rd-order, grouped 4th-order, and an arbitrary-order variant via random sampling.
- Discriminative clustering loss: pulls pseudo-labeled target samples toward their respective class centers. The centers are updated with a moving average.
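The arbitrary-order variant can be sketched as follows; the function name and signature are illustrative, not the repository's actual API:

```python
import torch

def random_homm_loss(hs, ht, order=3, num_samples=1000):
    """Match randomly sampled higher-order moments of two feature batches.

    hs, ht: (batch, dim) source/target feature activations.
    For each random index tuple of length `order`, the product of the selected
    coordinates is averaged over the batch; the loss is the mean squared
    difference of these sampled moments between the two domains.
    """
    dim = hs.size(1)
    # random index tuples of shape (num_samples, order)
    idx = torch.randint(0, dim, (num_samples, order), device=hs.device)
    hs_prod = hs[:, idx].prod(dim=2)  # (batch, num_samples)
    ht_prod = ht[:, idx].prod(dim=2)
    return ((hs_prod.mean(dim=0) - ht_prod.mean(dim=0)) ** 2).mean()
```

Sampling index tuples avoids materializing the full order-p moment tensor, whose size grows as dim^p.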
Data:
A custom SubsetImageFolder class is used to load only a subset of classes from the dataset. Two domains are considered, for example, product_images and real_life.
Training:
Training is performed in two modes:
- A full UDA training step that combines the classification, HoMM, and clustering losses.
- A baseline training step that uses cross-entropy loss only.
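A single full UDA step might look like the following sketch; the feature extractor, classifier, and the two loss callables are hypothetical stand-ins for the notebook's actual objects:

```python
import torch
import torch.nn.functional as F

def uda_step(feature_extractor, classifier, optimizer,
             xs, ys, xt, homm_loss, clustering_loss,
             lambda_d=100.0, lambda_dc=0.1):
    """One combined update: classification + lambda_d * HoMM + lambda_dc * clustering."""
    fs = feature_extractor(xs)          # source features
    ft = feature_extractor(xt)          # target features (unlabeled)
    loss_cls = F.cross_entropy(classifier(fs), ys)
    loss_d = homm_loss(fs, ft)          # domain discrepancy on features
    loss_dc = clustering_loss(ft, classifier(ft))  # clustering on target
    loss = loss_cls + lambda_d * loss_d + lambda_dc * loss_dc
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

The baseline mode is the same step with only loss_cls, i.e. lambda_d = lambda_dc = 0.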
Hyperparameters such as batch size, learning rate, HoMM order, and lambda factors can be tuned. For instance, the script uses a batch size of 128, a learning rate of 0.001, and lambda values to weight the loss contributions.
Requirements:
- Python with PyTorch and torchvision
- tqdm
- Google Colab (for drive mounting if running in Colab)
- matplotlib (for plotting training curves)
Preparation:
- Download the dataset (e.g., Adaptiope.zip) and unzip it.
- Adjust img_root and the subset names as needed.
Running the Experiment:
- To run the full UDA model, execute the notebook cells (or call the main function): main()
- To run the baseline model, call: main(baseline=True)
- To reverse the domains, use: main(reverse=True)
- The reverse and baseline options can be combined: main(reverse=True, baseline=True)
Monitoring:
- The training progress is printed and accuracy and loss plots are generated at the end of training.
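A minimal plotting helper could look like this; the history dict format is an assumption about how per-epoch metrics are collected:

```python
import matplotlib.pyplot as plt

def plot_curves(history):
    """Plot training loss and accuracy from a dict of per-epoch lists,
    e.g. {"loss": [...], "accuracy": [...]}."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(history["loss"])
    ax_loss.set_xlabel("epoch")
    ax_loss.set_title("Training loss")
    ax_acc.plot(history["accuracy"])
    ax_acc.set_xlabel("epoch")
    ax_acc.set_title("Target accuracy")
    fig.tight_layout()
    return fig
```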
Some key hyperparameter settings in the main function:
- batch_size=128
- homm_order=4
- num_samples=350000
- lambda_d=100 (discrepancy loss weight)
- lambda_dc=0.1 (clustering loss weight)
- alpha=0.7 (moving-average factor for center updates)
- threshold=0.65 (confidence threshold for selecting target samples for clustering)
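How alpha and threshold interact with the clustering loss can be sketched as follows; function and variable names are illustrative, not the notebook's actual code:

```python
import torch
import torch.nn.functional as F

def update_centers_and_dc_loss(centers, feats, logits, alpha=0.7, threshold=0.65):
    """Pseudo-label confident target samples, move class centers with an
    exponential moving average, and pull selected features toward them.

    centers: (num_classes, dim) running class centers.
    feats, logits: (batch, dim) target features and their class logits.
    """
    probs = F.softmax(logits, dim=1)
    conf, pseudo = probs.max(dim=1)
    mask = conf > threshold                  # keep only confident samples
    if mask.sum() == 0:
        return centers, feats.new_zeros(())
    f, y = feats[mask], pseudo[mask]
    new_centers = centers.clone()
    for c in y.unique():
        # detach so gradients flow through features, not the centers
        batch_mean = f[y == c].mean(dim=0).detach()
        new_centers[c] = alpha * centers[c] + (1 - alpha) * batch_mean
    dc_loss = ((f - new_centers[y]) ** 2).sum(dim=1).mean()
    return new_centers, dc_loss
```

Raising threshold makes the pseudo-labeling more conservative, while alpha closer to 1 makes the centers change more slowly between batches.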