A model for removing reflections from a single image. The goal is to drastically reduce the computational cost of the GAN so that reflections can be removed in real time; the quality of the reflection removal itself is not the priority of this project. Built with DINOv3.
| Real | Removal |
|---|---|
| ![]() | ![]() |
| ![]() | ![]() |
| Variant | Size | TensorRT inference latency | ONNX |
|---|---|---|---|
| dinov3_vitt_distill_disthyper_residual_gennerator_640x640_320x320 | 24.2 MB | 6.75 ms | Download |
| dinov3_vitt_distill_disthyper_residual_gennerator_640x640_640x640 | 24.2 MB | 18.82 ms | Download |
| dinov3_vits16_disthyper_residual_gennerator_640x640_320x320 | 89.2 MB | 9.49 ms | Download |
| dinov3_vits16_disthyper_residual_gennerator_640x640_640x640 | 89.3 MB | 24.97 ms | Download |
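The latency figures above were measured with TensorRT. As a rough cross-check, one of the exported ONNX files can be timed with ONNX Runtime along the lines below. This is only a sketch: the file name is one of the variants from the table, the input tensor name `input` matches the export notes later in this README, and the warmup/timing loop is purely illustrative.

```python
import time

import numpy as np
import onnxruntime as ort

# One of the variants from the table above; adjust to whichever file you downloaded.
MODEL = "dinov3_vitt_distill_disthyper_residual_gennerator_640x640_640x640.onnx"

# Prefer the TensorRT EP, fall back to CUDA and then CPU if it is unavailable.
session = ort.InferenceSession(
    MODEL,
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
)

x = np.random.rand(1, 3, 640, 640).astype(np.float32)  # NCHW float32 tensor named "input"

for _ in range(5):  # warmup (TensorRT builds its engine on the first runs)
    session.run(None, {"input": x})

runs = 50
start = time.perf_counter()
for _ in range(runs):
    session.run(None, {"input": x})
print(f"mean latency: {(time.perf_counter() - start) / runs * 1000:.2f} ms")
```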
```bash
git clone https://github.com/PINTO0309/reflection-removal.git && cd reflection-removal
curl -LsSf https://astral.sh/uv/install.sh | sh
uv sync
source .venv/bin/activate
```

- All pretrained weights (VGG-19) are fetched automatically via `torchvision`, so no manual download is required.
- Additional backbones are supported:
  - DINOv3 ViT-Tiny (`dinov3_vitt`) — DEIMv2-finetuned weights stored at `ckpts/deimv2_dinov3_s_wholebody34.pth`. No `torch.hub` download is available, so keep this file locally.
  - HGNetV2 (`hgnetv2`) — DEIMv2-finetuned CNN backbone stored at `ckpts/deimv2_hgnetv2_n_wholebody34.pth`.
  - DINOv3 standard variants (`dinov3_vits16`, `dinov3_vits16plus`, `dinov3_vitb16`) — place the provided checkpoints inside `./ckpts/` (files named `dinov3_*_pretrain_lvd1689m-*.pth`). If the files are missing, they will be downloaded automatically through `torch.hub`.
There is a significant shortage of real-world reflection data, but the available set should be enough to confirm that training is proceeding normally. Synthetic data is of little use, so authentic data should be diligently collected from the real world.
https://github.com/ceciliavision/perceptual-reflection-removal?tab=readme-ov-file#dataset
```
reflection-dataset/
├── real
│   ├── blended
│   │   ├── 100001.jpg
│   │   ├── 100002.jpg
│   │   ├── 100003.jpg
│   │   ├── 100004.jpg
│   │   └── 100005.jpg
│   └── transmission_layer
│       ├── 100001.jpg
│       ├── 100002.jpg
│       ├── 100003.jpg
│       ├── 100004.jpg
│       └── 100005.jpg
└── synthetic
    ├── reflection_layer
    │   ├── 100001.jpg
    │   ├── 100002.jpg
    │   ├── 100003.jpg
    │   ├── 100004.jpg
    │   └── 100005.jpg
    └── transmission_layer
        ├── 100001.jpg
        ├── 100002.jpg
        ├── 100003.jpg
        ├── 100004.jpg
        └── 100005.jpg
```
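Synthetic blended inputs are generated on the fly from each `reflection_layer`/`transmission_layer` pair (see the metrics section below). The repo's exact blending is not reproduced here; the snippet below is only a rough illustration of a typical additive formulation, and the blur kernel and mixing weight are assumptions.

```python
import cv2
import numpy as np

def rough_blend(transmission_path: str, reflection_path: str) -> np.ndarray:
    """Illustrative additive blend: transmission + attenuated, blurred reflection."""
    t = cv2.imread(transmission_path).astype(np.float32) / 255.0
    r = cv2.imread(reflection_path).astype(np.float32) / 255.0
    r = cv2.resize(r, (t.shape[1], t.shape[0]))
    r = cv2.GaussianBlur(r, (11, 11), 3.0)       # assumed blur strength
    blended = np.clip(t + 0.4 * r, 0.0, 1.0)     # assumed mixing weight
    return (blended * 255.0).astype(np.uint8)

preview = rough_blend(
    "reflection-dataset/synthetic/transmission_layer/100001.jpg",
    "reflection-dataset/synthetic/reflection_layer/100001.jpg",
)
cv2.imwrite("blended_preview.jpg", preview)
```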
`backbone`: `"vgg19"`, `"hgnetv2"`, `"dinov3_vitt"`, `"dinov3_vits16"`, `"dinov3_vits16plus"`, `"dinov3_vitb16"`
```bash
# Baseline
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16 \
--use_amp
```

```bash
# Initialized with DINOv3 default weights + Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16_residual \
--residual_skips \
--residual_init 0.1 \
--use_amp
```

Residual generator variants:

- `--residual_skips` enables single-stage residual scaling after each dilated block (see the sketch after this list).
- `--residual_in_residual_skips` stacks the eight dilated blocks into two Residual-in-Residual groups (4 layers each) for deeper skip modulation. This flag supersedes `--residual_skips` when both are provided.
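A minimal sketch of the residual scaling described above, under the assumption that each dilated block is wrapped as `out = x + α · block(x)` with `α` initialised from `--residual_init` (0.1 in the examples); the class and block definitions below are illustrative, not the repo's modules.

```python
import torch
import torch.nn as nn

class ResidualSkip(nn.Module):
    """Single-stage residual scaling around one dilated block (sketch)."""

    def __init__(self, block: nn.Module, residual_init: float = 0.1):
        super().__init__()
        self.block = block
        self.scale = nn.Parameter(torch.tensor(residual_init))  # learnable scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With the scale initialised near 0.1 the wrapped block starts close to identity.
        return x + self.scale * self.block(x)

# Hypothetical dilated block standing in for one of the generator's eight layers.
dilated = nn.Sequential(nn.Conv2d(64, 64, 3, padding=2, dilation=2), nn.LeakyReLU(0.2))
layer = ResidualSkip(dilated, residual_init=0.1)
print(layer(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
```

`--residual_in_residual_skips` applies the same idea again around groups of four such blocks, giving the deeper skip modulation described above.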
```bash
# Initialized with pretrained weights + Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16_residual \
--ckpt_file ckpts/reflection_removal_dinov3_vits16.pt \
--residual_skips \
--residual_init 0.1 \
--use_amp
```

```bash
# Distributed Hypercolumn + Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16_disthyper_residual \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 4 \
--residual_skips \
--residual_init 0.1 \
--use_amp
```

- `--use_distributed_hypercolumn` replaces the full hypercolumn concat (raw backbone features + RGB) with a distributed projection: each backbone stage is first reduced by a learnable 1×1 convolution and the concatenated tensor is then compressed to 64 channels via a final 1×1 (see the sketch after this list). This greatly lowers memory usage while keeping the generator interface unchanged. Enable this flag whenever you plan to export compact ONNX graphs or train on high resolutions; checkpoints store the necessary projection weights.
- `--hypercolumn_channel_reduction_scale` controls how aggressively each stage is reduced. The reducer for a layer with `C` channels will emit `ceil(C / scale)` channels. The default (`4`) keeps roughly 25 % of the original channels per stage; higher values (e.g. `8`, `16`) shrink both parameters and FLOPs linearly, at the cost of fewer hypercolumn features. The final post-projection always outputs 64 channels, so downstream generator layers remain compatible across scales. When loading a checkpoint, the correct scale is inferred automatically, so you only need to pass this flag during training if you want a non-default compression ratio.
- Practical guidance:
  - Start with `--use_distributed_hypercolumn --hypercolumn_channel_reduction_scale 4` for general training—it balances memory and fidelity.
  - If you are memory-bound, increase the scale to `8` or `16`. Expect parameter counts and compute inside the hypercolumn projector to drop by roughly ½ and ¼ respectively when moving from 4→8→16.
  - If you disable distributed hypercolumns, the generator reverts to concatenating fully-resolved backbone maps; this offers maximal information but is significantly heavier and may not match ONNX deployment paths.
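The following is a minimal sketch of the distributed hypercolumn projection described in this list, with made-up stage shapes and module names: each stage passes through a 1×1 reducer that emits `ceil(C / scale)` channels, the reduced maps are upsampled and concatenated with the RGB input, and a final 1×1 projects everything to 64 channels.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistributedHypercolumn(nn.Module):
    """Sketch of the per-stage 1x1 reduction plus the final 64-channel projection."""

    def __init__(self, stage_channels, scale=4, out_channels=64):
        super().__init__()
        self.reducers = nn.ModuleList(
            [nn.Conv2d(c, math.ceil(c / scale), kernel_size=1) for c in stage_channels]
        )
        reduced_total = sum(math.ceil(c / scale) for c in stage_channels)
        # +3 because the RGB input is concatenated alongside the backbone features.
        self.post = nn.Conv2d(reduced_total + 3, out_channels, kernel_size=1)

    def forward(self, rgb, stage_features):
        h, w = rgb.shape[-2:]
        columns = [rgb]
        for reducer, feat in zip(self.reducers, stage_features):
            reduced = reducer(feat)  # ceil(C / scale) channels per stage
            columns.append(
                F.interpolate(reduced, size=(h, w), mode="bilinear", align_corners=False)
            )
        return self.post(torch.cat(columns, dim=1))  # always 64 channels out

# Made-up example: four ViT stages with 192 channels each, scale 4 -> 48 channels per stage.
hyper = DistributedHypercolumn(stage_channels=[192, 192, 192, 192], scale=4)
out = hyper(torch.randn(1, 3, 320, 320), [torch.randn(1, 192, 20, 20) for _ in range(4)])
print(out.shape)  # torch.Size([1, 64, 320, 320])
```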
To resume from a previous checkpoint inside `runs/<exp_name>/`:

```bash
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vits16 \
--exp_name dinov3_vits16_disthyper_residual \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 4 \
--residual_skips \
--residual_init 0.1 \
--use_amp \
--resume
```

| Value | Note |
|---|---|
| loss | The average content loss: the L1 term + 0.2 × perceptual + grad. When the generator is actually updated, this value is multiplied by 100 and the adversarial term is added. |
| percep | The average perceptual loss, which measures distance in feature space (DINO/VGG). |
| grad | The average exclusion/gradient loss, which prevents the gradients of the transmitted image and the reflected image from overlapping. |
| adv | The average adversarial loss (BCE), which is used to make the discriminator believe the image is "real." |
| feat_dist | Mean MSE between the student and projected teacher feature maps; reported only when feature distillation is enabled. |
| pix_dist | Mean L1 distance between the student outputs and the frozen teacher outputs (transmission, and reflection when available); reported only when pixel distillation is enabled. |
Therefore, the loss display is not the final total loss, but an indicator for checking the basic loss balance on the content side.
If you want the teacher’s influence to fade during late epochs, pass --enable_distill_decay. This keeps the original distillation weights for a warmup period (default --distill_decay_warmup_epochs 5) and then applies a cosine schedule that smoothly decays both feature and pixel distillation weights to zero by the final epoch. The decay is disabled by default; omit the flag to keep static teacher weights. Distillation weights/decay settings are stored inside each checkpoint, so resuming training reproduces the same behavior even if you forget to re-specify the CLI flags.
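For intuition, a minimal sketch of that schedule is below (warmup at full weight, then a cosine ramp that reaches zero at the final epoch). The function name and the exact epoch indexing are assumptions, not the repo's implementation.

```python
import math

def distill_weight_multiplier(epoch: int, total_epochs: int, warmup_epochs: int = 5) -> float:
    """Multiplier applied to both feature- and pixel-distillation weights (sketch)."""
    if epoch < warmup_epochs:
        return 1.0  # keep the original distillation weights during warmup
    progress = (epoch - warmup_epochs) / max(1, total_epochs - 1 - warmup_epochs)
    progress = min(1.0, progress)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # 1.0 -> 0.0 by the final epoch

# Example: 100-epoch run with the default 5-epoch warmup.
print([round(distill_weight_multiplier(e, 100), 3) for e in (0, 5, 50, 99)])
```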
Distillation transfers knowledge from a `dinov3_vits16` teacher to a `dinov3_vitt` student whose backbone was fine-tuned in DEIMv2.
```bash
# Without Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill \
--ckpt_dir ckpts \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16.pt \
--use_amp
```

```bash
# With Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill_disthyper_rir \
--ckpt_dir ckpts \
--residual_in_residual_skips \
--residual_init 0.1 \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16_residual.pt \
--use_amp
```

```bash
# With Distributed Hypercolumn + Residual blocks
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill_disthyper_rir \
--ckpt_dir ckpts \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 4 \
--residual_in_residual_skips \
--residual_init 0.1 \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16_disthyper4_residual.pt \
--enable_distill_decay \
--use_amp
```
```bash
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill_disthyper_rir \
--ckpt_dir ckpts \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 8 \
--residual_in_residual_skips \
--residual_init 0.1 \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16_disthyper4_residual.pt \
--enable_distill_decay \
--use_amp
```

```bash
uv run python main.py \
--data_syn_dir reflection-dataset/synthetic \
--data_real_dir reflection-dataset/real \
--backbone dinov3_vitt \
--exp_name dinov3_vitt_distill_disthyper_rir \
--ckpt_dir ckpts \
--use_distributed_hypercolumn \
--hypercolumn_channel_reduction_scale 16 \
--residual_in_residual_skips \
--residual_init 0.1 \
--distill_teacher_backbone dinov3_vits16 \
--distill_teacher_checkpoint ckpts/reflection_removal_dinov3_vits16_disthyper4_residual.pt \
--enable_distill_decay \
--use_amp
```

- Content vs. adversarial balance: For each generator update we optimise `content_loss = L1(reflection) + 0.2 × perceptual + grad`, and the actual objective that is backpropagated is `total_g = 100 × content_loss + adv` (see the sketch after this list). A rise in `loss` usually means the reconstruction terms need attention; a rise in `adv` means the discriminator is currently winning.
- Synthetic vs. real batches: When training on synthetic pairs both the transmission and reflection branches contribute to the perceptual/L1 terms, so `loss` should generally be higher than during real batches (where reflection supervision is zero). Expect `grad` to be non-zero only for synthetic samples.
- Healthy dynamics: You want `loss`, `percep`, and `grad` to trend downward slowly while `adv` oscillates. If all climb together, the model is diverging. If `adv` collapses near 0 while others stagnate, the discriminator may be too weak—lower its learning rate or add more synthetic data. If `adv` stays very high but the other terms shrink, the discriminator is too strong—consider reducing the GAN weight (e.g., scaling the final `+ adv`) or adding label smoothing.
- Practical monitoring: Track the logged scalars in TensorBoard. Focus on the moving averages per epoch; transient spikes after checkpoint saves are normal. Compare checkpoints by running `--test_only` so you can visually confirm whether changes in the metrics translate to better separation.
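To make the balance above concrete, here is a minimal sketch of how the logged scalars combine into the generator objective, using the weights quoted in this list (`content_loss = L1 + 0.2 × perceptual + grad`, `total_g = 100 × content_loss + adv`). The helper and tensor names are placeholders rather than the repo's code.

```python
import torch
import torch.nn.functional as F

def generator_objective(pred, target, percep, grad_excl, disc_logits_fake):
    """Combine the logged scalars into the value actually backpropagated (sketch)."""
    l1 = F.l1_loss(pred, target)              # reconstruction term (transmission/reflection)
    content = l1 + 0.2 * percep + grad_excl   # logged as "loss"
    adv = F.binary_cross_entropy_with_logits( # logged as "adv"
        disc_logits_fake, torch.ones_like(disc_logits_fake)
    )
    return 100.0 * content + adv              # total generator objective

total = generator_objective(
    torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64),
    torch.tensor(0.3), torch.tensor(0.05), torch.randn(1, 1)
)
```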
After every training epoch the current generator runs inference on the images specified by --test_dir, and the outputs are written to test_results/<exp_name>/epoch_<NNNN>/. If --test_dir does not resolve to any images, the script falls back to a fixed set of up to 10 blended inputs gathered from --data_real_dir, so you still get a consistent visual trace of progress across epochs. Clean up these folders periodically if disk usage grows too large.
Each validation sample produces a directory numbered by processing order (e.g. test_results/<exp_name>/epoch_0001/0001_<image_stem>/) containing three PNGs:
- `input.png` — the raw blended input frame.
- `t_output.png` — the predicted transmission layer.
- `r_output.png` — the predicted reflection layer.
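For quick visual comparison across epochs, a small helper like the one below (purely illustrative, not part of the repo; the sample directory name is hypothetical) pastes the three PNGs side by side.

```python
from pathlib import Path

from PIL import Image

def stitch(sample_dir: str, out_path: str = "comparison.png") -> None:
    """Concatenate input / transmission / reflection predictions horizontally."""
    names = ["input.png", "t_output.png", "r_output.png"]
    images = [Image.open(Path(sample_dir) / n) for n in names]
    width = sum(im.width for im in images)
    height = max(im.height for im in images)
    canvas = Image.new("RGB", (width, height))
    x = 0
    for im in images:
        canvas.paste(im, (x, 0))
        x += im.width
    canvas.save(out_path)

# Hypothetical sample directory name; adjust to an actual epoch/sample folder.
stitch("test_results/dinov3_vits16/epoch_0001/0001_sample")
```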
Checkpoints, intermediate predictions, `train.log`, and TensorBoard summaries are all stored directly under `runs/dinov3_vits16/`. Launch TensorBoard via:
```bash
tensorboard --logdir runs
```

- `--exp_name`: experiment name. Training artifacts (checkpoints, `train.log`, image dumps) are stored under `./runs/<exp_name>/`, and inference outputs under `./test_results/<exp_name>/`.
- `--data_syn_dir`: comma-separated list of synthetic dataset roots
- `--data_real_dir`: comma-separated list of real dataset roots
- `--save_model_freq`: frequency at which the model and the output images are saved
- `--keep_checkpoint_history`: number of saved checkpoint epochs (`epoch_<NNNN>` folders under `runs/<exp_name>/`) to retain (`0` keeps all)
- `--backbone`: feature extractor for hypercolumns and perceptual loss (`vgg19`, `hgnetv2`, `dinov3_vitt`, `dinov3_vits16`, `dinov3_vits16plus`, `dinov3_vitb16`). Hypercolumn features are always enabled; older runs that used `--is_hyper` now default to the same behaviour.
- `--ckpt_dir`: directory where backbone checkpoints are searched (default `ckpts`)
- `--ckpt_file`: optional generator checkpoint that seeds training; loads weights before the first epoch (ignored with `--resume`)
- `--test_only`: skip training and run inference only
- `--resume`: resume training from the last checkpoint in `runs/<exp_name>/`
- `--use_amp`: enable `torch.cuda.amp` mixed precision (effective on CUDA devices)
- `--epochs`: number of training epochs (default `100`)
- `--device`: device string such as `cuda:0` or `cpu`
```bash
uv run python main.py \
--exp_name dinov3_vits16_test \
--test_only \
--backbone dinov3_vits16 \
--test_dir ./test_images
```

Make sure the `--backbone` flag matches the model that produced the checkpoint you are loading.
If `--test_only` is omitted, the script trains by default and writes checkpoints/metrics to `runs/<exp_name>/`.
Test outputs are written to `./test_results/<exp_name>/<image_name>/`.
Use compute_pnsr_ssim_metrics.py to compute transmission-layer PSNR and SSIM for the samples stored under reflection-dataset/. By default it compares the paired ground-truth images directly:
```bash
uv run python compute_pnsr_ssim_metrics.py --subset real
```

Key options:

- `--subset {real,synthetic,all}` – choose the dataset split.
- `--max-samples N` – limit the number of pairs per split (handy for quick spot checks).
- `--output-csv metrics.csv` – dump per-image metrics for further analysis.
- `--skip-missing` / `--no-skip-missing` – control what happens when a pair is absent (`--skip-missing` is enabled by default).
- `--synthetic-seed` – seed controlling the on-the-fly blend generation used for the synthetic split (requires a model flag).
To benchmark predicted transmissions from a PyTorch checkpoint (either a .pt file or a runs/<exp>/epoch_xxxx/ directory), point the script at the artifact:
```bash
uv run python compute_pnsr_ssim_metrics.py \
--subset real \
--max-samples 100 \
--checkpoint runs/dinov3_vitt_distill_disthyper_residual/epoch_0033/checkpoint.pt \
--device cuda:0 \
--output-csv metrics.csv
```

If you prefer ONNX Runtime instead, supply `--onnx-model <path>` and select the execution providers in priority order via repeated `--providers` flags (`cpu`, `cuda`, `tensorrt`):
```bash
uv run python compute_pnsr_ssim_metrics.py \
--subset all \
--max-samples 500 \
--onnx-model dinov3_vitt_gennerator_640x640_640x640.onnx \
--providers cuda --providers cpu \
--output-csv metrics.csv
```

When a model is provided, the script feeds each blended (or synthetic reflection) image through the generator, resizes the transmission prediction back to the ground-truth resolution, and reports split-level summary statistics plus optional CSV output.
For the synthetic split, the script now mirrors training by generating blended inputs on the fly from every reflection/transmission pair, so you must specify either --checkpoint or --onnx-model whenever you evaluate --subset synthetic (or --subset all, which includes it implicitly).
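For reference, the per-image metrics can be reproduced with standard scikit-image routines along these lines; this is a sketch of the math, not the exact preprocessing (resizing, color handling) used by `compute_pnsr_ssim_metrics.py`.

```python
import cv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def psnr_ssim(gt_path: str, pred_path: str) -> tuple[float, float]:
    """Transmission-layer PSNR/SSIM for one ground-truth / prediction pair (sketch)."""
    gt = cv2.imread(gt_path).astype(np.float32) / 255.0
    pred = cv2.imread(pred_path).astype(np.float32) / 255.0
    # Resize the prediction back to the ground-truth resolution, as the script does.
    pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]))
    psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)
    ssim = structural_similarity(gt, pred, channel_axis=2, data_range=1.0)
    return psnr, ssim

print(psnr_ssim("reflection-dataset/real/transmission_layer/100001.jpg", "t_output.png"))
```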
- Example output
```
[INFO] Loading ONNX model from dinov3_vitt_gennerator_640x640_640x640.onnx
real: blended vs transmission_layer: 100%|███████| 500/500 [01:54<00:00, 4.37it/s]
[RESULT] split=real (blended vs transmission_layer) samples=500
  PSNR -> mean=25.0195, min=9.6244, max=36.3645
  SSIM -> mean=0.8617, min=0.1860, max=0.9840
synthetic: blended vs transmission_layer: 100%|███████| 500/500 [01:15<00:00, 6.61it/s]
[RESULT] split=synthetic (synthetic_blended vs transmission_layer) samples=500
  PSNR -> mean=24.5095, min=9.9238, max=36.4029
  SSIM -> mean=0.9001, min=0.1582, max=0.9919
[INFO] Wrote per-image metrics to metrics.csv
```
- `/generator/blocks.0/Add_1_output_0`: float32[1,1601,192] -> float32[1,192,640,640]
- `/generator/blocks.4/Add_1_output_0`: float32[1,1601,192] -> float32[1,192,640,640]
- `/generator/blocks.7/Add_1_output_0`: float32[1,1601,192] -> float32[1,192,640,640]
- `/generator/blocks.11/Add_1_output_0`: float32[1,1601,192] -> float32[1,192,640,640]
- `input`: float32[1,3,640,640]
- 1 + 2 + 3 + 4 + 5 -> `/generator/Concat_14_output_0`: float32[1,771,640,640]
- Backbone + Head

  ```bash
  uv run python export_onnx.py \
  --checkpoint runs/dinov3_vitt/epoch_0001/checkpoint.pt \
  --output dinov3_vitt_gennerator_640x640_640x640.onnx \
  --backbone dinov3_vitt \
  --static_shape \
  --height 640 \
  --width 640 \
  --head_height 640 \
  --head_width 640

  uv run python export_onnx.py \
  --checkpoint runs/dinov3_vitt/epoch_0001/checkpoint.pt \
  --output dinov3_vitt_gennerator_640x640_320x320.onnx \
  --backbone dinov3_vitt \
  --static_shape \
  --height 640 \
  --width 640 \
  --head_height 320 \
  --head_width 320
  ```

  ```bash
  H=640
  W=640
  # dinov3_vits16
  # dinov3_vitt
  VAR=dinov3_vitt
  # _distill
  DIST=_distill
  EPOCH=0013

  pushd ../..
  uv run python export_onnx.py \
  --checkpoint runs/${VAR}${DIST}_disthyper_rir/epoch_${EPOCH}/checkpoint.pt \
  --output ${VAR}${DIST}_disthyper_rir_gennerator_640x640_${H}x${W}.onnx \
  --backbone ${VAR} \
  --static_shape \
  --height 640 \
  --width 640 \
  --head_height ${H} \
  --head_width ${W}

  uv run python demo_reflection_removal.py \
  --input test_images \
  --output runs/test \
  --model ${VAR}${DIST}_disthyper_rir_gennerator_640x640_${H}x${W}.onnx \
  --provider CUDAExecutionProvider
  popd
  ```

- Head only

  ```bash
  uv run python export_onnx.py \
  --checkpoint runs/dinov3_vitt/epoch_0001/checkpoint.pt \
  --output dinov3_vitt_gennerator_headonly_640x640_320x320.onnx \
  --backbone dinov3_vitt \
  --static_shape \
  --height 640 \
  --width 640 \
  --head_height 320 \
  --head_width 320 \
  --head_only
  ```
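If the tensors listed before the export examples are indeed the inputs of the head-only graph (an assumption based on that listing), driving it from ONNX Runtime would look roughly like the sketch below; the names and shapes are copied from the listing, and `session.get_inputs()` should be used to confirm them rather than trusting this snippet.

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "dinov3_vitt_gennerator_headonly_640x640_320x320.onnx",
    providers=["CPUExecutionProvider"],
)

# Inspect the real input signature first; the feed below assumes the tensor list above.
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)

feeds = {
    "input": np.random.rand(1, 3, 640, 640).astype(np.float32),
    "/generator/blocks.0/Add_1_output_0": np.random.rand(1, 1601, 192).astype(np.float32),
    "/generator/blocks.4/Add_1_output_0": np.random.rand(1, 1601, 192).astype(np.float32),
    "/generator/blocks.7/Add_1_output_0": np.random.rand(1, 1601, 192).astype(np.float32),
    "/generator/blocks.11/Add_1_output_0": np.random.rand(1, 1601, 192).astype(np.float32),
}
outputs = session.run(None, feeds)
print([o.shape for o in outputs])
```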
```bash
uv run python demo_reflection_removal.py \
--input test_images \
--output runs/test \
--model dinov3_vits16_gennerator_640x640_640x640.onnx \
--provider CUDAExecutionProvider
```

If you use this repository in your research, please cite both the original method and this implementation:

```bibtex
@software{Hyodo_2025_reflection_removal,
author = {Katsuya Hyodo},
title = {reflection-removal: Reflection-Removal},
year = {2025},
month = {nov},
publisher = {Zenodo},
version = {1.0.0},
doi = {10.5281/zenodo.17595044},
url = {https://github.com/PINTO0309/reflection-removal},
abstract = {A model for removing reflections using a single image.},
}
```

- https://github.com/ceciliavision/perceptual-reflection-removal
  ```bibtex
  @inproceedings{zhang2018single,
    title     = {Single Image Reflection Separation with Perceptual Losses},
    author    = {Zhang, Xuaner and Ng, Ren and Chen, Qifeng},
    booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
    year      = {2018}
  }
  ```
- https://github.com/Intellindust-AI-Lab/DEIMv2
  ```bibtex
  @article{huang2025deimv2,
    title   = {Real-Time Object Detection Meets DINOv3},
    author  = {Huang, Shihua and Hou, Yongjie and Liu, Longfei and Yu, Xuanlong and Shen, Xi},
    journal = {arXiv},
    year    = {2025}
  }
  ```
- https://github.com/facebookresearch/dinov3
  ```bibtex
  @misc{simeoni2025dinov3,
    title         = {{DINOv3}},
    author        = {Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr},
    year          = {2025},
    eprint        = {2508.10104},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV},
    url           = {https://arxiv.org/abs/2508.10104},
  }
  ```
- https://github.com/PINTO0309/PINTO_model_zoo/tree/main/472_DEIMv2-Wholebody34
  ```bibtex
  @software{DEIMv2-Wholebody34,
    author = {Katsuya Hyodo},
    title  = {Lightweight human detection models generated on high-quality human data sets. It can detect objects with high accuracy and speed in a total of 28 classes: body, adult, child, male, female, body_with_wheelchair, body_with_crutches, head, front, right-front, right-side, right-back, back, left-back, left-side, left-front, face, eye, nose, mouth, ear, collarbone, shoulder, solar_plexus, elbow, wrist, hand, hand_left, hand_right, abdomen, hip_joint, knee, ankle, foot.},
    url    = {https://github.com/PINTO0309/PINTO_model_zoo/tree/main/472_DEIMv2-Wholebody34},
    year   = {2025},
    month  = {10},
    doi    = {10.5281/zenodo.10229410}
  }
  ```



