From bbfeb0794cb72623e0ced258e6cd5b06b6b40fb7 Mon Sep 17 00:00:00 2001
From: gegejun
Date: Thu, 5 Mar 2026 13:11:10 +0800
Subject: [PATCH 1/6] Fix batch_size when using BatchSampler

---
 src/lightning/pytorch/utilities/data.py       | 17 +++++--
 .../trainer/test_batch_sampler.py             | 45 +++++++++++++++++++
 2 files changed, 59 insertions(+), 3 deletions(-)
 create mode 100644 tests/tests_pytorch/trainer/test_batch_sampler.py

diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py
index b04bc0dfdc2da..a701b2d444f9a 100644
--- a/src/lightning/pytorch/utilities/data.py
+++ b/src/lightning/pytorch/utilities/data.py
@@ -169,9 +169,9 @@ def _get_dataloader_init_args_and_kwargs(
     if was_wrapped:
         # if the dataloader was wrapped in a hook, only take arguments with default values
         # and assume user passes their kwargs correctly
-        params.update({
-            k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty
-        })
+        params.update(
+            {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
+        )
     else:
         params.update(inspect.signature(DataLoader.__init__).parameters)
     params.pop("self", None)
@@ -332,6 +332,17 @@ def _dataloader_init_kwargs_resolve_sampler(
             "batch_size": 1,
             "drop_last": False,
         }
+    if batch_sampler is not None and batch_sampler_cls is BatchSampler:
+        # A plain PyTorch `BatchSampler`, possibly created by the user, so preserve its batch_size and drop_last
+        batch_size = batch_sampler.batch_size
+        drop_last = batch_sampler.drop_last if not is_predicting else False
+        return {
+            "sampler": sampler,
+            "shuffle": False,
+            "batch_sampler": None,
+            "batch_size": batch_size,
+            "drop_last": drop_last,
+        }
     return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
 
 

diff --git a/tests/tests_pytorch/trainer/test_batch_sampler.py b/tests/tests_pytorch/trainer/test_batch_sampler.py
new file mode 100644
index 0000000000000..bd984a6712834
--- /dev/null
+++ b/tests/tests_pytorch/trainer/test_batch_sampler.py
@@ -0,0 +1,45 @@
+import pytest
+from torch.utils.data import RandomSampler, BatchSampler
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from lightning.pytorch import Callback, Trainer, seed_everything
+from tests_pytorch.helpers.runif import RunIf
+from lightning.pytorch.demos.boring_classes import (
+    BoringModel,
+    RandomDataset,
+)
+
+
+class DistribBatchSamplerCallback(Callback):
+    def __init__(self, expected_batch_size, expected_drop_last):
+        self.expected_batch_size = expected_batch_size
+        self.expected_drop_last = expected_drop_last
+
+    def on_train_start(self, trainer, pl_module):
+        assert isinstance(trainer.train_dataloader.sampler, DistributedSampler)
+        assert trainer.train_dataloader.batch_size == self.expected_batch_size
+        assert trainer.train_dataloader.drop_last == self.expected_drop_last
+
+
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("drop_last", [False, True])
+@RunIf(min_cuda_gpus=2, skip_windows=True)
+def test_dataloader_distributed_batch_sampler(tmp_path, batch_size, drop_last):
+    """Test BatchSampler and its arguments for the DDP backend."""
+    seed_everything(123)
+    dataset = RandomDataset(32, 64)
+    sampler = RandomSampler(dataset)
+    batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)
+    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+    print(batch_sampler.drop_last, dataloader.drop_last)
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator="gpu",
+        devices=[0, 1],
+        num_nodes=1,
+        strategy="ddp",
+        default_root_dir=tmp_path,
+        max_steps=1,
+        callbacks=[DistribBatchSamplerCallback(expected_batch_size=batch_size, expected_drop_last=drop_last)],
+    )
+    trainer.fit(model, train_dataloaders=dataloader)

From cb229f23c753a7de1a9215892d109f5e698e84d5 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 5 Mar 2026 07:05:40 +0000
Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/utilities/data.py           | 6 +++---
 tests/tests_pytorch/trainer/test_batch_sampler.py | 5 +++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/lightning/pytorch/utilities/data.py b/src/lightning/pytorch/utilities/data.py
index a701b2d444f9a..df565eea9af5d 100644
--- a/src/lightning/pytorch/utilities/data.py
+++ b/src/lightning/pytorch/utilities/data.py
@@ -169,9 +169,9 @@ def _get_dataloader_init_args_and_kwargs(
     if was_wrapped:
         # if the dataloader was wrapped in a hook, only take arguments with default values
         # and assume user passes their kwargs correctly
-        params.update(
-            {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
-        )
+        params.update({
+            k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty
+        })
     else:
         params.update(inspect.signature(DataLoader.__init__).parameters)
     params.pop("self", None)

diff --git a/tests/tests_pytorch/trainer/test_batch_sampler.py b/tests/tests_pytorch/trainer/test_batch_sampler.py
index bd984a6712834..5f1771af1d7a7 100644
--- a/tests/tests_pytorch/trainer/test_batch_sampler.py
+++ b/tests/tests_pytorch/trainer/test_batch_sampler.py
@@ -1,13 +1,14 @@
 import pytest
-from torch.utils.data import RandomSampler, BatchSampler
+from torch.utils.data import BatchSampler, RandomSampler
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.distributed import DistributedSampler
+
 from lightning.pytorch import Callback, Trainer, seed_everything
-from tests_pytorch.helpers.runif import RunIf
 from lightning.pytorch.demos.boring_classes import (
     BoringModel,
     RandomDataset,
 )
+from tests_pytorch.helpers.runif import RunIf
 
 
 class DistribBatchSamplerCallback(Callback):

From 489306a6775d6c694a22374abf577d5fd1e0a903 Mon Sep 17 00:00:00 2001
From: Deependu
Date: Thu, 5 Mar 2026 16:33:25 +0530
Subject: [PATCH 3/6] Apply suggestion from @deependujha

---
 tests/tests_pytorch/trainer/test_batch_sampler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/trainer/test_batch_sampler.py b/tests/tests_pytorch/trainer/test_batch_sampler.py
index 5f1771af1d7a7..9471c72958df2 100644
--- a/tests/tests_pytorch/trainer/test_batch_sampler.py
+++ b/tests/tests_pytorch/trainer/test_batch_sampler.py
@@ -24,7 +24,7 @@ def on_train_start(self, trainer, pl_module):
 
 @pytest.mark.parametrize("batch_size", [1, 5])
 @pytest.mark.parametrize("drop_last", [False, True])
-@RunIf(min_cuda_gpus=2, skip_windows=True)
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 def test_dataloader_distributed_batch_sampler(tmp_path, batch_size, drop_last):
     """Test BatchSampler and its arguments for the DDP backend."""
     seed_everything(123)

From 0736deeeff6953240e5c77ca6e850075228a049d Mon Sep 17 00:00:00 2001
From: gegejun
Date: Thu, 5 Mar 2026 19:59:54 +0800
Subject: [PATCH 4/6] Move test_batch_sampler.py into test_dataloaders.py

---
 .../trainer/test_batch_sampler.py             | 46 -------------------
 .../tests_pytorch/trainer/test_dataloaders.py | 37 ++++++++++++++-
 2 files changed, 36 insertions(+), 47 deletions(-)
 delete mode 100644 tests/tests_pytorch/trainer/test_batch_sampler.py

diff --git a/tests/tests_pytorch/trainer/test_batch_sampler.py b/tests/tests_pytorch/trainer/test_batch_sampler.py
deleted file mode 100644
index 9471c72958df2..0000000000000
--- a/tests/tests_pytorch/trainer/test_batch_sampler.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import pytest
-from torch.utils.data import BatchSampler, RandomSampler
-from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-from lightning.pytorch import Callback, Trainer, seed_everything
-from lightning.pytorch.demos.boring_classes import (
-    BoringModel,
-    RandomDataset,
-)
-from tests_pytorch.helpers.runif import RunIf
-
-
-class DistribBatchSamplerCallback(Callback):
-    def __init__(self, expected_batch_size, expected_drop_last):
-        self.expected_batch_size = expected_batch_size
-        self.expected_drop_last = expected_drop_last
-
-    def on_train_start(self, trainer, pl_module):
-        assert isinstance(trainer.train_dataloader.sampler, DistributedSampler)
-        assert trainer.train_dataloader.batch_size == self.expected_batch_size
-        assert trainer.train_dataloader.drop_last == self.expected_drop_last
-
-
-@pytest.mark.parametrize("batch_size", [1, 5])
-@pytest.mark.parametrize("drop_last", [False, True])
-@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
-def test_dataloader_distributed_batch_sampler(tmp_path, batch_size, drop_last):
-    """Test BatchSampler and its arguments for the DDP backend."""
-    seed_everything(123)
-    dataset = RandomDataset(32, 64)
-    sampler = RandomSampler(dataset)
-    batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)
-    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
-    print(batch_sampler.drop_last, dataloader.drop_last)
-    model = BoringModel()
-    trainer = Trainer(
-        accelerator="gpu",
-        devices=[0, 1],
-        num_nodes=1,
-        strategy="ddp",
-        default_root_dir=tmp_path,
-        max_steps=1,
-        callbacks=[DistribBatchSamplerCallback(expected_batch_size=batch_size, expected_drop_last=drop_last)],
-    )
-    trainer.fit(model, train_dataloaders=dataloader)

diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index a69176b00d74f..3bb21137ba34f 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -18,7 +18,7 @@
 import pytest
 import torch
 from lightning_utilities.test.warning import no_warning_call
-from torch.utils.data import RandomSampler
+from torch.utils.data import RandomSampler, BatchSampler
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -814,6 +814,41 @@ def test_dataloader_distributed_sampler(tmp_path):
     trainer.test(model)
 
 
+class DistribBatchSamplerCallback(Callback):
+    def __init__(self, expected_batch_size, expected_drop_last):
+        self.expected_batch_size = expected_batch_size
+        self.expected_drop_last = expected_drop_last
+
+    def on_train_start(self, trainer, pl_module):
+        assert isinstance(trainer.train_dataloader.sampler, DistributedSampler)
+        assert trainer.train_dataloader.batch_size == self.expected_batch_size
+        assert trainer.train_dataloader.drop_last == self.expected_drop_last
+
+
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize("drop_last", [False, True])
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+def test_dataloader_distributed_batch_sampler(tmp_path, batch_size, drop_last):
+    """Test BatchSampler and its arguments for the DDP backend."""
+    seed_everything(123)
+    dataset = RandomDataset(32, 64)
+    sampler = RandomSampler(dataset)
+    batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)
+    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+    print(batch_sampler.drop_last, dataloader.drop_last)
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator="gpu",
+        devices=[0, 1],
+        num_nodes=1,
+        strategy="ddp",
+        default_root_dir=tmp_path,
+        max_steps=1,
+        callbacks=[DistribBatchSamplerCallback(expected_batch_size=batch_size, expected_drop_last=drop_last)],
+    )
+    trainer.fit(model, train_dataloaders=dataloader)
+
+
 class TestModelUniqueDDPSampling(BoringModel):
     def __init__(self):
         super().__init__()

From 4ca8393d3ea1f0679727500eca758a7b450babb2 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 5 Mar 2026 12:01:45 +0000
Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/tests_pytorch/trainer/test_dataloaders.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index 3bb21137ba34f..1729763c8309c 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -18,7 +18,7 @@
 import pytest
 import torch
 from lightning_utilities.test.warning import no_warning_call
-from torch.utils.data import RandomSampler, BatchSampler
+from torch.utils.data import BatchSampler, RandomSampler
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import Dataset, IterableDataset
 from torch.utils.data.distributed import DistributedSampler

From 4b0de341eef89b086dc7f7d6a757f13e8356390c Mon Sep 17 00:00:00 2001
From: Deependu
Date: Thu, 5 Mar 2026 17:39:01 +0530
Subject: [PATCH 6/6] Apply suggestion from @deependujha

---
 tests/tests_pytorch/trainer/test_dataloaders.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index 1729763c8309c..49ea1430ce1a6 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -835,7 +835,6 @@ def test_dataloader_distributed_batch_sampler(tmp_path, batch_size, drop_last):
     sampler = RandomSampler(dataset)
     batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)
     dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
-    print(batch_sampler.drop_last, dataloader.drop_last)
     model = BoringModel()
     trainer = Trainer(
         accelerator="gpu",
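
-- 
Note (not part of the patches above): a minimal standalone sketch of the behaviour
the series fixes. The helper name `_preserved_kwargs` is illustrative only, not
Lightning API; the real logic is the new branch added to
`_dataloader_init_kwargs_resolve_sampler` in patch 1, where `sampler` would be the
(possibly distributed) replacement sampler rather than the user's inner sampler.

    import torch
    from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

    def _preserved_kwargs(batch_sampler, sampler, is_predicting=False):
        # Mirrors the patched branch: for a plain PyTorch BatchSampler (the exact
        # class, not a subclass), keep the user's batch_size and drop_last instead
        # of falling back to the re-instantiation defaults (batch_size=1, drop_last=False).
        if type(batch_sampler) is BatchSampler:
            return {
                "sampler": sampler,
                "shuffle": False,
                "batch_sampler": None,
                "batch_size": batch_sampler.batch_size,
                "drop_last": batch_sampler.drop_last if not is_predicting else False,
            }
        return {"sampler": sampler, "shuffle": False, "batch_sampler": None}

    dataset = TensorDataset(torch.randn(64, 32))
    user_bs = BatchSampler(RandomSampler(dataset), batch_size=5, drop_last=True)

    # The user's inner sampler stands in for the replacement sampler in this sketch.
    rebuilt = DataLoader(dataset, **_preserved_kwargs(user_bs, user_bs.sampler))
    assert rebuilt.batch_size == 5 and rebuilt.drop_last  # before the fix: 1 and False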