
reflection-removal: A model for removing reflections from a single image. The goal is to significantly reduce the computational cost of the GAN and remove reflections in real time; the quality of the reflection removal is not a priority for this project. Built with DINOv3.



PINTO0309/reflection-removal


reflection-removal


WIP, November 13, 2025


Real-image examples (input and removal result; images not reproduced here): 100025 and 100025_transmission, hyodo_001 and hyodo_001_transmission.
| Variant | Size | TensorRT inference latency | ONNX |
|:--|--:|--:|:--:|
| dinov3_vitt_distill_disthyper_residual_gennerator_640x640_320x320 | 24.2 MB | 6.75 ms | Download |
| dinov3_vitt_distill_disthyper_residual_gennerator_640x640_640x640 | 24.2 MB | 18.82 ms | Download |
| dinov3_vits16_disthyper_residual_gennerator_640x640_320x320 | 89.2 MB | 9.49 ms | Download |
| dinov3_vits16_disthyper_residual_gennerator_640x640_640x640 | 89.3 MB | 24.97 ms | Download |

Setup

git clone https://github.com/PINTO0309/reflection-removal.git && cd reflection-removal
curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync
source .venv/bin/activate
  • Pretrained VGG-19 weights are fetched automatically via torchvision, so no manual download is required for the vgg19 backbone.
  • Additional backbones are supported:
    • DINOv3 ViT-Tiny (dinov3_vitt) — DEIMv2-finetuned weights stored at ckpts/deimv2_dinov3_s_wholebody34.pth. No torch.hub download is available, so keep this file locally.
    • HGNetV2 (hgnetv2) — DEIMv2-finetuned CNN backbone stored at ckpts/deimv2_hgnetv2_n_wholebody34.pth.
    • DINOv3 standard variants (dinov3_vits16, dinov3_vits16plus, dinov3_vitb16) — place the provided checkpoints inside ./ckpts/ (files named dinov3_*_pretrain_lvd1689m-*.pth). If the files are missing, they will be downloaded automatically through torch.hub.

Dataset

Real-world reflection data is scarce, but what is available should be enough to confirm that training is proceeding normally. Synthetic data is of little use, so authentic real-world data should be collected diligently.

https://github.com/ceciliavision/perceptual-reflection-removal?tab=readme-ov-file#dataset

reflection-dataset/
├── real
│   ├── blended
│   │   ├── 100001.jpg
│   │   ├── 100002.jpg
│   │   ├── 100003.jpg
│   │   ├── 100004.jpg
│   │   └── 100005.jpg
│   └── transmission_layer
│       ├── 100001.jpg
│       ├── 100002.jpg
│       ├── 100003.jpg
│       ├── 100004.jpg
│       └── 100005.jpg
└── synthetic
    ├── reflection_layer
    │   ├── 100001.jpg
    │   ├── 100002.jpg
    │   ├── 100003.jpg
    │   ├── 100004.jpg
    │   └── 100005.jpg
    └── transmission_layer
        ├── 100001.jpg
        ├── 100002.jpg
        ├── 100003.jpg
        ├── 100004.jpg
        └── 100005.jpg

Training

backbone: "vgg19", "hgnetv2", "dinov3_vitt", "dinov3_vits16", "dinov3_vits16plus", "dinov3_vitb16"

# Baseline
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16 \
--use_amp
# Initialized with DINOv3 default weights + Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16_residual \
--residual_skips \
--residual_init 0.1 \
--use_amp

Residual generator variants:

  • --residual_skips enables single-stage residual scaling after each dilated block.
  • --residual_in_residual_skips stacks the eight dilated blocks into two Residual-in-Residual groups (4 layers each) for deeper skip modulation. This flag supersedes --residual_skips when both are provided.
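A toy sketch of the two skip layouts, assuming each learnable scale starts at --residual_init and that the Residual-in-Residual variant adds one outer scale per group of four blocks. Blocks are modeled as plain functions on floats instead of tensors, and the function names are hypothetical; the generator code is the authority on the exact wiring.

```python
from typing import Callable, Sequence

Block = Callable[[float], float]
RESIDUAL_INIT = 0.1  # mirrors --residual_init 0.1


def residual_skips(x: float, blocks: Sequence[Block], scales: Sequence[float]) -> float:
    """Single-stage scaling: x <- x + s_k * f_k(x) after every dilated block."""
    for block, s in zip(blocks, scales):
        x = x + s * block(x)
    return x


def residual_in_residual(x: float, blocks: Sequence[Block], inner_scales: Sequence[float],
                         outer_scales: Sequence[float], group_size: int = 4) -> float:
    """Split blocks into groups; each group's net change is scaled a second time."""
    for g, start in enumerate(range(0, len(blocks), group_size)):
        group_in = x
        stop = start + group_size
        h = residual_skips(group_in, blocks[start:stop], inner_scales[start:stop])
        x = group_in + outer_scales[g] * (h - group_in)  # outer (group-level) skip
    return x
```

With every scale at 0.1 and each block returning 1.0, eight single-stage skips add 0.8 to the input, whereas the two nested groups add only 0.08 because each contribution is scaled twice. Small initial scales keep the residual branches weak at the start of training, so the skip path dominates.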
# Initialized with pretrained weights + Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16_residual \
--ckpt_file ckpts/reflection_removal_dinov3_vits16.pt \
--residual_skips \
--residual_init 0.1 \
--use_amp
# Distributed Hypercolumn + Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16_disthyper_residual \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 4 \
--residual_skips \
--residual_init 0.1 \
--use_amp

Distributed hypercolumns & channel reduction

  • --use_distributed_hypercolumn replaces the full hypercolumn concat (raw backbone features + RGB) with a distributed projection: each backbone stage is first reduced by a learnable 1×1 convolution, and the concatenated tensor is then compressed to 64 channels by a final 1×1 convolution. This greatly lowers memory usage while keeping the generator interface unchanged. Enable this flag whenever you plan to export compact ONNX graphs or train at high resolution; checkpoints store the necessary projection weights.
  • --hypercolumn_channel_reduction_scale controls how aggressively each stage is reduced. The reducer for a layer with C channels will emit ceil(C / scale) channels. The default (4) keeps roughly 25 % of the original channels per stage; higher values (e.g. 8, 16) shrink both parameters and FLOPs linearly, at the cost of fewer hypercolumn features. The final post-projection always outputs 64 channels, so downstream generator layers remain compatible across scales. When loading a checkpoint, the correct scale is inferred automatically, so you only need to pass this flag during training if you want a non-default compression ratio.
  • Practical guidance:
    1. Start with --use_distributed_hypercolumn --hypercolumn_channel_reduction_scale 4 for general training—it balances memory and fidelity.
    2. If you are memory-bound, increase the scale to 8 or 16. Expect parameter counts and compute inside the hypercolumn projector to drop to roughly ½ (scale 8) and ¼ (scale 16) of the scale-4 cost.
    3. If you disable distributed hypercolumns, the generator reverts to concatenating fully-resolved backbone maps; this offers maximal information but is significantly heavier and may not match ONNX deployment paths.
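The ½ and ¼ figures follow from the ceil(C / scale) rule. A small sketch that counts projector parameters, assuming bias-enabled 1×1 convolutions, four 192-channel ViT-T stages, and an RGB input that is not routed through a reducer; the helper names are hypothetical.

```python
import math
from typing import Sequence


def conv1x1_params(c_in: int, c_out: int, bias: bool = True) -> int:
    # A 1x1 convolution has c_in * c_out weights plus c_out biases.
    return c_in * c_out + (c_out if bias else 0)


def projector_params(stage_channels: Sequence[int], scale: int, out_channels: int = 64) -> int:
    # Per-stage reducers emit ceil(C / scale) channels ...
    reduced = [math.ceil(c / scale) for c in stage_channels]
    reducers = sum(conv1x1_params(c, r) for c, r in zip(stage_channels, reduced))
    # ... and a final 1x1 compresses their concatenation to a fixed 64 channels.
    post = conv1x1_params(sum(reduced), out_channels)
    return reducers + post


stages = [192, 192, 192, 192]
base = projector_params(stages, 4)
for scale in (4, 8, 16):
    p = projector_params(stages, scale)
    print(f"scale={scale:2d} params={p:6d} ratio_vs_4={p / base:.3f}")
```

Under these assumptions the projector has 49408, 24736, and 12400 parameters at scales 4, 8, and 16, i.e. about 0.501 and 0.251 of the scale-4 count.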

To resume from the latest checkpoint in runs/<exp_name>/ (here runs/dinov3_vits16_disthyper_residual/), rerun the same command with --resume:

uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16_disthyper_residual \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 4 \
--residual_skips \
--residual_init 0.1 \
--use_amp \
--resume

| Value | Note |
|:--|:--|
| loss | Average content loss: L1 term + 0.2 × perceptual + grad. When the generator is actually updated, this value is multiplied by 100 and adv is added to it. |
| percep | Average perceptual loss, which measures distance in the backbone's feature space (DINO/VGG). |
| grad | Average exclusion (gradient) loss, which discourages the gradients of the transmission and reflection layers from overlapping. |
| adv | Average adversarial loss (BCE), which pushes the generator to make the discriminator classify its output as "real". |
| feat_dist | Mean MSE between the student and projected teacher feature maps; reported only when feature distillation is enabled. |
| pix_dist | Mean L1 distance between the student outputs and the frozen teacher outputs (transmission, and reflection when available); reported only when pixel distillation is enabled. |

The logged loss is therefore not the final total loss; it is an indicator for checking the balance of the content-side terms.

If you want the teacher’s influence to fade during late epochs, pass --enable_distill_decay. This keeps the original distillation weights for a warmup period (default --distill_decay_warmup_epochs 5) and then applies a cosine schedule that smoothly decays both feature and pixel distillation weights to zero by the final epoch. The decay is disabled by default; omit the flag to keep static teacher weights. Distillation weights/decay settings are stored inside each checkpoint, so resuming training reproduces the same behavior even if you forget to re-specify the CLI flags.
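One way to read the decay description as code, assuming 0-based epochs, a constant weight for the warmup epochs, and a half-cosine that reaches exactly zero on the last epoch. The function distill_weight is hypothetical; the exact epoch indexing in main.py may differ.

```python
import math


def distill_weight(epoch: int, base_weight: float, total_epochs: int,
                   warmup_epochs: int = 5) -> float:
    """Hypothetical schedule: constant during warmup, then cosine decay to 0.

    `epoch` is 0-based, so the final epoch is total_epochs - 1.
    """
    if epoch < warmup_epochs:
        return base_weight
    decay_span = max(total_epochs - 1 - warmup_epochs, 1)
    progress = min((epoch - warmup_epochs) / decay_span, 1.0)
    # 0.5 * (1 + cos(pi * p)) goes smoothly from 1 at p = 0 to 0 at p = 1.
    return base_weight * 0.5 * (1.0 + math.cos(math.pi * progress))
```

The same factor would be applied to both the feature and the pixel distillation weights.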

Distillation from a dinov3_vits16 teacher into a dinov3_vitt student whose backbone was fine-tuned in DEIMv2:

# Without Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill \
--ckpt_dir ckpts \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16.pt \
--use_amp
# With Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill_disthyper_rir \
--ckpt_dir ckpts \
--residual_in_residual_skips \
--residual_init 0.1 \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16_residual.pt \
--use_amp
# With Distributed Hypercolumn + Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill_disthyper_rir \
--ckpt_dir ckpts \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 4 \
--residual_in_residual_skips \
--residual_init 0.1 \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16_disthyper4_residual.pt \
--enable_distill_decay \
--use_amp

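# Same setup with --hypercolumn_channel_reduction_scale 8 (use a distinct --exp_name to keep the runs apart)
uv run python main.py \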
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill_disthyper_rir \
--ckpt_dir ckpts \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 8 \
--residual_in_residual_skips \
--residual_init 0.1 \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16_disthyper4_residual.pt \
--enable_distill_decay \
--use_amp

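# Same setup with --hypercolumn_channel_reduction_scale 16 (use a distinct --exp_name to keep the runs apart)
uv run python main.py \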
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill_disthyper_rir \
--ckpt_dir ckpts \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 16 \
--residual_in_residual_skips \
--residual_init 0.1 \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16_disthyper4_residual.pt \
--enable_distill_decay \
--use_amp

Interpreting the Loss Components

  • Content vs. adversarial balance: For each generator update we optimise content_loss = L1(reflection) + 0.2 × perceptual + grad, and the actual objective that is backpropagated is total_g = 100 × content_loss + adv. A rise in loss usually means the reconstruction terms need attention; a rise in adv means the discriminator is currently winning.
  • Synthetic vs. real batches: When training on synthetic pairs both the transmission and reflection branches contribute to the perceptual/L1 terms, so loss should generally be higher than during real batches (where reflection supervision is zero). Expect grad to be non-zero only for synthetic samples.
  • Healthy dynamics: You want loss, percep, and grad to trend downward slowly while adv oscillates. If they all climb together, the model is diverging. If adv collapses toward 0 while the other terms stagnate, the discriminator may be too weak: lower the generator's learning rate relative to the discriminator's, or add more synthetic data. If adv stays very high while the other terms shrink, the discriminator is too strong: consider reducing the GAN weight (e.g., scaling the final + adv term) or adding label smoothing.
  • Practical monitoring: Track the logged scalars in TensorBoard. Focus on the moving averages per epoch; transient spikes after checkpoint saves are normal. Compare checkpoints by running --test_only so you can visually confirm whether changes in the metrics translate to better separation.
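The rules of thumb above can be encoded as a toy check over per-epoch averages. This is a sketch only: the thresholds are hypothetical and the training script does not run anything like it.

```python
from typing import Dict, List

# Hypothetical thresholds; tune them for your own runs.
FLAT_TOL = 0.02     # relative change treated as "flat"
ADV_FLOOR = 0.05    # adv below this counts as "collapsed toward 0"
ADV_CEILING = 3.0   # adv above this counts as "very high"


def rel_change(values: List[float]) -> float:
    """Relative change from the first to the last per-epoch average."""
    return (values[-1] - values[0]) / max(abs(values[0]), 1e-8)


def diagnose(history: Dict[str, List[float]]) -> str:
    """Classify per-epoch averages of loss/percep/grad/adv."""
    content = [rel_change(history[k]) for k in ("loss", "percep", "grad")]
    adv_now = history["adv"][-1]
    if all(c > FLAT_TOL for c in content):
        return "diverging"
    if adv_now < ADV_FLOOR and all(abs(c) <= FLAT_TOL for c in content):
        return "discriminator too weak"
    if adv_now > ADV_CEILING and all(c < -FLAT_TOL for c in content):
        return "discriminator too strong"
    return "no obvious imbalance"
```

Comparing the first and last epoch is deliberately crude; in practice you would smooth the TensorBoard curves before drawing conclusions.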

Per-Epoch Validation Dumps

After every training epoch the current generator runs inference on the images specified by --test_dir, and the outputs are written to test_results/<exp_name>/epoch_<NNNN>/. If --test_dir does not resolve to any images, the script falls back to a fixed set of up to 10 blended inputs gathered from --data_real_dir, so you still get a consistent visual trace of progress across epochs. Clean up these folders periodically if disk usage grows too large.

Each validation sample produces a directory numbered by processing order (e.g. test_results/<exp_name>/epoch_0001/0001_<image_stem>/) containing three PNGs:

  • input.png — the raw blended input frame.
  • t_output.png — the predicted transmission layer.
  • r_output.png — the predicted reflection layer.
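The directory naming above can be reproduced with zero-padded format strings. A sketch, assuming 1-based epoch and sample indices as in the epoch_0001/0001_<image_stem> example; validation_sample_dir is a hypothetical helper, not a function from this repository.

```python
from pathlib import Path

OUTPUT_FILES = ("input.png", "t_output.png", "r_output.png")


def validation_sample_dir(exp_name: str, epoch: int, index: int, image_path: str,
                          root: str = "test_results") -> Path:
    """Build test_results/<exp_name>/epoch_<NNNN>/<NNNN>_<image_stem>/."""
    stem = Path(image_path).stem  # file name without its extension
    return Path(root) / exp_name / f"epoch_{epoch:04d}" / f"{index:04d}_{stem}"
```

Each such directory would then hold the three files listed in OUTPUT_FILES.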

Checkpoints, intermediate predictions, train.log, and TensorBoard summaries are all stored under runs/<exp_name>/ (e.g. runs/dinov3_vits16/); the TensorBoard event files are written directly into that directory. Launch TensorBoard via:

tensorboard --logdir runs

Arguments

--exp_name: experiment name. Training artifacts (checkpoints, train.log, image dumps) are stored under ./runs/<exp_name>/, and inference outputs under ./test_results/<exp_name>/.

--data_syn_dir: comma-separated list of synthetic dataset roots

--data_real_dir: comma-separated list of real dataset roots

--save_model_freq: frequency at which the model and the output images are saved

--keep_checkpoint_history: number of saved checkpoint epochs (epoch_<NNNN> folders under runs/<exp_name>/) to retain (0 keeps all)

--backbone: feature extractor for hypercolumns and perceptual loss (vgg19, hgnetv2, dinov3_vitt, dinov3_vits16, dinov3_vits16plus, dinov3_vitb16). Hypercolumn features are always enabled; older runs that used --is_hyper now default to the same behaviour.

--ckpt_dir: directory where backbone checkpoints are searched (default ckpts)

--ckpt_file: optional generator checkpoint that seeds training; its weights are loaded before the first epoch (ignored with --resume)

--test_only: skip training and run inference only

--resume: resume training from the last checkpoint in runs/<exp_name>/

--use_amp: enable torch.cuda.amp mixed precision (effective on CUDA devices)

--epochs: number of training epochs (default 100)

--device: device string such as cuda:0 or cpu
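For reference, the flags listed above can be mirrored in an argparse parser. This is a hypothetical sketch, not main.py's actual parser: the comma splitting, the choices list, and the None defaults are assumptions, and only the ckpts and 100-epoch defaults come from the list above. Flags documented elsewhere (residual, hypercolumn, and distillation options) are omitted.

```python
import argparse
from typing import List

BACKBONES = ["vgg19", "hgnetv2", "dinov3_vitt", "dinov3_vits16",
             "dinov3_vits16plus", "dinov3_vitb16"]


def comma_list(value: str) -> List[str]:
    """Split a comma-separated list of dataset roots, dropping empty entries."""
    return [item.strip() for item in value.split(",") if item.strip()]


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical mirror of the documented flags, not the repository's parser.
    p = argparse.ArgumentParser(description="Sketch of the main.py command line")
    p.add_argument("--exp_name", type=str)
    p.add_argument("--data_syn_dir", type=comma_list, default=[])
    p.add_argument("--data_real_dir", type=comma_list, default=[])
    p.add_argument("--save_model_freq", type=int)
    p.add_argument("--keep_checkpoint_history", type=int, help="0 keeps all")
    p.add_argument("--backbone", choices=BACKBONES)
    p.add_argument("--ckpt_dir", default="ckpts")
    p.add_argument("--ckpt_file", type=str, help="ignored with --resume")
    p.add_argument("--test_only", action="store_true")
    p.add_argument("--resume", action="store_true")
    p.add_argument("--use_amp", action="store_true")
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--device", type=str)
    return p
```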

Testing

uv run python main.py \
--exp_name dinov3_vits16_test \
--test_only \
--backbone dinov3_vits16 \
--test_dir ./test_images

Make sure the --backbone flag matches the model that produced the checkpoint you are loading.

If --test_only is omitted, the script trains by default and writes checkpoints/metrics to runs/<exp_name>/.

Test outputs are written to ./test_results/<exp_name>/<image_name>/.

Dataset PSNR/SSIM Benchmark

Use compute_pnsr_ssim_metrics.py to compute transmission-layer PSNR and SSIM for the samples stored under reflection-dataset/. By default (no model supplied) it compares each blended input directly against its paired ground-truth transmission layer, which gives a no-removal baseline:

uv run python compute_pnsr_ssim_metrics.py --subset real

Key options:

  • --subset {real,synthetic,all} – choose the dataset split.
  • --max-samples N – limit the number of pairs per split (handy for quick spot checks).
  • --output-csv metrics.csv – dump per-image metrics for further analysis.
  • --skip-missing/--no-skip-missing – control what happens when a pair is absent (--skip-missing is enabled by default).
  • --synthetic-seed – seed controlling the on-the-fly blend generation used for the synthetic split (requires a model flag).
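PSNR is 10·log10(MAX² / MSE). A plain-Python sketch for 8-bit images flattened into lists, included only to make the metric concrete; compute_pnsr_ssim_metrics.py may rely on a library implementation with different color or range handling, and SSIM is not shown.

```python
import math
from typing import Sequence


def psnr(pred: Sequence[float], target: Sequence[float], max_value: float = 255.0) -> float:
    """PSNR in dB between two equally sized, flattened images."""
    if len(pred) != len(target) or len(pred) == 0:
        raise ValueError("images must be non-empty and of equal size")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)
```

A uniform error of 10 gray levels gives MSE = 100 and a PSNR of about 28.13 dB.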

To benchmark predicted transmissions from a PyTorch checkpoint (either a .pt file or a runs/<exp>/epoch_xxxx/ directory), point the script at the artifact:

uv run python compute_pnsr_ssim_metrics.py \
--subset real \
--max-samples 100 \
--checkpoint runs/dinov3_vitt_distill_disthyper_residual/epoch_0033/checkpoint.pt \
--device cuda:0 \
--output-csv metrics.csv

If you prefer ONNX Runtime instead, supply --onnx-model <path> and select the execution providers in priority order via repeated --providers flags (cpu, cuda, tensorrt):

uv run python compute_pnsr_ssim_metrics.py \
--subset all \
--max-samples 500 \
--onnx-model dinov3_vitt_gennerator_640x640_640x640.onnx \
--providers cuda --providers cpu \
--output-csv metrics.csv

When a model is provided, the script feeds each blended input through the generator, resizes the predicted transmission back to the ground-truth resolution, and reports split-level summary statistics plus optional CSV output. For the synthetic split, the script mirrors training by generating blended inputs on the fly from every reflection/transmission pair, so you must specify either --checkpoint or --onnx-model whenever you evaluate --subset synthetic (or --subset all, which includes it).

  • Example output
    [INFO] Loading ONNX model from dinov3_vitt_gennerator_640x640_640x640.onnx
    real: blended vs transmission_layer: 100%|███████| 500/500 [01:54<00:00,  4.37it/s]
    [RESULT] split=real (blended vs transmission_layer) samples=500
            PSNR  -> mean=25.0195, min=9.6244, max=36.3645
            SSIM  -> mean=0.8617, min=0.1860, max=0.9840
    synthetic: blended vs transmission_layer: 100%|███████| 500/500 [01:15<00:00,  6.61it/s]
    [RESULT] split=synthetic (synthetic_blended vs transmission_layer) samples=500
            PSNR  -> mean=24.5095, min=9.9238, max=36.4029
            SSIM  -> mean=0.9001, min=0.1582, max=0.9919
    [INFO] Wrote per-image metrics to metrics.csv
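--synthetic-seed matters because the synthetic blends are regenerated on every run. The sketch below uses a simple additive blend T + αR with a seeded α purely to illustrate that reproducibility; the repository's actual blending model (which mirrors training) is not documented here, so the formula, the α range, and the function name synthesize_blend are all assumptions.

```python
import random
from typing import List, Sequence, Tuple


def synthesize_blend(transmission: Sequence[float], reflection: Sequence[float], seed: int,
                     alpha_range: Tuple[float, float] = (0.2, 0.5)) -> List[float]:
    """Hypothetical on-the-fly blend: T + alpha * R, clipped to [0, 255]."""
    rng = random.Random(seed)        # private RNG so the seed fully determines alpha
    alpha = rng.uniform(*alpha_range)
    return [min(255.0, max(0.0, t + alpha * r)) for t, r in zip(transmission, reflection)]
```

Re-running with the same seed regenerates identical inputs, so PSNR/SSIM on the synthetic split stay comparable across checkpoints.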
    

ViT-T (DEIMv2-S) backbone outputs -> generator inputs

  1. /generator/blocks.0/Add_1_output_0: float32[1,1601,192] -> float32[1,192,640,640]
  2. /generator/blocks.4/Add_1_output_0: float32[1,1601,192] -> float32[1,192,640,640]
  3. /generator/blocks.7/Add_1_output_0: float32[1,1601,192] -> float32[1,192,640,640]
  4. /generator/blocks.11/Add_1_output_0: float32[1,1601,192] -> float32[1,192,640,640]
  5. input: float32[1,3,640,640]
  6. 1 + 2 + 3 + 4 + 5 -> /generator/Concat_14_output_0: float32[1,771,640,640]
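The shapes listed above can be checked with a little arithmetic. A sketch, assuming the ViT-T backbone uses 16×16 patches and one extra (class) token, which is consistent with 1601 = 40 × 40 + 1 for a 640×640 input; the patch tokens of each stage are then laid out as [1,192,40,40] and upsampled to 640×640 before being concatenated with the RGB input. The helper names are hypothetical.

```python
from typing import Sequence, Tuple


def token_layout(image_size: int = 640, patch_size: int = 16,
                 n_tokens: int = 1601) -> Tuple[int, int, int]:
    """Split a ViT token count into (grid side, patch tokens, extra tokens)."""
    grid = image_size // patch_size   # 640 // 16 = 40 patches per side
    n_patches = grid * grid           # 40 * 40 = 1600 patch tokens
    n_extra = n_tokens - n_patches    # 1601 - 1600 = 1 extra token (assumed CLS)
    return grid, n_patches, n_extra


def hypercolumn_channels(stage_channels: Sequence[int], rgb_channels: int = 3) -> int:
    """Channels of the full (non-distributed) hypercolumn concat."""
    return sum(stage_channels) + rgb_channels   # 4 * 192 + 3 = 771


print(token_layout())                    # (40, 1600, 1)
print(hypercolumn_channels([192] * 4))   # 771
```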

ONNX Export

  • Backbone + Head

    uv run python export_onnx.py \
    --checkpoint runs/dinov3_vitt/epoch_0001/checkpoint.pt \
    --output dinov3_vitt_gennerator_640x640_640x640.onnx \
    --backbone dinov3_vitt \
    --static_shape \
    --height 640 \
    --width 640 \
    --head_height 640 \
    --head_width 640
    
    uv run python export_onnx.py \
    --checkpoint runs/dinov3_vitt/epoch_0001/checkpoint.pt \
    --output dinov3_vitt_gennerator_640x640_320x320.onnx \
    --backbone dinov3_vitt \
    --static_shape \
    --height 640 \
    --width 640 \
    --head_height 320 \
    --head_width 320
    
    
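    # Head resolution (passed to --head_height/--head_width)
    H=640
    W=640
    # Backbone variant: dinov3_vits16 or dinov3_vitt
    VAR=dinov3_vitt
    # Set to _distill for distilled runs, or leave empty otherwise
    DIST=_distill
    EPOCH=0013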
    
    pushd ../..
    uv run python export_onnx.py \
    --checkpoint runs/${VAR}${DIST}_disthyper_rir/epoch_${EPOCH}/checkpoint.pt \
    --output ${VAR}${DIST}_disthyper_rir_gennerator_640x640_${H}x${W}.onnx \
    --backbone ${VAR} \
    --static_shape \
    --height 640 \
    --width 640 \
    --head_height ${H} \
    --head_width ${W}
    uv run python demo_reflection_removal.py \
    --input test_images \
    --output runs/test \
    --model ${VAR}${DIST}_disthyper_rir_gennerator_640x640_${H}x${W}.onnx \
    --provider CUDAExecutionProvider
    popd
  • Head only

    uv run python export_onnx.py \
    --checkpoint runs/dinov3_vitt/epoch_0001/checkpoint.pt \
    --output dinov3_vitt_gennerator_headonly_640x640_320x320.onnx \
    --backbone dinov3_vitt \
    --static_shape \
    --height 640 \
    --width 640 \
    --head_height 320 \
    --head_width 320 \
    --head_only

ONNX Inference

uv run python demo_reflection_removal.py \
--input test_images \
--output runs/test \
--model dinov3_vits16_gennerator_640x640_640x640.onnx \
--provider CUDAExecutionProvider

Citation

If you use this repository in your research, please cite both the original method and this implementation:

@software{Hyodo_2025_reflection_removal,
  author    = {Katsuya Hyodo},
  title     = {reflection-removal: Reflection-Removal},
  year      = {2025},
  month     = {nov},
  publisher = {Zenodo},
  version   = {1.0.0},
  doi       = {10.5281/zenodo.17595044},
  url       = {https://github.com/PINTO0309/reflection-removal},
  abstract  = {A model for removing reflections using a single image.},
}

Acknowledgments

  • https://github.com/ceciliavision/perceptual-reflection-removal
    @inproceedings{zhang2018single,
      title = {Single Image Reflection Separation with Perceptual Losses},
      author = {Zhang, Xuaner and Ng, Ren and Chen, Qifeng},
      booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
      year = {2018}
    }
  • https://github.com/Intellindust-AI-Lab/DEIMv2
    @article{huang2025deimv2,
      title={Real-Time Object Detection Meets DINOv3},
      author={Huang, Shihua and Hou, Yongjie and Liu, Longfei and Yu, Xuanlong and Shen, Xi},
      journal={arXiv},
      year={2025}
    }
  • https://github.com/facebookresearch/dinov3
    @misc{simeoni2025dinov3,
      title={{DINOv3}},
      author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr},
      year={2025},
      eprint={2508.10104},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2508.10104},
    }
  • https://github.com/PINTO0309/PINTO_model_zoo/tree/main/472_DEIMv2-Wholebody34
    @software{DEIMv2-Wholebody34,
      author={Katsuya Hyodo},
      title={Lightweight human detection models generated on high-quality human data sets. It can detect objects with high accuracy and speed in a total of 28 classes: body, adult, child, male, female, body_with_wheelchair, body_with_crutches, head, front, right-front, right-side, right-back, back, left-back, left-side, left-front, face, eye, nose, mouth, ear, collarbone, shoulder, solar_plexus, elbow, wrist, hand, hand_left, hand_right, abdomen, hip_joint, knee, ankle, foot.},
      url={https://github.com/PINTO0309/PINTO_model_zoo/tree/main/472_DEIMv2-Wholebody34},
      year={2025},
      month={10},
      doi={10.5281/zenodo.10229410}
    }
