# Rethinking Feature Extraction: Gradient-based Localized Feature Extraction for End-to-End Surgical Downstream Tasks
This repository contains the reference code for the paper "Rethinking Feature Extraction: Gradient-based Localized Feature Extraction for End-to-End Surgical Downstream Tasks".
To be added
- Clone the repository: `git clone https://github.com/PangWinnie0219/GradCAMDownstreamTask.git`
- Install the required packages using the `requirements.txt` file: `pip install -r requirements.txt`
Note: Python 3.6 is required to run our code.
We use the Cholec80 dataset and the Robotic Instrument Segmentation dataset from the MICCAI 2018 Endoscopic Vision Challenge.
Cholec80 dataset: As the tissue label is required for the captioning and interaction tasks, we added one extra label at the end of the original tool annotations of all samples, as shown in the figure below. Since many types of tissue are present in the Cholec80 dataset (e.g. gallbladder, cystic plate and liver), the tissue label added in this work does not refer to a specific tissue but to the interacting tissue. For simplicity, we assume the interacting tissue appears in every frame of the Cholec80 dataset.
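The label extension amounts to appending one extra binary flag to each frame's annotation vector; a minimal sketch, assuming the original annotations are per-frame binary tool-presence flags (the helper name here is ours, not from the repository):

```python
def add_tissue_label(tool_labels):
    """Append the interacting-tissue flag to a per-frame binary tool annotation.

    Since we assume the interacting tissue appears in every frame,
    the appended label is always 1.
    """
    return tool_labels + [1]

frame_annotation = [1, 0, 0, 1, 0, 0, 0]   # seven binary tool-presence flags
print(add_tissue_label(frame_annotation))  # [1, 0, 0, 1, 0, 0, 0, 1]
```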
Run `python3.6 baseline.py` to start training the classification model. Ensure `save` is set to `True`, as this checkpoint will be used for visualization and feature extraction later.
Alternatively, you can download the trained model files:
- GC-A: [miccai2018_9class_ResNet50_256,320_32_lr_0.001_dropout_0.2_best_checkpoint.pth.tar] (To be added)
- GC-B: [miccai2018_9class_cholecResNet50_256,320_32_lr_0.001_dropout_0.2_best_checkpoint.pth.tar] (To be added)
- GC-C: [miccai2018_11class_cholec_ResNet50_256,320_32_lr_0.001_best_checkpoint.pth.tar] (To be added)
- GC-D: [combine_miccai18_ResNet50_256,320_170_best_checkpoint.pth.tar] (To be added)
Place the trained model file inside the `./best_model_checkpoints` directory.
Change into the `utils` directory: `cd utils`
You can visualise the Grad-CAM heatmap and bounding box using `python3.6 miccai_bbox.py`.
To select a specific frame and the heatmap of a specific class, define them with `bidx` and `tclass` respectively. For example, to view the heatmap for class 3 of the 15th image in the dataset, run `python3.6 miccai_bbox.py --bidx 15 --tclass 3`.
The threshold, T_ROI, can be defined using `threshold` to see the effect of thresholding on the bounding box generation.
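The effect of T_ROI can be illustrated with a simplified sketch of the box generation: normalize the heatmap, keep the pixels above the threshold, and take the tight box around them (an illustration of the idea, not the exact logic of `miccai_bbox.py`):

```python
import numpy as np

def heatmap_to_bbox(heatmap, threshold=0.5):
    """Return (x1, y1, x2, y2) around heatmap values >= threshold, or None."""
    # Normalize the heatmap to [0, 1] before thresholding.
    hm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    ys, xs = np.where(hm >= threshold)
    if len(xs) == 0:          # nothing survives the threshold
        return None
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())

hm = np.zeros((8, 8))
hm[2:5, 3:6] = 1.0            # a hot region of the heatmap
print(heatmap_to_bbox(hm))    # (3, 2, 5, 4)
```

A higher threshold shrinks the box toward the activation peak; a lower one grows it toward the whole frame.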
To be added
Set the `result_filename` in the code accordingly if you are training the Grad-CAM model from scratch. If you are using our checkpoints, set `gc_model` to `1`, `2`, `3` or `4` to load the checkpoint from GC-A, GC-B, GC-C or GC-D respectively. If you are using `gc_model = 1` or `2`, set `cls` to 9; otherwise, set `cls` to 11.
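The pairing above can be written as a small lookup table; a sketch using the checkpoint filenames listed earlier (the helper name and path handling are our own, not part of the codebase):

```python
# gc_model -> (checkpoint filename, number of classes).
# GC-A/GC-B (gc_model 1, 2) use 9 classes; GC-C/GC-D (3, 4) use 11.
GC_CHECKPOINTS = {
    1: ("miccai2018_9class_ResNet50_256,320_32_lr_0.001_dropout_0.2_best_checkpoint.pth.tar", 9),
    2: ("miccai2018_9class_cholecResNet50_256,320_32_lr_0.001_dropout_0.2_best_checkpoint.pth.tar", 9),
    3: ("miccai2018_11class_cholec_ResNet50_256,320_32_lr_0.001_best_checkpoint.pth.tar", 11),
    4: ("combine_miccai18_ResNet50_256,320_170_best_checkpoint.pth.tar", 11),
}

def resolve_checkpoint(gc_model):
    """Return (checkpoint path, cls) for a given gc_model setting."""
    name, cls = GC_CHECKPOINTS[gc_model]
    return "./best_model_checkpoints/" + name, cls
```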
This method is similar to conventional feature extraction: region images are cropped from the raw image based on the predicted bounding boxes, and the cropped region images are then forwarded to the feature extractor.
- Crop the region images based on the predicted bounding boxes: `python3.6 utils/crop_bbox.py`
- Forward the cropped region images to the model again: `python3.6 image_extract_feature.py`
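The two steps amount to slicing the raw image with the predicted box and re-running the extractor on the crop; a minimal NumPy sketch in which per-channel global average pooling stands in for the real model's forward pass:

```python
import numpy as np

def crop_region(image, bbox):
    """image: H x W x C array; bbox: (x1, y1, x2, y2) in pixel coordinates."""
    x1, y1, x2, y2 = bbox
    return image[y1:y2, x1:x2]

def extract_feature(region):
    # Stand-in for the classification model's forward pass:
    # average over the spatial dimensions, one value per channel.
    return region.mean(axis=(0, 1))

image = np.random.rand(224, 224, 3)
crop = crop_region(image, (50, 30, 150, 130))
feat = extract_feature(crop)
print(crop.shape, feat.shape)  # (100, 100, 3) (3,)
```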
The features are extracted from the feature map of the classification model based on the bounding box coordinates: `python3.6 bbox_extract_feature.py`
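Since the feature map is spatially downsampled relative to the input image, the bounding box coordinates have to be rescaled by the network stride before pooling; a sketch under that assumption (the stride and shapes are illustrative, not taken from the script):

```python
import numpy as np

def bbox_feature(feature_map, bbox, stride=32):
    """feature_map: h x w x C; bbox in input-image pixels; stride = input size / feature-map size."""
    x1, y1, x2, y2 = (int(round(v / stride)) for v in bbox)
    x2, y2 = max(x2, x1 + 1), max(y2, y1 + 1)   # keep at least one feature cell
    region = feature_map[y1:y2, x1:x2]
    return region.mean(axis=(0, 1))             # one C-dim vector per box

fmap = np.random.rand(8, 10, 512)               # e.g. a 256x320 input at stride 32
feat = bbox_feature(fmap, (64, 32, 192, 128))
print(feat.shape)  # (512,)
```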
The features are extracted from the feature map of the classification model in a single pass, based on the heatmap (no bounding box generation): `python3.6 heatmap_extract_feature.py`
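The single-pass variant can be sketched as heatmap-weighted pooling: use the normalized Grad-CAM heatmap as spatial weights over the feature map (an illustration of the idea, not necessarily the script's exact formulation):

```python
import numpy as np

def heatmap_feature(feature_map, heatmap):
    """feature_map: h x w x C; heatmap: h x w Grad-CAM activation for one class."""
    w = np.clip(heatmap, 0, None)               # keep only positive evidence
    w = w / (w.sum() + 1e-8)                    # spatial weights summing to 1
    return (feature_map * w[..., None]).sum(axis=(0, 1))

fmap = np.random.rand(8, 10, 512)
hm = np.random.rand(8, 10)
print(heatmap_feature(fmap, hm).shape)  # (512,)
```

With a uniform heatmap this reduces to plain global average pooling; a peaked heatmap focuses the feature on the localized region instead.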
Code adapted and modified from: pytorch-grad-cam
The extracted features can be used for downstream tasks such as:
- Captioning
- Paper: Meshed-Memory Transformer for Image Captioning
- Official implementation code
- Interaction
- Paper: CogTree: Cognition Tree Loss for Unbiased Scene Graph Generation
- Official implementation code