README.md (1 addition, 1 deletion)
@@ -16,7 +16,7 @@

## Latest News

-* [2025/10] We hosted the [Ray x DeepSpeed Meetup](https://luma.com/3wctqteh) at Anyscale. We shared our most recent work on SuperOffload, ZenFlow, Muon Optimizer Support, Arctic Long Sequence Training and DeepCompile. Please find the meetup slides [here](https://docs.google.com/presentation/d/1eM3mY6oW9GYkRy1Xz0iOnbbEr5T1t0JJXOM5BKtR-Ks/edit?slide=id.g38615d6b4c2_0_87#slide=id.g38615d6b4c2_0_87).
+* [2025/10] We hosted the [Ray x DeepSpeed Meetup](https://luma.com/3wctqteh) at Anyscale. We shared our most recent work on SuperOffload, ZenFlow, Muon Optimizer Support, Arctic Long Sequence Training and DeepCompile. Please find the meetup slides [here](https://docs.google.com/presentation/d/1eM3mY6oW9GYkRy1Xz0iOnbbEr5T1t0JJXOM5BKtR-Ks/edit?slide=id.g38615d6b4c2_0_87#slide=id.g38615d6b4c2_0_87).

* [2025/10] [SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips](https://pytorch.org/blog/superoffload-unleashing-the-power-of-large-scale-llm-training-on-superchips/)

deepspeed/runtime/compiler.py (12 additions, 0 deletions)
@@ -71,3 +71,15 @@ def wrapper(*args, **kwargs):

def is_compiling():
    return torch_is_compiling()


+def dummy_decorator(func):
+    return func
+
+
+# robust version of @torch.compile
+def compile():
+    if hasattr(torch, "compile"):
+        return torch.compile
+    else:
+        return dummy_decorator
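
A minimal usage sketch of the new helper (the decorated function `scaled_sum` is hypothetical, not part of the PR): on PyTorch builds that ship `torch.compile` the function is compiled, and on older builds the decorator degrades to a no-op.

```python
import torch
from deepspeed.runtime import compiler


@compiler.compile()
def scaled_sum(x, y):
    # compiled via torch.compile when available, otherwise runs eagerly
    return 2.0 * x + y


print(scaled_sum(torch.ones(4), torch.zeros(4)))  # tensor([2., 2., 2., 2.])
```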
deepspeed/runtime/zero/muon/original_muon.py (3 additions, 0 deletions)
@@ -29,8 +29,10 @@

import torch
import deepspeed.comm as dist # replace torch's distributed package with deepspeed.comm to resolve deepspeed check
+from deepspeed.runtime import compiler


+@compiler.compile()
def zeropower_via_newtonschulz5(G, steps: int):
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
Expand Down Expand Up @@ -60,6 +62,7 @@ def zeropower_via_newtonschulz5(G, steps: int):
return X


+@compiler.compile()
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
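
For context on what is being compiled: `muon_update` blends the incoming gradient into the momentum buffer in place, then orthogonalizes the update with the Newton-Schulz iteration above. A small illustrative check of the `lerp_` step (shapes and values are made up for the example):

```python
import torch

beta = 0.95
grad = torch.randn(128, 64)        # Muon requires 2-D (ndim > 1) parameters
momentum = torch.randn_like(grad)  # per-parameter momentum state

# momentum.lerp_(grad, 1 - beta) is the in-place form of:
#   momentum = beta * momentum + (1 - beta) * grad
expected = beta * momentum + (1 - beta) * grad
momentum.lerp_(grad, 1 - beta)
assert torch.allclose(momentum, expected)
```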
deepspeed/runtime/zero/stage_1_and_2.py (2 additions, 6 deletions)
@@ -1901,21 +1901,17 @@ def get_flat_partition(self,
                        tensor_list[0], 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
                # need to check the total # of elements in the parameters in this group and this partition
                total_size = sum([t.numel() for t in tensor_list])
-               flatten_bf_list = [torch.zeros([total_size], dtype=dtype)]  # put on cpu to save space
+               flatten_bf_list = [torch.zeros([total_size], dtype=dtype, device=device)]
                self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)

        buffer_idx = 0
        for i, tensor in enumerate(tensor_list):
            grad_accum = self.all_grad_tensors[param_group_idx][i]
            if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
                assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
-               # create a gpu copy
                buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
-                                     tensor.numel()).view(tensor.size()).to(device).to(dtype)
+                                     tensor.numel()).view(tensor.size())
                grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
-               # write back to the cpu copy
-               torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
-                            tensor.numel()).data.copy_(buffer.view(-1).data)
                tensor = grad_accum
                num_elements = tensor.numel()
                buffer_idx += num_elements
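
The deleted write-back is safe to drop because of how the new buffer is obtained. Previously the momentum buffer lived on CPU, so `.to(device)` materialized a separate GPU copy that had to be copied back after `muon_update` mutated it. Now the flat buffer is allocated on `device`, and `torch.narrow(...).view(...)` yields a view that shares storage with it, so the in-place update inside `muon_update` lands in the flat buffer directly. A small sketch of the aliasing behavior (toy sizes, not the PR's code):

```python
import torch

flat = torch.zeros(6)                          # stand-in for the flat momentum buffer
view = torch.narrow(flat, 0, 2, 4).view(2, 2)  # shares storage with `flat`
view.lerp_(torch.ones(2, 2), 0.5)              # in-place update, as muon_update does

print(flat)  # tensor([0.0000, 0.0000, 0.5000, 0.5000, 0.5000, 0.5000])
```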