Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
324 changes: 324 additions & 0 deletions docker/patch/latest/sglang_delta_compression.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
diff -urN tmp_sglang_orig/python/sglang/srt/managers/io_struct.py tmp_sglang_mod/python/sglang/srt/managers/io_struct.py
--- tmp_sglang_orig/python/sglang/srt/managers/io_struct.py 2026-04-07 01:29:33.170989783 +0000
+++ tmp_sglang_mod/python/sglang/srt/managers/io_struct.py 2026-04-07 01:30:22.328323959 +0000
@@ -1362,6 +1362,7 @@
names: List[str]
dtypes: List[str]
shapes: List[List[int]]
+ sparse_metadata: Optional[List[Dict[str, Any]]] = None
# The group name
group_name: str = "weight_update_group"
# Whether to flush the cache after updating weights
diff -urN tmp_sglang_orig/python/sglang/srt/managers/tp_worker.py tmp_sglang_mod/python/sglang/srt/managers/tp_worker.py
--- tmp_sglang_orig/python/sglang/srt/managers/tp_worker.py 2026-04-07 01:29:33.368983074 +0000
+++ tmp_sglang_mod/python/sglang/srt/managers/tp_worker.py 2026-04-07 01:30:32.446981062 +0000
@@ -148,6 +148,7 @@
recv_req.names,
recv_req.dtypes,
recv_req.shapes,
+ recv_req.sparse_metadata,
recv_req.group_name,
recv_req.load_format,
)
diff -urN tmp_sglang_orig/python/sglang/srt/model_executor/model_runner.py tmp_sglang_mod/python/sglang/srt/model_executor/model_runner.py
--- tmp_sglang_orig/python/sglang/srt/model_executor/model_runner.py 2026-04-07 01:29:33.592975483 +0000
+++ tmp_sglang_mod/python/sglang/srt/model_executor/model_runner.py 2026-04-07 01:31:33.304918738 +0000
@@ -13,6 +13,7 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

+import contextlib
import datetime
import gc
import inspect
@@ -247,6 +248,9 @@

logger = logging.getLogger(__name__)

+_ORIGINAL_TENSOR_COPY = torch.Tensor.copy_
+_ORIGINAL_TENSOR_FILL = torch.Tensor.fill_
+

def resolve_language_model(model: nn.Module) -> nn.Module:
model_cls_name = model.__class__.__name__
@@ -1341,6 +1345,7 @@
names,
dtypes,
shapes,
+ sparse_metadata,
group_name,
load_format: Optional[str] = None,
):
@@ -1363,6 +1368,18 @@
return self._update_bucketed_weights_from_distributed(
names, dtypes, shapes, group_name
)
+ if load_format == "distributed_delta_sparse_indices":
+ return self._apply_sparse_delta_weights_from_distributed(
+ dtypes, shapes, sparse_metadata, group_name, transport="sparse_indices"
+ )
+ if load_format == "distributed_delta_sparse_bitmask":
+ return self._apply_sparse_delta_weights_from_distributed(
+ dtypes, shapes, sparse_metadata, group_name, transport="sparse_bitmask"
+ )
+ if load_format == "distributed_delta":
+ return self._apply_delta_weights_from_distributed(
+ names, dtypes, shapes, group_name
+ )
try:
weights = []
handles = []
@@ -1395,6 +1412,138 @@
logger.error(error_msg)
return False, error_msg

