
2. Training the Model and Making Predictions


Now it's time to train the model. To do this, we'll use the pathflowai-train_model module. The main command within this module is train_model.

Let's start off with a classification task:

pathflowai-train_model train_model -h

This exposes a plethora of options, many devoted to tasks such as class balancing, and some likely tangential to the needs of your workflow. Let's focus on a few of them with an example:

CUDA_VISIBLE_DEVICES=0 pathflowai-train_model train_model --patch_size 512 -pr 224 --save_location outcomes_model.pkl -a resnet34 --input_dir inputs/ -nt 1 -t 10000 -lr 1e-4 -ne 10 -ss 0.5 -ssv 0.3 -tt 0.1 -bt 0.01 -imb -pi patch_information.db -bs 32 -ca

If we collapse the parenchyma annotations into the background (by adjusting either the SQL or the initial masks), we can focus on a binary classification of portal regions. A breakdown of the options used above:

- -nt specifies the number of targets; this can be expanded depending on the number of classes (with two classes, -nt can be either 1 or 2).
- -pi points to the SQL database, whose patch size "512" table records where to access each patch within the ZARR files; we pull these patches out by specifying 512 as the --patch_size.
- -pr resizes patches to 224 during training to accommodate input into the ResNet34 architecture (-a).
- -bs feeds in the training, validation, and test samples with a batch size of 32, and -ne trains for 10 epochs.
- -t randomly subsamples 10k training images every training epoch to speed up training; -ss and -ssv subsample the training and validation sets, respectively, ahead of time, before any training is run.
- -imb balances the classes by downsampling the majority classes during training, while -imb2 balances by weighting the training loss function by inverse class prevalence.
- -ca classifies on patch-level annotations; -bt gives the area of a patch from which to call the patch the main class, and -tc, -tt, and -ov, used in conjunction, provide mechanisms to custom-oversample a patch class label in question.
- -lr sets the learning rate to 1e-4, though this is modulated throughout the training process.
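For intuition on the inverse-class-prevalence weighting behind -imb2, here is a minimal PyTorch sketch (this illustrates the general technique, not pathflowai's internal code; the label array is hypothetical):

import numpy as np
import torch
import torch.nn as nn

# Hypothetical patch labels: class 0 (background) dominates class 1 (portal).
labels = np.array([0]*9000 + [1]*1000)
class_counts = np.bincount(labels)                          # [9000, 1000]
weights = len(labels) / (len(class_counts) * class_counts)  # inverse prevalence: [0.56, 5.0]
# The minority class receives the larger weight in the loss.
criterion = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float))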

Train, Validation, Test Split and Whole Slide Image Labels

The dataset is usually automatically split into training, validation, and testing sets during training and stored in train_val_test.pkl, but this split can be custom supplied or modified. Alternatively, supplying --dataset_df (e.g. slide_labels.csv) with --target_names (e.g. -tg Slide_LabelB -tg Slide_LabelD) will incorporate slide-level labels into the prediction model. If --pos_annotation_class and --other_annotations are activated, slides whose slide-level annotations are positive are forced to be positive for the pos_annotation_class, slides carrying the other annotations remain 0, and all remaining slides are removed; otherwise, all patches from a slide with a positive WSI label are labeled positive, and when making predictions, some thresholding must be used to rate an entire slide as positive for that label. This CSV can also be used to split the data into train/val/test sets, via its set column.

slide_labels.csv:

,ID,Slide_LabelA,Slide_LabelB,Slide_LabelC,Slide_LabelD,Slide_LabelE,Slide_LabelF,set
0,A1,0.0,0.0,0.0,0.0,0.0,0.0,train
1,A2,0.0,0.0,0.0,0.0,0.0,0.0,train
2,A3,0.0,0.0,0.0,0.0,0.0,0.0,train
3,A4,0.0,0.0,0.0,0.0,0.0,0.0,train
4,A5,0.0,0.0,0.0,0.0,0.0,0.0,train
5,A6,0.0,0.0,0.0,0.0,0.0,0.0,train
6,A7,0.0,0.0,0.0,0.0,0.0,0.0,train
7,A8,1.0,1.0,1.0,0.0,0.0,0.0,train
8,A9,0.0,0.0,0.0,0.0,0.0,0.0,train
9,A10,0.0,0.0,0.0,0.0,0.0,0.0,train
10,A11,0.0,0.0,0.0,0.0,0.0,0.0,train
11,A12,1.0,1.0,1.0,0.0,0.0,0.0,train
12,A13,0.0,0.0,0.0,0.0,0.0,0.0,train
13,A14,0.0,0.0,0.0,0.0,0.0,0.0,train
14,A15,0.0,0.0,0.0,0.0,0.0,0.0,train
15,A16,0.0,0.0,0.0,0.0,0.0,0.0,train
16,A17,0.0,0.0,0.0,0.0,0.0,0.0,train
17,A18,0.0,0.0,0.0,0.0,0.0,0.0,train
18,A19,1.0,1.0,0.0,0.0,1.0,1.0,train
19,A20,0.0,0.0,0.0,0.0,0.0,0.0,train
20,A21,0.0,0.0,0.0,0.0,0.0,0.0,val
21,A22,0.0,0.0,0.0,0.0,0.0,0.0,val
22,A23,1.0,1.0,0.0,0.0,1.0,1.0,val
23,A24,0.0,0.0,0.0,0.0,0.0,0.0,val
24,A25,1.0,1.0,0.0,1.0,0.0,0.0,val
25,A26,0.0,0.0,0.0,0.0,0.0,0.0,val
26,A27,0.0,0.0,0.0,0.0,0.0,0.0,val
27,A28,0.0,0.0,0.0,0.0,0.0,0.0,val
28,A29,0.0,0.0,0.0,0.0,0.0,0.0,val
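
As a sketch of the slide-level thresholding mentioned above (the patch-prediction column names here are hypothetical, not pathflowai output):

import pandas as pd

slide_df = pd.read_csv('slide_labels.csv', index_col=0)
train_ids = slide_df.loc[slide_df['set'] == 'train', 'ID']  # split via the set column

# Hypothetical patch-level probabilities, keyed by slide ID.
patch_preds = pd.DataFrame({'ID': ['A21']*50 + ['A23']*50,
                            'prob_Slide_LabelB': [0.1]*50 + [0.8]*50})
# Call a slide positive when its mean patch probability clears a threshold.
slide_calls = patch_preds.groupby('ID')['prob_Slide_LabelB'].mean() >= 0.5
print(slide_calls)  # A21 False, A23 True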

Saved Model and Transfer Learning

The model is saved to --save_location outcomes_model.pkl. To transfer-learn and apply the model to a new dataset, load this model via --pretrained_save_location and specify a new --save_location.
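For example, a transfer-learning run on a new dataset might look like the following (a sketch; the new input directory, database, and model names are hypothetical):

CUDA_VISIBLE_DEVICES=0 pathflowai-train_model train_model --patch_size 512 -pr 224 --pretrained_save_location outcomes_model.pkl --save_location new_outcomes_model.pkl -a resnet34 --input_dir new_inputs/ -nt 1 -ne 10 -lr 1e-4 -pi new_patch_information.db -bs 32 -ca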

Segmentation Models

Training the model for segmentation tasks is quite similar to the classification task; just additionally supply the -s flag and set --num_targets to the number of labels in the segmentation mask files.
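For instance, a segmentation run might look like this (a sketch; the unet choice and the label count of 4 are illustrative assumptions):

CUDA_VISIBLE_DEVICES=0 pathflowai-train_model train_model -s --patch_size 512 -pr 224 --save_location segmentation_model.pkl -a unet --input_dir inputs/ --num_targets 4 -ne 10 -lr 1e-4 -pi patch_information.db -bs 32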

Prediction

The only change that must be made to make a prediction on the test set is adding the --prediction option:

CUDA_VISIBLE_DEVICES=0 pathflowai-train_model train_model --prediction --patch_size 512 -pr 224 --save_location outcomes_model.pkl -a resnet34 --input_dir inputs/ -nt 1 -t 10000 -lr 1e-4 -ne 10 -ss 0.5 -ssv 0.3 -tt 0.1 -bt 0.01 -imb -pi patch_information.db -bs 32 -ca

For classification tasks, this creates a new SQL database (predictions.db) with the patch information and the prediction information (probabilities and predicted class). For segmentation tasks, it stores new [basename]_predict.npy segmentation masks into the --prediction_output_dir directory.
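A quick way to inspect these outputs from Python (the table names inside predictions.db and the output directory name are assumptions here):

import sqlite3
import numpy as np
import pandas as pd

# Classification: list the tables in predictions.db, then read the relevant one.
con = sqlite3.connect('predictions.db')
print(pd.read_sql("SELECT name FROM sqlite_master WHERE type='table'", con))

# Segmentation: load a predicted mask (basename and directory hypothetical).
mask = np.load('predictions/A21_predict.npy')
print(mask.shape, np.unique(mask))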

External testing datasets replace the originally set-aside sets from your data, and can be predicted on by using the --external_test_db and --external_test_dir options together.
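For example (a sketch; the external database and directory names are hypothetical):

CUDA_VISIBLE_DEVICES=0 pathflowai-train_model train_model --prediction --patch_size 512 -pr 224 --save_location outcomes_model.pkl -a resnet34 --input_dir inputs/ -nt 1 -bt 0.01 -pi patch_information.db -bs 32 -ca --external_test_db external_patch_information.db --external_test_dir external_inputs/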

Model Extraction

To extract the model to send off for deployment, possibly in someone else's compute set-up, add the --extract_model flag alongside the --prediction flag; this pickles the entire model and saves it under a name derived from the --save_location you had denoted. This extracted model can also be used for the model interpretation SHAP module.
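Loading the extracted model elsewhere might look like the following (a sketch: the extracted filename is an assumption, so check your output directory for the actual name; since pathflowai models wrap PyTorch, torch.load is used here):

import torch

# Filename assumed; substitute the actual extracted-model file.
model = torch.load('outcomes_model.pkl_extracted_model.pkl', map_location='cpu')
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # one 224x224 RGB patch, matching -pr 224
    print(model(dummy))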

Extract Embeddings

Although this is only really necessary for the embedding and SHAP steps, we cover it here for brevity: for classification tasks, replacing --extract_model with --extract_embedding stores a 1000-dimensional vector (this dimensionality can be changed via an Issue/PR) for each patch of the prediction set into a dataframe, which is stored in a pickle file. These embeddings can be useful for visualizing class separation of patches using the visualization module.
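A minimal sketch of exploring these embeddings (the pickle filename and dataframe layout are assumptions):

import pandas as pd
from sklearn.decomposition import PCA

# Filename assumed; each row is a patch, columns are the embedding dimensions.
embeddings_df = pd.read_pickle('embeddings.pkl')
coords = PCA(n_components=2).fit_transform(embeddings_df.values)
print(coords[:5])  # 2-D coordinates, e.g. for a quick class-separation scatter plot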

Architectures

A large variety of architectures can be chosen from (-a); they are listed below. Some are for classification and others for segmentation. Note that not all architectures have been tested, but this can be remedied by posting an issue:
['alexnet','densenet121','densenet161','densenet169',
'densenet201','inception_v3','resnet101','resnet152',
'resnet18','resnet34','resnet50','vgg11','vgg11_bn',
'unet','unet2','nested_unet','fast_scnn','vgg13',
'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19','vgg19_bn',
'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101',
'fcn_resnet50','efficientnet-b0','efficientnet-b1',
'efficientnet-b2','efficientnet-b3','efficientnet-b4',
'efficientnet-b5','efficientnet-b6','efficientnet-b7']
