diff --git a/examples/knowledge_distillation/README.md b/examples/knowledge_distillation/README.md new file mode 100644 index 0000000000..bfcee276c3 --- /dev/null +++ b/examples/knowledge_distillation/README.md @@ -0,0 +1,62 @@ +# Knowledge Distillation Example + +This example shows how to run **knowledge distillation (KD)** using slime. A student model learns from a teacher model by minimizing KL divergence on teacher-generated trajectories. + +## Key Features + +- **Online KD**: Teacher generates trajectories via external SGLang server (`--rm-url`) +- **Offline KD**: Load pre-saved teacher data from JSONL files (no teacher server needed) +- **Top-K KL**: Forward KL on teacher's top-K tokens (configurable via `KD_TOP_K`) + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `KD_TOP_K` | `8` | Top-K tokens for KL (0 = sampled KL on generated tokens) | +| `KD_TEMPERATURE` | `1.0` | Temperature for top-K KL loss | +| `KD_SAVE_PATH` | - | Save teacher data (supports `{rollout_id}` placeholder) | +| `KD_LOAD_PATH` | - | Load teacher data for offline KD | + +## Components + +- `knowledge_distillation.py`: Online rollout function that calls teacher server and optionally saves data +- `offline_kd.py`: Offline rollout function that loads pre-saved teacher data +- `kd_loss.py`: KD loss functions (top-K KL and sampled KL) + +## Data Format + +Teacher data is saved as JSONL with metadata header: + +```json +{"__metadata__": true, "distillation_type": "top_k", "top_k": 8} +{"prompt": "...", "tokens": [...], "response": "...", "response_length": 1024, "teacher_log_probs": [...], "teacher_top_k_ids": [[...]], "teacher_top_k_logprobs": [[...]]} +``` + +| Field | Description | +|-------|-------------| +| `tokens` | Full sequence (prompt + response token IDs) | +| `teacher_log_probs` | Teacher's log-prob for each response token | +| `teacher_top_k_ids` | Top-K token IDs at each position | +| `teacher_top_k_logprobs` | Top-K 
# Teacher top-K metadata for the loss function, keyed by a sample's complete token sequence.
# NOTE(review): nothing in this block evicts entries, so the store grows with every rollout.
_topk_data_store = {}


def _topk_key(tokens):
    """Build the store key for a token sequence.

    Args:
        tokens: A sequence of token IDs, or an object exposing ``tolist()``
            (for example a tensor) that converts to such a sequence.

    Returns:
        The whole sequence as a tuple, usable as a dictionary key.
    """
    if hasattr(tokens, "tolist"):
        tokens = tokens.tolist()
    return tuple(tokens)


def store_topk_data(samples):
    """Store each sample's teacher top-K metadata for retrieval by the loss function.

    Entries are keyed by the sample's full token sequence. A fixed-length prefix is
    not a unique key: samples generated from the same prompt, and prompts that start
    with the same chat-template preamble, share their leading tokens, so a prefix key
    lets one sample overwrite another sample's teacher data.

    Args:
        samples: Iterable of groups, each an iterable of samples exposing ``tokens``
            and ``train_metadata``. Samples whose ``train_metadata`` is falsy are
            skipped.
    """
    for group in samples:
        for sample in group:
            if sample.train_metadata:
                _topk_data_store[_topk_key(sample.tokens)] = sample.train_metadata


def _get_topk_data(tokens):
    """Return the metadata stored for ``tokens``, or ``None`` when nothing was stored.

    Args:
        tokens: The sample's full token sequence, as a plain sequence or an object
            exposing ``tolist()``. It must equal the ``tokens`` the sample had when
            ``store_topk_data`` was called.
    """
    return _topk_data_store.get(_topk_key(tokens))
