Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinRovang committed Jul 21, 2023
1 parent 753c5dc commit 3f16cbf
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 55 deletions.
34 changes: 20 additions & 14 deletions central_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@ def __init__(self) -> NoReturn:
"""Constructor for the central processing unit."""
super().__init__()
logger.info("Initializing central processing unit")
self.test_data = np.random.randn(32, 32, 32) # Please make sure this data mimics your own data
self.test_data = np.random.randn(
32, 32, 32
) # Please make sure this data mimics your own data

@timer
def preprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -> np.ndarray:
def preprocess(
self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}
) -> np.ndarray:
"""Preprocess the data before training/val/test/predict.
Parameters
Expand Down Expand Up @@ -96,14 +100,14 @@ def predict_step(self, data: np.ndarray, model: config.ModelInput) -> np.ndarray
# TODO: Your prediction code here
# --------------------- #

logger.success(f"=> Prediction completed successfully")
logger.success("=> Prediction completed successfully")
return data
except (
NameError,
ValueError,
TypeError,
AttributeError,
RuntimeError,
NameError,
ValueError,
TypeError,
AttributeError,
RuntimeError,
) as e:
msg = f"I failed predicting the image with error: {e}"
logger.exception(msg)
Expand All @@ -112,7 +116,9 @@ def predict_step(self, data: np.ndarray, model: config.ModelInput) -> np.ndarray
logger.exception(msg)

@timer
def postprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -> np.ndarray:
def postprocess(
self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}
) -> np.ndarray:
"""Postprocess the data after training/val/test/predict
Parameters
Expand Down Expand Up @@ -144,11 +150,11 @@ def postprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -
logger.success("=> Postprocessing completed successfully")
return data
except (
NameError,
ValueError,
TypeError,
AttributeError,
RuntimeError,
NameError,
ValueError,
TypeError,
AttributeError,
RuntimeError,
) as e:
msg = f"I failed postprocessing with error {e}"
logger.exception(msg)
Expand Down
8 changes: 5 additions & 3 deletions config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# EXTRAS
DATA_SAVE_DIR = "./data"

#github repo metadata
PROJECTMETADATAURL = "https://raw.githubusercontent.com/NeoMedSys/gingerbread_sc/main/pyproject.toml"
# github repo metadata
PROJECTMETADATAURL = (
"https://raw.githubusercontent.com/NeoMedSys/gingerbread_sc/main/pyproject.toml"
)


class BaseModel(PydanticBaseModel):
Expand All @@ -25,4 +27,4 @@ class Config:


