Skip to content

Commit 4d80c88

Browse files
authored
[fix] Save custom module kwargs if specified (#3112)
* Save custom module kwargs if specified This should have been included in the save all along * Also try to load a 'dynamic module' if not trust-remote, but local model
1 parent 5dfd360 commit 4d80c88

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

sentence_transformers/SentenceTransformer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,11 @@ def save(
11931193
# For other cases, we want to add the class name:
11941194
elif not class_ref.startswith("sentence_transformers."):
11951195
class_ref = f"{class_ref}.{type(module).__name__}"
1196-
modules_config.append({"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref})
1196+
1197+
module_config = {"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref}
1198+
if self.module_kwargs and name in self.module_kwargs and (module_kwargs := self.module_kwargs[name]):
1199+
module_config["kwargs"] = module_kwargs
1200+
modules_config.append(module_config)
11971201

11981202
with open(os.path.join(path, "modules.json"), "w") as fOut:
11991203
json.dump(modules_config, fOut, indent=2)
@@ -1556,7 +1560,7 @@ def _load_module_class_from_ref(
15561560
if class_ref.startswith("sentence_transformers."):
15571561
return import_from_string(class_ref)
15581562

1559-
if trust_remote_code:
1563+
if trust_remote_code or os.path.exists(model_name_or_path):
15601564
code_revision = model_kwargs.pop("code_revision", None) if model_kwargs else None
15611565
try:
15621566
return get_class_from_dynamic_module(

0 commit comments

Comments
 (0)