diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 7b5b7817f1..9f88be274e 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -87,7 +87,8 @@ def update_weights(self) -> None: if dist.get_rank() == 0: ray.get([engine.pause_generation.remote() for engine in self.rollout_engines]) - ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) + if not getattr(self.args, "keep_cache_on_weight_update", False): + ray.get([engine.flush_cache.remote() for engine in self.rollout_engines]) # int4/fp4 pre_process if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index d8da881ea3..cfd37e5446 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -434,6 +434,12 @@ def add_rollout_arguments(parser): default=1, help="Interval for updating the weights", ) + parser.add_argument( + "--keep-cache-on-weight-update", + action="store_true", + default=False, + help="If set, skip flushing the rollout engine cache during weight updates.", + ) parser.add_argument( "--keep-old-actor", action="store_true",