+ def _apply_delta_weights_from_distributed(
+ self, names, dtypes, shapes, group_name
+ ):
+ try:
+ weights = []
+ handles = []
+ for name, dtype, shape in zip(names, dtypes, shapes):
+ target_dtype = (
+ dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+ )
+ weight = torch.empty(shape, dtype=target_dtype, device=self.device)
+ handles.append(
+ torch.distributed.broadcast(
+ weight,
+ src=0,
+ group=self._model_update_group[group_name],
+ async_op=True,
+ )
+ )
+ weights.append((name, weight))
+ for handle in handles:
+ handle.wait()
+ with _additive_weight_copy_context():
+ with _wrap_post_load_weights_with_original_copy_context(self.model):
+ self.model.load_weights(weights)
+ return True, "Succeeded to apply weight deltas online."
+
+ except Exception as e:
+ error_msg = (
+ f"Failed to apply weight deltas online: {e}. "
+ f"The model weights may be in an inconsistent state. "
+ f"Please discard the whole weights."
+ )
+ logger.error(error_msg)
+ return False, error_msg
+
+ def _apply_sparse_delta_weights_from_distributed(
+ self, dtypes, shapes, sparse_metadata, group_name, transport
+ ):
+ try:
+ encoded_tensors = []
+ handles = []
+ for dtype, shape in zip(dtypes, shapes):
+ target_dtype = (
+ dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+ )
+ encoded_tensor = torch.empty(shape, dtype=target_dtype, device=self.device)
+ handles.append(
+ torch.distributed.broadcast(
+ encoded_tensor,
+ src=0,
+ group=self._model_update_group[group_name],
+ async_op=True,
+ )
+ )
+ encoded_tensors.append(encoded_tensor)
+ for handle in handles:
+ handle.wait()
+ weights = []
+ scratch_buffer = None
+ scratch_numel = 0
+ if transport == "sparse_indices":
+ packed_indices = encoded_tensors[0].to(dtype=torch.long)
+ packed_values = encoded_tensors[1]
+ for meta in sparse_metadata:
+ target_dtype = getattr(torch, meta["dtype"])
+ target_shape = tuple(meta["shape"])
+ numel = int(meta["numel"])
+ if (
+ scratch_buffer is None
+ or scratch_buffer.dtype != target_dtype
+ or scratch_numel < numel
+ ):
+ scratch_buffer = torch.empty(
+ numel, dtype=target_dtype, device=self.device
+ )
+ scratch_numel = numel
+ decoded_flat = scratch_buffer[:numel]
+ decoded_flat.zero_()
+ index_start = int(meta["index_start"])
+ index_end = int(meta["index_end"])
+ value_start = int(meta["value_start"])
+ value_end = int(meta["value_end"])
+ if index_end > index_start:
+ decoded_flat.index_copy_(
+ 0,
+ packed_indices[index_start:index_end],
+ packed_values[value_start:value_end],
+ )
+ weights.append((meta["name"], decoded_flat.view(target_shape)))
+ elif transport == "sparse_bitmask":
+ packed_masks = encoded_tensors[0]
+ packed_values = encoded_tensors[1]
+ for meta in sparse_metadata:
+ target_dtype = getattr(torch, meta["dtype"])
+ target_shape = tuple(meta["shape"])
+ numel = int(meta["numel"])
+ if (
+ scratch_buffer is None
+ or scratch_buffer.dtype != target_dtype
+ or scratch_numel < numel
+ ):
+ scratch_buffer = torch.empty(
+ numel, dtype=target_dtype, device=self.device
+ )
+ scratch_numel = numel
+ decoded_flat = scratch_buffer[:numel]
+ decoded_flat.zero_()
+ mask_start = int(meta["mask_start"])
+ mask_end = int(meta["mask_end"])
+ value_start = int(meta["value_start"])
+ value_end = int(meta["value_end"])
+ unpacked_mask = _unpack_bitmask(
+ packed_masks[mask_start:mask_end], numel, self.device
+ )
+ decoded_flat[unpacked_mask] = packed_values[value_start:value_end]
+ weights.append((meta["name"], decoded_flat.view(target_shape)))
+ else:
+ raise ValueError(f"Unsupported sparse delta transport: {transport}")
+ with _additive_weight_copy_context():
+ with _wrap_post_load_weights_with_original_copy_context(self.model):
+ self.model.load_weights(weights)
+ return True, "Succeeded to apply sparse weight deltas online."
+ except Exception as e:
+ error_msg = (
+ f"Failed to apply sparse weight deltas online: {e}. "
+ f"The model weights may be in an inconsistent state. "
+ f"Please discard the whole weights."
+ )
+ logger.error(error_msg)
+ return False, error_msg
+
def _update_bucketed_weights_from_distributed(
self, names, dtypes, shapes, group_name
):
@@ -1437,6 +1586,10 @@
return self._update_weights_from_flattened_bucket(
flattened_tensor_bucket_dict=named_tensors
)
+ if load_format == "flattened_bucket_delta":
+ return self._apply_weight_deltas_from_flattened_bucket(
+ flattened_tensor_bucket_dict=named_tensors
+ )

# We need to get device after patch otherwise the device would be wrong
self.device_module = torch.get_device_module(self.device)
@@ -1489,6 +1642,35 @@

return True, "Success"

+ def _apply_weight_deltas_from_flattened_bucket(
+ self,
+ flattened_tensor_bucket_dict,
+ ):
+ flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+ metadata = flattened_tensor_bucket_dict["metadata"]
+
+ converted_metadata = []
+ for meta in metadata:
+ converted_meta = FlattenedTensorMetadata(
+ name=meta.name,
+ shape=meta.shape,
+ dtype=meta.dtype,
+ start_idx=meta.start_idx,
+ end_idx=meta.end_idx,
+ numel=meta.numel,
+ )
+ converted_metadata.append(converted_meta)
+
+ bucket = FlattenedTensorBucket(
+ flattened_tensor=flattened_tensor, metadata=converted_metadata
+ )
+ delta_tensors = bucket.reconstruct_tensors()
+
+ with _additive_weight_copy_context():
+ with _wrap_post_load_weights_with_original_copy_context(self.model):
+ self.model.load_weights(delta_tensors)
+ return True, "Success"
+
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
@@ -2718,6 +2900,67 @@
return True, "Success"


