From 6aa92e57e448d1b04d4bcc2f1aeef9d7e73a2677 Mon Sep 17 00:00:00 2001 From: Theo Beers <32523293+theodore-s-beers@users.noreply.github.com> Date: Tue, 6 May 2025 12:34:49 -0400 Subject: [PATCH 1/7] Sort dependencies --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index cd8a8f1..d277304 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,12 +49,12 @@ package_dir = # For more information, check out https://semver.org/. install_requires = importlib-metadata; python_version<"3.8" + datafed GPutil + m3learning-util psutil - datafed torch torchvision - m3learning-util [options.packages.find] where = src From 9492bbae3b680ba9fdd085ccb6cc8614acf7f210 Mon Sep 17 00:00:00 2001 From: Theo Beers <32523293+theodore-s-beers@users.noreply.github.com> Date: Tue, 6 May 2025 12:36:29 -0400 Subject: [PATCH 2/7] Add a couple dev dependencies --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index d277304..b980d0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,6 +53,8 @@ install_requires = GPutil m3learning-util psutil + pyright + ruff torch torchvision From efcf9ed81401cc38316dc3f5761133341dfe78b0 Mon Sep 17 00:00:00 2001 From: Theo Beers <32523293+theodore-s-beers@users.noreply.github.com> Date: Tue, 6 May 2025 13:14:49 -0400 Subject: [PATCH 3/7] Add a missing dependency --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index b980d0b..3967ff6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,6 +55,7 @@ install_requires = psutil pyright ruff + sphinx torch torchvision From c09e59abb9f0d96c4a995a88bef7a69293dbcc63 Mon Sep 17 00:00:00 2001 From: Theo Beers <32523293+theodore-s-beers@users.noreply.github.com> Date: Tue, 6 May 2025 13:15:22 -0400 Subject: [PATCH 4/7] Fix overload function signature --- src/datafed_torchflow/JSON.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/datafed_torchflow/JSON.py b/src/datafed_torchflow/JSON.py index 3341c01..ee1f12c 100644 --- a/src/datafed_torchflow/JSON.py +++ b/src/datafed_torchflow/JSON.py @@ -8,7 +8,7 @@ class UniversalEncoder(json.JSONEncoder): A custom JSON encoder that can handle numpy data types, sets, and objects with __dict__ attributes. """ - def default(self, obj): + def default(self, o): """ Override the default method to provide custom serialization for unsupported data types. @@ -19,16 +19,16 @@ def default(self, obj): any: The serialized form of the object. """ # Convert numpy types to their Python equivalents - if isinstance(obj, np.integer): - return int(obj) # Convert numpy integers to Python int - elif isinstance(obj, np.floating): - return float(obj) # Convert numpy floats to Python float - elif isinstance(obj, np.ndarray): - return obj.tolist() # Convert numpy arrays to lists - elif isinstance(obj, set): - return list(obj) # Convert sets to lists - elif hasattr(obj, "__dict__"): - return obj.__dict__ # Serialize object attributes + if isinstance(o, np.integer): + return int(o) # Convert numpy integers to Python int + elif isinstance(o, np.floating): + return float(o) # Convert numpy floats to Python float + elif isinstance(o, np.ndarray): + return o.tolist() # Convert numpy arrays to lists + elif isinstance(o, set): + return list(o) # Convert sets to lists + elif hasattr(o, "__dict__"): + return o.__dict__ # Serialize object attributes else: # Call the default method for other cases - return super().default(obj) + return super().default(o) From 471c6cd520ec8fe3623c0b5b2bee1b7bbf4a6f84 Mon Sep 17 00:00:00 2001 From: Theo Beers <32523293+theodore-s-beers@users.noreply.github.com> Date: Tue, 6 May 2025 13:15:32 -0400 Subject: [PATCH 5/7] Auto-format files --- examples/PyTorchModelLogger.ipynb | 2 +- examples/PyTorchModelLogger.old.ipynb | 2 +- src/datafed_torchflow/computer.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/PyTorchModelLogger.ipynb b/examples/PyTorchModelLogger.ipynb index 11c3db4..0fb8f34 100644 --- a/examples/PyTorchModelLogger.ipynb +++ b/examples/PyTorchModelLogger.ipynb @@ -256,7 +256,7 @@ " if batch_idx % 100 == 0:\n", " print(\n", " f\"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} \"\n", - " f\"({100. * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n", + " f\"({100.0 * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n", " )\n", "\n", " file_name = f\"MNSIT_epoch_{epoch}_loss_{loss.item():.4e}\"\n", diff --git a/examples/PyTorchModelLogger.old.ipynb b/examples/PyTorchModelLogger.old.ipynb index d6e3297..f70787e 100644 --- a/examples/PyTorchModelLogger.old.ipynb +++ b/examples/PyTorchModelLogger.old.ipynb @@ -223,7 +223,7 @@ " if batch_idx % 100 == 0:\n", " print(\n", " f\"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} \"\n", - " f\"({100. * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n", + " f\"({100.0 * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n", " )\n", "\n", " file_name = f\"MNSIT_epoch_{epoch}_loss_{loss.item():.4e}\"\n", diff --git a/src/datafed_torchflow/computer.py b/src/datafed_torchflow/computer.py index c5816a4..b08d136 100644 --- a/src/datafed_torchflow/computer.py +++ b/src/datafed_torchflow/computer.py @@ -52,9 +52,9 @@ def get_memory_info(): """ mem = psutil.virtual_memory() return { - "total": f"{mem.total / (1024 ** 3):.2f} GB", - "available": f"{mem.available / (1024 ** 3):.2f} GB", - "used": f"{mem.used / (1024 ** 3):.2f} GB", + "total": f"{mem.total / (1024**3):.2f} GB", + "available": f"{mem.available / (1024**3):.2f} GB", + "used": f"{mem.used / (1024**3):.2f} GB", "percent": mem.percent, } From c774fce5b856013703700adb4d120038324ee04a Mon Sep 17 00:00:00 2001 From: Theo Beers <32523293+theodore-s-beers@users.noreply.github.com> Date: Tue, 6 May 2025 13:17:17 -0400 Subject: [PATCH 6/7] Fix some type errors --- src/datafed_torchflow/pytorch.py | 40 +++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/datafed_torchflow/pytorch.py b/src/datafed_torchflow/pytorch.py index 805ad11..a61df2a 100644 --- a/src/datafed_torchflow/pytorch.py +++ b/src/datafed_torchflow/pytorch.py @@ -1,6 +1,7 @@ import os import sys from datetime import datetime +from typing import Any, Optional import torch import torch.nn as nn @@ -60,7 +61,7 @@ def __init__( script_path=None, local_model_path="/.", log_file_path="log.txt", - input_data_shape=None, + input_data_shape: Optional[tuple[int, ...]] = None, # Is this right? dataset_id_or_path=None, logging=False, download_kwargs={"wait": True, "orig_fname": True}, @@ -140,7 +141,12 @@ def optimizer(self, optimizer): """ self._optimizer = optimizer - def getMetadata(self, local_vars=None, model_hyperparameters=None, **kwargs): + def getMetadata( + self, + local_vars: Optional[list[tuple[str, Any]]] = None, + model_hyperparameters: Optional[dict[str, Any]] = None, + **kwargs, + ): """ Gathers metadata including the serialized model, optimizer, system info, and user details. @@ -170,6 +176,12 @@ def getMetadata(self, local_vars=None, model_hyperparameters=None, **kwargs): "System Information": {}, } + if local_vars is None: + raise ValueError("local_vars cannot be None") + + if model_hyperparameters is None: + raise ValueError("model_hyperparameters cannot be None") + # loop through the local variables to add to the metadata dictionary for key, value in local_vars: # exclude modules and other undesired local variables. Use casefold string matching for flexibility @@ -259,7 +271,10 @@ def getMetadata(self, local_vars=None, model_hyperparameters=None, **kwargs): Warning(warning_message) # put the model hyperparameters in the Model Hyperparameters sub-dictionary (the hyperparameters might be 1-value torch tensors or just floats) elif key in model_hyperparameters.keys(): - if type(value) in [np.ndarray, torch.Tensor]: + if ( + isinstance(value, (np.ndarray, torch.Tensor)) + and self.input_data_shape is not None + ): if value.shape < self.input_data_shape: DataFed_record_metadata["Model Parameters"][ "Model Hyperparameters" @@ -271,8 +286,11 @@ def getMetadata(self, local_vars=None, model_hyperparameters=None, **kwargs): # convert numpy arrays and torch tensors that are small enough (arbitrarily chosen to be smaller than the input data dimensions) # into lists so they can be serialized into JSON - elif type(value) in [np.ndarray, torch.Tensor]: - if value.shape < self.input_data_shape: + elif isinstance(value, (np.ndarray, torch.Tensor)): + if ( + self.input_data_shape is not None + and value.shape < self.input_data_shape + ): # put other lists into the Model Parameters dictionary try: json.dumps(value.tolist()) @@ -423,9 +441,12 @@ def save_notebook(self): # generate a checksum (and scipt path) for the notebook self.notebook_metadata = getNotebookMetadata(self.__file__) + if self.notebook_metadata is None: + raise ValueError(f"Failed to get metadata for notebook {self.__file__}") # extract the checksum new_checksum = self.notebook_metadata["script"]["checksum"] + old_checksum = None # if the notebook has a DataFed record ID, extract the checksum and compare to the new checksum if self.notebook_record_id is not None: @@ -495,10 +516,13 @@ def save( """ # include the model architecture state dictionary and model hyperparameters in the checkpoint - if not str(local_file_path).endswith(".zip") and not os.path.exists( - str(local_file_path) + if ( + local_file_path is not None + and not str(local_file_path).endswith(".zip") + and not os.path.exists(str(local_file_path)) ): - checkpoint = self.getModelArchitectureStateDict() | model_hyperparameters + checkpoint = self.getModelArchitectureStateDict() + checkpoint.update(model_hyperparameters or {}) # Save the model state dict locally torch.save(checkpoint, local_file_path) From 8838b294f673cd9abc1c813f73390d4be72ae803 Mon Sep 17 00:00:00 2001 From: Theo Beers <32523293+theodore-s-beers@users.noreply.github.com> Date: Tue, 6 May 2025 13:37:51 -0400 Subject: [PATCH 7/7] Fix more type errors --- src/datafed_torchflow/datafed.py | 35 ++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/src/datafed_torchflow/datafed.py b/src/datafed_torchflow/datafed.py index 01a4c9b..7a1fdb0 100644 --- a/src/datafed_torchflow/datafed.py +++ b/src/datafed_torchflow/datafed.py @@ -2,6 +2,7 @@ import os import traceback from datetime import datetime +from typing import Optional import numpy as np from datafed.CommandLib import API @@ -29,10 +30,10 @@ class DataFed(API): def __init__( self, - datafed_path, + datafed_path: str, local_model_path="./Trained Models", log_file_path="log.txt", - dataset_id_or_path=None, + dataset_id_or_path: Optional[str] = None, download_kwargs={"wait": True, "orig_fname": True}, upload_kwargs={"wait": True}, logging=False, @@ -91,11 +92,12 @@ def upload_dataset_to_DataFed(self): The DataFed record ID for the dataset files, as a string for a single dataset file and a list of strings for multiple dataset files. """ if self.dataset_id_or_path is not None: + ls_resp = self.collectionItemsList(self.collection_id) + if isinstance( self.dataset_id_or_path, list ): # to specify multiple dataset files dataset_ids = [] - ls_resp = self.collectionItemsList(self.collection_id) for dataset in self.dataset_id_or_path: if dataset.startswith("d/"): dataset_ids.append(dataset) @@ -163,6 +165,8 @@ def upload_dataset_to_DataFed(self): if self.dataset_id_or_path.startswith("d/"): dataset_ids = self.dataset_id_or_path else: + dataset = self.dataset_id_or_path # Is this right? + try: path_id = ( ls_resp[0] @@ -176,6 +180,7 @@ def upload_dataset_to_DataFed(self): ] .id ) + # update record (dependencies have been added) record_id = self.get_notebook_DataFed_ID_from_path_and_title( dataset, path_id=path_id @@ -345,14 +350,17 @@ def check_if_endpoint_set(self): ) @property - def user_id(self): + def user_id(self) -> str: """ Gets the user ID from the authenticated user's information. Returns: str: The user ID extracted from the authenticated user information. """ - return self.getAuthUser().split("/")[-1] + auth_user = self.getAuthUser() + if auth_user is None: + raise RuntimeError("Failed to get authenticated user information") + return auth_user.split("/")[-1] @staticmethod def check_string_for_dot_or_slash(s): @@ -524,7 +532,7 @@ def get_notebook_DataFed_ID_from_path_and_title( def data_record_create( self, metadata=None, - record_title=None, + record_title: Optional[str] = None, parent_collection=None, deps=None, **kwargs, @@ -545,6 +553,9 @@ def data_record_create( # make sure the user is logged into DataFed self.check_if_logged_in() + if record_title is None: + raise ValueError("record_title cannot be None") + # If the record title is longer than the maximum allowed by DataFed (80 characters) # truncate the record title to 80 characters. If logging is true, print out a statement letting the user # know the record_title has been truncated. @@ -593,7 +604,7 @@ def data_record_create( def data_record_update( self, record_id=None, - record_title=None, + record_title: Optional[str] = None, metadata=None, deps=None, overwrite_metadata=False, @@ -617,6 +628,9 @@ def data_record_update( # make sure the user is logged into DataFed self.check_if_logged_in() + if record_title is None: + raise ValueError("record_title cannot be None") + # If the record title is longer than the maximum allowed by DataFed (80 characters) # truncate the record title to 80 characters. If logging is true, print out a statement letting the user # know the record_title has been truncated. @@ -835,7 +849,6 @@ def _get_metadata_list(self, record_ids, exclude=None): return metadata - @staticmethod def required_keys(self, dict_list, required_keys): """ Filters a list of dictionaries to include only those that contain all specified required keys. @@ -1155,12 +1168,16 @@ def getFileExtension(self): # Split the file name by '.' and return the last part as the extension return "." + self.getFileName(self.dataset_id_or_path).split(".")[-1] - def getData(self, dataset_id=None): + def getData(self, dataset_id: Optional[str] = None): """ Downloads the data from the dataset """ + self.data_path = None # Bandaid fix; this needs to be defined for the class! + if dataset_id is None: + if self.dataset_id_or_path is None: + raise ValueError("dataset_id_or_path is not set") dataset_id = self.dataset_id_or_path # if a data path is not provided, download the data to the current directory