From 94324963c92302ad3f7e16f6d18d54b7ea373458 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 1 Dec 2023 09:06:12 +0800 Subject: [PATCH] Add classification template (#533) Part of tutorial#1456 ### Description Add a classification template ### Status **Ready** ### Please ensure all the checkboxes: - [x] Codeformat tests passed locally by running `./runtests.sh --codeformat`. - [ ] In-line docstrings updated. - [ ] Update `version` and `changelog` in `metadata.json` if changing an existing bundle. - [ ] Please ensure the naming rules in config files meet our requirements (please refer to: `CONTRIBUTING.md`). - [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy` are correct in `metadata.json`. - [ ] Descriptions should be consistent with the content, such as `eval_metrics` of the provided weights and TorchScript modules. - [ ] Files larger than 25MB are excluded and replaced by providing download links in `large_file.yml`. - [ ] Avoid using path that contains personal information within config files (such as use `/home/your_name/` for `"bundle_root"`). --------- Signed-off-by: KumoLiu --- models/classification_template/LICENSE | 21 ++ .../configs/evaluate.yaml | 38 +++ .../configs/inference.yaml | 115 +++++++++ .../configs/logging.conf | 21 ++ .../configs/metadata.json | 63 +++++ .../configs/multi_gpu_train.yaml | 37 +++ .../configs/train.yaml | 233 +++++++++++++++++ models/classification_template/docs/README.md | 55 ++++ .../docs/generate_data.ipynb | 236 ++++++++++++++++++ .../classification_template/large_files.yml | 5 + .../configs/metadata.json | 3 +- .../scripts/detection_inferer.py | 2 +- 12 files changed, 827 insertions(+), 2 deletions(-) create mode 100644 models/classification_template/LICENSE create mode 100644 models/classification_template/configs/evaluate.yaml create mode 100644 models/classification_template/configs/inference.yaml create mode 100644 models/classification_template/configs/logging.conf create mode 100644 models/classification_template/configs/metadata.json create mode 100644 models/classification_template/configs/multi_gpu_train.yaml create mode 100644 models/classification_template/configs/train.yaml create mode 100644 models/classification_template/docs/README.md create mode 100644 models/classification_template/docs/generate_data.ipynb create mode 100644 models/classification_template/large_files.yml diff --git a/models/classification_template/LICENSE b/models/classification_template/LICENSE new file mode 100644 index 00000000..5a2a4c0f --- /dev/null +++ b/models/classification_template/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 MONAI Consortium + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/models/classification_template/configs/evaluate.yaml b/models/classification_template/configs/evaluate.yaml new file mode 100644 index 00000000..42ef0675 --- /dev/null +++ b/models/classification_template/configs/evaluate.yaml @@ -0,0 +1,38 @@ +# This implements the workflow for applying the network to a directory of images and measuring network performance with metrics. + +# these transforms are used for inference to load and regularise inputs +transforms: +- _target_: AsDiscreted + keys: ['@pred', '@label'] + argmax: [true, false] + to_onehot: '@num_classes' +- _target_: ToTensord + keys: ['@pred', '@label'] + device: '@device' + +postprocessing: + _target_: Compose + transforms: $@transforms + +# inference handlers to load checkpoint, gather statistics +val_handlers: +- _target_: CheckpointLoader + _disabled_: $not os.path.exists(@ckpt_path) + load_path: '@ckpt_path' + load_dict: + model: '@network' +- _target_: StatsHandler + name: null # use engine.logger as the Logger object to log to + output_transform: '$lambda x: None' +- _target_: MetricsSaver + save_dir: '@output_dir' + metrics: ['val_accuracy'] + metric_details: ['val_accuracy'] + batch_transform: "$lambda x: [xx['image'].meta for xx in x]" + summary_ops: "*" + +initialize: +- "$monai.utils.set_determinism(seed=123)" +- "$setattr(torch.backends.cudnn, 'benchmark', True)" +run: +- $@evaluator.run() diff --git a/models/classification_template/configs/inference.yaml b/models/classification_template/configs/inference.yaml new file mode 100644 index 00000000..68115e94 --- /dev/null +++ b/models/classification_template/configs/inference.yaml @@ -0,0 +1,115 @@ +# This implements the workflow for applying the network to a directory of images and measuring network performance with metrics. + +imports: +- $import os +- $import json +- $import torch +- $import glob + +# pull out some constants from MONAI +image: $monai.utils.CommonKeys.IMAGE +label: $monai.utils.CommonKeys.LABEL +pred: $monai.utils.CommonKeys.PRED + +# hyperparameters for you to modify on the command line +batch_size: 1 # number of images per batch +num_workers: 0 # number of workers to generate batches with +num_classes: 4 # number of classes in training data which network should predict +device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + +# define various paths +bundle_root: . 
# root directory of the bundle +ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting +dataset_dir: $@bundle_root + '/data/test_data' # where data is coming from + +# network definition, this could be parameterised by pre-defined values or on the command line +network_def: + _target_: DenseNet121 + spatial_dims: 2 + in_channels: 1 + out_channels: '@num_classes' +network: $@network_def.to(@device) + +# list all niftis in the input directory +test_json: "$@bundle_root+'/data/test_samples.json'" +test_fp: "$open(@test_json,'r', encoding='utf8')" +# load json file +test_dict: "$json.load(@test_fp)" + +# these transforms are used for inference to load and regularise inputs +transforms: +- _target_: LoadImaged + keys: '@image' +- _target_: EnsureChannelFirstd + keys: '@image' +- _target_: ScaleIntensityd + keys: '@image' + +preprocessing: + _target_: Compose + transforms: $@transforms + +dataset: + _target_: Dataset + data: '@test_dict' + transform: '@preprocessing' + +dataloader: + _target_: ThreadDataLoader # generate data ansynchronously from inference + dataset: '@dataset' + batch_size: '@batch_size' + num_workers: '@num_workers' + +# should be replaced with other inferer types if training process is different for your network +inferer: + _target_: SimpleInferer + +# transform to apply to data from network to be suitable for validation +postprocessing: + _target_: Compose + transforms: + - _target_: Activationsd + keys: '@pred' + softmax: true + - _target_: AsDiscreted + keys: ['@pred', '@label'] + argmax: [true, false] + to_onehot: '@num_classes' + - _target_: ToTensord + keys: ['@pred', '@label'] + device: '@device' + +# inference handlers to load checkpoint, gather statistics +val_handlers: +- _target_: CheckpointLoader + _disabled_: $not os.path.exists(@ckpt_path) + load_path: '@ckpt_path' + load_dict: + model: '@network' +- _target_: StatsHandler + name: null # use engine.logger as the Logger object to log to + output_transform: '$lambda x: None' + +# engine for running inference, ties together objects defined above and has metric definitions +evaluator: + _target_: SupervisedEvaluator + device: '@device' + val_data_loader: '@dataloader' + network: '@network' + inferer: '@inferer' + postprocessing: '@postprocessing' + key_val_metric: + val_accuracy: + _target_: ignite.metrics.Accuracy + output_transform: $monai.handlers.from_engine([@pred, @label]) + additional_metrics: + val_f1: # can have other metrics + _target_: ConfusionMatrix + metric_name: 'f1 score' + output_transform: $monai.handlers.from_engine([@pred, @label]) + val_handlers: '@val_handlers' + +initialize: +- "$setattr(torch.backends.cudnn, 'benchmark', True)" +run: +- "$@evaluator.run()" diff --git a/models/classification_template/configs/logging.conf b/models/classification_template/configs/logging.conf new file mode 100644 index 00000000..91c1a21c --- /dev/null +++ b/models/classification_template/configs/logging.conf @@ -0,0 +1,21 @@ +[loggers] +keys=root + +[handlers] +keys=consoleHandler + +[formatters] +keys=fullFormatter + +[logger_root] +level=INFO +handlers=consoleHandler + +[handler_consoleHandler] +class=StreamHandler +level=INFO +formatter=fullFormatter +args=(sys.stdout,) + +[formatter_fullFormatter] +format=%(asctime)s - %(name)s - %(levelname)s - %(message)s diff --git a/models/classification_template/configs/metadata.json b/models/classification_template/configs/metadata.json new file mode 100644 index 00000000..fa2428ce --- /dev/null +++ 
b/models/classification_template/configs/metadata.json @@ -0,0 +1,63 @@ +{ + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json", + "version": "0.0.1", + "changelog": { + "0.0.1": "Initial version" + }, + "monai_version": "1.3.0", + "pytorch_version": "2.0.1", + "numpy_version": "1.24.4", + "optional_packages_version": { + "pytorch-ignite": "0.4.12" + }, + "name": "Classification Template", + "task": "Classification Template in 2D images", + "description": "This is a template bundle for classifying in 2D, take this as a basis for your own bundles.", + "authors": "Yun Liu", + "copyright": "Copyright (c) 2023 MONAI Consortium", + "network_data_format": { + "inputs": { + "image": { + "type": "image", + "format": "magnitude", + "modality": "none", + "num_channels": 1, + "spatial_shape": [ + 128, + 128 + ], + "dtype": "float32", + "value_range": [], + "is_patch_data": false, + "channel_def": { + "0": "image" + } + } + }, + "outputs": { + "pred": { + "type": "probabilities", + "format": "classes", + "num_channels": 4, + "spatial_shape": [ + 1, + 4 + ], + "dtype": "float32", + "value_range": [ + 0, + 1, + 2, + 3 + ], + "is_patch_data": false, + "channel_def": { + "0": "background", + "1": "circle", + "2": "triangle", + "3": "rectangle" + } + } + } + } +} diff --git a/models/classification_template/configs/multi_gpu_train.yaml b/models/classification_template/configs/multi_gpu_train.yaml new file mode 100644 index 00000000..d41f5800 --- /dev/null +++ b/models/classification_template/configs/multi_gpu_train.yaml @@ -0,0 +1,37 @@ +# This file contains the changes to implement DDP training with the train.yaml config. + +device: "$torch.device('cuda:' + os.environ['LOCAL_RANK'])" # assumes GPU # matches rank # + +# wrap the network in a DistributedDataParallel instance, moving it to the chosen device for this process +network: + _target_: torch.nn.parallel.DistributedDataParallel + module: $@network_def.to(@device) + device_ids: ['@device'] + find_unused_parameters: true + +train_sampler: + _target_: DistributedSampler + dataset: '@train_dataset' + even_divisible: true + shuffle: true + +train_dataloader#sampler: '@train_sampler' +train_dataloader#shuffle: false + +val_sampler: + _target_: DistributedSampler + dataset: '@val_dataset' + even_divisible: false + shuffle: false + +val_dataloader#sampler: '@val_sampler' + +initialize: +- $import torch.distributed as dist +- $dist.init_process_group(backend='nccl') +- $torch.cuda.set_device(@device) +- $monai.utils.set_determinism(seed=123) # may want to choose a different seed or not do this here +run: +- '$@trainer.run()' +finalize: +- '$dist.is_initialized() and dist.destroy_process_group()' diff --git a/models/classification_template/configs/train.yaml b/models/classification_template/configs/train.yaml new file mode 100644 index 00000000..d7eed769 --- /dev/null +++ b/models/classification_template/configs/train.yaml @@ -0,0 +1,233 @@ +# This config file implements the training workflow. It can be combined with multi_gpu_train.yaml to use DDP for +# multi-GPU runs. 
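+#
+# Typical invocations (taken from docs/README.md in this bundle):
+#   single GPU/CPU: python -m monai.bundle run --config_file configs/train.yaml
+#   multi-GPU:      torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run \
+#                     --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']"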
+
+imports:
+- $import os
+- $import json
+- $import datetime
+- $import torch
+- $import glob
+
+# pull out some constants from MONAI
+image: $monai.utils.CommonKeys.IMAGE
+label: $monai.utils.CommonKeys.LABEL
+pred: $monai.utils.CommonKeys.PRED
+
+# multi-gpu values, `rank` will be replaced in a separate script implementing multi-gpu changes
+rank: 0 # without multi-gpu support consider the process as rank 0 anyway
+is_not_rank0: '$@rank > 0' # true if not main process, used to disable handlers for other ranks
+
+# hyperparameters for you to modify on the command line
+val_interval: 1 # how often to perform validation after an epoch
+ckpt_interval: 1 # how often to save a checkpoint after an epoch
+rand_prob: 0.5 # probability a random transform is applied
+batch_size: 5 # number of images per batch
+num_epochs: 10 # number of epochs to train for
+num_substeps: 1 # how many times to repeat training on the same batch
+num_workers: 4 # number of workers to generate batches with
+learning_rate: 0.001 # initial learning rate
+num_classes: 4 # number of classes in training data which network should predict
+device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+# define various paths
+bundle_root: . # root directory of the bundle
+ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting
+dataset_dir: $@bundle_root + '/data/train_data' # where data is coming from
+results_dir: $@bundle_root + '/results' # where results are being stored to
+# a new output directory is chosen using a timestamp for every invocation
+output_dir: '$datetime.datetime.now().strftime(@results_dir + ''/output_%y%m%d_%H%M%S'')'
+
+# network definition, this could be parameterised by pre-defined values or on the command line
+network_def:
+  _target_: DenseNet121
+  spatial_dims: 2
+  in_channels: 1
+  out_channels: '@num_classes'
+network: $@network_def.to(@device)
+
+# dataset values, this assumes a JSON file listing the img##.nii.gz files and their labels
+data_json: $@bundle_root + '/data/train_samples.json' # JSON file listing training images and labels
+data_fp: "$open(@data_json,'r', encoding='utf8')"
+data_dict: "$json.load(@data_fp)"
+partitions: '$monai.data.partition_dataset(@data_dict, (4, 1), shuffle=True, seed=0)'
+train_sub: '$@partitions[0]' # train partition
+val_sub: '$@partitions[1]' # validation partition
+
+# these transforms are used for training and validation transform sequences
+base_transforms:
+- _target_: LoadImaged
+  keys: '@image'
+- _target_: EnsureChannelFirstd
+  keys: '@image'
+
+# these are the random and regularising transforms used only for training
+train_transforms:
+- _target_: RandAxisFlipd
+  keys: '@image'
+  prob: '@rand_prob'
+- _target_: RandRotate90d
+  keys: '@image'
+  prob: '@rand_prob'
+- _target_: RandGaussianNoised
+  keys: '@image'
+  prob: '@rand_prob'
+  std: 0.05
+- _target_: ScaleIntensityd
+  keys: '@image'
+
+# these are used for validation data so no randomness
+val_transforms:
+- _target_: ScaleIntensityd
+  keys: '@image'
+
+# define the Compose objects for training and validation
+preprocessing:
+  _target_: Compose
+  transforms: $@base_transforms + @train_transforms
+
+val_preprocessing:
+  _target_: Compose
+  transforms: $@base_transforms + @val_transforms
+
+# define the datasets for training and validation
+train_dataset:
+  _target_: Dataset
+  data: '@train_sub'
+  transform: '@preprocessing'
+
+val_dataset:
+  _target_: Dataset
+  data: '@val_sub'
+  transform: '@val_preprocessing'
+
+# define the dataloaders for training and validation
+train_dataloader: + _target_: ThreadDataLoader # generate data ansynchronously from training + dataset: '@train_dataset' + batch_size: '@batch_size' + repeats: '@num_substeps' + num_workers: '@num_workers' + +val_dataloader: + _target_: DataLoader # faster transforms probably won't benefit from threading + dataset: '@val_dataset' + batch_size: '@batch_size' + num_workers: '@num_workers' + +# Simple CrossEntropy loss configured for multi-class classification +lossfn: + _target_: torch.nn.CrossEntropyLoss + reduction: sum + +# hyperparameters could be added for other arguments of this class +optimizer: + _target_: torch.optim.Adam + params: $@network.parameters() + lr: '@learning_rate' + +# should be replaced with other inferer types if training process is different for your network +inferer: + _target_: SimpleInferer + +# transform to apply to data from network to be suitable for validation +postprocessing: + _target_: Compose + transforms: + - _target_: Activationsd + keys: '@pred' + softmax: true + - _target_: AsDiscreted + keys: ['@pred', '@label'] + argmax: [true, false] + to_onehot: '@num_classes' + - _target_: ToTensord + keys: ['@pred', '@label'] + device: '@device' + +# validation handlers to gather statistics, log these to a file, and save best checkpoint +val_handlers: +- _target_: StatsHandler + name: null # use engine.logger as the Logger object to log to + output_transform: '$lambda x: None' +- _target_: LogfileHandler # log outputs from the validation engine + output_dir: '@output_dir' +- _target_: CheckpointSaver + _disabled_: '@is_not_rank0' # only need rank 0 to save + save_dir: '@output_dir' + save_dict: + model: '@network' + save_interval: 0 # don't save iterations, just when the metric improves + save_final: false + epoch_level: false + save_key_metric: true + key_metric_name: val_accuracy # save the checkpoint when this value improves + +# engine for running validation, ties together objects defined above and has metric definitions +evaluator: + _target_: SupervisedEvaluator + device: '@device' + val_data_loader: '@val_dataloader' + network: '@network' + postprocessing: '@postprocessing' + key_val_metric: + val_accuracy: + _target_: ignite.metrics.Accuracy + output_transform: $monai.handlers.from_engine([@pred, @label]) + additional_metrics: + val_f1: # can have other metrics + _target_: ConfusionMatrix + metric_name: 'f1 score' + output_transform: $monai.handlers.from_engine([@pred, @label]) + val_handlers: '@val_handlers' + +# gathers the loss and validation values for each iteration, referred to by CheckpointSaver so defined separately +metriclogger: + _target_: MetricLogger + evaluator: '@evaluator' + +handlers: +- '@metriclogger' +- _target_: CheckpointLoader + _disabled_: $not os.path.exists(@ckpt_path) + load_path: '@ckpt_path' + load_dict: + model: '@network' +- _target_: ValidationHandler # run validation at the set interval, bridge between trainer and evaluator objects + validator: '@evaluator' + epoch_level: true + interval: '@val_interval' +- _target_: CheckpointSaver + _disabled_: '@is_not_rank0' # only need rank 0 to save + save_dir: '@output_dir' + save_dict: # every epoch checkpoint saves the network and the metric logger in a dictionary + model: '@network' + logger: '@metriclogger' + save_interval: '@ckpt_interval' + save_final: true + epoch_level: true +- _target_: StatsHandler + name: null # use engine.logger as the Logger object to log to + tag_name: train_loss + output_transform: $monai.handlers.from_engine(['loss'], first=True) # log loss value +- 
_target_: LogfileHandler # log outputs from the training engine
+  output_dir: '@output_dir'
+
+# engine for training, ties together the values defined above into the main engine for the training process
+trainer:
+  _target_: SupervisedTrainer
+  max_epochs: '@num_epochs'
+  device: '@device'
+  train_data_loader: '@train_dataloader'
+  network: '@network'
+  inferer: '@inferer' # unnecessary since SimpleInferer is the default if this isn't provided
+  loss_function: '@lossfn'
+  optimizer: '@optimizer'
+  # postprocessing: '@postprocessing' # uncomment if you have train metrics that need post-processing
+  key_train_metric: null
+  train_handlers: '@handlers'
+
+initialize:
+- "$monai.utils.set_determinism(seed=123)"
+- "$setattr(torch.backends.cudnn, 'benchmark', True)"
+run:
+- "$@trainer.run()"
diff --git a/models/classification_template/docs/README.md b/models/classification_template/docs/README.md
new file mode 100644
index 00000000..6a98a213
--- /dev/null
+++ b/models/classification_template/docs/README.md
@@ -0,0 +1,55 @@
+# Template Classification Bundle
+
+This bundle is meant to be an example of classification in 2D which you can copy and modify to create your own bundle.
+It is only roughly trained on the synthetic data you can generate with [this notebook](./generate_data.ipynb),
+so it doesn't do anything useful on its own. Its purpose is to demonstrate the baseline structure of a classification bundle.
+
+To use this bundle, copy the contents of the whole directory and change the definitions for network, data, transforms,
+or whatever else you want for your own new classification bundle.
+
+## Generating Demo Data
+
+Run all the cells of [this notebook](./generate_data.ipynb) to generate training and test data. These will be 2D
+NIfTI files, each containing a randomly generated circle, triangle, or rectangle (or no shape at all). The classification
+task is very easy, so your network will train in minutes with the default configuration of values. A separate test
+data directory and a JSON file listing its samples will also be created, which the inference config reads to locate its inputs.
+
+## Training
+
+To train a new network the `train.yaml` config can be used alone with no other arguments (commands are run from the root
+directory of the bundle):
+
+```
+python -m monai.bundle run --config_file configs/train.yaml
+```
+
+The training config includes a number of hyperparameters such as `learning_rate` and `num_workers`. These control aspects
+of how training operates: how many worker processes to use, when to perform validation, when to save checkpoints,
+and so on. Other values in the config can also be overridden on the command line, so these aren't exhaustive, but they are
+a guide to the kind of parameterisation that makes sense for a bundle.
+
+## Override the `train` config to execute multi-GPU training:
+
+```
+torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']"
+```
+
+Please note that the distributed training-related options depend on the actual running environment; thus, users may need to remove `--standalone`, modify `--nnodes`, or make other necessary changes according to the machine used. For more details, please refer to [pytorch's official tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).
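+
+Individual hyperparameters defined in `train.yaml` can also be overridden for any of the above commands by passing them
+as extra arguments; the values below are purely illustrative (any config key can be set this way):
+
+```
+python -m monai.bundle run --config_file configs/train.yaml --learning_rate 0.0001 --num_epochs 20
+```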
+
+## Override the `train` config to execute evaluation with the trained model:
+
+```
+python -m monai.bundle run --config_file "['configs/train.yaml','configs/evaluate.yaml']"
+```
+
+## Execute inference:
+
+```
+python -m monai.bundle run --config_file configs/inference.yaml
+```
+
+## Other Considerations
+
+There is no `scripts` directory in this bundle, which would otherwise contain a valid Python module to be imported in your
+configs. This wasn't necessary for this bundle, but if you want to include custom code in a bundle, please follow the bundle
+tutorials on how to do this.
diff --git a/models/classification_template/docs/generate_data.ipynb b/models/classification_template/docs/generate_data.ipynb
new file mode 100644
index 00000000..3e69541d
--- /dev/null
+++ b/models/classification_template/docs/generate_data.ipynb
@@ -0,0 +1,236 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "b1c9de9d-6777-4a1d-bb7c-c2413d01bd7d",
+   "metadata": {},
+   "source": [
+    "# Generate Data\n",
+    "\n",
+    "This bundle uses simple synthetic data for training and testing. The `generate_images` function defined below creates 2D images, each containing at most one randomly placed shape (circle, triangle, or rectangle), with the shape type used as the class label. The network will be able to train very quickly on this of course, but it's for demonstration purposes and your specialised bundle will be adapted to your data and its layout.\n",
+    "\n",
+    "Assuming this notebook is being run from the `docs` directory, it will create two new directories under `data` in the root of the bundle, `train_data` and `test_data`.\n",
+    "\n",
+    "First imports:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "1e7cb4a8-f91a-4f15-a8aa-3136c2b954d6",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import json\n",
+    "import random\n",
+    "\n",
+    "import matplotlib.pyplot as plt\n",
+    "import nibabel as nib\n",
+    "import numpy as np\n",
+    "\n",
+    "plt.rcParams[\"image.interpolation\"] = \"none\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2b2c3de5-01e5-4578-832b-b24a75d095d5",
+   "metadata": {},
+   "source": [
+    "As shown below, each generated image contains at most one shape and is returned together with its class label (0 = no shape, 1 = circle, 2 = triangle, 3 = rectangle):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_images(image_size=128, border=20, shape_probabilities=None, shape_sizes=None):\n",
+    "    image = np.zeros((image_size, image_size))\n",
+    "\n",
+    "    if shape_probabilities is None:\n",
+    "        shape_probabilities = [0.25, 0.2, 0.3, 0.25] # Default probabilities for no shape, circle, triangle, rectangle\n",
+    "\n",
+    "    if shape_sizes is None:\n",
+    "        shape_sizes = [(10, 30), (20, 40), (20, 40)] # Default size ranges for circle, triangle, rectangle\n",
+    "\n",
+    "    def draw_zero(image):\n",
+    "        return image\n",
+    "\n",
+    "    def draw_circle(image):\n",
+    "        center_x, center_y = np.random.randint(border, image_size - border), np.random.randint(border, image_size - border)\n",
+    "        radius = np.random.randint(*shape_sizes[0])\n",
+    "        y, x = np.ogrid[-center_x:image_size-center_x, -center_y:image_size-center_y]\n",
+    "        mask = x ** 2 + y ** 2 <= radius ** 2\n",
+    "        image[mask] = 1\n",
+    "        return image\n",
+    "\n",
+    "    def draw_triangle(image):\n",
+    "        size = np.random.randint(*shape_sizes[1])\n",
+    "        x1, y1 = np.random.randint(border, image_size - border), np.random.randint(border, image_size - border)\n",
+    "        x2, y2 = x1 + size, y1\n",
+    "        x3, y3 = x1 + size // 2, y1 - int(size * np.sqrt(3) / 2)\n",
+    "        triangle = np.array([[x1, x2, x3], [y1, y2, y3]])\n",
+    "        mask = plt.matplotlib.path.Path(np.transpose(triangle)).contains_points(\n",
+    "            np.array([(i, j) for i in range(image_size) for j in range(image_size)])\n",
+    "        )\n",
+    "        image[mask.reshape(image_size, image_size)] = 1\n",
+    "        return image\n",
+    "\n",
+    "    def draw_rectangle(image):\n",
+    "        x1, y1 = np.random.randint(border, image_size - border), np.random.randint(border, image_size - border)\n",
+    "        x2, y2 = x1 + np.random.randint(*shape_sizes[2]), y1 + np.random.randint(*shape_sizes[2])\n",
+    "        image[x1:x2, y1:y2] = 1\n",
+    "        return image\n",
+    "\n",
+    "    label, shape = random.choices([(0, draw_zero), (1, draw_circle), (2, draw_triangle), (3, draw_rectangle)], weights=shape_probabilities)[0]\n",
+    "    image = shape(image)\n",
+    "\n",
+    "    return image, label"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "<base64-encoded PNG of the 3x3 sample grid omitted>",
+      "text/plain": [
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axes = plt.subplots(3, 3, figsize=(5, 5))\n", + "for i, ax in enumerate(axes.flatten()):\n", + " for j in range(9):\n", + " images, label = generate_images(128)\n", + " ax.imshow(images, cmap='gray')\n", + " ax.axis('off')\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "8e08c4a1-6630-4ab3-832b-e53face81e35", + "metadata": {}, + "source": [ + "50 image/label pairs are now generated into the directory `../data/train_data`, assuming this notebook is run from the `docs` directory this will be in the bundle root:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "num_images = 50\n", + "out_dir = os.path.abspath(\"../data/train_data\")\n", + "os.makedirs(out_dir, exist_ok=True)\n", + "\n", + "train_data = []\n", + "for i in range(num_images):\n", + " data = {}\n", + " img, lbl = generate_images(128)\n", + " n = nib.Nifti1Image(img, np.eye(4))\n", + " train_file_path = os.path.join(out_dir, f\"img{i:02}.nii.gz\")\n", + " nib.save(n, train_file_path)\n", + "\n", + " data[\"image\"] = train_file_path\n", + " data[\"label\"] = lbl\n", + " train_data.append(data)\n", + "\n", + "with open(os.path.abspath(\"../data/train_samples.json\"), \"w\") as f:\n", + " json.dump(train_data, f, indent=2)" + ] + }, + { + "cell_type": "markdown", + "id": "7fe344f7-d01d-49d5-adca-a7071939ca53", + "metadata": {}, + "source": [ + "We'll also generate some test data in a separate folder:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c3b8d8f3-8d73-4657-98f3-5605d4b1bad9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "num_images = 10\n", + "out_dir = os.path.abspath(\"../data/test_data\")\n", + "os.makedirs(out_dir, exist_ok=True)\n", + "\n", + "train_data = []\n", + "for i in range(num_images):\n", + " data = {}\n", + " img, lbl = generate_images(128)\n", + " n = nib.Nifti1Image(img, np.eye(4))\n", + " train_file_path = os.path.join(out_dir, f\"img{i:02}.nii.gz\")\n", + " nib.save(n, train_file_path)\n", + "\n", + " data[\"image\"] = train_file_path\n", + " data[\"label\"] = lbl\n", + " train_data.append(data)\n", + "\n", + "with open(os.path.abspath(\"../data/test_samples.json\"), \"w\") as f:\n", + " json.dump(train_data, f, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "599cff25-4894-481b-aec3-6aedda327a09", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "img00.nii.gz img02.nii.gz img04.nii.gz img06.nii.gz\timg08.nii.gz\n", + "img01.nii.gz img03.nii.gz img05.nii.gz img07.nii.gz\timg09.nii.gz\n" + ] + } + ], + "source": [ + "!ls {out_dir}" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/models/classification_template/large_files.yml b/models/classification_template/large_files.yml new file mode 100644 index 00000000..5b82bc39 --- /dev/null +++ b/models/classification_template/large_files.yml @@ -0,0 +1,5 @@ +large_files: + - path: "models/model.pt" + url: 
"https://drive.google.com/uc?id=1kClwSCzVzahn4OTVePLhbvW4vIOLKDlu" + hash_val: "915f54538655e9e6091c5d09dfdee621" + hash_type: "md5" diff --git a/models/lung_nodule_ct_detection/configs/metadata.json b/models/lung_nodule_ct_detection/configs/metadata.json index 6c4d2fba..4c3756c9 100644 --- a/models/lung_nodule_ct_detection/configs/metadata.json +++ b/models/lung_nodule_ct_detection/configs/metadata.json @@ -1,7 +1,8 @@ { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json", - "version": "0.6.0", + "version": "0.6.1", "changelog": { + "0.6.1": "fix format error", "0.6.0": "remove meta_dict usage", "0.5.9": "use monai 1.2.0", "0.5.8": "update TRT memory requirement in readme", diff --git a/models/lung_nodule_ct_detection/scripts/detection_inferer.py b/models/lung_nodule_ct_detection/scripts/detection_inferer.py index 6fc93c3b..40174001 100644 --- a/models/lung_nodule_ct_detection/scripts/detection_inferer.py +++ b/models/lung_nodule_ct_detection/scripts/detection_inferer.py @@ -63,4 +63,4 @@ def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module and not all([data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs]) ) - return self.detector(inputs, use_inferer=use_inferer, *args, **kwargs) + return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs)