Skip to content

Commit ea08148

Browse files
committed
recompile limit
1 parent 3a61081 commit ea08148

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tests/models/testing_utils/compile.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@ def test_torch_compile_repeated_blocks(self, recompile_limit=1):
9292
model.eval()
9393
model.compile_repeated_blocks(fullgraph=True)
9494

95-
if self.model_class.__name__ == "UNet2DConditionModel":
96-
recompile_limit = 2
97-
9895
with (
9996
torch._inductor.utils.fresh_inductor_cache(),
10097
torch._dynamo.config.patch(recompile_limit=recompile_limit),

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,9 @@ def test_ip_adapter_plus(self):
11991199
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
12001200
"""Torch compile tests for UNet2DConditionModel."""
12011201

1202+
def test_torch_compile_repeated_blocks(self):
1203+
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
1204+
12021205

12031206
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
12041207
"""LoRA hot-swapping tests for UNet2DConditionModel."""

0 commit comments

Comments
 (0)