somvy/prm_interp

The code for our paper on interpreting process reward models
This repository contains the code for our paper "Out of Distribution, Out of Luck: Process Rewards Misguide Reasoning Models", covering reward model inference, linear probe training, and sparse autoencoder (SAE) training and analysis.

Trained SAEs are available at: https://huggingface.co/therem/qwen-prm-saes

Installation

# Install dependencies using uv
uv sync

# Activate the virtual environment
source .venv/bin/activate

Environment Variables

Configure paths using environment variables; an example is provided in example.env:

export REWARD_SAES_DATA_DIR="./data"
export SAE_CHECKPOINT_PATH="./checkpoints/sae.pt"
export SAE_CHECKPOINT_DIR="./checkpoints"
export WANDB_API_KEY="your_key"      
export HF_TOKEN="your_token"                    
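The variables above can be read with standard defaults. A minimal sketch of such a helper (the function name is hypothetical; the variable names and defaults come from example.env and the option defaults documented below):

```python
import os

# Hypothetical helper showing how the scripts could resolve paths from the
# environment; defaults mirror the ones documented in this README.
def resolve_paths(env=os.environ):
    return {
        "data_dir": env.get("REWARD_SAES_DATA_DIR", "./data"),
        "sae_checkpoint": env.get("SAE_CHECKPOINT_PATH", "./checkpoints/sae.pt"),
        "checkpoint_dir": env.get("SAE_CHECKPOINT_DIR", "./checkpoints"),
    }

paths = resolve_paths({})  # no variables set -> defaults
print(paths["data_dir"])   # -> ./data
```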

Quick Start

We use invoke to manage runs; the main commands are listed below and in the tasks.py file.

Infer models on datasets

MODEL_NAME can be any model supported by Hugging Face, e.g., mistralai/Mathstral-7B-v0.1, AI-MO/NuminaMath-7B-CoT, deepseek-ai/deepseek-math-7b-instruct, Qwen/Qwen3-8B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, openai/gpt-oss-20b.

By default, the script runs the model on the datasets math-ai/minervamath, HuggingFaceH4/MATH-500, and math-ai/olympiadbench, and saves the generations to a JSON file.

invoke infer --model_name MODEL_NAME [--datasets math-ai/minervamath,HuggingFaceH4/MATH-500,math-ai/olympiadbench]

Note: to run inference with gpt-oss, you may need to upgrade the following libraries:

python -m pip install --upgrade transformers kernels accelerate "triton>=3.4" torchvision vllm

Infer model with PRM

invoke infer-with-prm --model_name MODEL_NAME 

Calculate metrics on the inferred datasets

Pass the path to the json file generated in the previous step:

invoke metrics --dataset data/deepseek_llama_math-ai_minervamath_beam_search_prm.json
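The core of such a metric is a simple match rate over the saved generations. A minimal sketch, assuming a hypothetical record schema with "answer" and "gold" fields (the actual JSON produced by the inference step may differ):

```python
import json

# Toy accuracy metric over inferred generations.
# The field names ("answer", "gold") are assumptions, not the repo's schema.
def accuracy(records):
    if not records:
        return 0.0
    correct = sum(1 for r in records if r["answer"].strip() == r["gold"].strip())
    return correct / len(records)

sample = json.loads('[{"answer": "42", "gold": "42"}, {"answer": "7", "gold": "8"}]')
print(accuracy(sample))  # -> 0.5
```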

Save activations with rewards

Saves the reward model's activations on the specified generations, along with the reward scores, for training probes. This will create many .npz files in the data directory.

invoke save-acts --ds-name "data/deepseek_ai_deepseek_math_7b_instructHuggingFaceH4_MATH_500.jsonl"

Optional parameters:

  • --base-path: Base path for data (default: ./data)
  • --ds-split: Dataset split (default: full)
  • --mode: Mode for saving (default: rm)
  • --limit: Maximum number of samples (default: 10000)
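Conceptually, each sample ends up as one .npz file pairing activations with a reward score. An illustrative sketch (not the repo's actual code; the file naming, field names, and tensor shapes are assumptions):

```python
import os
import tempfile
import numpy as np

# Illustrative: dump one .npz per sample with hidden-state activations
# plus the PRM reward score, then read it back.
def save_sample_acts(out_dir, idx, hiddens, reward):
    path = os.path.join(out_dir, f"acts_{idx:05d}.npz")
    np.savez(path, hiddens=hiddens, reward=np.float32(reward))
    return path

out_dir = tempfile.mkdtemp()
# Fake activations: 28 layers x 16 tokens x 64 dims (shapes are assumptions).
fake = np.zeros((28, 16, 64), dtype=np.float32)
path = save_sample_acts(out_dir, 0, fake, 0.87)
loaded = np.load(path)
print(loaded["hiddens"].shape)  # -> (28, 16, 64)
```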

Train probes

invoke train-linreg --ds-name "math-rm-deepseek_ai_deepseek_math_7b_instructHuggingFaceH4_MATH_500.jsonl"

Reads the saved activations, trains linear regression probes to predict the reward scores, and saves the metrics. Optional parameters:

  • --layers: Layer range to train on (default: 0-27, can also be a single layer like 15)
  • --data-type: Type of activations (default: hiddens)
  • --base-dir: Base directory for data (default: ./data)
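A linear probe here is just a least-squares fit from a layer's activations to the reward scores. A minimal sketch with an ordinary-least-squares solver (the repo may use a different solver and regularization; shapes and names are assumptions):

```python
import numpy as np

# Sketch of a per-layer linear probe: fit reward ~ activations with OLS
# and report R^2 on the training data.
def fit_probe(acts, rewards):
    # acts: (n_samples, d_model), rewards: (n_samples,)
    X = np.hstack([acts, np.ones((acts.shape[0], 1))])  # add bias column
    w, *_ = np.linalg.lstsq(X, rewards, rcond=None)
    preds = X @ w
    ss_res = np.sum((rewards - preds) ** 2)
    ss_tot = np.sum((rewards - rewards.mean()) ** 2)
    return w, 1.0 - ss_res / ss_tot

rng = np.random.default_rng(0)
acts = rng.normal(size=(200, 8))
rewards = acts @ rng.normal(size=8) + 0.5  # noiseless linear target
w, r2 = fit_probe(acts, rewards)
print(round(r2, 3))  # noiseless target -> 1.0
```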

Figures

The code for plotting the figures from the paper is in the figures folder.

Training a SAE

Trains multiple SAEs on different layers of a model simultaneously (this is faster than training them one by one).

python src/sparse/pt/multi_sae.py 

For the full list of arguments, check the script or run it with --help.
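For readers unfamiliar with SAEs, here is a toy forward pass of one sparse autoencoder: encode to an overcomplete ReLU feature layer, reconstruct, and combine reconstruction error with an L1 sparsity penalty. This is a standard SAE setup for illustration, not necessarily the exact architecture in src/sparse/pt/multi_sae.py:

```python
import numpy as np

# Toy sketch of a sparse autoencoder forward pass (one such SAE per layer).
# ReLU encoder + L1 penalty is a common SAE recipe; dimensions are made up.
def sae_forward(x, W_enc, b_enc, W_dec, b_dec, l1_coef=1e-3):
    f = np.maximum(x @ W_enc + b_enc, 0.0)  # sparse feature activations
    x_hat = f @ W_dec + b_dec               # reconstruction of the input
    recon_loss = np.mean((x - x_hat) ** 2)
    sparsity = l1_coef * np.abs(f).mean()
    return x_hat, recon_loss + sparsity

rng = np.random.default_rng(0)
d_model, d_sae = 16, 64  # overcomplete feature dictionary
x = rng.normal(size=(4, d_model))
W_enc = rng.normal(scale=0.1, size=(d_model, d_sae))
W_dec = rng.normal(scale=0.1, size=(d_sae, d_model))
x_hat, loss = sae_forward(x, W_enc, np.zeros(d_sae), W_dec, np.zeros(d_model))
print(x_hat.shape)  # -> (4, 16)
```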

Generate a dashboard with feature visualizations, showing on which tokens each feature is most active:

invoke dash

For the parameters, see src/dash/create.py.
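The core dashboard computation is simple: for each SAE feature, rank the tokens by that feature's activation. A minimal sketch (the function and inputs are hypothetical, not the actual src/dash/create.py interface):

```python
# Hypothetical illustration: pick the tokens on which one SAE feature
# fires most strongly, returning (token, activation) pairs.
def top_activating_tokens(tokens, feature_acts, k=3):
    order = sorted(range(len(tokens)), key=lambda i: feature_acts[i], reverse=True)
    return [(tokens[i], feature_acts[i]) for i in order[:k]]

tokens = ["Let", "x", "=", "2", ".", "Then", "x", "+", "1", "=", "3"]
acts = [0.0, 0.9, 0.1, 0.7, 0.0, 0.0, 0.8, 0.1, 0.6, 0.1, 0.5]
print(top_activating_tokens(tokens, acts))  # -> [('x', 0.9), ('x', 0.8), ('2', 0.7)]
```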

Feature Scoring

Compare SAE feature activations between reasoning and non-reasoning datasets to find which features are more active on each.

invoke feature-diff --part 1 --total 10

List of parameters:

  • --total: Total number of parts into which the dataset pairs (reasoning vs. non-reasoning) are split
  • --part: Index of the part to process
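The --part/--total scheme is a standard way to shard work across parallel jobs, and the score itself can be as simple as a difference of mean activations. A sketch under those assumptions (the sharding rule, names, and scoring are illustrative, not the repo's exact implementation):

```python
from statistics import mean

# Illustrative sharding: job `part` (1-indexed) of `total` takes every
# total-th dataset pair, as in `invoke feature-diff --part 1 --total 10`.
def shard(pairs, part, total):
    return [p for i, p in enumerate(pairs) if i % total == part - 1]

# Toy feature-difference score: mean activation on the reasoning dataset
# minus mean activation on the non-reasoning one.
def feature_diff(acts_reasoning, acts_plain):
    return mean(acts_reasoning) - mean(acts_plain)

pairs = [("gsm8k", "wiki"), ("math500", "news"), ("olympiad", "chat")]
print(shard(pairs, 1, 2))                        # -> [('gsm8k', 'wiki'), ('olympiad', 'chat')]
print(feature_diff([1.0, 0.5], [0.25, 0.25]))    # -> 0.5
```

A positive score marks a feature as more active on reasoning text; a negative score marks the opposite.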

License

MIT License
