From 3f16cbf0fb4b321f4d8e24b3b7c818d551bfafb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20R=C3=B8vang?= Date: Fri, 21 Jul 2023 20:26:24 +0200 Subject: [PATCH] formatting --- central_processing.py | 34 +++++++++++++---------- config/config.py | 8 +++--- data_download.py | 37 +++++++++++++++----------- neotemplate/base_central_processing.py | 19 +++++++------ noxfile.py | 36 ++++++++++++++++++++----- tests/general_test.py | 1 - utils/helpers.py | 4 +-- xmodules/models/mock_model.py | 3 +-- 8 files changed, 87 insertions(+), 55 deletions(-) diff --git a/central_processing.py b/central_processing.py index 3a896b1..91ebf0c 100755 --- a/central_processing.py +++ b/central_processing.py @@ -30,10 +30,14 @@ def __init__(self) -> NoReturn: """Constructor for the central processing unit.""" super().__init__() logger.info("Initializing central processing unit") - self.test_data = np.random.randn(32, 32, 32) # Please make sure this data mimics your own data + self.test_data = np.random.randn( + 32, 32, 32 + ) # Please make sure this data mimics your own data @timer - def preprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -> np.ndarray: + def preprocess( + self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {} + ) -> np.ndarray: """Preprocess the data before training/val/test/predict. Parameters @@ -96,14 +100,14 @@ def predict_step(self, data: np.ndarray, model: config.ModelInput) -> np.ndarray # TODO: Your prediction code here # --------------------- # - logger.success(f"=> Prediction completed successfully") + logger.success("=> Prediction completed successfully") return data except ( - NameError, - ValueError, - TypeError, - AttributeError, - RuntimeError, + NameError, + ValueError, + TypeError, + AttributeError, + RuntimeError, ) as e: msg = f"I failed predicting the image with error: {e}" logger.exception(msg) @@ -112,7 +116,9 @@ def predict_step(self, data: np.ndarray, model: config.ModelInput) -> np.ndarray logger.exception(msg) @timer - def postprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -> np.ndarray: + def postprocess( + self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {} + ) -> np.ndarray: """Postprocess the data after training/val/test/predict Parameters @@ -144,11 +150,11 @@ def postprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) - logger.success("=> Postprocessing completed successfully") return data except ( - NameError, - ValueError, - TypeError, - AttributeError, - RuntimeError, + NameError, + ValueError, + TypeError, + AttributeError, + RuntimeError, ) as e: msg = f"I failed postprocessing with error {e}" logger.exception(msg) diff --git a/config/config.py b/config/config.py index e441fa8..eef68cc 100644 --- a/config/config.py +++ b/config/config.py @@ -15,8 +15,10 @@ # EXTRAS DATA_SAVE_DIR = "./data" -#github repo metadata -PROJECTMETADATAURL = "https://raw.githubusercontent.com/NeoMedSys/gingerbread_sc/main/pyproject.toml" +# github repo metadata +PROJECTMETADATAURL = ( + "https://raw.githubusercontent.com/NeoMedSys/gingerbread_sc/main/pyproject.toml" +) class BaseModel(PydanticBaseModel): @@ -25,4 +27,4 @@ class Config: class ModelInput(BaseModel): - model: torch.nn.Module \ No newline at end of file + model: torch.nn.Module diff --git a/data_download.py b/data_download.py index af77f6a..0a9ac93 100755 --- a/data_download.py +++ b/data_download.py @@ -20,16 +20,17 @@ class MedqueryDataDownloader: """ def __init__(self): - self.mq = pymq.PyMedQuery() logger.info("MedqueryDataDownloader initialized.") - def download_data(self, - project_id: str, - get_affines: bool = False, - get_all: bool = True, - include_mask: bool = False, - batch_size: int = 20) -> NoReturn: + def download_data( + self, + project_id: str, + get_affines: bool = False, + get_all: bool = True, + include_mask: bool = False, + batch_size: int = 20, + ) -> NoReturn: """Download data from MedQuery and save it to disk. Note @@ -55,16 +56,18 @@ def download_data(self, """ try: - large_data = self.mq.batch_extract(get_all=get_all, - get_affines=get_affines, - project_id=project_id, - batch_size=batch_size, - include_mask=include_mask) + large_data = self.mq.batch_extract( + get_all=get_all, + get_affines=get_affines, + project_id=project_id, + batch_size=batch_size, + include_mask=include_mask, + ) if not os.path.exists(cfg.DATA_SAVE_DIR): os.makedirs(cfg.DATA_SAVE_DIR) logger.info(f"Downloading data from MedQuery for project {project_id}") - with h5py.File(f'{cfg.DATA_SAVE_DIR}/{project_id}.hdf5', 'w') as f: + with h5py.File(f"{cfg.DATA_SAVE_DIR}/{project_id}.hdf5", "w") as f: for batch in tqdm(large_data, desc="Saving data to disk..."): for key, value in batch.items(): f.create_dataset(key, data=value) @@ -94,7 +97,7 @@ def hdf5_to_nifti_all(self, hdf5_path: str, output_dir: str) -> NoReturn: # make output directory if it does not exist if not os.path.exists(output_dir): os.makedirs(output_dir) - with h5py.File(hdf5_path, 'r') as hdf5: + with h5py.File(hdf5_path, "r") as hdf5: for series_uid, value_ in hdf5.items(): if "affine" in series_uid: continue @@ -107,7 +110,9 @@ def hdf5_to_nifti_all(self, hdf5_path: str, output_dir: str) -> NoReturn: except IndexError: logger.exception(f"Error with hdf5 indexing") - def hdf5_to_nifti_single(self, hdf5_path: str, output_dir: str, series_uid: str) -> NoReturn: + def hdf5_to_nifti_single( + self, hdf5_path: str, output_dir: str, series_uid: str + ) -> NoReturn: """Convert single series to nifti file. Parameters @@ -128,7 +133,7 @@ def hdf5_to_nifti_single(self, hdf5_path: str, output_dir: str, series_uid: str) This method assumed that the hdf5 file contains affine matrices and data. If this is not the case, the method will not work. """ try: - with h5py.File(hdf5_path, 'r') as hdf5: + with h5py.File(hdf5_path, "r") as hdf5: data = hdf5[series_uid] affine_uid = series_uid.replace("series", "affine") affine = hdf5[affine_uid] diff --git a/neotemplate/base_central_processing.py b/neotemplate/base_central_processing.py index 4293d93..bfa9e75 100644 --- a/neotemplate/base_central_processing.py +++ b/neotemplate/base_central_processing.py @@ -12,7 +12,7 @@ class CPNeoTemplate(nn.Module): - """Central processing unit for the NeoTemplate. """ + """Central processing unit for the NeoTemplate.""" def __init__(self) -> NoReturn: """Constructor for the central processing unit @@ -79,14 +79,13 @@ def save_checkpoint(self, checkpoint_path: str) -> NoReturn: # save with the state_dict and the hyperparameters logger.info(f"Saving checkpoint to {checkpoint_path}") torch.save( - { - "state_dict": self.state_dict(), - "hyperparameters": self.args - }, + {"state_dict": self.state_dict(), "hyperparameters": self.args}, checkpoint_path, ) - def postprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -> np.ndarray: + def postprocess( + self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {} + ) -> np.ndarray: """Postprocess the data after training/val/test/predict Parameters @@ -113,7 +112,9 @@ def postprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) - except Exception as e: logger.error(f"Postprocessing failed with error {e}") - def preprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -> np.ndarray: + def preprocess( + self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {} + ) -> np.ndarray: """Preprocess the data before training/val/test/predict Parameters @@ -162,8 +163,6 @@ def predict_step(self, data: np.ndarray) -> np.ndarray: except Exception: logger.exception("predict_step failed") - - def check_version(self): try: # Fetch the contents of the .toml file @@ -183,4 +182,4 @@ def check_version(self): ) except requests.exceptions.HTTPError as e: - logger.error(f"Error in version check: {e}") \ No newline at end of file + logger.error(f"Error in version check: {e}") diff --git a/noxfile.py b/noxfile.py index 68f2f9d..ada8450 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,15 +1,37 @@ import nox import os + @nox.session() def tests(session): - session.run('poetry', 'install', '--with', 'dev') - session.run('poetry', 'run', 'pytest', './tests', '--junitxml=./junit.xml') + session.run("poetry", "install", "--with", "dev") + session.run("poetry", "run", "pytest", "./tests", "--junitxml=./junit.xml") # coverage - session.run('poetry', 'run', 'coverage', 'run', '--source=.', '--data-file', './.coverage', '-m', 'pytest', './tests') - session.run('poetry', 'run', 'coverage', 'xml') - + session.run( + "poetry", + "run", + "coverage", + "run", + "--source=.", + "--data-file", + "./.coverage", + "-m", + "pytest", + "./tests", + ) + session.run("poetry", "run", "coverage", "xml") + + @nox.session() def lint(session): - session.install('flake8') - session.run('flake8', '.', '--exit-zero', '--format=html', '--statistics', '--tee', '--output-file', 'flake8.txt') \ No newline at end of file + session.install("flake8") + session.run( + "flake8", + ".", + "--exit-zero", + "--format=html", + "--statistics", + "--tee", + "--output-file", + "flake8.txt", + ) diff --git a/tests/general_test.py b/tests/general_test.py index d61d83e..c09a481 100644 --- a/tests/general_test.py +++ b/tests/general_test.py @@ -9,7 +9,6 @@ class Test_General: - def test_imports(self): assert central_processing assert data_download diff --git a/utils/helpers.py b/utils/helpers.py index 81fec33..4ecb210 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -15,7 +15,7 @@ def timer(orig_func: Callable): ------- type elapsed runtime for the function. - + """ @wraps(orig_func) @@ -26,4 +26,4 @@ def wrapper(*args, **kwargs): logger.info("Runtime for {}: {} sec".format(orig_func.__name__, t2)) return result - return wrapper \ No newline at end of file + return wrapper diff --git a/xmodules/models/mock_model.py b/xmodules/models/mock_model.py index 9177606..136ef41 100644 --- a/xmodules/models/mock_model.py +++ b/xmodules/models/mock_model.py @@ -2,7 +2,6 @@ class MockModel(torch.nn.Module): - def __init__(self, *args, **kwargs): super().__init__() @@ -10,4 +9,4 @@ def forward(self, x): return x def load_state_dict(self, state_dict, strict=True): - pass \ No newline at end of file + pass