Mammoth: Mixture of Mini Experts in Pathology


Mixture of Mini Experts: Overcoming the Linear Layer Bottleneck in Multiple Instance Learning, ICLR 2026.
Daniel Shao, Joel Runevic, Richard J. Chen, Drew F. K. Williamson, Ahrong Kim, Andrew H. Song*, Faisal Mahmood*

A parameter-efficient mixture of experts module for multiple instance learning in computational pathology
Paper | OpenReview | Citation


How does Mammoth work?

Mammoth architecture
 

Key Ideas

In Multiple Instance Learning (MIL) for whole-slide images, the standard pipeline is:

  1. Extract patch features (e.g. from a pretrained encoder),
  2. Transform them with a linear layer into task-specific patch features,
  3. Aggregate patches into a slide-level representation for classification.

Most works focus on (1) and (3). Mammoth explicitly targets (2): it replaces the single linear layer with a low-rank mixture of experts so that each patch receives a transformation tailored to its phenotype, at a parameter count comparable to the original linear layer.

  • Low-rank: Each expert is a factorized (LoRA-style) linear layer, keeping the parameter count close to a single linear layer.
  • Mixture of experts: Slot-based routing assigns each patch to a combination of experts; the final representation is a weighted combination of expert outputs.
  • Plug-and-play: Drop-in replacement for the patch embedding linear layer in any MIL method. Works with mean/max pooling, attention, CLAM, TransMIL, and others.
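To make the parameter-efficiency claim concrete, the sketch below estimates the largest rank at which a set of factorized experts stays within the parameter budget of a single dense linear layer. This is only illustrative of the idea behind the `auto_rank` option; the exact formula used by the library may differ.

```python
def max_rank_within_budget(d_in: int, d_out: int, num_experts: int) -> int:
    """Largest rank r such that num_experts factorized experts
    (each with d_in*r + r*d_out parameters) use no more parameters
    than one dense d_in x d_out linear layer."""
    linear_params = d_in * d_out
    params_per_unit_rank = num_experts * (d_in + d_out)
    return max(1, linear_params // params_per_unit_rank)

d_in, d_out, num_experts = 1024, 512, 30
r = max_rank_within_budget(d_in, d_out, num_experts)
expert_params = num_experts * r * (d_in + d_out)
# With these sizes r = 11: 30 rank-11 experts use 506,880 parameters,
# versus 524,288 for the dense 1024 x 512 linear layer.
print(r, expert_params, d_in * d_out)
```

As the number of experts grows, the affordable rank shrinks, which is why each expert stays "mini".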
Main Findings
  • Improved performance: Across 8 MIL methods and 19 classification tasks, Mammoth improves performance in 130/152 configurations with an average change of +3.8%, often exceeding the effect of the choice of aggregation method. (The accompanying figure shows the average performance per MIL method, averaged across all tasks.)
  • Structured Feature Space: Mammoth yields a structured feature space, with outputs forming distinct clusters per expert, and subclusters per slot.
  • Expert specialization: Mammoth experts focus on diverse morphological phenotypes, enabling context-specific processing.
  • Mitigated instance-gradient interference: Heterogeneous instances produce conflicting gradient updates in the standard linear layer; Mammoth's expert routing mitigates this.

Overall performance

| MIL Model | Linear (Morph, T=6) | Mammoth (Morph, T=6) | Linear (Molec, T=13) | Mammoth (Molec, T=13) |
|---|---|---|---|---|
| ABMIL | 75.2 | 78.4 | 72.8 | 74.6 |
| CLAM | 71.7 | 78.5 | 72.9 | 73.7 |
| TransMIL | 72.8 | 76.5 | 72.2 | 73.7 |
| Transformer | 73.5 | 77.5 | 71.8 | 74.2 |
| ILRA | 71.5 | 77.7 | 71.6 | 72.8 |
| MeanMIL | 72.5 | 77.0 | 72.6 | 74.5 |
| MaxMIL | 71.9 | 74.8 | 72.9 | 74.1 |
| DSMIL | 72.7 | 75.6 | 72.1 | 73.3 |

Shown is average performance for the standard linear layer vs. Mammoth across different MIL methods with UNI patch features. Balanced accuracy is reported for morphological subtyping tasks, and AUROC is reported for molecular subtyping tasks.

Installation

Install via pip

pip install mammoth-moe

or from source

git clone https://github.com/mahmoodlab/MAMMOTH.git
cd MAMMOTH
pip install -e .

The mammoth.py module only depends on:

  • PyTorch
  • einops

For quickstart instructions for using the MIL models in this repository, including environment setup, please see the MIL-Lab repository.

Note

The pip package only contains the mammoth.py module. Install from source to use the full suite of MIL methods in this repository.


Minimal example: adding Mammoth to any MIL model

Mammoth is a drop-in replacement for the first linear layer that maps patch features to the dimension used by the rest of your MIL model. Below, a simple mean-pooling MIL model uses either a linear layer or Mammoth:

import torch
import torch.nn as nn
from mammoth import Mammoth

class MeanMIL(nn.Module):
    def __init__(self, in_dim, out_dim, num_classes, moe_args=None):
        super().__init__()
        moe_args = moe_args or {}
        # Use Mammoth when expert settings are provided; otherwise fall
        # back to a plain linear patch embedding.
        if moe_args.get('num_experts', 0) > 0:
            self.fc = Mammoth(**moe_args)
        else:
            self.fc = nn.Linear(in_dim, out_dim)
        self.classifier = nn.Linear(out_dim, num_classes)

    def forward(self, x):
        # x: (batch, num_patches, in_dim)
        x = self.fc(x)           # -> (batch, num_patches, out_dim)
        x = torch.mean(x, dim=1) # aggregate
        return self.classifier(x)


in_dim = 1024   # e.g. patch feature dimension from a backbone
dim = 512       # dimension for aggregation / classifier
num_classes = 3

# our recommended hyperparameters for MAMMOTH
moe_args = {
    "input_dim": in_dim,
    "dim": dim,
    "num_experts": 30,  
    "num_slots": 10,
    "num_heads": 16,
    "slot_dim": 256,
    "keep_slots": True,  # if True, return the E*S aggregated features instead of the N transformed patch features
    "share_lora_weights": True,  # share the weights of the first low rank layer 
    "dropout": 0.1,
    "auto_rank": True,   # automatically calculate the appropriate low rank for parameter efficiency
}

model = MeanMIL(in_dim, dim, num_classes, moe_args=moe_args)
x = torch.randn(2, 1000, in_dim)
logits = model(x)  # (2, num_classes)

Note

Mammoth is intended to be a drop-in replacement for the linear layer at comparable parameter counts. While num_experts, num_slots, and num_heads may be adjusted, we strongly recommend setting share_lora_weights=True and auto_rank=True to automatically compute the appropriate layer sizes.


Viewing Expert Specialization

 

Routing scores for heatmaps can be generated by passing the return_weights flag.

# B: batch, N: patches, E: experts, S: slots, H: heads, D: head dim
input = torch.randn(B, N, H * D)

# out has shape (B, S*E, H*D)
out = model.patch_embed(input)

# routing_weights has shape (B, N, E, S, H, D)
out, routing_weights = model.patch_embed(input, return_weights=True)

For starter code to generate your own visualizations from the routing scores, please see this script.
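As a minimal starting point, the sketch below collapses a routing-weight tensor of shape (B, N, E, S, H, D) into the dominant expert per patch, which is the quantity typically painted onto a heatmap. The reduction (mean over slots, heads, and feature dimensions, then argmax over experts) is an illustrative assumption, not the repository's visualization code.

```python
import torch

def dominant_expert_per_patch(routing_weights: torch.Tensor) -> torch.Tensor:
    """Collapse routing weights of shape (B, N, E, S, H, D) down to the
    index of the highest-weighted expert for each patch, shape (B, N)."""
    # Average the routing mass over slots, heads, and feature dims,
    # leaving one score per (patch, expert) pair.
    scores = routing_weights.mean(dim=(3, 4, 5))  # (B, N, E)
    return scores.argmax(dim=-1)                  # (B, N)

B, N, E, S, H, D = 2, 100, 30, 10, 16, 32
weights = torch.rand(B, N, E, S, H, D)
expert_map = dominant_expert_per_patch(weights)
print(expert_map.shape)  # torch.Size([2, 100])
```

The resulting (B, N) map can be scattered back onto patch coordinates to color each WSI region by its assigned expert.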

Full MIL models

Enabling Mammoth requires passing a moe_args dict with num_experts > 0 and the usual Mammoth arguments (num_experts, input_dim, dim, num_heads, etc.). If moe_args is empty or num_experts == 0, the model uses the original linear layer.

Example: ABMIL with Mammoth

from millab.src.models.abmil import ABMILGatedBaseConfig, ABMILModel

# minimal args needed to initialize Mammoth: this creates 30 experts,
# 16 heads, 10 slots per expert, with weight sharing enabled
moe_args = {
    "num_experts": 30
}

config = ABMILGatedBaseConfig(
    in_dim=1024,
    embed_dim=512,
    num_classes=2,
    moe_args=moe_args,
)
model = ABMILModel(config)
# Forward: (B, M, D) patch features -> logits, loss, etc.

MIL models with mammoth can also be instantiated with MIL-Lab's create_model method by specifying the base_mammoth config:

from millab.src.builder import create_model

# standard abmil model with linear layer and uni's 1024 input dimension
create_model('abmil.base.uni', num_classes=5)

# use standard abmil with mammoth
create_model('abmil.base_mammoth.uni', num_classes=5)

# Specify the encoder to automatically update the input dimension
create_model('abmil.base_mammoth.conch_v15', num_classes=5) 

The following MIL implementations are available. In each, the patch_embed layer can optionally be equipped with Mammoth by passing moe_args into the model class or to create_model.

| Model | Code | Paper | Model Class | Initialization |
|---|---|---|---|---|
| ABMIL | Link | Link | `ABMILModel()` | `create_model('abmil.base_mammoth')` |
| TransMIL | Link | Link | `TransMILModel()` | `create_model('transmil.base_mammoth')` |
| Transformer | Link | Link | `TransformerModel()` | `create_model('transformer.base_mammoth')` |
| WiKG | Link | Link | `WIKGMILModel()` | `create_model('wikg.base_mammoth')` |
| DFTD | Link | Link | `DFTDModel()` | `create_model('dftd.base_mammoth')` |
| DSMIL | Link | Link | `DSMILModel()` | `create_model('dsmil.base_mammoth')` |
| ILRA | Link | Link | `ILRAModel()` | `create_model('ilra.base_mammoth')` |
| RRT | Link | Link | `RRTMILModel()` | `create_model('rrt.base_mammoth')` |
| CLAM | Link | Link | `CLAMModel()` | `create_model('clam.base_mammoth')` |

Repository layout

| Path | Description |
|---|---|
| `modules/mammoth.py` | Core Mammoth module: `Mammoth`, factorized experts, slot routing, and supporting layers |
| `modules/components.py` | Shared utilities (e.g. `ensure_batched`) used by `mammoth.py` |
| `MIL-Lab/src/models/` | MIL model wrappers (ABMIL, CLAM, DSMIL, TransMIL, etc.) with optional Mammoth patch embedding |
| `examples/tutorial_mammoth_visualization.py` | Expert dispatch heatmaps on WSIs using a saved Mammoth checkpoint |
| `config/paths.py` | Central path mappings for tasks and WSI/feature directories |

Issues

Please open a GitHub issue to report bugs or ask questions.
Funding

This work was funded by NIH NIGMS R35GM138216.

License and Terms of Use

ⓒ Mahmood Lab. This repository is released under the CC-BY-NC-ND 4.0 license and may only be used for non-commercial, academic research purposes with proper attribution. Any commercial use, sale, or other monetization of this repository is prohibited and requires prior approval. By downloading any pretrained encoder, you agree to follow the model's respective license.

Acknowledgements

This project was built on top of amazing repositories such as HuggingFace and on open-source MIL model implementations from the community. We thank the authors and developers for their contributions.

Citation

If you use this code, the Mammoth method, or the MIL model implementations in your work, please cite:

@inproceedings{shao2026mammoth,
  title={Mixture of Mini Experts: Overcoming the Linear Layer Bottleneck in Multiple Instance Learning},
  author={Shao, Daniel and Runevic, Joel and Chen, Richard J. and Williamson, Drew F. K. and Kim, Ahrong and Song, Andrew H. and Mahmood, Faisal},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2026},
  url={https://openreview.net/forum?id=S5Io33pc78}
}

@inproceedings{shao2025do,
    title={Do Multiple Instance Learning Models Transfer?},
    author={Shao, Daniel and Chen, Richard J. and Song, Andrew H. and Runevic, Joel and Lu, Ming Y. and Ding, Tong and Mahmood, Faisal},
    booktitle={International Conference on Machine Learning (ICML)},
    year={2025},
}
