127 changes: 54 additions & 73 deletions nemo_rl/models/policy/dtensor_policy_worker_v2.py
@@ -34,13 +34,11 @@
create_context_parallel_ctx,
get_train_context,
)
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
from nemo_automodel.components.distributed.grad_utils import (
clip_grad_by_total_norm_,
get_grad_norm,
)
from nemo_automodel.components.distributed.parallelizer import (
fsdp2_strategy_parallelize,
)
from nemo_automodel.components.distributed.tensor_utils import (
get_cpu_state_dict,
to_local_if_dtensor,
@@ -53,7 +51,6 @@
from torch.distributed.fsdp import (
CPUOffloadPolicy,
MixedPrecisionPolicy,
OffloadPolicy,
)
from torch.distributed.tensor import DTensor, Shard
from transformers import (
@@ -143,8 +140,10 @@ def __init__(
configure_dynamo_cache()

self.cfg = config
self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"]
# torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call
torch.distributed.init_process_group(backend="nccl")
# backend = "nccl" if not self.cpu_offload else "cuda:nccl,cpu:gloo"
torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo")
self.rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
model_name = self.cfg["model_name"]
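The process group is now initialized with the combined backend string `cuda:nccl,cpu:gloo`, so CUDA tensors use NCCL while any CPU-resident tensors (relevant once `cpu_offload` is enabled) go through Gloo; the commented-out line preserves the old conditional form. A minimal sketch of that conditional variant, assuming only the `cpu_offload` config flag read above (the helper name is illustrative, not part of the PR):

```python
import torch.distributed as dist

def init_process_group_for_offload(cpu_offload: bool) -> None:
    # NCCL handles CUDA tensors only; with CPU offload some collectives touch
    # CPU-resident tensors, so a CPU-capable backend (Gloo) is registered too.
    # The "device:backend" syntax maps each device type to its own backend.
    backend = "cuda:nccl,cpu:gloo" if cpu_offload else "nccl"
    dist.init_process_group(backend=backend)
```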
@@ -175,6 +174,20 @@ def __init__(

hf_config_overrides = self.cfg.get("hf_config_overrides", {}) or {}

# Choose attention implementation
# - Packed sequence requires FA2 and CP must be 1
# - CP > 1 requires SDPA
attn_impl = (
"flash_attention_2"
if (
self.enable_seq_packing
and self.cfg["dtensor_cfg"]["context_parallel_size"] == 1
)
else (
"sdpa" if self.cfg["dtensor_cfg"]["context_parallel_size"] > 1 else None
)
)

model_config = AutoConfig.from_pretrained(
model_name,
# Always load the model in float32 to keep master weights in float32.
@@ -184,9 +197,7 @@ def __init__(
**sliding_window_overwrite(
model_name
), # due to https://github.com/huggingface/transformers/issues/38002
attn_implementation="flash_attention_2"
if self.enable_seq_packing
else None,
attn_implementation=attn_impl,
**hf_config_overrides,
)
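The `attn_impl` value computed above encodes the two constraints spelled out in the comment: sequence packing requires FlashAttention-2 and is only used when context parallelism is off, while context parallelism requires SDPA. A standalone sketch of the same decision with the config values passed explicitly (the helper is illustrative; the flag names mirror the diff):

```python
from typing import Optional

def choose_attn_implementation(
    enable_seq_packing: bool, context_parallel_size: int
) -> Optional[str]:
    # Packed sequences rely on FlashAttention-2 varlen kernels, and FA2 is only
    # selected when context parallelism is disabled (cp == 1).
    if enable_seq_packing and context_parallel_size == 1:
        return "flash_attention_2"
    # Context parallelism shards attention over the sequence dim and needs SDPA.
    if context_parallel_size > 1:
        return "sdpa"
    # Otherwise fall back to the model's default implementation.
    return None
```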

@@ -255,9 +266,7 @@ def __init__(
# https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180
self.model = model_class.from_config(
model_config,
attn_implementation="flash_attention_2"
if self.enable_seq_packing
else None,
attn_implementation=attn_impl,
use_liger_kernel=False,
trust_remote_code=True,
torch_dtype=str(model_config.torch_dtype),
@@ -268,15 +277,16 @@ def __init__(

tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"]
cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"]
sequence_parallel_enabled = self.cfg["dtensor_cfg"]["sequence_parallel"]
dp_size = self.cfg["dtensor_cfg"].get("data_parallel_size", None)
dp_replicate_size = self.cfg["dtensor_cfg"].get(
"data_parallel_replicate_size", 1
)
self.dp_replicate_size = dp_replicate_size
if cp_size > 1 and self.enable_seq_packing:
raise ValueError(
"Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details."
)
dp_size = world_size // tp_size // cp_size
sequence_parallel_enabled = self.cfg["dtensor_cfg"]["sequence_parallel"]
assert world_size == dp_size * tp_size * cp_size, (
f"World size({world_size}) must equal to dp_size({dp_size}) * tp_size({tp_size}) * cp_size({cp_size}) to use DTensor"
)

if sequence_parallel_enabled and tp_size == 1:
print(
@@ -303,70 +313,41 @@ def __init__(
"Context parallel is yet not supported for VLM models. Please set cp_size = 1 to train VLM models."
)

# For FSDP2 compatibility, we need to support HSDP structure
# For now, we use dp_replicate_size = 1 (no hybrid sharding)
dp_replicate_size = 1
dp_shard_size = dp_size

# torch==2.8 uses LOCAL_RANK to set the device here (https://github.com/pytorch/pytorch/blob/ba56102387ef21a3b04b357e5b183d48f0afefc7/torch/distributed/device_mesh.py#L500),
# but CUDA_VISIBLE_DEVICES is set to only 1 gpu, so we need to temporarily set LOCAL_RANK to 0.
# TODO: consider changing the default LOCAL_RANK set in worker_groups.py
prev_local_rank = os.environ["LOCAL_RANK"]
os.environ["LOCAL_RANK"] = "0"

# Create device mesh with HSDP structure for FSDP2 compatibility
device_mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda",
(dp_replicate_size, dp_shard_size, cp_size, tp_size),
mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"),
)
os.environ["LOCAL_RANK"] = prev_local_rank

# Create flattened submeshes for different use cases
# Flatten dp_replicate + dp_shard for the "dp" dimension (backward compatibility)
device_mesh[("dp_replicate", "dp_shard")]._flatten(mesh_dim_name="dp")

# Flatten dp_shard + cp for FSDP2 sharding
device_mesh[("dp_shard", "cp")]._flatten(mesh_dim_name="dp_shard_cp")

# Flatten dp_replicate + dp_shard + cp for gradient operations
device_mesh[("dp_replicate", "dp_shard", "cp")]._flatten(mesh_dim_name="dp_cp")

# Store mesh references for backward compatibility
self.dp_cp_mesh = device_mesh["dp_cp"]
self.dp_mesh = device_mesh["dp"]
self.tp_mesh = device_mesh["tp"]
self.cp_mesh = device_mesh["cp"]

self.dp_size = dp_size
self.tp_size = tp_size
self.cp_size = cp_size
self.device_mesh = device_mesh

# ------------------------------------------------
# 3) Move to GPU + Composable FSDP
# (Initialize device mesh, shard submodules, then shard entire model)
# ------------------------------------------------
self.model = fsdp2_strategy_parallelize(
self.model,
device_mesh=self.device_mesh,
manager = FSDP2Manager(
dp_size=dp_size,
dp_replicate_size=dp_replicate_size,
tp_size=tp_size,
cp_size=cp_size,
sequence_parallel=sequence_parallel_enabled,
backend="nccl",
use_hf_tp_plan=self.cfg["dtensor_cfg"].get("use_hf_tp_plan", False),
mp_policy=MixedPrecisionPolicy(
param_dtype=self.dtype,
reduce_dtype=torch.float32,
output_dtype=torch.float32,
),
offload_policy=CPUOffloadPolicy(pin_memory=False)
if self.cpu_offload
else OffloadPolicy(),
sequence_parallel=sequence_parallel_enabled,
activation_checkpointing=self.cfg["dtensor_cfg"][
"activation_checkpointing"
],
tp_shard_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"],
dp_replicate_mesh_name="dp_replicate",
dp_shard_cp_mesh_name="dp_shard_cp",
tp_mesh_name="tp",
else None,
world_size=world_size,
)
self.device_mesh = manager.device_mesh

# Store mesh references for backward compatibility
self.dp_mesh = self.device_mesh["dp"]
self.dp_shard_cp_mesh = self.device_mesh["dp_shard_cp"]
self.tp_mesh = self.device_mesh["tp"]
self.cp_mesh = self.device_mesh["cp"]
self.dp_size = manager.dp_size
self.dp_shard_size = manager.dp_shard_size
self.tp_size = manager.tp_size
self.cp_size = manager.cp_size

# ------------------------------------------------
# 3) Move to GPU + Composable FSDP
# (Initialize device mesh, shard submodules, then shard entire model)
# ------------------------------------------------
self.model = manager.parallelize(self.model)

print(f"[Rank {self.rank}] Loading state dict from rank 0...")
# This will broadcast the state dict from rank 0 to all other ranks
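The deleted block built the HSDP device mesh and its flattened views by hand; the worker now reads the same submesh names ("dp", "dp_shard_cp", "tp", "cp") off `manager.device_mesh`, which suggests `FSDP2Manager` constructs an equivalent layout internally. A sketch of that layout using the same torch calls as the removed code (this mirrors the deleted lines; that the manager does exactly this is an assumption):

```python
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

def build_hsdp_mesh(
    dp_replicate_size: int, dp_shard_size: int, cp_size: int, tp_size: int
) -> DeviceMesh:
    # 4-D mesh: (replicated DP, sharded DP, context parallel, tensor parallel).
    mesh = init_device_mesh(
        "cuda",
        (dp_replicate_size, dp_shard_size, cp_size, tp_size),
        mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"),
    )
    # Flattened views: "dp" for data-parallel bookkeeping, "dp_shard_cp" for
    # FSDP2 sharding, "dp_cp" for gradient reductions.
    mesh[("dp_replicate", "dp_shard")]._flatten(mesh_dim_name="dp")
    mesh[("dp_shard", "cp")]._flatten(mesh_dim_name="dp_shard_cp")
    mesh[("dp_replicate", "dp_shard", "cp")]._flatten(mesh_dim_name="dp_cp")
    return mesh
```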
@@ -830,7 +811,7 @@ def train(
with torch.no_grad():
grad_norm = get_grad_norm(
self.model.parameters(),
dp_cp_group=self.dp_cp_mesh.get_group(),
dp_cp_group=self.dp_shard_cp_mesh.get_group(),
tp_group=self.tp_mesh.get_group(),
dtype=torch.float32,
)
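The only change in `train` is the process group used for the gradient-norm reduction: the `dp_shard_cp` submesh (the dims FSDP2 actually shards parameters over) replaces the old flattened `dp_cp` mesh. A small sketch connecting the mesh handles stored in `__init__` to this call, using only names visible in the diff (the rest of the training step is elided):

```python
import torch
from nemo_automodel.components.distributed.grad_utils import get_grad_norm

def compute_grad_norm(model, device_mesh):
    # Groups come from the flattened submeshes: "dp_shard_cp" covers the FSDP2
    # sharding dims, "tp" the tensor-parallel dim.
    dp_shard_cp_group = device_mesh["dp_shard_cp"].get_group()
    tp_group = device_mesh["tp"].get_group()
    with torch.no_grad():
        return get_grad_norm(
            model.parameters(),
            dp_cp_group=dp_shard_cp_group,
            tp_group=tp_group,
            dtype=torch.float32,
        )
```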
1 change: 1 addition & 0 deletions tests/unit/models/policy/test_dtensor_worker.py
@@ -568,6 +568,7 @@ def test_dtensor_worker_logprob_tp2_or_cp2_matches_unsharded(use_v2, logprob_set


@pytest.mark.hf_gated
@pytest.mark.skip(reason="Disabled temporarily")
@pytest.mark.parametrize("use_v2", [True, False])
def test_dtensor_tp_and_tied_model_with_custom_parallel_plan(
use_v2, two_gpu_virtual_cluster, tiny_llama_tied_model_path