58 | 58 | "from monai.networks.layers import Norm\n",
59 | 59 | "from monai.networks.nets import UNet\n",
60 | 60 | "from monai.losses import DiceLoss\n",
61 | - "from monai.metrics import compute_meandice"
| 61 | + "from monai.metrics import compute_meandice\n",
| 62 | + "from monai.utils import set_determinism"
62 | 63 | ]
63 | 64 | },
64 | 65 | {

146 | 147 | " return epoch_num, time.time() - total_start, epoch_loss_values, metric_values, epoch_times"
|
147 | 148 | ]
|
148 | 149 | },
|
149 |
| - { |
150 |
| - "cell_type": "markdown", |
151 |
| - "metadata": {}, |
152 |
| - "source": [ |
153 |
| - "## Define deterministic training for reproducibility" |
154 |
| - ] |
155 |
| - }, |
156 |
| - { |
157 |
| - "cell_type": "code", |
158 |
| - "execution_count": 3, |
159 |
| - "metadata": {}, |
160 |
| - "outputs": [], |
161 |
| - "source": [ |
162 |
| - "def set_deterministic(train_transforms):\n", |
163 |
| - " train_transforms.set_random_state(seed=0)\n", |
164 |
| - " torch.manual_seed(0)\n", |
165 |
| - " torch.backends.cudnn.deterministic = True\n", |
166 |
| - " torch.backends.cudnn.benchmark = False\n", |
167 |
| - " return train_transforms" |
168 |
| - ] |
169 |
| - }, |
170 | 150 | {
|
171 | 151 | "cell_type": "markdown",
|
172 | 152 | "metadata": {},
|
|
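The helper removed above is essentially what `monai.utils.set_determinism` now provides in a single call. A rough sketch of the equivalence (illustrative only, not part of the commit; `manual_determinism` is a hypothetical name):

    import random

    import numpy as np
    import torch

    def manual_determinism(seed: int = 0) -> None:
        # Roughly what monai.utils.set_determinism(seed=0) covers: seed the
        # Python, NumPy and PyTorch RNGs and pin the cuDNN flags that the
        # removed set_deterministic helper used to set by hand.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

One behavioural difference: the old helper seeded an already-built `Compose` via `set_random_state(seed=0)`, while `set_determinism` affects MONAI's random transforms as they are constructed afterwards, which is presumably why the commit wraps the transform definitions in a `transformations()` factory below and rebuilds them after every `set_determinism(seed=0)` call.
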
185 | 165 | },
186 | 166 | {
187 | 167 | "cell_type": "code",
188 | - "execution_count": 4,
| 168 | + "execution_count": 3,
189 | 169 | "metadata": {},
190 | 170 | "outputs": [],
191 | 171 | "source": [

220 | 200 | },
221 | 201 | {
222 | 202 | "cell_type": "code",
223 | - "execution_count": 5,
| 203 | + "execution_count": 4,
224 | 204 | "metadata": {},
225 | 205 | "outputs": [],
226 | 206 | "source": [
227 | - "train_transforms = Compose([\n",
228 | - "    LoadNiftid(keys=['image', 'label']),\n",
229 | - "    AddChanneld(keys=['image', 'label']),\n",
230 | - "    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
231 | - "    Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
232 | - "    ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
233 | - "    CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
234 | - "    # randomly crop out patch samples from big image based on pos / neg ratio\n",
235 | - "    # the image centers of negative samples must be in valid image area\n",
236 | - "    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n",
237 | - "                           neg=1, num_samples=4, image_key='image', image_threshold=0),\n",
238 | - "    ToTensord(keys=['image', 'label'])\n",
239 | - "])\n",
| 207 | + "def transformations():\n",
| 208 | + "    train_transforms = Compose([\n",
| 209 | + "        LoadNiftid(keys=['image', 'label']),\n",
| 210 | + "        AddChanneld(keys=['image', 'label']),\n",
| 211 | + "        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
| 212 | + "        Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
| 213 | + "        ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
| 214 | + "        CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
| 215 | + "        # randomly crop out patch samples from big image based on pos / neg ratio\n",
| 216 | + "        # the image centers of negative samples must be in valid image area\n",
| 217 | + "        RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,\n",
| 218 | + "                               neg=1, num_samples=4, image_key='image', image_threshold=0),\n",
| 219 | + "        ToTensord(keys=['image', 'label'])\n",
| 220 | + "    ])\n",
240 | 221 | "\n",
241 | - "# NOTE: No random cropping in the validation data, we will evaluate the entire image using a sliding window.\n",
242 | - "val_transforms = Compose([\n",
243 | - "    LoadNiftid(keys=['image', 'label']),\n",
244 | - "    AddChanneld(keys=['image', 'label']),\n",
245 | - "    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
246 | - "    Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
247 | - "    ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
248 | - "    CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
249 | - "    ToTensord(keys=['image', 'label'])\n",
250 | - "])"
| 222 | + "    # NOTE: No random cropping in the validation data, we will evaluate the entire image using a sliding window.\n",
| 223 | + "    val_transforms = Compose([\n",
| 224 | + "        LoadNiftid(keys=['image', 'label']),\n",
| 225 | + "        AddChanneld(keys=['image', 'label']),\n",
| 226 | + "        Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=(3, 0)),\n",
| 227 | + "        Orientationd(keys=['image', 'label'], axcodes='RAS'),\n",
| 228 | + "        ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),\n",
| 229 | + "        CropForegroundd(keys=['image', 'label'], source_key='image'),\n",
| 230 | + "        ToTensord(keys=['image', 'label'])\n",
| 231 | + "    ])\n",
| 232 | + "    return train_transforms, val_transforms"
251 | 233 | ]
252 | 234 | },
253 | 235 | {
254 | 236 | "cell_type": "markdown",
255 | 237 | "metadata": {},
256 | 238 | "source": [
257 | - "## `Dataset`: Define, train and validate\n",
| 239 | + "## Enable deterministic training and regular `Dataset`\n",
258 | 240 | "\n",
259 | 241 | "Load each original data item and apply the transforms every time it is needed."
260 | 242 | ]

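For context on what the regular `Dataset` timing below measures: every `__getitem__` call re-runs the whole transform chain, including the deterministic loading, resampling and intensity steps. A minimal sketch of how such a dataset is consumed (this is not the notebook's `train_process`; `train_files` and `train_trans` are assumed from the surrounding cells):

    import torch
    from monai.data import Dataset, list_data_collate

    train_ds = Dataset(data=train_files, transform=train_trans)
    # RandCropByPosNegLabeld(num_samples=4) yields a list of 4 patch dicts per
    # item, so list_data_collate flattens them into a single batch.
    loader = torch.utils.data.DataLoader(
        train_ds, batch_size=2, shuffle=True, num_workers=4,
        collate_fn=list_data_collate)
    batch = next(iter(loader))  # transforms run on the fly, every epoch
    images, labels = batch['image'], batch['label']
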
265 | 247 | "metadata": {},
266 | 248 | "outputs": [],
267 | 249 | "source": [
268 | - "train_transforms = set_deterministic(train_transforms)\n",
269 | - "train_ds = Dataset(data=train_files, transform=train_transforms)\n",
270 | - "val_ds = Dataset(data=val_files, transform=val_transforms)\n",
| 250 | + "set_determinism(seed=0)\n",
| 251 | + "train_trans, val_trans = transformations()\n",
| 252 | + "train_ds = Dataset(data=train_files, transform=train_trans)\n",
| 253 | + "val_ds = Dataset(data=val_files, transform=val_trans)\n",
271 | 254 | "\n",
272 | 255 | "epoch_num, total_time, epoch_loss_values, metric_values, epoch_times = train_process(train_ds, val_ds)\n",
273 | 256 | "print('Total training time of {} epochs with regular Dataset: {:.4f}'.format(epoch_num, total_time))"

277 | 260 | "cell_type": "markdown",
278 | 261 | "metadata": {},
279 | 262 | "source": [
280 | - "## `PersistentDataset`: Define, train and validate\n",
| 263 | + "## Enable deterministic training and `PersistentDataset`\n",
281 | 264 | "\n",
282 | 265 | "Use persistent storage for the non-random transformed training and validation data: it is computed once and stored persistently across runs."
283 | 266 | ]

291 | 274 | "persistent_cache: Path = Path(\"./persistent_cache\")\n",
292 | 275 | "persistent_cache.mkdir(parents=True, exist_ok=True)\n",
293 | 276 | "\n",
294 | - "train_transforms = set_deterministic(train_transforms)\n",
295 | - "train_persitence_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache)\n",
296 | - "val_persitence_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache)\n",
| 277 | + "set_determinism(seed=0)\n",
| 278 | + "train_trans, val_trans = transformations()\n",
| 279 | + "train_persitence_ds = PersistentDataset(data=train_files, transform=train_trans, cache_dir=persistent_cache)\n",
| 280 | + "val_persitence_ds = PersistentDataset(data=val_files, transform=val_trans, cache_dir=persistent_cache)\n",
297 | 281 | "\n",
298 | 282 | "persistence_epoch_num, persistence_total_time, persistence_epoch_loss_values, \\\n",
299 | 283 | "    persistence_metric_values, persistence_epoch_times = \\\n",

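For clarity on what `PersistentDataset` persists: the transform chain runs only up to the first random transform (here `RandCropByPosNegLabeld`), and that deterministic intermediate result is written to `cache_dir`; the random cropping still executes every epoch. A cache left over from a previous run would hide the one-off preprocessing cost, so when re-running the benchmark it is worth clearing it first (illustrative snippet, not part of the commit):

    import shutil
    from pathlib import Path

    persistent_cache = Path("./persistent_cache")
    shutil.rmtree(persistent_cache, ignore_errors=True)  # drop any stale cache
    persistent_cache.mkdir(parents=True, exist_ok=True)
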
306 | 290 | "cell_type": "markdown",
307 | 291 | "metadata": {},
308 | 292 | "source": [
309 | - "## `CacheDataset`: Define, train and validate\n",
| 293 | + "## Enable deterministic training and `CacheDataset`\n",
310 | 294 | "\n",
311 | 295 | "Precompute all non-random transforms of the original data and store the results in memory."
312 | 296 | ]

317 | 301 | "metadata": {},
318 | 302 | "outputs": [],
319 | 303 | "source": [
320 | - "train_transforms = set_deterministic(train_transforms)\n",
| 304 | + "set_determinism(seed=0)\n",
| 305 | + "train_trans, val_trans = transformations()\n",
321 | 306 | "cache_init_start = time.time()\n",
322 | - "cache_train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)\n",
323 | - "cache_val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)\n",
| 307 | + "cache_train_ds = CacheDataset(data=train_files, transform=train_trans, cache_rate=1.0, num_workers=4)\n",
| 308 | + "cache_val_ds = CacheDataset(data=val_files, transform=val_trans, cache_rate=1.0, num_workers=4)\n",
324 | 309 | "cache_init_time = time.time() - cache_init_start\n",
325 | 310 | "\n",
326 | 311 | "cache_epoch_num, cache_total_time, cache_epoch_loss_values, cache_metric_values, cache_epoch_times = \\\n",

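`CacheDataset` applies the same split of the chain but keeps the deterministic intermediates in RAM instead of on disk: `cache_rate=1.0` caches every item, and `num_workers=4` only parallelises the one-off initialisation timed by `cache_init_time`. If memory is tight, a partial cache is the usual compromise (hypothetical variation on the cell above, not in the commit):

    # Cache half the items in memory; the rest re-run the full chain on the fly.
    cache_train_ds = CacheDataset(data=train_files, transform=train_trans,
                                  cache_rate=0.5, num_workers=4)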