Attention-aware LRP (AttnLRP) outperforms gradient-, decomposition- and perturbation-based methods and provides faithful attributions for the entirety of a black-box transformer model, while scaling efficiently in computational complexity.
Since we obtain relevance values for every single neuron in the model as a by-product, we know how important each neuron is for the model's prediction. Combined with Activation Maximization, this lets us label neurons or SAE features in LLMs and even steer the LLM's generation process by activating specialized knowledge neurons in latent space!
For the mathematical details and foundational work, please take a look at our paper:
Achtibat, et al. “AttnLRP: Attention-Aware Layer-Wise Relevance Propagation for Transformers.” ICML 2024.
A collection of papers that have utilized LXT:
- Arras, et al. “Close Look at Decomposition-based XAI-Methods for Transformer Language Models.” arXiv preprint, 2025.
- Pan, et al. “The Hidden Dimensions of LLM Alignment: A Multi-Dimensional Safety Analysis.” arXiv preprint, 2025.
- Hu, et al. “LRP4RAG: Detecting Hallucinations in Retrieval-Augmented Generation via Layer-wise Relevance Propagation.” arXiv preprint, 2024.
- Sarti, et al. “Quantifying the Plausibility of Context Reliance in Neural Machine Translation.” ICLR 2024.
This project is licensed under the BSD-3-Clause License. Note, however, that LRP is a patented technology that can only be used free of charge for personal and scientific purposes.
```bash
pip install lxt
```
Tested with: `transformers==4.48.3`, `torch==2.6.0`, `python==3.11`
You can find example scripts in the `examples/*` directory. For an in-depth tutorial, take a look at the Quickstart in the Documentation.
To get an overview, you can keep reading below ⬇️
Layer-wise Relevance Propagation is a rule-based backpropagation algorithm. This means that we can implement LRP in a single backward pass! For this, LXT offers two different approaches:
The first approach, the efficient implementation in `lxt.efficient`, uses a Gradient×Input formulation, which reduces LRP to a standard, fast gradient computation by monkey-patching the model class.
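Why is patching needed at all? Plain Gradient×Input is not conservative for every operation in a transformer. As an illustration, assume the uniform rule that AttnLRP uses for multiplicative interactions, in which each factor receives an equal share of the relevance. The minimal sketch below shows that plain autograd on an elementwise product `a * b` assigns the full product as relevance to both factors, doubling the total, while a custom autograd function that halves the gradient restores conservation. `UniformProduct` and `relevance_of_product` are hypothetical names; LXT's patched operations may be implemented differently.

```python
import torch

class UniformProduct(torch.autograd.Function):
    """Hypothetical sketch: elementwise a * b whose backward halves the gradient,
    so Gradient*Input splits the relevance equally between both factors."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # The true gradients are grad_out * b and grad_out * a; scale each by 1/2
        return 0.5 * grad_out * b, 0.5 * grad_out * a

def relevance_of_product(product_fn):
    a = torch.tensor([2.0, -1.0, 3.0], requires_grad=True)
    b = torch.tensor([0.5, 4.0, -2.0], requires_grad=True)
    out = product_fn(a, b).sum()
    out.backward()
    # Total Gradient*Input relevance assigned to both factors
    total = (a.grad * a).sum() + (b.grad * b).sum()
    return out.item(), total.item()

print(relevance_of_product(lambda a, b: a * b))  # (-9.0, -18.0): relevance doubled
print(relevance_of_product(UniformProduct.apply))  # (-9.0, -9.0): relevance conserved
```

The forward pass is unchanged; only the backward pass is modified, which is why such a fix can be applied by swapping functions inside the model class rather than by changing its outputs.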
```python
from lxt.efficient import monkey_patch
# Patch module first
monkey_patch(your_module)
# Forward pass with gradient tracking
outputs = model(inputs_embeds=input_embeds.requires_grad_())
# Backward pass
outputs.logits[...].backward()
# Get relevance at *ANY LAYER* in your model. Simply multiply the gradient * activation!
# here for the input embeddings:
relevance = (input_embeds.grad * input_embeds).sum(-1)
```
This is the recommended approach for most users, as it is significantly faster and easier to use. The technique was introduced in Arras, et al. “Close Look at Decomposition-based XAI-Methods for Transformer Language Models.” arXiv preprint, 2025.
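The idea of reading off relevance at any layer by multiplying gradient and activation can be illustrated in plain PyTorch, without any LXT patching. The sketch below uses a hypothetical toy network: for bias-free ReLU networks, Gradient×Input coincides with the basic LRP-0 rule, so the relevance at the input and at a hidden layer both sum to the explained output score. The forward hook and the `activations` dict are illustrative choices, not part of LXT.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy bias-free ReLU network: positively homogeneous, so Gradient*Input is conservative
model = nn.Sequential(
    nn.Linear(8, 16, bias=False),
    nn.ReLU(),
    nn.Linear(16, 16, bias=False),
    nn.ReLU(),
    nn.Linear(16, 3, bias=False),
).double()

# Store the hidden activation after the second ReLU and keep its gradient
activations = {}

def save_activation(module, inputs, output):
    output.retain_grad()
    activations["hidden"] = output

model[3].register_forward_hook(save_activation)

x = torch.randn(1, 8, dtype=torch.float64, requires_grad=True)
logits = model(x)
target = logits[0, 0]  # explain the score of class 0
target.backward()

# Relevance = gradient * activation, at the input and at the hidden layer
input_relevance = (x.grad * x).detach()
hidden = activations["hidden"]
hidden_relevance = (hidden.grad * hidden).detach()  # one value per hidden neuron

print("output score:        ", target.item())
print("sum input relevance: ", input_relevance.sum().item())
print("sum hidden relevance:", hidden_relevance.sum().item())
```

In a transformer, normalization layers, softmax, and multiplicative interactions break this simple picture, which is why LXT patches the model before the backward pass.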
The second approach, the explicit implementation in `lxt.explicit`, was used in the original ICML 2024 paper. It is more complex and slower, but useful for understanding the mathematical foundations of LRP.
To achieve this, we have implemented custom PyTorch autograd Functions for commonly used operations in transformers. These functions behave identically in the forward pass but substitute the gradient with LRP attributions in the backward pass. For example, to compute the ε-LRP rule for a linear layer, you can write:
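```python
import torch
import lxt.explicit.functional as lf

# Example input, weight and bias (W shaped (out_features, in_features), as in F.linear)
x, W, b = torch.randn(1, 10), torch.randn(5, 10), torch.randn(5)
y = lf.linear_epsilon(x.requires_grad_(), W, b)
y.backward(y)  # start the backward pass with the output itself as relevance
relevance = x.grad  # LRP relevance of the input
```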
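To make concrete what an ε-rule for a linear layer computes, here is a minimal pure-PyTorch sketch of the textbook ε-LRP rule, $R_i = x_i \sum_j W_{ji} R_j / (z_j + \varepsilon\,\mathrm{sign}(z_j))$ with $z = Wx + b$. It is written once explicitly and once as a single backward pass in which the normalized relevance is detached, so autograd performs the remaining computation. The helper names are hypothetical, and LXT's `linear_epsilon` may differ in details such as the exact stabilizer.

```python
import torch
import torch.nn.functional as F

def stabilize(z, eps):
    # Push the denominator away from zero while keeping its sign (sign(0) treated as +1)
    sign = (z >= 0).to(z.dtype) * 2 - 1
    return z + eps * sign

def epsilon_lrp_linear_explicit(x, W, b, R_out, eps=1e-6):
    # Hypothetical helper: R_i = x_i * sum_j W_ji * R_j / stabilize(z_j)
    z = F.linear(x, W, b)
    s = R_out / stabilize(z, eps)
    return x * (s @ W)

def epsilon_lrp_linear_autograd(x, W, b, R_out, eps=1e-6):
    # Same rule as one backward pass: with s detached, the gradient of
    # (z * s).sum() w.r.t. x is s @ W, which is then multiplied by x
    x = x.detach().requires_grad_()
    z = F.linear(x, W, b)
    s = (R_out / stabilize(z, eps)).detach()
    (z * s).sum().backward()
    return (x * x.grad).detach()

torch.manual_seed(0)
x = torch.randn(4, 10, dtype=torch.float64)
W = torch.randn(5, 10, dtype=torch.float64)
b = torch.zeros(5, dtype=torch.float64)

R_out = F.linear(x, W, b)  # start from the output itself, like y.backward(y) above
R_explicit = epsilon_lrp_linear_explicit(x, W, b, R_out)
R_autograd = epsilon_lrp_linear_autograd(x, W, b, R_out)

print(torch.allclose(R_explicit, R_autograd))
# With zero bias, relevance is conserved per sample up to the epsilon term
print(R_explicit.sum(-1), R_out.sum(-1))
```

The autograd variant shows why a single backward pass suffices: once the normalization is treated as a constant, the rest of the rule is an ordinary gradient.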
There are also "super-functions" that wrap an arbitrary nn.Module and compute LRP rules via automatic vector-Jacobian products! These rules are simple to attach to models:
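```python
import torch.nn as nn
from lxt.explicit.core import Composite
import lxt.explicit.rules as rules

# nn.RMSNorm (PyTorch >= 2.4) stands in here for your model's RMS normalization class
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.RMSNorm(10),
)

# Map module types to LRP rules and attach them to the model
Composite({
    nn.Linear: rules.EpsilonRule,
    nn.RMSNorm: rules.IdentityRule,
}).register(model)

print(model)
```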
For more details, read the documentation.
Feel free to explore the code and experiment with different datasets and models. We encourage contributions and feedback from the community, and we are especially grateful for contributions that add support for new model architectures! 🙏
```bibtex
@InProceedings{pmlr-v235-achtibat24a,
  title = {{A}ttn{LRP}: Attention-Aware Layer-Wise Relevance Propagation for Transformers},
  author = {Achtibat, Reduan and Hatefi, Sayed Mohammad Vakilzadeh and Dreyer, Maximilian and Jain, Aakriti and Wiegand, Thomas and Lapuschkin, Sebastian and Samek, Wojciech},
  booktitle = {Proceedings of the 41st International Conference on Machine Learning},
  pages = {135--168},
  year = {2024},
  editor = {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume = {235},
  series = {Proceedings of Machine Learning Research},
  month = {21--27 Jul},
  publisher = {PMLR}
}
```
The code is heavily inspired by Zennit, a tool for LRP attributions in PyTorch using hooks. Zennit is 100% compatible with the explicit version of LXT and offers even more LRP rules 🎉