Skip to content

Commit eaf0412

Browse files
authored
torch-loader(example): use persistent workers to reduce test time (#694)
1 parent 250ce97 commit eaf0412

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

examples/get_started/torch-loader.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from datachain.torch import label_to_int
2020

2121
STORAGE = "gs://datachain-demo/dogs-and-cats/"
22-
NUM_EPOCHS = os.getenv("NUM_EPOCHS", "3")
22+
NUM_EPOCHS = int(os.getenv("NUM_EPOCHS", "3"))
2323

2424
# Define transformation for data preprocessing
2525
transform = v2.Compose(
@@ -68,7 +68,8 @@ def forward(self, x):
6868
train_loader = DataLoader(
6969
ds.to_pytorch(transform=transform),
7070
batch_size=25,
71-
num_workers=4,
71+
num_workers=max(4, os.cpu_count() or 2),
72+
persistent_workers=True,
7273
multiprocessing_context=multiprocessing.get_context("spawn"),
7374
)
7475

@@ -77,7 +78,7 @@ def forward(self, x):
7778
optimizer = optim.Adam(model.parameters(), lr=0.001)
7879

7980
# Train the model
80-
for epoch in range(int(NUM_EPOCHS)):
81+
for epoch in range(NUM_EPOCHS):
8182
with tqdm(
8283
train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
8384
) as loader:

noxfile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def examples(session: nox.Session) -> None:
8181
session.install(".[examples]")
8282
session.run(
8383
"pytest",
84+
"--durations=0",
85+
"tests/examples",
8486
"-m",
8587
"examples",
8688
*session.posargs,

tests/examples/test_examples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def smoke_test(example: str, env: Optional[dict] = None):
5252
@pytest.mark.get_started
5353
@pytest.mark.parametrize("example", get_started_examples)
5454
def test_get_started_examples(example):
55-
smoke_test(example, {"NUM_EPOCHS": "1"})
55+
smoke_test(example)
5656

5757

5858
@pytest.mark.examples

0 commit comments

Comments
 (0)