From bbb4bfcf6e2d8db824fc2bf2be02f3c9883ab6a5 Mon Sep 17 00:00:00 2001
From: Guokai Ma
Date: Fri, 24 Oct 2025 02:58:07 -0700
Subject: [PATCH 1/4] make muon optimizer totally running on GPU

Signed-off-by: Guokai Ma
---
 deepspeed/runtime/zero/stage_1_and_2.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 0bc1f939a4a2..0b312db8e39e 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -1878,7 +1878,7 @@ def get_flat_partition(self,
                 tensor_list[0], 'use_muon',
                 False) and 'muon' in self.optimizer.__class__.__name__.lower():
             # need to check the total # of elements in the parameters in this group and this partition
             total_size = sum([t.numel() for t in tensor_list])
-            flatten_bf_list = [torch.zeros([total_size], dtype=dtype)]  # put on cpu to save space
+            flatten_bf_list = [torch.zeros([total_size], dtype=dtype, device=device)]
             self.optimizer.state[flatten_copy]["momentum_buffer"] = self.flatten(flatten_bf_list)
 
         buffer_idx = 0
@@ -1886,13 +1886,9 @@ def get_flat_partition(self,
             grad_accum = self.all_grad_tensors[param_group_idx][i]
             if getattr(tensor, 'use_muon', False) and 'muon' in self.optimizer.__class__.__name__.lower():
                 assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
-                # create a gpu copy
                 buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
-                                      tensor.numel()).view(tensor.size()).to(device).to(dtype)
+                                      tensor.numel()).view(tensor.size())
                 grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
-                # write back to the cpu copy
-                torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
-                             tensor.numel()).data.copy_(buffer.view(-1).data)
                 tensor = grad_accum
             num_elements = tensor.numel()
             buffer_idx += num_elements

From e02c0ec98bf07b85436195fbaac2728a7e0e39b9 Mon Sep 17 00:00:00 2001
From: Guokai Ma
Date: Fri, 24 Oct 2025 07:14:47 -0700
Subject: [PATCH 2/4] apply torch.compile to Muon optimizer

Signed-off-by: Guokai Ma
---
 deepspeed/runtime/zero/muon/original_muon.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/deepspeed/runtime/zero/muon/original_muon.py b/deepspeed/runtime/zero/muon/original_muon.py
index 4f477882ec11..96f143ba62d8 100644
--- a/deepspeed/runtime/zero/muon/original_muon.py
+++ b/deepspeed/runtime/zero/muon/original_muon.py
@@ -31,6 +31,7 @@
 import deepspeed.comm as dist  # replace torch's distributed package with deepspeed.comm to resolve deepspeed check
 
 
+@torch.compile
 def zeropower_via_newtonschulz5(G, steps: int):
     """
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -60,6 +61,7 @@ def zeropower_via_newtonschulz5(G, steps: int):
     return X
 
 
+@torch.compile
 def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
     momentum.lerp_(grad, 1 - beta)
     update = grad.lerp_(momentum, beta) if nesterov else momentum

From 632ab6b2b088f059bab431c5748dae6caa783a58 Mon Sep 17 00:00:00 2001
From: Guokai Ma
Date: Sun, 26 Oct 2025 20:06:36 -0700
Subject: [PATCH 3/4] make torch.compile more adaptive to old pytorch version

Signed-off-by: Guokai Ma
---
 deepspeed/runtime/compiler.py                | 12 ++++++++++++
 deepspeed/runtime/zero/muon/original_muon.py |  5 +++--
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index 8dcd6dab4e1d..84f2067a0dce 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -71,3 +71,15 @@ def wrapper(*args, **kwargs):
 
 def is_compiling():
     return torch_is_compiling()
+
+
+def dummy_decorator(func):
+    return func
+
+
+# robust version of @torch.compile
+def compile():
+    if hasattr(torch, "compile"):
+        return torch.compile
+    else:
+        return dummy_decorator

diff --git a/deepspeed/runtime/zero/muon/original_muon.py b/deepspeed/runtime/zero/muon/original_muon.py
index 96f143ba62d8..f4dc7a0909bb 100644
--- a/deepspeed/runtime/zero/muon/original_muon.py
+++ b/deepspeed/runtime/zero/muon/original_muon.py
@@ -29,9 +29,10 @@
 
 import torch
 import deepspeed.comm as dist  # replace torch's distributed package with deepspeed.comm to resolve deepspeed check
+from deepspeed.runtime import compiler
 
 
-@torch.compile
+@compiler.compile()
 def zeropower_via_newtonschulz5(G, steps: int):
     """
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -61,7 +62,7 @@ def zeropower_via_newtonschulz5(G, steps: int):
     return X
 
 
-@torch.compile
+@compiler.compile()
 def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
     momentum.lerp_(grad, 1 - beta)
     update = grad.lerp_(momentum, beta) if nesterov else momentum

From 3b6a6d92e493e761009fd3deebe1d3b23b2a00eb Mon Sep 17 00:00:00 2001
From: "Ma, Guokai"
Date: Wed, 26 Nov 2025 15:07:41 +0800
Subject: [PATCH 4/4] Fix trailing space

Signed-off-by: Ma, Guokai
---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 660b0d70de4f..5fc248dbfe72 100755
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
 
 ## Latest News
 
-* [2025/10] We hosted the [Ray x DeepSpeed Meetup](https://luma.com/3wctqteh) at Anyscale. We shared our most recent work on SuperOffload, ZenFlow, Muon Optimizer Support, Arctic Long Sequence Training and DeepCompile. Please find the meetup slides [here](https://docs.google.com/presentation/d/1eM3mY6oW9GYkRy1Xz0iOnbbEr5T1t0JJXOM5BKtR-Ks/edit?slide=id.g38615d6b4c2_0_87#slide=id.g38615d6b4c2_0_87). 
+* [2025/10] We hosted the [Ray x DeepSpeed Meetup](https://luma.com/3wctqteh) at Anyscale. We shared our most recent work on SuperOffload, ZenFlow, Muon Optimizer Support, Arctic Long Sequence Training and DeepCompile. Please find the meetup slides [here](https://docs.google.com/presentation/d/1eM3mY6oW9GYkRy1Xz0iOnbbEr5T1t0JJXOM5BKtR-Ks/edit?slide=id.g38615d6b4c2_0_87#slide=id.g38615d6b4c2_0_87).
 * [2025/10] [SuperOffload: Unleashing the Power of Large-Scale LLM Training on Superchips](https://pytorch.org/blog/superoffload-unleashing-the-power-of-large-scale-llm-training-on-superchips/)
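
Note (illustration, not part of the patches above): a minimal, self-contained sketch of the version-adaptive compile decorator that PATCH 3/4 adds to deepspeed/runtime/compiler.py and applies to the Muon kernels. It assumes torch.compile exists on PyTorch 2.x and falls back to a pass-through decorator otherwise; demo_muon_update and the __main__ block are hypothetical names for demonstration only, not DeepSpeed APIs.

import torch


def dummy_decorator(func):
    # Pass-through used when torch.compile is unavailable (older PyTorch).
    return func


def compile():
    # Return torch.compile when the running PyTorch provides it,
    # otherwise fall back to the no-op decorator above.
    return torch.compile if hasattr(torch, "compile") else dummy_decorator


@compile()
def demo_muon_update(grad, momentum, beta=0.95, nesterov=True):
    # Same momentum / Nesterov blend as muon_update in original_muon.py,
    # without the Newton-Schulz orthogonalization, for illustration only.
    momentum.lerp_(grad, 1 - beta)
    return grad.lerp_(momentum, beta) if nesterov else momentum


if __name__ == "__main__":
    g, m = torch.randn(8, 8), torch.zeros(8, 8)
    print(demo_muon_update(g, m).shape)  # torch.Size([8, 8])

Deciding the decorator once at import time, as the patch does, lets the Newton-Schulz and momentum-update kernels be compiled on newer PyTorch while older builds simply run them eagerly.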