Skip to content

Commit a7918a7

Browse files
authored
Use SGD for list of models from PyTorch. (pytorch#6324)
1 parent ab28b23 commit a7918a7

File tree

1 file changed

+33
-3
lines changed

1 file changed

+33
-3
lines changed

benchmarks/torchbench_model.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,35 @@
2929
"detectron2_fcos_r_50_fpn",
3030
}
3131

32+
# torchbench models that might OOM using Adam.
33+
# This list was extracted from PyTorch's repository: benchmarks/dynamo/common.py
34+
TRAIN_WITH_SGD = {
35+
"BERT_pytorch",
36+
"LearningToPaint",
37+
"alexnet",
38+
"dcgan",
39+
"demucs",
40+
"densenet121",
41+
"dlrm",
42+
"fastNLP_Bert",
43+
"mobilenet_v2",
44+
"phlippe_densenet",
45+
"phlippe_resnet",
46+
"pytorch_stargan",
47+
"resnet18",
48+
"shufflenet_v2_x1_0",
49+
"speech_transformer",
50+
"squeezenet1_1",
51+
"stable_diffusion_text_encoder",
52+
"timm_efficientdet",
53+
"timm_nfnet",
54+
"timm_regnet",
55+
"timm_vision_transformer",
56+
"timm_vovnet",
57+
"vgg16",
58+
"hf_T5",
59+
}
60+
3261
# Skip the experiment of a model if any of the experiment configs in the list is fully matched
3362
DENY_LIST = {
3463
"doctr_det_predictor": [{
@@ -179,7 +208,10 @@ def set_up(self):
179208
180209
This is model suite specific.
181210
"""
182-
self.optimizer_class = torch.optim.Adam
211+
if self.benchmark_experiment.test == "train" and self.model_name in TRAIN_WITH_SGD:
212+
self.optimizer_class = torch.optim.SGD
213+
else:
214+
self.optimizer_class = torch.optim.Adam
183215

184216
benchmark = self.load_benchmark()
185217

@@ -205,8 +237,6 @@ def set_up(self):
205237
if self.model_name == "yolov3":
206238
self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, 3,
207239
384, 512),)
208-
if self.benchmark_experiment.test == "train" and self.model_name in DETECTRON2_MODELS:
209-
self.optimizer = benchmark.optimizer
210240

211241
del benchmark
212242
self._cleanup()

0 commit comments

Comments
 (0)