
How to visualize attention map for model with multiple heads attention (e.g., vilbert) #917

Closed
CCYChongyanChen opened this issue Apr 30, 2021 · 6 comments


CCYChongyanChen commented Apr 30, 2021

❓ Questions and Help

Overall goal: I am trying to extract the visual attention map from ViLBERT to explore where the model is looking in the image.

My question

Question 1:
I know ViLBERT has three kinds of attention: image attention, text attention, and co-attention. I don't know whether I should go with image attention or co-attention; currently, I go with image attention.
Question 2:
I know that for image attention it outputs 6 tensors, each of size (1, 8, 100, 100). I would like to know (1) what the 8, 100, and 100 represent, (2) which tensor I should select, and (3) how I can visualize an attention map from the image attention weights.

My understanding for Question 2:
According to https://github.com/facebookresearch/mmf/blob/3947693aafcc9cc2a16d7c1c5e1479bf0f88ed4b/mmf/configs/models/vilbert/defaults.yaml, it seems that 8 represents the number of attention heads. My guess is that the 1 is the batch size (I changed the batch size to 1) and that 100 is the image width and height.
If that is correct, then my question 2 becomes "how to deal with multiple attention heads?"

Possible solution for Question 2:
I know how to visualize an attention map when the attention weights are a 1D or 2D array. For a 4D tensor, I am not sure whether it makes sense to directly squeeze() it down to 2D for visualization, or whether I should average the multi-head attention to get 2D attention weights (see the sketch below).
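
A minimal sketch of both options, assuming a PyTorch tensor with the (1, 8, 100, 100) shape described above (all names here are illustrative, not from mmf):

    import torch

    att = torch.rand(1, 8, 100, 100)   # (batch, heads, regions, regions), dummy data

    # Option A: look at a single head (what indexing/squeezing amounts to)
    att_head0 = att[0, 0]              # (100, 100)

    # Option B: average over the 8 heads to get one 2D map
    att_mean = att[0].mean(dim=0)      # (100, 100)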

Other questions

(1) I am worried that the way the image is represented in transformers makes it impossible to visualize the image attention map for ViLBERT:

[screenshot]

(2) I got two image attention weights from Pythia; which one should I use for visualization?

Thank you in advance!

@CCYChongyanChen CCYChongyanChen changed the title How to visualize attention map for vilbert How to visualize attention map for model with multiple heads attention (e.g., vilbert) Apr 30, 2021
hackgoofer (Contributor) commented
Hi @CCYChongyanChen, thanks for using mmf.

Most of what you said seems correct to me. In ViLBERT, the image representations we use are features pre-extracted from bounding boxes detected by object detection. The attention map on the image is therefore at the bounding-box level, i.e., which bounding box was attended to more or less, as opposed to the pixel level. Alternatively, you can use ViLBERT with a grid feature extractor, which will give you finer-grained attention coverage.

Let me know if this helps.

CCYChongyanChen (Author) commented

Thanks! Do you have a pretrained ViLBERT model with a grid feature extractor? @ytsheng

CCYChongyanChen (Author) commented

[three screenshots of the resulting attention-map visualizations]

Finally got this to work. For each image, ViLBERT generates 8 (heads) × 100 attention maps.

apsdehal (Contributor) commented

@CCYChongyanChen, if possible, can you share the code for future users?


CCYChongyanChen commented Jul 16, 2021

@apsdehal
Definitely.

Step 1: Allow visualization

(1) set output_all_attention_masks=True following #507

(2) set visualization=true in the config file following #507
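
For reference, a sketch of how those two settings might look in the YAML config (the exact keys and nesting depend on your setup; see #507 for the authoritative instructions):

    model_config:
      vilbert:
        visualization: true
        output_all_attention_masks: true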

Step 2: Edit the prediction_loop function in evaluation_loop.py

def prediction_loop(self, dataset_type: str) -> None:
    # cv2 needs to be imported at the top of evaluation_loop.py; torch,
    # logger, and to_device are already available there. reporter,
    # model_name, and dataset are assumed to be set up earlier.
    with torch.no_grad():
        self.model.eval()
        logger.info(f"Starting {dataset_type} inference predictions")
        dataloader = reporter.get_dataloader()
        batches = iter(dataloader)
        next(batches)                # skip the first batch
        batch = next(batches)        # use the second batch as an example
        # bbox has shape (100, 5); drop the last column, which is the cls score
        bboxs = batch.image_info_0.bbox.numpy()[0][:, :-1]
        img_index = str(batch.image_id.item())
        prepared_batch = reporter.prepare_batch(batch)
        prepared_batch = to_device(prepared_batch, torch.device("cuda"))
        model_output = self.model(prepared_batch)
        if model_name == "vilbert":
            # attention is (1, 8, 100, 100): take batch 0, head 0 -> (100, 100)
            att_n = (
                model_output["attention_weights"][1][1]
                .detach().cpu().clone().numpy()[0, 0, :, :]
            )
        if dataset == "COCO":
            img_path = (
                "/data/vqa/datasets/COCO/test2015/COCO_test2015_000000"
                + img_index + ".jpg"
            )
            img = cv2.imread(img_path)
        # one attention map per attended region (100 regions in total)
        for i in range(100):
            self.visualize_pred(img_path, bboxs, att_n[:, i])

Step 3: Visualization!

Visualize the attention map following #145, using the attention_bbox_interpolation(im, bboxes, att) and visualize_pred(im_path, boxes, att_weights) functions.

Note that attention_bbox_interpolation needs a small adjustment: ViLBERT's box coordinates are normalized to (0, 1), so they must be scaled up to the image size:
    for bbox, weight in zip(bboxes, softmax):
        x1, y1, x2, y2 = bbox
        opacity[int(y1 * img_h):int(y2 * img_h), int(x1 * img_w):int(x2 * img_w)] += weight
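
For completeness, here is a self-contained sketch of how those two functions could fit together, assuming normalized (x1, y1, x2, y2) boxes and OpenCV BGR images. It follows the fragments above rather than the exact code in #145, and the weight normalization and heatmap blending are illustrative choices:

    import cv2
    import numpy as np

    def attention_bbox_interpolation(im, bboxes, att):
        # Spread each region's attention weight over its normalized box.
        img_h, img_w = im.shape[:2]
        opacity = np.zeros((img_h, img_w), dtype=np.float32)
        softmax = att / (att.sum() + 1e-8)   # simple normalization, illustrative
        for bbox, weight in zip(bboxes, softmax):
            x1, y1, x2, y2 = bbox
            opacity[int(y1 * img_h):int(y2 * img_h),
                    int(x1 * img_w):int(x2 * img_w)] += weight
        opacity = np.clip(opacity / (opacity.max() + 1e-8), 0.0, 1.0)
        # Blend a JET heatmap of the attention mask onto the image.
        heatmap = cv2.applyColorMap((opacity * 255).astype(np.uint8),
                                    cv2.COLORMAP_JET)
        return cv2.addWeighted(im, 0.5, heatmap, 0.5, 0)

    def visualize_pred(im_path, boxes, att_weights):
        im = cv2.imread(im_path)
        out = attention_bbox_interpolation(im, boxes, att_weights)
        cv2.imwrite("attention_" + im_path.split("/")[-1], out)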

HireTheHero commented

FYI: for some other models, #1052 looks useful for exporting attention weights. Checked with MMBT and VisualBERT.