def _extract_response_log_probs(logits, unconcat_tokens, total_lengths, response_lengths):
    """Return full-vocabulary log-probabilities for every response position of each sample.

    For a sample with ``prompt_len`` prompt tokens and ``total_len`` tokens overall,
    the rows ``[prompt_len - 1, total_len - 1)`` of its logits are selected, i.e. the
    window is shifted one position to the left of the response tokens themselves
    (presumably because the logits at position ``t`` score token ``t + 1`` -- confirm
    against the model's output convention).

    Args:
        logits: Student logits, indexed as ``logits[batch, position]`` with the
            vocabulary on the last axis. When the batch axis has size 1 and there is
            more than one sample, the samples are treated as packed back to back
            along the position axis.
        unconcat_tokens: Per-sample token sequences; only the number of entries is used.
        total_lengths: Per-sample total length (prompt + response).
        response_lengths: Per-sample response length.

    Returns:
        A list with one float32 tensor per sample holding ``log_softmax`` over the
        last axis of the selected rows.

    Raises:
        ValueError: If a sample has no prompt token. There is then no logits row in
            front of the first response token, and the slice would either come back
            short or borrow a row from the previous packed sample.
    """
    num_samples = len(unconcat_tokens)
    packed = logits.shape[0] == 1 and num_samples > 1

    results = []
    offset = 0  # start of the current sample inside the packed row
    for i in range(num_samples):
        total_len, resp_len = int(total_lengths[i]), int(response_lengths[i])
        prompt_len = total_len - resp_len
        if prompt_len < 1:
            raise ValueError(
                f"Sample {i}: need at least one prompt token to score the response "
                f"(total_length={total_len}, response_length={resp_len})"
            )
        if packed:
            row = logits[0, offset + prompt_len - 1 : offset + total_len - 1]
            offset += total_len
        else:
            row = logits[i, prompt_len - 1 : total_len - 1]
        # Normalize in float32 regardless of the dtype the logits arrive in.
        results.append(torch.log_softmax(row.float(), dim=-1))
    return results
def _build_sampling_params(args):
    """Assemble the ``sampling_params`` dictionary sent with each teacher request.

    Args:
        args: Parsed arguments exposing the ``rollout_*`` attributes listed below.

    Returns:
        A new dict with the seven values copied from ``args`` followed by two
        fixed flags (``no_stop_trim=True``, ``spaces_between_special_tokens=False``).
    """
    # Request key -> name of the ``args`` attribute whose value is forwarded unchanged.
    forwarded = {
        "temperature": "rollout_temperature",
        "top_p": "rollout_top_p",
        "top_k": "rollout_top_k",
        "max_new_tokens": "rollout_max_response_len",
        "stop": "rollout_stop",
        "stop_token_ids": "rollout_stop_token_ids",
        "skip_special_tokens": "rollout_skip_special_tokens",
    }
    params = {key: getattr(args, attr) for key, attr in forwarded.items()}

    # Constants that do not depend on the command line.
    params["no_stop_trim"] = True
    params["spaces_between_special_tokens"] = False
    return params
def _save_to_jsonl(samples, save_path, rollout_id):
    """Write one rollout's teacher data to a JSONL file.

    The first line is a metadata record (``__metadata__``, ``distillation_type`` and,
    when ``KD_TOP_K`` is positive, ``top_k``); every following line is one sample.

    Args:
        samples: Sequence of groups, each a sequence of samples exposing ``prompt``,
            ``tokens``, ``response``, ``response_length``, ``teacher_log_probs`` and
            ``train_metadata``.
        save_path: Destination path; ``{rollout_id}`` is substituted via ``str.format``.
        rollout_id: Value substituted into ``save_path``.
    """
    target = save_path.format(rollout_id=rollout_id)
    parent_dir = os.path.dirname(target) or "."  # a bare file name means the current directory
    os.makedirs(parent_dir, exist_ok=True)

    top_k_enabled = KD_TOP_K > 0
    if top_k_enabled:
        header = {"__metadata__": True, "distillation_type": "top_k", "top_k": KD_TOP_K}
    else:
        header = {"__metadata__": True, "distillation_type": "sampled_kl"}

    with open(target, "w") as out:
        out.write(json.dumps(header) + "\n")

        for sample in (member for group in samples for member in group):
            row = {
                "prompt": sample.prompt,
                "tokens": sample.tokens,
                "response": sample.response,
                "response_length": sample.response_length,
                "teacher_log_probs": sample.teacher_log_probs,
            }
            # Top-K fields are written only when top-K mode is on and the sample carries them.
            if top_k_enabled and sample.train_metadata:
                extra = sample.train_metadata
                row["teacher_top_k_ids"] = extra["teacher_top_k_ids"]
                row["teacher_top_k_logprobs"] = extra["teacher_top_k_logprobs"]
            out.write(json.dumps(row) + "\n")

    saved = sum(len(group) for group in samples)
    logger.info(f"Saved {saved} samples to {target}")
