Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/PyTorchModelLogger.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@
" if batch_idx % 100 == 0:\n",
" print(\n",
" f\"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} \"\n",
" f\"({100. * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n",
" f\"({100.0 * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n",
" )\n",
"\n",
" file_name = f\"MNSIT_epoch_{epoch}_loss_{loss.item():.4e}\"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/PyTorchModelLogger.old.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@
" if batch_idx % 100 == 0:\n",
" print(\n",
" f\"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} \"\n",
" f\"({100. * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n",
" f\"({100.0 * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n",
" )\n",
"\n",
" file_name = f\"MNSIT_epoch_{epoch}_loss_{loss.item():.4e}\"\n",
Expand Down
7 changes: 5 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ package_dir =
# For more information, check out https://semver.org/.
install_requires =
importlib-metadata; python_version<"3.8"
datafed
GPutil
m3learning-util
psutil
datafed
pyright
ruff
sphinx
torch
torchvision
m3learning-util

[options.packages.find]
where = src
Expand Down
24 changes: 12 additions & 12 deletions src/datafed_torchflow/JSON.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class UniversalEncoder(json.JSONEncoder):
A custom JSON encoder that can handle numpy data types, sets, and objects with __dict__ attributes.
"""

def default(self, obj):
def default(self, o):
"""
Override the default method to provide custom serialization for unsupported data types.

Expand All @@ -19,16 +19,16 @@ def default(self, obj):
any: The serialized form of the object.
"""
# Convert numpy types to their Python equivalents
if isinstance(obj, np.integer):
return int(obj) # Convert numpy integers to Python int
elif isinstance(obj, np.floating):
return float(obj) # Convert numpy floats to Python float
elif isinstance(obj, np.ndarray):
return obj.tolist() # Convert numpy arrays to lists
elif isinstance(obj, set):
return list(obj) # Convert sets to lists
elif hasattr(obj, "__dict__"):
return obj.__dict__ # Serialize object attributes
if isinstance(o, np.integer):
return int(o) # Convert numpy integers to Python int
elif isinstance(o, np.floating):
return float(o) # Convert numpy floats to Python float
elif isinstance(o, np.ndarray):
return o.tolist() # Convert numpy arrays to lists
elif isinstance(o, set):
return list(o) # Convert sets to lists
elif hasattr(o, "__dict__"):
return o.__dict__ # Serialize object attributes
else:
# Call the default method for other cases
return super().default(obj)
return super().default(o)
6 changes: 3 additions & 3 deletions src/datafed_torchflow/computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def get_memory_info():
"""
mem = psutil.virtual_memory()
return {
"total": f"{mem.total / (1024 ** 3):.2f} GB",
"available": f"{mem.available / (1024 ** 3):.2f} GB",
"used": f"{mem.used / (1024 ** 3):.2f} GB",
"total": f"{mem.total / (1024**3):.2f} GB",
"available": f"{mem.available / (1024**3):.2f} GB",
"used": f"{mem.used / (1024**3):.2f} GB",
"percent": mem.percent,
}

Expand Down
35 changes: 26 additions & 9 deletions src/datafed_torchflow/datafed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import traceback
from datetime import datetime
from typing import Optional

import numpy as np
from datafed.CommandLib import API
Expand Down Expand Up @@ -29,10 +30,10 @@ class DataFed(API):

