Skip to content

Commit

Permalink
Update imports for image transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi committed Nov 21, 2023
1 parent 08b4659 commit cec48e8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions cyclops/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, config
from monai.transforms import Compose # type: ignore
from sklearn.compose import ColumnTransformer
from sklearn.exceptions import NotFittedError
from torchvision.transforms import Compose

from cyclops.data.slicer import SliceSpec
from cyclops.evaluate.evaluator import evaluate
Expand Down Expand Up @@ -419,7 +419,7 @@ def predict(
splits_mapping = {"test": "test"}
model_name, model = self.get_model(model_name)
if transforms:
transforms = partial(apply_image_transforms, transforms=transforms) # type: ignore
transforms = partial(apply_image_transforms, transforms=transforms)
if isinstance(dataset, (Dataset, DatasetDict)):
return model.predict(
dataset,
Expand Down
9 changes: 5 additions & 4 deletions docs/source/tutorials/nihcxr/cxr_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@
"\n",
"import numpy as np\n",
"import plotly.express as px\n",
"from monai.transforms import Compose, Lambdad, Resized\n",
"from torchvision.transforms import Compose\n",
"from torchxrayvision.models import DenseNet\n",
"\n",
"from cyclops.data.loader import load_nihcxr\n",
"from cyclops.data.slicer import (\n",
" SliceSpec,\n",
" filter_value, # noqa: E402\n",
")\n",
"from cyclops.data.transforms import Lambdad, Resized\n",
"from cyclops.data.utils import apply_transforms\n",
"from cyclops.evaluate import evaluator\n",
"from cyclops.evaluate.metrics.factory import create_metric\n",
Expand Down Expand Up @@ -635,9 +636,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "cyclops",
"language": "python",
"name": "python3"
"name": "cyclops"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -649,7 +650,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.9.12"
}
},
"nbformat": 4,
Expand Down
5 changes: 3 additions & 2 deletions docs/source/tutorials/nihcxr/generate_nihcxr_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
import numpy as np
import numpy.typing as npt
import plotly.express as px
from monai.transforms import Compose, Lambdad, Resized # type: ignore[attr-defined]
from torchvision.transforms import Compose
from torchxrayvision.models import DenseNet

from cyclops.data.loader import load_nihcxr
from cyclops.data.slicer import (
SliceSpec,
filter_value, # noqa: E402
)
from cyclops.data.transforms import Lambdad, Resized
from cyclops.data.utils import apply_transforms
from cyclops.evaluate import evaluator
from cyclops.evaluate.metrics.factory import create_metric
Expand Down Expand Up @@ -55,7 +56,7 @@
allow_missing_keys=True,
),
Lambdad(
("image",),
keys=("image",),
func=lambda x: np.mean(x, axis=0)[np.newaxis, :] if x.shape[0] != 1 else x,
),
],
Expand Down

0 comments on commit cec48e8

Please sign in to comment.