class ModelInput(BaseModel):
model: torch.nn.Module
model: torch.nn.Module
37 changes: 21 additions & 16 deletions data_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ class MedqueryDataDownloader:
"""

def __init__(self):

self.mq = pymq.PyMedQuery()
logger.info("MedqueryDataDownloader initialized.")

def download_data(self,
project_id: str,
get_affines: bool = False,
get_all: bool = True,
include_mask: bool = False,
batch_size: int = 20) -> NoReturn:
def download_data(
self,
project_id: str,
get_affines: bool = False,
get_all: bool = True,
include_mask: bool = False,
batch_size: int = 20,
) -> NoReturn:
"""Download data from MedQuery and save it to disk.
Note
Expand All @@ -55,16 +56,18 @@ def download_data(self,
"""
try:
large_data = self.mq.batch_extract(get_all=get_all,
get_affines=get_affines,
project_id=project_id,
batch_size=batch_size,
include_mask=include_mask)
large_data = self.mq.batch_extract(
get_all=get_all,
get_affines=get_affines,
project_id=project_id,
batch_size=batch_size,
include_mask=include_mask,
)

if not os.path.exists(cfg.DATA_SAVE_DIR):
os.makedirs(cfg.DATA_SAVE_DIR)
logger.info(f"Downloading data from MedQuery for project {project_id}")
with h5py.File(f'{cfg.DATA_SAVE_DIR}/{project_id}.hdf5', 'w') as f:
with h5py.File(f"{cfg.DATA_SAVE_DIR}/{project_id}.hdf5", "w") as f:
for batch in tqdm(large_data, desc="Saving data to disk..."):
for key, value in batch.items():
f.create_dataset(key, data=value)
Expand Down Expand Up @@ -94,7 +97,7 @@ def hdf5_to_nifti_all(self, hdf5_path: str, output_dir: str) -> NoReturn:
# make output directory if it does not exist
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with h5py.File(hdf5_path, 'r') as hdf5:
with h5py.File(hdf5_path, "r") as hdf5:
for series_uid, value_ in hdf5.items():
if "affine" in series_uid:
continue
Expand All @@ -107,7 +110,9 @@ def hdf5_to_nifti_all(self, hdf5_path: str, output_dir: str) -> NoReturn:
except IndexError:
logger.exception(f"Error with hdf5 indexing")

def hdf5_to_nifti_single(self, hdf5_path: str, output_dir: str, series_uid: str) -> NoReturn:
def hdf5_to_nifti_single(
self, hdf5_path: str, output_dir: str, series_uid: str
) -> NoReturn:
"""Convert single series to nifti file.
Parameters
Expand All @@ -128,7 +133,7 @@ def hdf5_to_nifti_single(self, hdf5_path: str, output_dir: str, series_uid: str)
This method assumed that the hdf5 file contains affine matrices and data. If this is not the case, the method will not work.
"""
try:
with h5py.File(hdf5_path, 'r') as hdf5:
with h5py.File(hdf5_path, "r") as hdf5:
data = hdf5[series_uid]
affine_uid = series_uid.replace("series", "affine")
affine = hdf5[affine_uid]
Expand Down
19 changes: 9 additions & 10 deletions neotemplate/base_central_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class CPNeoTemplate(nn.Module):
"""Central processing unit for the NeoTemplate. """
"""Central processing unit for the NeoTemplate."""

def __init__(self) -> NoReturn:
"""Constructor for the central processing unit
Expand Down Expand Up @@ -79,14 +79,13 @@ def save_checkpoint(self, checkpoint_path: str) -> NoReturn:
# save with the state_dict and the hyperparameters
logger.info(f"Saving checkpoint to {checkpoint_path}")
torch.save(
{
"state_dict": self.state_dict(),
"hyperparameters": self.args
},
{"state_dict": self.state_dict(), "hyperparameters": self.args},
checkpoint_path,
)

def postprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -> np.ndarray:
def postprocess(
self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}
) -> np.ndarray:
"""Postprocess the data after training/val/test/predict
Parameters
Expand All @@ -113,7 +112,9 @@ def postprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -
except Exception as e:
logger.error(f"Postprocessing failed with error {e}")

def preprocess(self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}) -> np.ndarray:
def preprocess(
self, data: np.ndarray, extras: Optional[Dict[str, Any]] = {}
) -> np.ndarray:
"""Preprocess the data before training/val/test/predict
Parameters
Expand Down Expand Up @@ -162,8 +163,6 @@ def predict_step(self, data: np.ndarray) -> np.ndarray:
except Exception:
logger.exception("predict_step failed")



def check_version(self):
try:
# Fetch the contents of the .toml file
Expand All @@ -183,4 +182,4 @@ def check_version(self):
)

except requests.exceptions.HTTPError as e:
logger.error(f"Error in version check: {e}")
logger.error(f"Error in version check: {e}")
36 changes: 29 additions & 7 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
import nox
import os


@nox.session()
def tests(session):
session.run('poetry', 'install', '--with', 'dev')
session.run('poetry', 'run', 'pytest', './tests', '--junitxml=./junit.xml')
session.run("poetry", "install", "--with", "dev")
session.run("poetry", "run", "pytest", "./tests", "--junitxml=./junit.xml")
# coverage
session.run('poetry', 'run', 'coverage', 'run', '--source=.', '--data-file', './.coverage', '-m', 'pytest', './tests')
session.run('poetry', 'run', 'coverage', 'xml')

session.run(
"poetry",
"run",
"coverage",
"run",
"--source=.",
"--data-file",
"./.coverage",
"-m",
"pytest",
"./tests",
)
session.run("poetry", "run", "coverage", "xml")


@nox.session()
def lint(session):
session.install('flake8')
session.run('flake8', '.', '--exit-zero', '--format=html', '--statistics', '--tee', '--output-file', 'flake8.txt')
session.install("flake8")
session.run(
"flake8",
".",
"--exit-zero",
"--format=html",
"--statistics",
"--tee",
"--output-file",
"flake8.txt",
)
1 change: 0 additions & 1 deletion tests/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


class Test_General:

def test_imports(self):
assert central_processing
assert data_download
Expand Down
4 changes: 2 additions & 2 deletions utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def timer(orig_func: Callable):
-------
type
elapsed runtime for the function.
"""

@wraps(orig_func)
Expand All @@ -26,4 +26,4 @@ def wrapper(*args, **kwargs):
logger.info("Runtime for {}: {} sec".format(orig_func.__name__, t2))
return result

return wrapper
return wrapper
3 changes: 1 addition & 2 deletions xmodules/models/mock_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@


class MockModel(torch.nn.Module):

def __init__(self, *args, **kwargs):
super().__init__()

def forward(self, x):
return x

def load_state_dict(self, state_dict, strict=True):
pass
pass

0 comments on commit 3f16cbf

Please sign in to comment.