diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index cfe524be8d..6d89316f3a 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -158,6 +158,7 @@ def __init__(
         init_reference_model: bool = True,
         **kwargs: Any,
     ):
+        self.is_prepared = False
         self.tokenizer = tokenizer
         self.processor = processor
         self.is_vlm = processor is not None
@@ -534,6 +535,11 @@ def train(
         mbs: Optional[int] = None,
     ) -> dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training()?"
+            )
         if gbs is None:
             gbs = self.cfg["train_global_batch_size"]
         if mbs is None:
@@ -910,6 +916,11 @@ def get_logprobs(
         We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
         The logprob of input token i is specified at position i in the output logprobs tensor.
         """
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training()?"
+            )
         logprob_batch_size = (
             micro_batch_size
             if micro_batch_size is not None
@@ -1197,6 +1208,11 @@ def get_logprobs(
     # TODO @Rayen Tian: Related Issue: Refactor shared logic between score() and get_logprobs() (https://github.com/NVIDIA-NeMo/RL/issues/1094)
     @wrap_with_nvtx_name("dtensor_policy_worker/score")
     def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]:
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training()?"
+            )
         global_batch_size = min(self.cfg["batch_size"], data.size)
         sequence_dim = 1

@@ -1819,6 +1835,7 @@ def broadcast_weights_for_collective(self) -> None:

     @wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_lp_inference")
     def prepare_for_lp_inference(self) -> None:
+        self.is_prepared = True
         if not self.cpu_offload:
             self.move_to_cuda(self.model)
         else:
@@ -1829,6 +1846,7 @@ def prepare_for_lp_inference(self) -> None:

     @wrap_with_nvtx_name("dtensor_policy_worker/prepare_for_training")
     def prepare_for_training(self, *args, **kwargs) -> None:
+        self.is_prepared = True
         # onload models and optimizer state to cuda
         if not self.cpu_offload:
             self.move_to_cuda(self.model)
diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py
index 5d08003ad9..80dcfa847c 100644
--- a/nemo_rl/models/policy/lm_policy.py
+++ b/nemo_rl/models/policy/lm_policy.py
@@ -70,6 +70,7 @@ def __init__(
         init_reference_model: bool = True,
         processor: Optional[AutoProcessor] = None,
     ):
+        self.is_prepared = False
         if weights_path:
             weights_path = os.path.abspath(weights_path)
         if optimizer_path:
@@ -257,6 +258,11 @@ def get_logprobs(
         We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
         The logprob of input token i is specified at position i in the output logprobs tensor.
         """
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training()?"
+            )
         dp_size = self.sharding_annotations.get_axis_size("data_parallel")
         sharded_data: list[SlicedDataDict]
         unsorted_data_indices: list[int]
@@ -452,6 +458,11 @@ def train(
         mbs: Optional[int] = None,
     ) -> dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training()?"
+            )
         batch_size = gbs or self.cfg["train_global_batch_size"]
         micro_batch_size = mbs or self.cfg["train_micro_batch_size"]
         # Shard and replicate the batch
@@ -541,6 +552,11 @@ def train(
     def generate(
         self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
     ) -> BatchedDataDict[GenerationOutputSpec]:
         """Generate a batch of data using the policy."""
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training()?"
+            )
         # Verify input data is right-padded
         assert isinstance(data, BatchedDataDict), (
@@ -631,11 +647,13 @@ def prepare_for_generation(self, *args: Any, **kwargs: Any) -> bool:
         return True

     def prepare_for_training(self, *args: Any, **kwargs: Any) -> None:
+        self.is_prepared = True
         # onload everything to the GPU
         futures = self.worker_group.run_all_workers_single_data("prepare_for_training")
         ray.get(futures)

     def prepare_for_lp_inference(self, *args: Any, **kwargs: Any) -> None:
+        self.is_prepared = True
         futures = self.worker_group.run_all_workers_single_data(
             "prepare_for_lp_inference"
         )
diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py
index 44472224df..aece338182 100644
--- a/nemo_rl/models/policy/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/megatron_policy_worker.py
@@ -453,6 +453,7 @@ def __init__(
         **kwargs: Any,
     ):
         self.is_generation_colocated = None
+        self.is_prepared = False
         if "generation" in config and config["generation"] is not None:
             self.is_generation_colocated = config["generation"]["colocated"]["enabled"]

@@ -883,6 +884,11 @@ def train(
         mbs: Optional[int] = None,
     ) -> dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training() or prepare_for_lp_inference()?"
+            )
         self.model.zero_grad_buffer()
         if hasattr(self.model, "inference_params"):
             self.model.inference_params = None
@@ -1146,6 +1152,11 @@ def get_logprobs(
         We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
         The logprob of input token i is specified at position i in the output logprobs tensor.
         """
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training() or prepare_for_lp_inference()?"
+            )
         no_grad = torch.no_grad()
         no_grad.__enter__()
         logprob_batch_size = (
@@ -1436,6 +1447,11 @@ def generate(
             - logprobs: Log probabilities for each token
             - generation_lengths: Lengths of each response
         """
+        if not self.is_prepared:
+            raise RuntimeError(
+                "Model is not prepared for GPU execution. "
+                "Did you forget to call prepare_for_training() or prepare_for_lp_inference()?"
+            )
         no_grad = torch.no_grad()
         no_grad.__enter__()
         self.model.config.flash_decode = True
@@ -1763,12 +1779,14 @@ def broadcast_weights_for_collective(self) -> None:
             self.model_update_group.broadcast(tensor, src=0)

     def prepare_for_lp_inference(self):
+        self.is_prepared = True
         self.model = self.move_model(self.model, "cuda", move_grads=False)
         self.model.eval()
         self.offload_before_refit()

     def prepare_for_training(self, *args, **kwargs):
         # onload models and optimizer state to cuda
+        self.is_prepared = True
         self.model = self.move_model(
             self.model, "cuda", move_grads=True, move_params=True
         )
diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py
index 48b2c01dc8..01dea7e790 100644
--- a/tests/unit/models/policy/test_megatron_worker.py
+++ b/tests/unit/models/policy/test_megatron_worker.py
@@ -27,6 +27,7 @@
 from nemo_rl.models.generation import configure_generation_config
 from nemo_rl.models.policy import PolicyConfig
 from nemo_rl.models.policy.lm_policy import Policy
+from nemo_rl.models.policy.megatron_policy_worker import MegatronPolicyWorker
 from tests.unit.test_utils import SimpleLoss


@@ -1932,3 +1933,30 @@ def test_megatron_policy_flops_range_check(tiny_llama_model_path):
     finally:
         policy.shutdown()
         cluster.shutdown()
+
+
+def test_megatron_policy_worker_raises_if_not_prepared(tiny_llama_model_path):
+    """Test that MegatronPolicyWorker methods raise if prepare_for_training or prepare_for_lp_inference is not called."""
+    config = PolicyConfig(model_name=tiny_llama_model_path)
+    tokenizer = get_tokenizer({"name": tiny_llama_model_path})
+
+    worker = MegatronPolicyWorker(
+        config=config,
+        tokenizer=tokenizer,
+        worker_sharding_annotations=None,
+        pre_init_communication_queue=None,
+    )
+
+    dummy_data = BatchedDataDict({"input_ids": None, "input_lengths": None})
+
+    # train should raise
+    with pytest.raises(RuntimeError, match="Model is not prepared for GPU execution"):
+        worker.train(dummy_data, loss_fn=None)
+
+    # get_logprobs should raise
+    with pytest.raises(RuntimeError, match="Model is not prepared for GPU execution"):
+        worker.get_logprobs(data=dummy_data)
+
+    # generate should raise
+    with pytest.raises(RuntimeError, match="Model is not prepared for GPU execution"):
+        worker.generate(data=dummy_data)
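
Usage note: the guard added across these workers enforces a simple call-order contract. A minimal sketch under assumed names (`policy`, `batch`, and `loss_fn` are hypothetical stand-ins for a constructed Policy and its inputs, not objects defined in this diff):

    # Hypothetical driver code illustrating the contract enforced by this diff.
    policy.prepare_for_training()           # sets is_prepared = True, onloads model/optimizer to GPU
    results = policy.train(batch, loss_fn)  # permitted once prepared

    policy.prepare_for_lp_inference()       # also sets is_prepared = True
    logprobs = policy.get_logprobs(batch)   # permitted once prepared

    # Calling train(), get_logprobs(), generate(), or score() before either
    # prepare_* method now raises:
    #   RuntimeError: Model is not prepared for GPU execution. ...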