18 changes: 18 additions & 0 deletions nemo_rl/models/policy/dtensor_policy_worker.py
@@ -158,6 +158,7 @@ def __init__(
init_reference_model: bool = True,
**kwargs: Any,
):
self.is_prepared = False
self.tokenizer = tokenizer
self.processor = processor
self.is_vlm = processor is not None
@@ -534,6 +535,11 @@ def train(
mbs: Optional[int] = None,
) -> dict[str, Any]:
"""Train the policy on a batch of data with a given loss function."""
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training()?"
)
if gbs is None:
gbs = self.cfg["train_global_batch_size"]
if mbs is None:
@@ -910,6 +916,11 @@ def get_logprobs(
We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
The logprob of input token i is specified at position i in the output logprobs tensor.
"""
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training()?"
)
logprob_batch_size = (
micro_batch_size
if micro_batch_size is not None
@@ -1197,6 +1208,11 @@ def get_logprobs(
# TODO @Rayen Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094)
@wrap_with_nvtx_name("dtensor_policy_worker/score")
def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]:
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training()?"
)
global_batch_size = min(self.cfg["batch_size"], data.size)

sequence_dim = 1
@@ -1819,6 +1835,7 @@ def broadcast_weights_for_collective(self) -> None:

@wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_lp_inference")
def prepare_for_lp_inference(self) -> None:
self.is_prepared = True
if not self.cpu_offload:
self.move_to_cuda(self.model)
else:
@@ -1829,6 +1846,7 @@ def prepare_for_lp_inference(self) -> None:

@wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_training")
def prepare_for_training(self, *args, **kwargs) -> None:
self.is_prepared = True
# onload models and optimizer state to cuda
if not self.cpu_offload:
self.move_to_cuda(self.model)
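
Reviewer note on the get_logprobs() docstring in the hunk above ("the logprob of the first token is 0 so that the sequence length is maintained"; "the logprob of input token i is specified at position i"): the snippet below is a minimal, standalone illustration of that layout in plain PyTorch. The shapes and variable names are invented for the example and are not taken from this diff.

import torch

# Illustrative only: position 0 holds 0.0 and position i holds
# log p(token_i | tokens_<i), so the output keeps the input sequence length.
logits = torch.randn(1, 5, 32)             # [batch, seq_len, vocab] from a hypothetical model
input_ids = torch.randint(0, 32, (1, 5))   # [batch, seq_len]

log_probs = torch.log_softmax(logits, dim=-1)
# Logits at position i-1 predict token i, hence the one-token shift.
token_logprobs = log_probs[:, :-1].gather(2, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
# Prepend 0.0 for the first token so the length matches input_ids.
token_logprobs = torch.cat([torch.zeros(1, 1), token_logprobs], dim=1)
assert token_logprobs.shape == input_ids.shape  # [1, 5]
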
18 changes: 18 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
@@ -70,6 +70,7 @@ def __init__(
init_reference_model: bool = True,
processor: Optional[AutoProcessor] = None,
):
self.is_prepared = False
if weights_path:
weights_path = os.path.abspath(weights_path)
if optimizer_path:
@@ -257,6 +258,11 @@ def get_logprobs(
We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
The logprob of input token i is specified at position i in the output logprobs tensor.
"""
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training()?"
)
dp_size = self.sharding_annotations.get_axis_size("data_parallel")
sharded_data: list[SlicedDataDict]
unsorted_data_indices: list[int]
@@ -452,6 +458,11 @@ def train(
mbs: Optional[int] = None,
) -> dict[str, Any]:
"""Train the policy on a batch of data with a given loss function."""
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training()?"
)
batch_size = gbs or self.cfg["train_global_batch_size"]
micro_batch_size = mbs or self.cfg["train_micro_batch_size"]
# Shard and replicate the batch
@@ -541,6 +552,11 @@ def train(
def generate(
self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
) -> BatchedDataDict[GenerationOutputSpec]:
"""Generate a batch of data using the policy."""
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training() or prepare_for_lp_inference()?"
)
# Verify input data is right-padded
assert isinstance(data, BatchedDataDict), (
@@ -631,11 +647,13 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool:
return True

def prepare_for_training(self, *args: Any, **kwargs: Any) -> None:
self.is_prepared = True
# onload everything to the GPU
futures = self.worker_group.run_all_workers_single_data("prepare_for_training")
ray.get(futures)

def prepare_for_lp_inference(self, *args: Any, **kwargs: Any) -> None:
self.is_prepared = True
futures = self.worker_group.run_all_workers_single_data(
"prepare_for_lp_inference"
)
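
For reviewers scanning the new guards in Policy: both prepare_for_training() and prepare_for_lp_inference() set is_prepared, so either call satisfies the check. A minimal usage sketch under that assumption follows; policy, batch, and loss_fn are placeholders for an already-constructed Policy, a BatchedDataDict of training data, and a loss function, none of which are defined in this diff.

# Training path: onload weights and optimizer state first.
policy.prepare_for_training()
metrics = policy.train(batch, loss_fn)

# Log-prob / scoring path: the lighter onload is enough to pass the guard.
policy.prepare_for_lp_inference()
logprobs = policy.get_logprobs(batch)

# Calling train(), get_logprobs(), or generate() before either prepare_* call
# now raises RuntimeError("Model is not prepared for GPU execution. ...").
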
18 changes: 18 additions & 0 deletions nemo_rl/models/policy/megatron_policy_worker.py
@@ -453,6 +453,7 @@ def __init__(
**kwargs: Any,
):
self.is_generation_colocated = None
self.is_prepared = False
if "generation" in config and config["generation"] is not None:
self.is_generation_colocated = config["generation"]["colocated"]["enabled"]

@@ -883,6 +884,11 @@ def train(
mbs: Optional[int] = None,
) -> dict[str, Any]:
"""Train the policy on a batch of data with a given loss function."""
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training() or prepare_for_lp_inference()?"
)
self.model.zero_grad_buffer()
if hasattr(self.model, "inference_params"):
self.model.inference_params = None
@@ -1146,6 +1152,11 @@ def get_logprobs(
We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
The logprob of input token i is specified at position i in the output logprobs tensor.
"""
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training() or prepare_for_lp_inference()?"
)
no_grad = torch.no_grad()
no_grad.__enter__()
logprob_batch_size = (
@@ -1436,6 +1447,11 @@ def generate(
- logprobs: Log probabilities for each token
- generation_lengths: Lengths of each response
"""
if not self.is_prepared:
raise RuntimeError(
"Model is not prepared for GPU execution. "
"Did you forget to call prepare_for_training() or prepare_for_lp_inference()?"
)
no_grad = torch.no_grad()
no_grad.__enter__()
self.model.config.flash_decode = True
@@ -1763,12 +1779,14 @@ def broadcast_weights_for_collective(self) -> None:
self.model_update_group.broadcast(tensor, src=0)

def prepare_for_lp_inference(self):
self.is_prepared = True
self.model = self.move_model(self.model, "cuda", move_grads=False)
self.model.eval()
self.offload_before_refit()

def prepare_for_training(self, *args, **kwargs):
# onload models and optimizer state to cuda
self.is_prepared = True
self.model = self.move_model(
self.model, "cuda", move_grads=True, move_params=True
)
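
Optional follow-up, not part of this change: the same four-line guard now appears in eight methods across the three files above. If the author prefers, it could be factored into a small decorator. A rough sketch is below; the name requires_prepared and where it would live are hypothetical.

import functools

def requires_prepared(method):
    # Hypothetical helper: raise the same RuntimeError the inline guards raise
    # whenever the worker/policy has not been prepared yet.
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.is_prepared:
            raise RuntimeError(
                "Model is not prepared for GPU execution. "
                "Did you forget to call prepare_for_training() "
                "or prepare_for_lp_inference()?"
            )
        return method(self, *args, **kwargs)
    return wrapper

# Usage sketch:
# @wrap_with_nvtx_name("megatron_policy_worker/train")
# @requires_prepared
# def train(self, data, loss_fn, ...): ...
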
28 changes: 28 additions & 0 deletions tests/unit/models/policy/test_megatron_worker.py
@@ -27,6 +27,7 @@
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.models.policy.lm_policy import Policy
from nemo_rl.models.policy.megatron_policy_worker import MegatronPolicyWorker
from tests.unit.test_utils import SimpleLoss


@@ -1932,3 +1933,30 @@ def test_megatron_policy_flops_range_check(tiny_llama_model_path):
finally:
policy.shutdown()
cluster.shutdown()


def test_megatron_policy_worker_raises_if_not_prepared(tiny_llama_model_path):
"""Test that MegatronPolicyWorker methods raise if prepare_for_training or prepare_for_lp_inference is not called."""
config = PolicyConfig(model_name=tiny_llama_model_path)
tokenizer = get_tokenizer({"name": tiny_llama_model_path})

worker = MegatronPolicyWorker(
config=config,
tokenizer=tokenizer,
worker_sharding_annotations=None,
pre_init_communication_queue=None,
)

dummy_data = BatchedDataDict({"input_ids": None, "input_lengths": None})

# train should raise
with pytest.raises(RuntimeError, match="Model is not prepared for GPU execution"):
worker.train(dummy_data, loss_fn=None)

# get_logprobs should raise
with pytest.raises(RuntimeError, match="Model is not prepared for GPU execution"):
worker.get_logprobs(data=dummy_data)

# generate should raise
with pytest.raises(RuntimeError, match="Model is not prepared for GPU execution"):
worker.generate(data=dummy_data)