async def _generate_rollout_async(args, data_source):
    """Fetch one batch of prompt groups and have the teacher generate every sample.

    Args:
        args: Parsed arguments; reads ``rollout_global_dataset``, ``rollout_batch_size``
            and, when present, ``sglang_server_concurrency`` (default 64, floored at 1).
        data_source: Object whose ``get_samples(n)`` returns groups of samples.

    Returns:
        ``RolloutFnTrainOutput`` holding the generated groups and a
        ``kd/token_count`` metric (sum of ``response_length`` over all samples).
    """
    assert args.rollout_global_dataset

    tokenizer = _get_tokenizer(args)
    prompt_groups = data_source.get_samples(args.rollout_batch_size)
    sampling_params = _build_sampling_params(args)

    # One semaphore shared by every request caps how many are in flight at once.
    concurrency = getattr(args, "sglang_server_concurrency", 64)
    semaphore = asyncio.Semaphore(max(concurrency, 1))

    async with aiohttp.ClientSession() as session:
        # One inner gather per group, so the result keeps the group structure.
        per_group = [
            asyncio.gather(
                *[
                    _generate_sample(args, member, sampling_params, tokenizer, session, semaphore)
                    for member in group
                ]
            )
            for group in prompt_groups
        ]
        generated_groups = await asyncio.gather(*per_group)

    head = generated_groups[0][0]
    logger.info(
        f"KD rollout: prompt={head.prompt[:80]!r}, response={head.response[:80]!r}, len={head.response_length}"
    )

    token_count = 0
    for group in generated_groups:
        for member in group:
            token_count += member.response_length

    metrics = {"kd/token_count": token_count}
    return RolloutFnTrainOutput(samples=generated_groups, metrics=metrics)
def _load_from_jsonl(load_path, rollout_id, batch_size, num_rollouts_per_prompt):
    """Load saved teacher data for one rollout and arrange it into prompt groups.

    Args:
        load_path: Path template; ``{rollout_id}`` is substituted via ``str.format``.
        rollout_id: Value substituted into ``load_path``.
        batch_size: Number of groups to return.
        num_rollouts_per_prompt: Number of samples per group.

    Returns:
        A list of ``batch_size`` lists of ``Sample`` objects, taken in file order from
        the first ``batch_size * num_rollouts_per_prompt`` records, with
        ``group_index`` and ``index`` set on each sample.

    Raises:
        AssertionError: If the file is missing, a record's teacher data does not match
            its ``response_length``, the metadata line is absent or disagrees with
            ``KD_TOP_K``, or the file holds too few samples.
    """
    path = load_path.format(rollout_id=rollout_id)
    assert os.path.exists(path), f"KD data file not found: {path}"

    header = None
    loaded = []
    with open(path) as fh:
        for raw in fh:
            if not raw.strip():
                continue

            entry = json.loads(raw)
            if entry.get("__metadata__"):
                # The metadata line describes the file; it is not a sample.
                header = entry
                continue

            position = len(loaded)  # index this sample will get, used in error messages
            n_resp = entry["response_length"]
            assert len(entry["teacher_log_probs"]) == n_resp, f"Sample {position}: log_probs length mismatch"

            item = Sample(
                prompt=entry["prompt"],
                tokens=entry["tokens"],
                response=entry["response"],
                response_length=n_resp,
                teacher_log_probs=entry["teacher_log_probs"],
                reward=0.0,
                status=Sample.Status.COMPLETED,
            )

            has_top_k = "teacher_top_k_ids" in entry and "teacher_top_k_logprobs" in entry
            if has_top_k:
                top_ids = entry["teacher_top_k_ids"]
                top_lps = entry["teacher_top_k_logprobs"]
                assert len(top_ids) == n_resp, f"Sample {position}: top_k_ids length mismatch"
                assert len(top_lps) == n_resp, f"Sample {position}: top_k_logprobs length mismatch"
                item.train_metadata = {
                    "teacher_top_k_ids": top_ids,
                    "teacher_top_k_logprobs": top_lps,
                }

            loaded.append(item)

    # The file must have been produced in the same mode this process is configured for.
    assert header, f"Missing metadata in {path}"
    file_type = header.get("distillation_type")
    expected_type = "top_k" if KD_TOP_K > 0 else "sampled_kl"
    assert file_type == expected_type, f"Type mismatch: file={file_type}, expected={expected_type}"
    if file_type == "top_k":
        assert header.get("top_k") == KD_TOP_K, f"Top-K mismatch: file={header.get('top_k')}, KD_TOP_K={KD_TOP_K}"

    needed = batch_size * num_rollouts_per_prompt
    assert len(loaded) >= needed, f"Not enough samples: got {len(loaded)}, need {needed}"

    groups = []
    for group_idx in range(batch_size):
        start = group_idx * num_rollouts_per_prompt
        members = loaded[start : start + num_rollouts_per_prompt]
        for member_idx, member in enumerate(members):
            member.group_index = group_idx
            member.index = member_idx
        groups.append(members)

    logger.info(f"Loaded {len(loaded)} samples from {path} ({file_type}, top_k={header.get('top_k')})")
    return groups
