Skip to content

Commit afafb24

Browse files
committed
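Fix DDUF loading: put the transformers model loaded from a DDUF archive into training mode before returning it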
1 parent 6c2e10a commit afafb24

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/diffusers/pipelines/transformers_loading_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,14 @@ def _load_transformers_model_from_dduf(
112112
tensors = safetensors.torch.load(mmap)
113113
# Update the state dictionary with tensors
114114
state_dict.update(tensors)
115-
return cls.from_pretrained(
115+
model = cls.from_pretrained(
116116
pretrained_model_name_or_path=None,
117117
config=config,
118118
generation_config=generation_config,
119119
state_dict=state_dict,
120120
**kwargs,
121121
)
122+
# Models loaded via from_pretrained are in eval mode by default,
123+
# but we explicitly switch to training mode for consistency with non-DDUF loading (NOTE(review): confirm that non-DDUF loading really returns models in training mode)
124+
model.train()
125+
return model

0 commit comments

Comments
 (0)