[ICLR 2025] Mitigating Modality Prior-Induced Hallucinations in Multimodal Large Language Models via Deciphering Attention Causality

CausalMM

The official repo for CausalMM, a plug-and-play method for deciphering attention causality in MLLMs. The full paper can be found at: https://arxiv.org/abs/2410.04780.

🗓 To-Do List

  • ✅ Full code release (LLaVA with vision)
  • ✅ Key code for editing attention released
  • ✅ Preprint of the paper released; check it here

Intro

We propose CausalMM, a framework that helps MLLMs/LVLMs alleviate multimodal hallucinations caused by the priors of the visual, language, and other modalities, achieving a maximum improvement of 65.3% across 6 VLind-Bench indicators and 164 points on the MME Benchmark compared to conventional methods.

Structural Causal Model

(Figure: the structural causal model used by CausalMM.)

Environment Setup

cd env
conda env create -f causalmm_llava.yml
conda activate causalmm_llava

Or you can simply build the LLaVA-1.5 environment following VCD.

Run CausalMM

POPE

  1. Download the POPE benchmark data (COCO2014 / AOKVQA) to the CausalMM directory (e.g., CausalMM/COCO/val2014).

  2. Download LLaVA-v1.5-7b to the CausalMM directory (e.g., CausalMM/llava-v1.5-7b).

  3. Run the evaluation script (on a single GPU; device 0 by default).

conda activate causalmm_llava
cd llava-1.5/experiments/scripts
CUDA_VISIBLE_DEVICES=0 bash llava1.5_eval.bash

Counterfactual Attention

Vision Counterfactual Attention

Code Sample

import torch

def edit_attention(self, attention_maps, method='shuffle'):
    # Shape depends on how the vision encoder extracts attention.
    batch_size, num_heads, height, width = attention_maps.shape

    if method == 'random':
        # Replace attention with random values in [0, 2).
        edited_attention_maps = torch.rand(
            batch_size, num_heads, height, width, device=attention_maps.device) * 2

    elif method == 'uniform':
        # Spread the per-map average uniformly over all positions.
        avg_value = torch.mean(attention_maps, dim=(2, 3), keepdim=True)
        edited_attention_maps = avg_value.expand(batch_size, num_heads, height, width)

    elif method == 'reversed':
        # Invert attention: the most-attended positions get the lowest weight.
        max_value = attention_maps.amax(dim=(2, 3), keepdim=True)
        edited_attention_maps = max_value - attention_maps

    elif method == 'shuffle':
        # Randomly permute the spatial positions of each head's attention map.
        edited_attention_maps = attention_maps.clone()
        for i in range(num_heads):
            perm = torch.randperm(height * width, device=attention_maps.device)
            edited_attention_maps[:, i] = (
                edited_attention_maps[:, i]
                .view(batch_size, -1)
                .gather(1, perm.expand(batch_size, -1))
                .view(batch_size, height, width)
            )

    else:
        raise ValueError("Invalid method. Choose from ['random', 'uniform', 'reversed', 'shuffle']")

    return edited_attention_maps

The complete experimental code can be found in cf_encoder.
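
For reference, here is a minimal usage sketch of the sample above on dummy ViT-style attention maps. It assumes edit_attention from the code sample is in scope; the tensor shapes are illustrative and not part of the released code. Since self is unused in the sample, None is passed in its place.

import torch

# ViT-style attention: batch of 2, 12 heads, 24x24 attention maps (illustrative shapes).
attn = torch.softmax(torch.randn(2, 12, 24, 24), dim=-1)

for method in ['random', 'uniform', 'reversed', 'shuffle']:
    cf_attn = edit_attention(None, attn, method=method)  # self is unused in the sample
    print(method, tuple(cf_attn.shape))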

Visualization of Vision Counterfactual Attention

(Figures: vision counterfactual attention maps produced by the random, reverse, uniform, and shuffle strategies.)

LLM Counterfactual Attention

Code Sample

import torch

def create_attention_mask(attention):
    # Shape depends on the LLM decoder: (batch, heads, seq_len, seq_len).
    bsz, num_heads, seq_len, _ = attention.size()
    # Upper-triangular mask marking the positions blocked by causal attention.
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(attention.device)
    return mask.unsqueeze(0).unsqueeze(0).expand(bsz, num_heads, -1, -1)

def reverse_attention(attention):
    # Invert each row so the most-attended tokens become the least attended.
    attention_mask = create_attention_mask(attention)
    max_values = attention.max(dim=-1, keepdim=True)[0]
    reversed_attention = max_values - attention
    return reversed_attention * (1 - attention_mask)

def normalize_attention(attention):
    # Renormalize each row to sum to one, then re-apply the causal mask.
    attention_mask = create_attention_mask(attention)
    normalized_attention = attention / attention.sum(dim=-1, keepdim=True)
    return normalized_attention * (1 - attention_mask)

def reverse_and_normalize_attention(attention):
    reversed_attention = reverse_attention(attention)
    return normalize_attention(reversed_attention)

def random_attention(attention):
    # Replace attention with random weights, normalized and causally masked.
    attention_mask = create_attention_mask(attention)
    random_weights = torch.rand_like(attention)
    normalized_random = normalize_attention(random_weights)
    return normalized_random * (1 - attention_mask)

def uniform_attention(attention):
    # Assign equal weight to every position, then re-apply the causal mask.
    attention_mask = create_attention_mask(attention)
    uniform_weights = torch.ones_like(attention) / attention.size(-1)
    return uniform_weights * (1 - attention_mask)

def apply_counterfactual_attention(attention, method):
    if method == 'reverse':
        return reverse_attention(attention)
    elif method == 'reverse_and_normalize':
        return reverse_and_normalize_attention(attention)
    elif method == 'random':
        return random_attention(attention)
    elif method == 'uniform':
        return uniform_attention(attention)
    else:
        raise ValueError(f"Unknown method: {method}")

These functions can be inserted directly into the modeling_qwen2_vl.py file of the transformers library.
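
A minimal sketch of how these functions could be exercised on a dummy causal attention tensor, assuming they are in scope; the batch size, head count, and sequence length below are illustrative, not those used by Qwen2-VL.

import torch

# Dummy decoder self-attention: batch 1, 2 heads, sequence length 5 (illustrative).
bsz, num_heads, seq_len = 1, 2, 5
scores = torch.randn(bsz, num_heads, seq_len, seq_len)
causal = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
attn = torch.softmax(scores.masked_fill(causal, float('-inf')), dim=-1)

for method in ['reverse', 'reverse_and_normalize', 'random', 'uniform']:
    cf_attn = apply_counterfactual_attention(attn, method)
    print(method, tuple(cf_attn.shape))

# The re-applied causal mask keeps the above-diagonal (future) positions at zero,
# e.g. for the uniform variant:
print(apply_counterfactual_attention(attn, 'uniform')[0, 0])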

Visualization of LLM Counterfactual Attention

(Figures: LLM counterfactual attention maps for the normal, reverse, uniform, and random strategies.)

Modality Priors

If you want to learn more about work on modality priors, click here.

Citation

Welcome to star our repo and cite our work:

@inproceedings{
  zhou2025mitigating,
  title={Mitigating Modality Prior-Induced Hallucinations in Multimodal Large Language Models via Deciphering Attention Causality},
  author={Guanyu Zhou and Yibo Yan and Xin Zou and Kun Wang and Aiwei Liu and Xuming Hu},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025}
}

Acknowledgement

  • LLaVA: Large Language and Vision Assistant
  • Qwen2-VL: Enhancing Vision-Language Model’s Perception of the World at Any Resolution
  • VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding
  • OPERA: Alleviating Hallucination in Multi-Modal Large Language Models via Over-Trust Penalty and Retrospection-Allocation
