Added ability to save and load from huggingface #83

Merged 5 commits into main from hf-save on May 17, 2024

danyoungday
Collaborator

For #82. Adds a new method, from_pretrained, that loads predictors from huggingface, and a separate script that saves them to huggingface. Only users with the access token can push to huggingface.

@@ -1,5 +1,6 @@
# Ignores saved predictors
predictors/*/trained_models/
predictors/trained_models

Default cache dir for our huggingface models

if not (load_path / "config.json").exists() or \
        not (load_path / "model.pt").exists() or \
        not (load_path / "scaler.joblib").exists():
    raise FileNotFoundError("Model files not found in path.")

Check to see all the files were downloaded properly before loading
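
One way to make this check easier to debug is to report which files are missing rather than raising a generic message. The following is only a sketch: the helper names missing_model_files and check_model_files are hypothetical and not part of this PR, and the file names are simply the three from the snippet above.

```python
from pathlib import Path

# File names taken from the check above; the helpers themselves are hypothetical.
REQUIRED_FILES = ("config.json", "model.pt", "scaler.joblib")


def missing_model_files(load_path, required=REQUIRED_FILES):
    """Return the names of required files that are absent from load_path."""
    load_path = Path(load_path)
    return [name for name in required if not (load_path / name).exists()]


def check_model_files(load_path):
    """Raise FileNotFoundError naming every missing file, if there are any."""
    missing = missing_model_files(load_path)
    if missing:
        raise FileNotFoundError(
            f"Model files not found in {load_path}: {', '.join(missing)}"
        )
```

Calling check_model_files(load_path) before loading would keep the same exception type as the original check while pointing at the file that failed to download.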

hf_args["local_dir"] = local_dir
snapshot_download(repo_id=path_or_url, **hf_args)

return cls.load(Path(local_dir))

Implementation of from_pretrained.

  1. We check disk for the model
  2. If it doesn't exist we download from hub. We default our local save dir to predictors/trained_models.
  3. We load our model from the local file
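
The three steps above can be sketched roughly as follows. This is not the code from the PR: SketchPredictor, the download parameter, and the way the default sub-directory is derived from the repo id are assumptions made for illustration. Only snapshot_download(repo_id=..., local_dir=...), cls.load, and the predictors/trained_models default come from the snippet and comment above.

```python
import json
from pathlib import Path

# Default local save dir mentioned in the comment above.
DEFAULT_DIR = Path("predictors/trained_models")


class SketchPredictor:
    """Hypothetical stand-in for a predictor class, for illustration only."""

    def __init__(self, config):
        self.config = config

    @classmethod
    def load(cls, path):
        # A real predictor would also load weights; here we only read the config.
        with open(Path(path) / "config.json", "r", encoding="utf-8") as file:
            return cls(json.load(file))

    @classmethod
    def from_pretrained(cls, path_or_url, download=None, **hf_args):
        # Step 1: if the argument is already a directory on disk, load it directly.
        if Path(path_or_url).is_dir():
            return cls.load(Path(path_or_url))

        # Step 2: otherwise treat it as a hub repo id and download a snapshot.
        # Deriving the sub-directory from the repo id is an assumption of this sketch.
        local_dir = hf_args.get("local_dir") or DEFAULT_DIR / path_or_url.replace("/", "--")
        hf_args["local_dir"] = str(local_dir)
        if download is None:
            # Imported lazily so the sketch can be exercised without the library.
            from huggingface_hub import snapshot_download
            download = snapshot_download
        download(repo_id=path_or_url, **hf_args)

        # Step 3: load the model from the local files.
        return cls.load(Path(local_dir))
```

The injectable download argument exists only so the flow can be tested offline; with the default, the call goes through huggingface_hub.snapshot_download as in the snippet above.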

self.features = model_config.get("features", None)
self.label = model_config.get("label", None)

self.config = model_config

Some config refactoring so we can save our training arguments for reproducibility

 with open(save_path / "config.json", "w", encoding="utf-8") as file:
-    json.dump(config, file)
+    json.dump(self.config, file)

We now dump all the arguments we used to create the model instead of just the ones we use at inference time
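
As a rough illustration of why dumping the full config helps reproducibility, the sketch below writes a config to config.json and reads it back unchanged. The helper names save_config and load_config are hypothetical, and any training keys passed in (for example an epoch count) are placeholders rather than the project's actual arguments.

```python
import json
from pathlib import Path


def save_config(config, save_path):
    """Write the complete config dict (training and inference arguments) to disk."""
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    with open(save_path / "config.json", "w", encoding="utf-8") as file:
        json.dump(config, file)


def load_config(load_path):
    """Read the config dict back from config.json."""
    with open(Path(load_path) / "config.json", "r", encoding="utf-8") as file:
        return json.load(file)
```

Because the whole dict is stored, a later load recovers the training arguments alongside features and label, as long as every value is JSON-serializable.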

raise FileNotFoundError(f"Path {path} does not exist.")
if not (load_path / "config.json").exists() or not (load_path / "model.joblib").exists():
    raise FileNotFoundError("Model files not found in path.")

Check if all the files exist before we load

-model_config.pop("label", None)
-self.model = LinearRegression(**model_config)
+lr_config = {key: value for key, value in model_config.items() if key not in ["features", "label"]}
+self.model = LinearRegression(**lr_config)

Copy instead of referencing config so we don't remove features and label from our actual stored config
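
The difference between the two approaches can be seen in a small standalone sketch. The function names and the example config are hypothetical; fit_intercept is used only as an example of an argument that would be forwarded to the estimator.

```python
def estimator_kwargs_by_pop(model_config):
    """Old approach: pops keys in place, so the caller's dict loses them."""
    model_config.pop("features", None)
    model_config.pop("label", None)
    return model_config


def estimator_kwargs_by_copy(model_config):
    """New approach: builds a filtered copy and leaves the stored config intact."""
    return {
        key: value
        for key, value in model_config.items()
        if key not in ["features", "label"]
    }


stored = {"features": ["x1", "x2"], "label": "y", "fit_intercept": True}
kwargs = estimator_kwargs_by_copy(stored)
# kwargs holds only the estimator arguments; stored still has features and label,
# so it can later be dumped to config.json in full.
```

With the copy, the same dict can be both passed to the estimator (filtered) and saved to disk (complete).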

Script to upload a model. We still have to create a readme template for the models. The script takes a token, since only specified users can push to the project resilience repo.
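
The upload script itself is not shown in this conversation, so the following is only a guess at its general shape, assuming it uses huggingface_hub's HfApi with create_repo and upload_folder. The function name upload_predictor and the api parameter are hypothetical, and the repo id is whatever the caller supplies.

```python
from pathlib import Path


def upload_predictor(model_dir, repo_id, token, api=None):
    """Push every file in model_dir to a huggingface model repo.

    The api parameter is an assumption of this sketch: it lets a stub be
    injected for testing. By default a real HfApi client is created.
    """
    model_dir = Path(model_dir)
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Path {model_dir} does not exist.")
    if api is None:
        # Imported lazily so the sketch can be read and tested without the library.
        from huggingface_hub import HfApi
        api = HfApi(token=token)
    # Create the repo if needed; the token must belong to a user with push access.
    api.create_repo(repo_id=repo_id, exist_ok=True)
    api.upload_folder(folder_path=str(model_dir), repo_id=repo_id)
```

A readme template, once written, could be placed in model_dir before the call so that it is uploaded with the model files.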

@ofrancon (Member) left a comment

lgtm

@danyoungday merged commit 4ce6138 into main on May 17, 2024
1 check passed
@danyoungday deleted the hf-save branch on May 17, 2024 at 17:44