62 changes: 62 additions & 0 deletions examples/knowledge_distillation/README.md
@@ -0,0 +1,62 @@
# Knowledge Distillation Example

This example shows how to run **knowledge distillation (KD)** using slime. A student model learns from a teacher model by minimizing KL divergence on teacher-generated trajectories.

## Key Features

- **Online KD**: The teacher generates trajectories via an external SGLang server (`--rm-url`)
- **Offline KD**: Load pre-saved teacher data from JSONL files (no teacher server needed)
- **Top-K KL**: Forward KL on teacher's top-K tokens (configurable via `KD_TOP_K`)

## Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `KD_TOP_K` | `8` | Number of top-K teacher tokens used for the KL term (`0` = sampled KL on the generated tokens) |
| `KD_TEMPERATURE` | `1.0` | Temperature for top-K KL loss |
| `KD_SAVE_PATH` | - | Save teacher data (supports `{rollout_id}` placeholder) |
| `KD_LOAD_PATH` | - | Load teacher data for offline KD |
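
With `KD_TOP_K > 0`, the loss in `kd_loss.py` renormalizes the teacher's top-K log-probs and the student's log-probs at those same K token ids at temperature τ, then takes the forward KL at each response position, scaled by τ²:

```math
\mathcal{L}_{\mathrm{KD}} = \tau^{2} \sum_{i=1}^{K} p_T^{(\tau)}(i)\left(\log p_T^{(\tau)}(i) - \log p_S^{(\tau)}(i)\right)
```

where $p^{(\tau)}$ is the softmax over the K entries after dividing the log-probs by τ; the τ² factor keeps gradient magnitudes comparable across temperatures. With `KD_TOP_K=0`, the loss falls back to the single-sample estimator $\log p_T(y_t) - \log p_S(y_t)$ on the generated tokens.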

## Components

- `knowledge_distillation.py`: Online rollout function that calls the teacher server and optionally saves its outputs
- `offline_kd.py`: Offline rollout function that loads pre-saved teacher data
- `kd_loss.py`: KD loss functions (top-K KL and sampled KL)

## Data Format

Teacher data is saved as JSONL; the first line is a metadata header:

```json
{"__metadata__": true, "distillation_type": "top_k", "top_k": 8}
{"prompt": "...", "tokens": [...], "response": "...", "response_length": 1024, "teacher_log_probs": [...], "teacher_top_k_ids": [[...]], "teacher_top_k_logprobs": [[...]]}
```

| Field | Description |
|-------|-------------|
| `tokens` | Full sequence (prompt + response token IDs) |
| `teacher_log_probs` | Teacher's log-prob for each response token |
| `teacher_top_k_ids` | Top-K token IDs at each position |
| `teacher_top_k_logprobs` | Top-K log-probs at each position |
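
A minimal sketch for inspecting a saved file, assuming the format above (the path is just an example):

```python
import json

with open("/data/kd/teacher_0.jsonl") as f:  # example path
    header = json.loads(f.readline())
    assert header.get("__metadata__"), "first line should be the metadata header"
    top_k = header.get("top_k", 0)
    for line in f:
        rec = json.loads(line)
        # one teacher log-prob per response token
        assert len(rec["teacher_log_probs"]) == rec["response_length"]
        # one list of K ids per response position (present only for top-K KD)
        assert all(len(ids) == top_k for ids in rec.get("teacher_top_k_ids", []))
```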

## Running the Example

### Online KD

1. Start the teacher SGLang server (a request-level sanity check is sketched after step 2):
```bash
python -m sglang.launch_server --model-path /path/to/teacher --port 13141
```

2. Run training:
```bash
bash examples/knowledge_distillation/run-qwen3-1.7B-kd.sh
```
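
Before launching training, you can verify the teacher endpoint from step 1 responds with log-probs. A sketch using the same payload fields as `knowledge_distillation.py` (the tokenizer path and prompt are placeholders):

```python
import requests
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/path/to/teacher", trust_remote_code=True)
payload = {
    "input_ids": tok.encode("Hello, world!", add_special_tokens=False),
    "sampling_params": {"temperature": 1.0, "max_new_tokens": 16},
    "return_logprob": True,
    "logprob_start_len": 0,
    "top_logprobs_num": 8,  # matches the default KD_TOP_K
}
out = requests.post("http://localhost:13141/generate", json=payload).json()
meta = out["meta_info"]
print(len(meta["output_token_logprobs"]), "tokens with teacher log-probs")
```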

### Offline KD

```bash
# First generate teacher data with online KD (set KD_SAVE_PATH)
# Then train from saved data:
bash examples/knowledge_distillation/run-qwen3-1.7B-offline-kd.sh
```
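
The `{rollout_id}` placeholder in `KD_SAVE_PATH` is expanded with Python's `str.format` (see `_save_to_jsonl`), so one file is written per rollout; the path below is just an example:

```python
save_path = "/data/kd/teacher_{rollout_id}.jsonl".format(rollout_id=3)
print(save_path)  # /data/kd/teacher_3.jsonl
```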
1 change: 1 addition & 0 deletions examples/knowledge_distillation/__init__.py
@@ -0,0 +1 @@

103 changes: 103 additions & 0 deletions examples/knowledge_distillation/kd_loss.py
@@ -0,0 +1,103 @@
import os

import torch
from slime.backends.megatron_utils.loss import get_log_probs_and_entropy

KD_TEMPERATURE = float(os.environ.get("KD_TEMPERATURE", "1.0"))
KD_TOP_K = int(os.environ.get("KD_TOP_K", "8"))

_topk_data_store = {}


def store_topk_data(samples):
"""Store top-k data indexed by token prefix for loss function retrieval."""
global _topk_data_store
for group in samples:
for s in group:
if s.train_metadata:
_topk_data_store[tuple(s.tokens[:20])] = s.train_metadata


def _get_topk_data(tokens):
key = tuple(tokens[:20].tolist() if hasattr(tokens, "tolist") else tokens[:20])
return _topk_data_store.get(key)


def sampled_kl_loss(args, batch, logits, sum_of_sample_mean):
"""Forward KL on teacher-sampled tokens (KD_TOP_K=0)."""
_, log_probs_result = get_log_probs_and_entropy(
logits,
args=args,
unconcat_tokens=batch["unconcat_tokens"],
total_lengths=batch["total_lengths"],
response_lengths=batch["response_lengths"],
with_entropy=True,
max_seq_lens=batch.get("max_seq_lens"),
)
student_lps = log_probs_result["log_probs"]
entropy = log_probs_result.get("entropy", [])

kl_terms = []
for s_lp, t_lp in zip(student_lps, batch["teacher_log_probs"], strict=False):
kl_terms.append(t_lp.to(s_lp) - s_lp)

loss = sum_of_sample_mean(torch.cat(kl_terms))
log = {"kd/loss": loss.detach()}
if entropy:
log["kd/entropy"] = sum_of_sample_mean(torch.cat(entropy)).detach()
return loss, log


def _extract_response_log_probs(logits, unconcat_tokens, total_lengths, response_lengths):
results = []
packed = logits.shape[0] == 1 and len(unconcat_tokens) > 1
offset = 0
for i in range(len(unconcat_tokens)):
total_len, resp_len = int(total_lengths[i]), int(response_lengths[i])
prompt_len = total_len - resp_len
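        # logits at position t predict token t+1, hence the -1 shift in the slices below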
if packed:
row = logits[0, offset + prompt_len - 1 : offset + total_len - 1]
offset += total_len
else:
row = logits[i, prompt_len - 1 : total_len - 1]
results.append(torch.log_softmax(row.float(), dim=-1))
return results


def topk_kl_loss(args, batch, logits, sum_of_sample_mean):
"""Forward KL on teacher's top-K tokens with temperature scaling."""
student_full_lps = _extract_response_log_probs(
logits,
batch["unconcat_tokens"],
batch["total_lengths"],
batch["response_lengths"],
)

topk_data_list = [_get_topk_data(tokens) for tokens in batch["unconcat_tokens"]]
valid_data = [
(s_lp, data) for s_lp, data in zip(student_full_lps, topk_data_list, strict=False) if data is not None
]

if not valid_data:
return sampled_kl_loss(args, batch, logits, sum_of_sample_mean)

tau = KD_TEMPERATURE
kl_terms = []
for s_lp, data in valid_data:
t_ids = torch.tensor(data["teacher_top_k_ids"], device=s_lp.device, dtype=torch.long)
t_lps = torch.tensor(data["teacher_top_k_logprobs"], device=s_lp.device, dtype=s_lp.dtype)

s_topk = s_lp.gather(1, t_ids)
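        # Renormalize teacher and student top-K log-probs at temperature tau and
        # take forward KL per position; tau**2 restores the gradient scale.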
t_renorm = torch.log_softmax(t_lps / tau, dim=-1)
s_renorm = torch.log_softmax(s_topk / tau, dim=-1)
kl_terms.append((tau**2) * (t_renorm.exp() * (t_renorm - s_renorm)).sum(dim=-1))

loss = sum_of_sample_mean(torch.cat(kl_terms))
return loss, {"kd/loss": loss.detach()}


def kd_loss_function(args, batch, logits, sum_of_sample_mean):
"""KD loss: top-K KL if KD_TOP_K > 0, else sampled KL."""
if KD_TOP_K > 0:
return topk_kl_loss(args, batch, logits, sum_of_sample_mean)
return sampled_kl_loss(args, batch, logits, sum_of_sample_mean)
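
A standalone sanity check of the per-position top-K KL term computed in `topk_kl_loss` above, on toy tensors (shapes and seed are arbitrary):

```python
import torch

torch.manual_seed(0)
tau, K = 2.0, 4
t_lps = torch.log_softmax(torch.randn(3, K), dim=-1)           # teacher top-K log-probs
s_topk = torch.log_softmax(torch.randn(3, 50), dim=-1)[:, :K]  # student log-probs at those ids

t_renorm = torch.log_softmax(t_lps / tau, dim=-1)
s_renorm = torch.log_softmax(s_topk / tau, dim=-1)
kl = (tau**2) * (t_renorm.exp() * (t_renorm - s_renorm)).sum(dim=-1)
print(kl)  # one non-negative value per response position
```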
146 changes: 146 additions & 0 deletions examples/knowledge_distillation/knowledge_distillation.py
@@ -0,0 +1,146 @@
import asyncio
import json
import logging
import os

import aiohttp
from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput
from slime.utils.async_utils import run
from slime.utils.processing_utils import load_tokenizer

logger = logging.getLogger(__name__)

KD_TOP_K = int(os.environ.get("KD_TOP_K", "8"))
KD_SAVE_PATH = os.environ.get("KD_SAVE_PATH")
TOKENIZER = None


def _get_tokenizer(args):
global TOKENIZER
if TOKENIZER is None:
TOKENIZER = load_tokenizer(args.hf_checkpoint, trust_remote_code=True)
return TOKENIZER


def _build_sampling_params(args):
return {
"temperature": args.rollout_temperature,
"top_p": args.rollout_top_p,
"top_k": args.rollout_top_k,
"max_new_tokens": args.rollout_max_response_len,
"stop": args.rollout_stop,
"stop_token_ids": args.rollout_stop_token_ids,
"skip_special_tokens": args.rollout_skip_special_tokens,
"no_stop_trim": True,
"spaces_between_special_tokens": False,
}


async def _generate_sample(args, sample, sampling_params, tokenizer, session, semaphore):
assert isinstance(sample.prompt, str), "KD rollout requires string prompts. Enable --apply-chat-template."

prompt_ids = tokenizer.encode(sample.prompt, add_special_tokens=False)
payload = {
"input_ids": prompt_ids,
"sampling_params": sampling_params,
"return_logprob": True,
"logprob_start_len": 0,
}
if KD_TOP_K > 0:
payload["top_logprobs_num"] = KD_TOP_K

async with semaphore:
async with session.post(args.rm_url, json=payload) as resp:
resp.raise_for_status()
output = await resp.json()

meta = output["meta_info"]
generated = meta["output_token_logprobs"]
response_tokens = [int(item[1]) for item in generated]

sample.tokens = prompt_ids + response_tokens
sample.response = output.get("text", "")
sample.response_length = len(response_tokens)
sample.teacher_log_probs = [float(item[0]) for item in generated]
sample.reward = 0.0

if KD_TOP_K > 0:
top_logprobs = meta["output_top_logprobs"]
sample.train_metadata = {
"teacher_top_k_ids": [[int(e[1]) for e in pos[:KD_TOP_K]] for pos in top_logprobs],
"teacher_top_k_logprobs": [[float(e[0]) for e in pos[:KD_TOP_K]] for pos in top_logprobs],
}

return sample


def _save_to_jsonl(samples, save_path, rollout_id):
save_path = save_path.format(rollout_id=rollout_id)
os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)

with open(save_path, "w") as f:
metadata = {"__metadata__": True, "distillation_type": "top_k" if KD_TOP_K > 0 else "sampled_kl"}
if KD_TOP_K > 0:
metadata["top_k"] = KD_TOP_K
f.write(json.dumps(metadata) + "\n")

for group in samples:
for s in group:
record = {
"prompt": s.prompt,
"tokens": s.tokens,
"response": s.response,
"response_length": s.response_length,
"teacher_log_probs": s.teacher_log_probs,
}
if KD_TOP_K > 0 and s.train_metadata:
record["teacher_top_k_ids"] = s.train_metadata["teacher_top_k_ids"]
record["teacher_top_k_logprobs"] = s.train_metadata["teacher_top_k_logprobs"]
f.write(json.dumps(record) + "\n")

logger.info(f"Saved {sum(len(g) for g in samples)} samples to {save_path}")


async def _generate_rollout_async(args, data_source):
assert args.rollout_global_dataset

tokenizer = _get_tokenizer(args)
samples = data_source.get_samples(args.rollout_batch_size)
sampling_params = _build_sampling_params(args)
semaphore = asyncio.Semaphore(max(getattr(args, "sglang_server_concurrency", 64), 1))

async with aiohttp.ClientSession() as session:
generated_groups = await asyncio.gather(
*(
asyncio.gather(
*(_generate_sample(args, s, sampling_params, tokenizer, session, semaphore) for s in group)
)
for group in samples
)
)

first = generated_groups[0][0]
logger.info(
f"KD rollout: prompt={first.prompt[:80]!r}, response={first.response[:80]!r}, len={first.response_length}"
)

token_count = sum(s.response_length for g in generated_groups for s in g)
return RolloutFnTrainOutput(samples=generated_groups, metrics={"kd/token_count": token_count})


def generate_rollout(args, rollout_id, data_source, evaluation=False):
if evaluation:
return RolloutFnEvalOutput(data={})
assert args.rm_url, "--rm-url must be set for KD rollout."

result = run(_generate_rollout_async(args, data_source))

# Store top-k data for loss function access
if KD_TOP_K > 0:
from examples.knowledge_distillation.kd_loss import store_topk_data

store_topk_data(result.samples)

if KD_SAVE_PATH:
_save_to_jsonl(result.samples, KD_SAVE_PATH, rollout_id)
return result