+@contextlib.contextmanager
+def _restore_weight_copy_context():
+ current_copy = torch.Tensor.copy_
+ current_fill = torch.Tensor.fill_
+ torch.Tensor.copy_ = _ORIGINAL_TENSOR_COPY
+ torch.Tensor.fill_ = _ORIGINAL_TENSOR_FILL
+ try:
+ yield
+ finally:
+ torch.Tensor.copy_ = current_copy
+ torch.Tensor.fill_ = current_fill
+
+
+@contextlib.contextmanager
+def _wrap_post_load_weights_with_original_copy_context(model):
+ original_post_load = getattr(model, "post_load_weights", None)
+ if original_post_load is None:
+ yield
+ return
+
+ def wrapped_post_load_weights(*args, **kwargs):
+ with _restore_weight_copy_context():
+ return original_post_load(*args, **kwargs)
+
+ model.post_load_weights = wrapped_post_load_weights
+ try:
+ yield
+ finally:
+ model.post_load_weights = original_post_load
+
+
+@contextlib.contextmanager
+def _additive_weight_copy_context():
+ original_copy_ = torch.Tensor.copy_
+ original_fill_ = torch.Tensor.fill_
+
+ def _additive_copy_(self, src, non_blocking=False):
+ self.add_(src.to(device=self.device, dtype=self.dtype))
+ return self
+
+ def _additive_fill_(self, value):
+ self.add_(value)
+ return self
+
+ torch.Tensor.copy_ = _additive_copy_
+ torch.Tensor.fill_ = _additive_fill_
+ try:
+ yield
+ finally:
+ torch.Tensor.copy_ = original_copy_
+ torch.Tensor.fill_ = original_fill_
+
+
+def _unpack_bitmask(packed, numel, device):
+ if numel == 0:
+ return torch.empty(0, dtype=torch.bool, device=device)
+ shifts = torch.arange(8, dtype=torch.uint8, device=device)
+ expanded = ((packed.unsqueeze(1) >> shifts) & 1).reshape(-1)
+ return expanded[:numel].to(dtype=torch.bool)
+
+
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
for name, tensor in named_tensors:
55 changes: 55 additions & 0 deletions slime/backends/megatron_utils/update_weight/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
from argparse import Namespace
from collections.abc import Iterator, Sequence
from dataclasses import dataclass

import torch
import torch.distributed as dist
Expand All @@ -10,6 +11,60 @@

from slime.backends.megatron_utils.misc_utils import strip_param_name_prefix
from slime.utils.types import ParamInfo
from .delta_weight_update import DeltaCompressionCommitState


@dataclass
class HFUpdate:
    """One HF-format weight update waiting to be shipped to the inference engine."""

    # Named payload as (parameter_name, tensor) pairs.
    tensors: list[tuple[str, torch.Tensor]]
    # Format hint for the receiving side (None -> default load path).
    load_format: str | None
    # Delta-compression bookkeeping to commit after a successful send, if any.
    commit_state: DeltaCompressionCommitState | None
    # Optional on-the-wire size override; when absent the dense size is used.
    transport_byte_size: int | None = None

    @property
    def should_send(self) -> bool:
        """True when there is at least one tensor to transmit."""
        return len(self.tensors) > 0

    @property
    def byte_size(self) -> int:
        """Bytes this update occupies in transport.

        Prefers the explicit transport override (e.g. for compressed
        encodings); otherwise sums the dense storage of every tensor.
        """
        override = self.transport_byte_size
        if override is not None:
            return override
        total = 0
        for _, tensor in self.tensors:
            total += tensor.numel() * tensor.element_size()
        return total


@dataclass
class PendingHFUpdateBucket:
    """Accumulates HFUpdate payloads so several can be flushed as one send."""

    # Flattened (parameter_name, tensor) pairs from every buffered update.
    tensors: list[tuple[str, torch.Tensor]]
    # One commit-state entry per buffered update (None when it carries none).
    commit_states: list[DeltaCompressionCommitState | None]
    # Load format shared by everything currently buffered.
    load_format: str | None
    # Running byte total of the buffered payload.
    byte_size: int = 0

    @classmethod
    def empty(cls) -> "PendingHFUpdateBucket":
        """Construct a bucket with nothing buffered yet."""
        return cls(tensors=[], commit_states=[], load_format=None)

    @property
    def has_updates(self) -> bool:
        """True when at least one update has been buffered."""
        return len(self.tensors) > 0

    def should_flush_before_add(self, update: HFUpdate, byte_limit: int) -> bool:
        """Decide whether the bucket must be flushed before ``update`` joins.

        A flush is needed when the incoming update's load format differs
        from the buffered one, or when adding it would exceed ``byte_limit``.
        An empty bucket never requires a flush.
        """
        if not self.has_updates:
            return False
        format_mismatch = self.load_format != update.load_format
        would_overflow = self.byte_size + update.byte_size > byte_limit
        return format_mismatch or would_overflow

    def add(self, update: HFUpdate) -> None:
        """Buffer ``update``, adopting its load format and growing the total."""
        for named_tensor in update.tensors:
            self.tensors.append(named_tensor)
        self.commit_states.append(update.commit_state)
        self.load_format = update.load_format
        self.byte_size += update.byte_size

    def clear(self) -> None:
        """Drop everything buffered and return to the empty state."""
        del self.tensors[:]
        del self.commit_states[:]
        self.load_format = None
        self.byte_size = 0


def all_gather_param(name: str, param: torch.nn.Parameter) -> torch.Tensor:
Expand Down
Loading
Loading