def __init__(
self,
datafed_path,
datafed_path: str,
local_model_path="./Trained Models",
log_file_path="log.txt",
dataset_id_or_path=None,
dataset_id_or_path: Optional[str] = None,
download_kwargs={"wait": True, "orig_fname": True},
upload_kwargs={"wait": True},
logging=False,
Expand Down Expand Up @@ -91,11 +92,12 @@ def upload_dataset_to_DataFed(self):
The DataFed record ID for the dataset files, as a string for a single dataset file and a list of strings for multiple dataset files.
"""
if self.dataset_id_or_path is not None:
ls_resp = self.collectionItemsList(self.collection_id)

if isinstance(
self.dataset_id_or_path, list
): # to specify multiple dataset files
dataset_ids = []
ls_resp = self.collectionItemsList(self.collection_id)
for dataset in self.dataset_id_or_path:
if dataset.startswith("d/"):
dataset_ids.append(dataset)
Expand Down Expand Up @@ -163,6 +165,8 @@ def upload_dataset_to_DataFed(self):
if self.dataset_id_or_path.startswith("d/"):
dataset_ids = self.dataset_id_or_path
else:
dataset = self.dataset_id_or_path # Is this right?

try:
path_id = (
ls_resp[0]
Expand All @@ -176,6 +180,7 @@ def upload_dataset_to_DataFed(self):
]
.id
)

# update record (dependencies have been added)
record_id = self.get_notebook_DataFed_ID_from_path_and_title(
dataset, path_id=path_id
Expand Down Expand Up @@ -345,14 +350,17 @@ def check_if_endpoint_set(self):
)

@property
def user_id(self):
def user_id(self) -> str:
"""
Gets the user ID from the authenticated user's information.

Returns:
str: The user ID extracted from the authenticated user information.
"""
return self.getAuthUser().split("/")[-1]
auth_user = self.getAuthUser()
if auth_user is None:
raise RuntimeError("Failed to get authenticated user information")
return auth_user.split("/")[-1]

@staticmethod
def check_string_for_dot_or_slash(s):
Expand Down Expand Up @@ -524,7 +532,7 @@ def get_notebook_DataFed_ID_from_path_and_title(
def data_record_create(
self,
metadata=None,
record_title=None,
record_title: Optional[str] = None,
parent_collection=None,
deps=None,
**kwargs,
Expand All @@ -545,6 +553,9 @@ def data_record_create(
# make sure the user is logged into DataFed
self.check_if_logged_in()

if record_title is None:
raise ValueError("record_title cannot be None")

# If the record title is longer than the maximum allowed by DataFed (80 characters)
# truncate the record title to 80 characters. If logging is true, print out a statement letting the user
# know the record_title has been truncated.
Expand Down Expand Up @@ -593,7 +604,7 @@ def data_record_create(
def data_record_update(
self,
record_id=None,
record_title=None,
record_title: Optional[str] = None,
metadata=None,
deps=None,
overwrite_metadata=False,
Expand All @@ -617,6 +628,9 @@ def data_record_update(
# make sure the user is logged into DataFed
self.check_if_logged_in()

if record_title is None:
raise ValueError("record_title cannot be None")

# If the record title is longer than the maximum allowed by DataFed (80 characters)
# truncate the record title to 80 characters. If logging is true, print out a statement letting the user
# know the record_title has been truncated.
Expand Down Expand Up @@ -835,7 +849,6 @@ def _get_metadata_list(self, record_ids, exclude=None):

return metadata

@staticmethod
def required_keys(self, dict_list, required_keys):
"""
Filters a list of dictionaries to include only those that contain all specified required keys.
Expand Down Expand Up @@ -1155,12 +1168,16 @@ def getFileExtension(self):
# Split the file name by '.' and return the last part as the extension
return "." + self.getFileName(self.dataset_id_or_path).split(".")[-1]

def getData(self, dataset_id=None):
def getData(self, dataset_id: Optional[str] = None):
"""
Downloads the data from the dataset
"""

self.data_path = None # Bandaid fix; this needs to be defined for the class!

if dataset_id is None:
if self.dataset_id_or_path is None:
raise ValueError("dataset_id_or_path is not set")
dataset_id = self.dataset_id_or_path

# if a data path is not provided, download the data to the current directory
Expand Down
40 changes: 32 additions & 8 deletions src/datafed_torchflow/pytorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
from datetime import datetime
from typing import Any, Optional

import torch
import torch.nn as nn
Expand Down Expand Up @@ -60,7 +61,7 @@ def __init__(
script_path=None,
local_model_path="/.",
log_file_path="log.txt",
input_data_shape=None,
input_data_shape: Optional[tuple[int, ...]] = None, # Is this right?
dataset_id_or_path=None,
logging=False,
download_kwargs={"wait": True, "orig_fname": True},
Expand Down Expand Up @@ -140,7 +141,12 @@ def optimizer(self, optimizer):
"""
self._optimizer = optimizer

def getMetadata(self, local_vars=None, model_hyperparameters=None, **kwargs):
def getMetadata(
self,
local_vars: Optional[list[tuple[str, Any]]] = None,
model_hyperparameters: Optional[dict[str, Any]] = None,
**kwargs,
):
"""
Gathers metadata including the serialized model, optimizer, system info, and user details.

Expand Down Expand Up @@ -170,6 +176,12 @@ def getMetadata(self, local_vars=None, model_hyperparameters=None, **kwargs):
"System Information": {},
}

if local_vars is None:
raise ValueError("local_vars cannot be None")

if model_hyperparameters is None:
raise ValueError("model_hyperparameters cannot be None")

# loop through the local variables to add to the metadata dictionary
for key, value in local_vars:
# exclude modules and other undesired local variables. Use casefold string matching for flexibility
Expand Down Expand Up @@ -259,7 +271,10 @@ def getMetadata(self, local_vars=None, model_hyperparameters=None, **kwargs):
Warning(warning_message)
# put the model hyperparameters in the Model Hyperparameters sub-dictionary (the hyperparameters might be 1-value torch tensors or just floats)
elif key in model_hyperparameters.keys():
if type(value) in [np.ndarray, torch.Tensor]:
if (
isinstance(value, (np.ndarray, torch.Tensor))
and self.input_data_shape is not None
):
if value.shape < self.input_data_shape:
DataFed_record_metadata["Model Parameters"][
"Model Hyperparameters"
Expand All @@ -271,8 +286,11 @@ def getMetadata(self, local_vars=None, model_hyperparameters=None, **kwargs):

# convert numpy arrays and torch tensors that are small enough (arbitrarily chosen to be smaller than the input data dimensions)
# into lists so they can be serialized into JSON
elif type(value) in [np.ndarray, torch.Tensor]:
if value.shape < self.input_data_shape:
elif isinstance(value, (np.ndarray, torch.Tensor)):
if (
self.input_data_shape is not None
and value.shape < self.input_data_shape
):
# put other lists into the Model Parameters dictionary
try:
json.dumps(value.tolist())
Expand Down Expand Up @@ -423,9 +441,12 @@ def save_notebook(self):

# generate a checksum (and script path) for the notebook
self.notebook_metadata = getNotebookMetadata(self.__file__)
if self.notebook_metadata is None:
raise ValueError(f"Failed to get metadata for notebook {self.__file__}")

# extract the checksum
new_checksum = self.notebook_metadata["script"]["checksum"]
old_checksum = None

# if the notebook has a DataFed record ID, extract the checksum and compare to the new checksum
if self.notebook_record_id is not None:
Expand Down Expand Up @@ -495,10 +516,13 @@ def save(
"""

# include the model architecture state dictionary and model hyperparameters in the checkpoint
if not str(local_file_path).endswith(".zip") and not os.path.exists(
str(local_file_path)
if (
local_file_path is not None
and not str(local_file_path).endswith(".zip")
and not os.path.exists(str(local_file_path))
):
checkpoint = self.getModelArchitectureStateDict() | model_hyperparameters
checkpoint = self.getModelArchitectureStateDict()
checkpoint.update(model_hyperparameters or {})

# Save the model state dict locally
torch.save(checkpoint, local_file_path)
Expand Down