#!/bin/bash

# usage: bash examples/knowledge_distillation/run-qwen3-1.7B-kd.sh
#
# Knowledge distillation: Qwen3-4B-Instruct-2507 (teacher) -> Qwen3-1.7B (student)
# Prerequisites: Start teacher SGLang server before running this script.
#
# Optional environment overrides:
#   KD_MODEL_DIR     directory with the student checkpoint, prompt data and outputs (default: /workspace)
#   KD_TEACHER_IP    address of the teacher server (default: 127.0.0.1)
#   KD_TEACHER_PORT  port of the teacher server (default: 13141)

set -ex

MODEL_DIR="${KD_MODEL_DIR:-/workspace}"

# Wait for teacher server
TEACHER_IP="${KD_TEACHER_IP:-127.0.0.1}"
TEACHER_PORT="${KD_TEACHER_PORT:-13141}"
TEACHER_URL="http://${TEACHER_IP}:${TEACHER_PORT}"

until curl -sf "${TEACHER_URL}/health_generate" > /dev/null; do
    echo "Waiting for teacher server at $TEACHER_IP:$TEACHER_PORT..."
    sleep 5
done
echo "Teacher server ready."

# Qwen3-1.7B model architecture
MODEL_ARGS=(
    --swiglu
    --num-layers 28
    --hidden-size 2048
    --ffn-hidden-size 6144
    --num-attention-heads 16
    --group-query-attention
    --num-query-groups 8
    --use-rotary-position-embeddings
    --disable-bias-linear
    --normalization "RMSNorm"
    --norm-epsilon 1e-6
    --rotary-base 1000000
    --vocab-size 151936
    --kv-channels 128
    --qk-layernorm
)

export PYTHONPATH=/root/Megatron-LM/
export CUDA_DEVICE_MAX_CONNECTIONS=1
# {rollout_id} is left literal here; the rollout function fills it in for each rollout.
export KD_SAVE_PATH="$MODEL_DIR/kd_teacher_data/rollout_{rollout_id}.jsonl"

# Quote the array and every path so values containing spaces or glob characters
# reach train.py as single, unmodified arguments.
python3 train.py \
    "${MODEL_ARGS[@]}" \
    --debug-train-only \
    --rollout-function-path examples.knowledge_distillation.knowledge_distillation.generate_rollout \
    --loss-type custom_loss \
    --custom-loss-function-path examples.knowledge_distillation.kd_loss.kd_loss_function \
    --calculate-per-token-loss \
    --rm-url "${TEACHER_URL}/generate" \
    --prompt-data "$MODEL_DIR/dapo-math-17k/dapo-math-17k.jsonl" \
    --input-key prompt \
    --label-key reward_model \
    --apply-chat-template \
    --rollout-batch-size 28 \
    --n-samples-per-prompt 1 \
    --rollout-max-response-len 1024 \
    --rollout-temperature 0.8 \
    --global-batch-size 28 \
    --num-rollout 100 \
    --train-backend megatron \
    --megatron-to-hf-mode bridge \
    --hf-checkpoint "$MODEL_DIR/Qwen3-1.7B" \
    --ref-load "$MODEL_DIR/Qwen3-1.7B" \
    --save "$MODEL_DIR/KD-Qwen3-1.7B" \
    --save-interval 20 \
    --tensor-model-parallel-size 1 \
    --pipeline-model-parallel-size 1 \
    --actor-num-nodes 1 \
    --actor-num-gpus-per-node 7 \
    --colocate \
    --optimizer adam \
    --lr 1e-6 \
    --lr-decay-style constant \
    --weight-decay 0.1 \
    --adam-beta1 0.9 \
    --adam-beta2 0.98

# usage: bash examples/knowledge_distillation/run-qwen3-1.7B-offline-kd.sh
#
# Offline knowledge distillation: Load pre-saved teacher data -> Qwen3-1.7B (student)
# Prerequisites: Run online KD with KD_SAVE_PATH first to generate teacher data.
#
# Optional environment override:
#   KD_MODEL_DIR  directory with the student checkpoint, prompt data, teacher data and outputs
#                 (default: /workspace)

set -ex

MODEL_DIR="${KD_MODEL_DIR:-/workspace}"

# Qwen3-1.7B model architecture
MODEL_ARGS=(
    --swiglu
    --num-layers 28
    --hidden-size 2048
    --ffn-hidden-size 6144
    --num-attention-heads 16
    --group-query-attention
    --num-query-groups 8
    --use-rotary-position-embeddings
    --disable-bias-linear
    --normalization "RMSNorm"
    --norm-epsilon 1e-6
    --rotary-base 1000000
    --vocab-size 151936
    --kv-channels 128
    --qk-layernorm
)

export PYTHONPATH=/root/Megatron-LM/
export CUDA_DEVICE_MAX_CONNECTIONS=1
# {rollout_id} is left literal here; the rollout function fills it in for each rollout.
export KD_LOAD_PATH="$MODEL_DIR/kd_teacher_data/rollout_{rollout_id}.jsonl"

# Quote the array and every path so values containing spaces or glob characters
# reach train.py as single, unmodified arguments.
python3 train.py \
    "${MODEL_ARGS[@]}" \
    --debug-train-only \
    --rollout-function-path examples.knowledge_distillation.offline_kd.generate_rollout \
    --loss-type custom_loss \
    --custom-loss-function-path examples.knowledge_distillation.kd_loss.kd_loss_function \
    --calculate-per-token-loss \
    --prompt-data "$MODEL_DIR/dapo-math-17k/dapo-math-17k.jsonl" \
    --input-key prompt \
    --label-key reward_model \
    --apply-chat-template \
    --rollout-batch-size 28 \
    --n-samples-per-prompt 1 \
    --rollout-max-response-len 1024 \
    --rollout-temperature 0.8 \
    --global-batch-size 28 \
    --num-rollout 100 \
    --train-backend megatron \
    --megatron-to-hf-mode bridge \
    --hf-checkpoint "$MODEL_DIR/Qwen3-1.7B" \
    --ref-load "$MODEL_DIR/Qwen3-1.7B" \
    --save "$MODEL_DIR/KD-Qwen3-1.7B-offline" \
    --save-interval 20 \
    --tensor-model-parallel-size 1 \
    --pipeline-model-parallel-size 1 \
    --actor-num-nodes 1 \
    --actor-num-gpus-per-node 7 \
    --colocate \
    --optimizer adam \
    --lr 1e-6 \
    --lr-decay-style constant \
    --weight-decay 0.1 \
    --adam-beta1 0.9 \
    --adam-beta2 0.98