diff --git a/3rdparty/Automodel-workspace/Automodel b/3rdparty/Automodel-workspace/Automodel index a2db048383..756ed10c29 160000 --- a/3rdparty/Automodel-workspace/Automodel +++ b/3rdparty/Automodel-workspace/Automodel @@ -1 +1 @@ -Subproject commit a2db048383cd54b3fafc928df4c30bf7bbf7c430 +Subproject commit 756ed10c29039cd9af551761d054a526021f559d diff --git a/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml new file mode 100644 index 0000000000..0f150b1c4a --- /dev/null +++ b/examples/configs/recipes/llm/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel.yaml @@ -0,0 +1,29 @@ +defaults: ../../sft.yaml +policy: + model_name: openai/gpt-oss-20b + train_global_batch_size: 128 + train_micro_batch_size: 8 + max_total_sequence_length: 512 + dequantize_base_checkpoint: true + automodel_model_kwargs: + backend: + _target_: nemo_automodel.components.moe.utils.BackendConfig + attn: te + linear: te + rms_norm: te + enable_deepep: true + fake_balanced_gate: false + enable_hf_state_dict_adapter: true + dtensor_cfg: + _v2: true + expert_parallel_size: 8 + data_parallel_size: 8 + optimizer: + name: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam + kwargs: + store_param_remainders: true + master_weights: true + exp_avg_dtype: bfloat16 + exp_avg_sq_dtype: bfloat16 +checkpointing: + checkpoint_dir: results/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py index 47a0c60da4..5e749d4815 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py @@ -13,52 +13,60 @@ # limitations under the License. import gc +import inspect import itertools import os import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext -from typing import Any, Generator, Optional, cast +from typing import Any, Callable, Generator, Optional, cast import ray import torch import zmq from accelerate import init_empty_weights -from nemo_automodel import ( +from nemo_automodel._transformers.auto_model import ( NeMoAutoModelForSequenceClassification, ) +from nemo_automodel.components.checkpoint._backports.filesystem import ( + SerializationFormat, +) +from nemo_automodel.components.checkpoint.checkpointing import ( + Checkpointer, + _maybe_adapt_state_dict_to_hf, +) +from nemo_automodel.components.checkpoint.checkpointing import ( + CheckpointingConfig as AutomodelCheckpointingConfig, +) +from nemo_automodel.components.config.loader import _resolve_target from nemo_automodel.components.distributed.cp_utils import ( create_context_parallel_ctx, get_train_context, ) -from nemo_automodel.components.distributed.grad_utils import ( - clip_grad_by_total_norm_, - get_grad_norm, -) -from nemo_automodel.components.distributed.parallelizer import ( - fsdp2_strategy_parallelize, +from nemo_automodel.components.distributed.fsdp2 import ( + FSDP2Manager, ) from nemo_automodel.components.distributed.tensor_utils import ( get_cpu_state_dict, to_local_if_dtensor, ) -from torch import nn -from torch.distributed.checkpoint.state_dict import ( - StateDictOptions, - set_model_state_dict, +from nemo_automodel.components.moe.parallelizer import ( + parallelize_model as moe_parallelize_model, ) +from torch import nn from torch.distributed.fsdp import ( CPUOffloadPolicy, MixedPrecisionPolicy, - OffloadPolicy, ) from torch.distributed.tensor 
import DTensor, Shard from transformers import ( AutoConfig, AutoProcessor, AutoTokenizer, + PreTrainedModel, ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM +from transformers.utils import TRANSFORMERS_CACHE from nemo_rl.algorithms.interfaces import LossFunction, LossType from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper @@ -85,14 +93,16 @@ import_class_from_path, resolve_model_class, ) -from nemo_rl.utils.automodel_checkpoint import ( - load_checkpoint, - save_checkpoint, -) from nemo_rl.utils.checkpoint import CheckpointingConfig from nemo_rl.utils.nsys import wrap_with_nvtx_name from nemo_rl.utils.packed_tensor import packed_broadcast_producer +STRING_TO_DTYPE = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + @ray.remote( runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2") @@ -140,26 +150,28 @@ def __init__( configure_dynamo_cache() self.cfg = config + self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call - torch.distributed.init_process_group(backend="nccl") + backend = "nccl" if not self.cpu_offload else "cuda:nccl,cpu:gloo" + torch.distributed.init_process_group(backend=backend) self.rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] + # We initialize the AutoModel checkpointer here. This needs to be persistent because of async checkpointing support + # once NeMo-RL is >= torch 2.9.0 + self.checkpointer = None + self.checkpoint_config = None + self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] self.max_grad_norm = self.cfg["max_grad_norm"] - if self.cfg["precision"] == "float32": - self.dtype = torch.float32 - elif self.cfg["precision"] == "bfloat16": - self.dtype = torch.bfloat16 - elif self.cfg["precision"] == "float16": - self.dtype = torch.float16 - else: + try: + self.dtype = STRING_TO_DTYPE[self.cfg["precision"]] + except KeyError: raise ValueError(f"Unknown precision: {self.cfg['precision']}") - print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") self.enable_seq_packing = self.cfg["sequence_packing"]["enabled"] if self.enable_seq_packing: assert not self.is_vlm, ( @@ -172,6 +184,16 @@ def __init__( hf_config_overrides = self.cfg.get("hf_config_overrides", {}) or {} + # Choose attention implementation on the following basis: + # - Packed sequence requires FA2 and CP must be 1 + # - CP > 1 requires SDPA + cp_size_cfg = self.cfg["dtensor_cfg"]["context_parallel_size"] + attn_impl = ( + "flash_attention_2" + if (self.enable_seq_packing and cp_size_cfg == 1) + else ("sdpa" if cp_size_cfg > 1 else None) + ) + model_config = AutoConfig.from_pretrained( model_name, # Always load the model in float32 to keep master weights in float32. @@ -221,56 +243,51 @@ def __init__( # DO NOT assume AutoModelForCausalLM, multimodal models can inherit from AutoModelForImageTextToText, AutoModelForTextToWaveform, etc. 
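The precision lookup via STRING_TO_DTYPE and the attention-backend choice introduced above are easy to sanity-check in isolation. Below is a minimal sketch of the same decision logic; `pick_dtype_and_attn` and the inline `cfg` dict are names invented for this illustration, not part of the worker API:

```python
import torch

STRING_TO_DTYPE = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
}

def pick_dtype_and_attn(cfg: dict) -> tuple[torch.dtype, str | None]:
    # Mirrors the worker: dict lookup for precision, then
    # FA2 for packed sequences (CP must stay 1), SDPA when CP > 1, else default.
    try:
        dtype = STRING_TO_DTYPE[cfg["precision"]]
    except KeyError:
        raise ValueError(f"Unknown precision: {cfg['precision']}")
    seq_packing = cfg["sequence_packing"]["enabled"]
    cp_size = cfg["dtensor_cfg"]["context_parallel_size"]
    if seq_packing and cp_size == 1:
        attn_impl = "flash_attention_2"
    elif cp_size > 1:
        attn_impl = "sdpa"
    else:
        attn_impl = None
    return dtype, attn_impl

# Illustrative values only:
print(pick_dtype_and_attn({
    "precision": "bfloat16",
    "sequence_packing": {"enabled": False},
    "dtensor_cfg": {"context_parallel_size": 1},
}))  # (torch.bfloat16, None)
```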
model_class = resolve_model_class(model_config.model_type) - full_state_dict = None - model_state_dict_keys = None - if self.rank == 0: - print(f"[Rank {self.rank}] Loading model {model_name} on CPU...") - model = model_class.from_pretrained( - model_name, - device_map="cpu", # load weights onto CPU initially - trust_remote_code=True, - config=model_config, - use_liger_kernel=False, - torch_dtype=str(model_config.torch_dtype), - ) - - full_state_dict = model.state_dict() - # Store the original model state dict keys before any parallelization - model_state_dict_keys = list(full_state_dict.keys()) - del model - print(f"[Rank {self.rank}] Initializing empty model for FSDP...") # All ranks initialize model on meta device, so FSDP can shard it. # The actual weights will be broadcast from rank 0. + automodel_model_kwargs = self.cfg.get("automodel_model_kwargs", {}) + if automodel_model_kwargs.get("backend", None) is not None: + backend_class = _resolve_target( + automodel_model_kwargs.get("backend", None)["_target_"] + ) + backend_kwargs = automodel_model_kwargs.get("backend") + backend_kwargs.pop("_target_") + backend = backend_class( + **backend_kwargs, + ) + automodel_model_kwargs["backend"] = backend + with init_empty_weights(): # NeMoAutoModelForCausalLM uses flash_attention_2 by default # so we need to set it to None if sequence packing is disabled # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180 - self.model = model_class.from_config( - model_config, - attn_implementation="flash_attention_2" - if self.enable_seq_packing - else None, - use_liger_kernel=False, - trust_remote_code=True, + self.model = model_class.from_pretrained( + model_name, + attn_implementation=attn_impl, torch_dtype=str(model_config.torch_dtype), + trust_remote_code=True, + config=model_config, + use_liger_kernel=False, + **automodel_model_kwargs, ) + # Hold a copy of model state_dict keys before any parallelization + self.model_state_dict_keys = list(self.model.state_dict().keys()) + if self.model.config.pad_token_id is None: self.model.config.pad_token_id = tokenizer.pad_token_id - tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"] - cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"] + tp_size = self.cfg["dtensor_cfg"].get("tensor_parallel_size", 1) + cp_size = self.cfg["dtensor_cfg"].get("context_parallel_size", 1) + ep_size = self.cfg["dtensor_cfg"].get("expert_parallel_size", 1) + dp_size = self.cfg["dtensor_cfg"].get("data_parallel_size", None) if cp_size > 1 and self.enable_seq_packing: raise ValueError( "Context parallel is not supported for sequence packing. Refer to https://github.com/NVIDIA/NeMo-RL/blob/main/docs/model-quirks.md#context-parallel-with-fsdp2 for more details." ) - dp_size = world_size // tp_size // cp_size sequence_parallel_enabled = self.cfg["dtensor_cfg"]["sequence_parallel"] - assert world_size == dp_size * tp_size * cp_size, ( - f"World size({world_size}) must equal to dp_size({dp_size}) * tp_size({tp_size}) * cp_size({cp_size}) to use DTensor" - ) if sequence_parallel_enabled and tp_size == 1: print( @@ -297,53 +314,18 @@ def __init__( "Context parallel is yet not supported for VLM models. Please set cp_size = 1 to train VLM models." 
) - # For FSDP2 compatibility, we need to support HSDP structure - # For now, we use dp_replicate_size = 1 (no hybrid sharding) - dp_replicate_size = 1 - dp_shard_size = dp_size - - # torch==2.8 uses LOCAL_RANK to set the device here (https://github.com/pytorch/pytorch/blob/ba56102387ef21a3b04b357e5b183d48f0afefc7/torch/distributed/device_mesh.py#L500), - # but CUDA_VISIBLE_DEVICES is set to only 1 gpu, so we need to temporarily set LOCAL_RANK to 0. - # TODO: consider changing the default LOCAL_RANK set in worker_groups.py - prev_local_rank = os.environ["LOCAL_RANK"] - os.environ["LOCAL_RANK"] = "0" - - # Create device mesh with HSDP structure for FSDP2 compatibility - device_mesh = torch.distributed.device_mesh.init_device_mesh( - "cuda", - (dp_replicate_size, dp_shard_size, cp_size, tp_size), - mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"), - ) - os.environ["LOCAL_RANK"] = prev_local_rank - - # Create flattened submeshes for different use cases - # Flatten dp_replicate + dp_shard for the "dp" dimension (backward compatibility) - device_mesh[("dp_replicate", "dp_shard")]._flatten(mesh_dim_name="dp") - - # Flatten dp_shard + cp for FSDP2 sharding - device_mesh[("dp_shard", "cp")]._flatten(mesh_dim_name="dp_shard_cp") - - # Flatten dp_replicate + dp_shard + cp for gradient operations - device_mesh[("dp_replicate", "dp_shard", "cp")]._flatten(mesh_dim_name="dp_cp") - - # Store mesh references for backward compatibility - self.dp_cp_mesh = device_mesh["dp_cp"] - self.dp_mesh = device_mesh["dp"] - self.tp_mesh = device_mesh["tp"] - self.cp_mesh = device_mesh["cp"] - - self.dp_size = dp_size - self.tp_size = tp_size - self.cp_size = cp_size - self.device_mesh = device_mesh - # ------------------------------------------------ - # 3) Move to GPU + Composable FSDP - # (Initialize device mesh, shard submodules, then shard entire model) + # Build device mesh and parallelize # ------------------------------------------------ - self.model = fsdp2_strategy_parallelize( - self.model, - device_mesh=self.device_mesh, + manager = FSDP2Manager( + dp_size=dp_size, + dp_replicate_size=1, + tp_size=tp_size, + cp_size=cp_size, + ep_size=ep_size, + pp_size=1, + sequence_parallel=sequence_parallel_enabled, + use_hf_tp_plan=self.cfg["dtensor_cfg"].get("use_hf_tp_plan", False), mp_policy=MixedPrecisionPolicy( param_dtype=self.dtype, reduce_dtype=torch.float32, @@ -351,33 +333,73 @@ def __init__( ), offload_policy=CPUOffloadPolicy(pin_memory=False) if self.cpu_offload - else OffloadPolicy(), - sequence_parallel=sequence_parallel_enabled, + else None, + backend="nccl", + world_size=world_size, activation_checkpointing=self.cfg["dtensor_cfg"][ "activation_checkpointing" ], - tp_shard_plan=self.cfg["dtensor_cfg"]["custom_parallel_plan"], - dp_replicate_mesh_name="dp_replicate", - dp_shard_cp_mesh_name="dp_shard_cp", - tp_mesh_name="tp", + custom_tp_plan=self.cfg["dtensor_cfg"].get("custom_parallel_plan", None), ) - print(f"[Rank {self.rank}] Loading state dict from rank 0...") - # This will broadcast the state dict from rank 0 to all other ranks - # and load it into the FSDP model. 
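For readers comparing the two code paths: the device mesh that the deleted block above constructed by hand (and that FSDP2Manager now builds, extended with expert-parallel dimensions) can be reproduced as a small standalone sketch. The helper name `build_legacy_mesh` and the sizes are placeholders; this restates the removed logic, not the manager's internals:

```python
from torch.distributed.device_mesh import init_device_mesh

def build_legacy_mesh(dp_replicate: int, dp_shard: int, cp: int, tp: int):
    # Requires an initialized process group (e.g. under torchrun) with
    # world_size == dp_replicate * dp_shard * cp * tp.
    mesh = init_device_mesh(
        "cuda",
        (dp_replicate, dp_shard, cp, tp),
        mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"),
    )
    # Flattened sub-meshes the worker previously created for FSDP2 sharding
    # and gradient operations.
    mesh[("dp_replicate", "dp_shard")]._flatten(mesh_dim_name="dp")
    mesh[("dp_shard", "cp")]._flatten(mesh_dim_name="dp_shard_cp")
    mesh[("dp_replicate", "dp_shard", "cp")]._flatten(mesh_dim_name="dp_cp")
    return mesh

# e.g. a single 8-GPU node with pure data parallelism:
# mesh = build_legacy_mesh(dp_replicate=1, dp_shard=8, cp=1, tp=1)
```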
- set_model_state_dict( - self.model, - model_state_dict=full_state_dict, - options=StateDictOptions( - full_state_dict=True, - broadcast_from_rank0=True, - ), + # Store mesh references for downstream usage + self.device_mesh = manager.device_mesh + self.dp_cp_mesh = self.device_mesh["dp_cp"] + self.dp_mesh = self.device_mesh["dp"] + self.tp_mesh = self.device_mesh["tp"] + self.cp_mesh = self.device_mesh["cp"] + self.moe_mesh = getattr(manager, "moe_mesh", None) + + self.dp_size = manager.dp_size + self.tp_size = manager.tp_size + self.cp_size = manager.cp_size + + # Parallelize model + is_moe_model = any(["expert" in key for key in self.model_state_dict_keys]) + if not isinstance(self.model, PreTrainedModel) and is_moe_model: + moe_parallelize_model( + model=self.model, + world_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + pp_enabled=False, + dp_axis_names=( + ("dp_replicate", "dp_shard_cp") + if "dp_replicate" in self.device_mesh.mesh_dim_names + and "dp_shard_cp" in self.device_mesh.mesh_dim_names + else ("dp_shard_cp",) + ), + cp_axis_name="cp", + tp_axis_name="tp", + ep_axis_name="ep", + ep_shard_axis_names=("ep_shard",), + ) + else: + self.model = manager.parallelize(self.model) + + # Load base model weights across all ranks using Automodel Checkpointer + # This mirrors build_model_and_optimizer's is_meta_device + load_weights path + print(self.model) + self._ensure_checkpointer( + config_updates={ + "model_repo_id": model_name, + "dequantize_base_checkpoint": self.cfg.get( + "dequantize_base_checkpoint", False + ), + }, + checkpoint_root=None, ) + self.checkpointer.config.model_state_dict_keys = self.model_state_dict_keys - # Broadcast model state dict keys to all ranks and store as instance variable - keys_to_broadcast = [model_state_dict_keys] - torch.distributed.broadcast_object_list(keys_to_broadcast, src=0) - self.model_state_dict_keys = keys_to_broadcast[0] + # Load base HF weights unless an explicit checkpoint is provided later + # This puts shards directly into the parallelized model + self.checkpointer.load_base_model( + self.model, + device=torch.cuda.current_device(), + root_dir=hf_config_overrides.get("cache_dir", TRANSFORMERS_CACHE), + model_name=model_name, + peft_init_method=None, # TODO: change for LoRA + load_base_model=True, + ) # Handle tied word embeddings after loading the state dict # We need to actually tie the parameters at the model level @@ -394,10 +416,6 @@ def __init__( if embed_tokens_weight is not None: self.model.lm_head.weight = embed_tokens_weight - # Manually broadcast buffers - for _, buf in self.model.named_buffers(): - torch.distributed.broadcast(to_local_if_dtensor(buf), src=0) - if self.cpu_offload: self.model = self.move_to_device(self.model, "cpu") @@ -408,8 +426,12 @@ def __init__( if init_optimizer: optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) + optimizer_kwargs = _resolve_kwargs( + optimizer_cls, self.cfg["optimizer"]["kwargs"] + ) self.optimizer = optimizer_cls( - self.model.parameters(), **self.cfg["optimizer"]["kwargs"] + self.model.parameters(), + **optimizer_kwargs, ) else: self.optimizer = None @@ -453,7 +475,7 @@ def __init__( self.load_checkpoint(weights_path, optimizer_path) else: print( - "No weights path provided. Starting from scratch (default policy init)" + "No weights path provided. 
Loaded base HF weights via Checkpointer (default policy init)" ) def _apply_temperature_scaling(self, logits: torch.Tensor) -> torch.Tensor: @@ -689,10 +711,11 @@ def train( ) with get_train_context(False, False, context_parallel_ctx)(): - with torch.autocast(device_type="cuda", dtype=self.dtype): + with nullcontext(): model_args = dict( input_ids=input_ids, attention_mask=attention_mask, + padding_mask=~attention_mask, position_ids=position_ids, use_cache=False, flash_attn_kwargs=flash_attn_kwargs, @@ -718,10 +741,14 @@ def train( outputs = self.model(**model_args) # Get logprobs - if not hasattr(outputs, "logits"): - logits = self.model.lm_head(outputs.last_hidden_state) + if isinstance(outputs, (torch.Tensor, DTensor)): + # custom models (e.g., those coming from AutoModel) can output logits directly + logits = outputs else: - logits = outputs.logits + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits del outputs # Apply temperature scaling @@ -812,7 +839,7 @@ def train( # when FSDP reduces the gradients over the DP dim, they're automatically averaged # but we want to sum them so we cancel out the average here - loss *= self.dp_size * self.cp_size + # loss *= self.dp_size * self.cp_size loss.backward() if num_valid_samples > 0: @@ -821,20 +848,29 @@ def train( grad_norm: Optional[float | torch.Tensor] = None if not eval_mode: - with torch.no_grad(): - grad_norm = get_grad_norm( - self.model.parameters(), - dp_cp_group=self.dp_cp_mesh.get_group(), - tp_group=self.tp_mesh.get_group(), - dtype=torch.float32, - ) - if self.max_grad_norm is not None: - clip_grad_by_total_norm_( - self.model.parameters(), - max_grad_norm=self.max_grad_norm, - total_norm=grad_norm, - ) - grad_norm = torch.tensor([grad_norm]) + from nemo_automodel.components.training.utils import ( + scale_grads_and_clip_grad_norm, + ) + + grad_norm = scale_grads_and_clip_grad_norm( + self.max_grad_norm, + [self.model], + norm_type=2.0, + pp_enabled=False, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + ep_axis_name="ep" + if self.moe_mesh is not None + and "ep" in self.moe_mesh.mesh_dim_names + else None, + pp_axis_name=None, + foreach=True, + num_label_tokens=1, + dp_group_size=self.dp_size * self.cp_size, + ) + grad_norm = torch.tensor( + grad_norm, device="cpu", dtype=torch.float32 + ) # Update parameters self.optimizer.step() @@ -1015,7 +1051,7 @@ def get_logprobs( ) with get_train_context(False, False, context_parallel_ctx)(): - with torch.autocast(device_type="cuda", dtype=self.dtype): + with nullcontext(): model_args = dict( input_ids=input_ids, attention_mask=attention_mask, @@ -1035,7 +1071,7 @@ def get_logprobs( outputs = self.model(**model_args) - logits = outputs.logits + logits = outputs.logits if hasattr(outputs, "logits") else outputs # Apply temperature scaling logits = self._apply_temperature_scaling(logits) @@ -1877,6 +1913,7 @@ def save_checkpoint( the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. 
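The save path below hands the Checkpointer a checkpoint root inferred from the weights path. A small standalone sketch of that mapping; the concrete path is illustrative (borrowed from the layout used in the deleted unit tests) and the helper mirrors `_infer_checkpoint_root` defined near the bottom of this module:

```python
import os

def infer_checkpoint_root(weights_path: str) -> str:
    # Same rule as _infer_checkpoint_root: ".../policy/weights/model" -> ".../policy"
    weights_dir = os.path.dirname(weights_path)
    if weights_dir.endswith("weights"):
        return os.path.dirname(weights_dir)
    return weights_dir

# Illustrative layout:
#   <root>/step_3/policy/weights/model     <- model shards (safetensors or DCP)
#   <root>/step_3/policy/optimizer/optim   <- optimizer + scheduler state
print(infer_checkpoint_root("/results/step_3/policy/weights/model"))
# -> /results/step_3/policy
```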
""" + print(f"Saving checkpoint to {weights_path}") if checkpointing_cfg is None: raise ValueError( "checkpointing_cfg must be provided when saving checkpoint" @@ -1892,41 +1929,152 @@ def save_checkpoint( "save_consolidated", "is_peft", "peft_config", + "model_cache_dir", + "model_repo_id", + "is_async", + "dequantize_base_checkpoint", } } - save_checkpoint( + checkpoint_root = _infer_checkpoint_root(weights_path) + + # Ensure a persistent Checkpointer exists and is configured + self._ensure_checkpointer( + config_updates=checkpoint_kwargs, checkpoint_root=checkpoint_root + ) + + self.checkpointer.save_model( model=self.model, weights_path=weights_path, - optimizer=self.optimizer if optimizer_path else None, - scheduler=self.scheduler if optimizer_path else None, - optimizer_path=optimizer_path, - tokenizer=self.tokenizer if tokenizer_path else None, - tokenizer_path=tokenizer_path, - model_state_dict_keys=self.model_state_dict_keys, - **checkpoint_kwargs, + peft_config=checkpoint_kwargs.get("peft_config"), + tokenizer=self.tokenizer if tokenizer_path is None else None, ) + if optimizer_path and self.optimizer is not None: + self.checkpointer.save_optimizer( + optimizer=self.optimizer, + model=self.model, + weights_path=optimizer_path, + scheduler=self.scheduler, + ) + + # TODO: needed? + if tokenizer_path and self.tokenizer is not None: + print(f"Saving tokenizer (or processor) to {tokenizer_path}") + self.tokenizer.save_pretrained(tokenizer_path) + def load_checkpoint( self, weights_path: str, optimizer_path: Optional[str] = None, ) -> None: - """Load a checkpoint into the model.""" - load_checkpoint( + """Load a checkpoint into the model using Automodel Checkpointer.""" + print(f"Loading weights from {weights_path}") + + model_save_format, is_peft = detect_checkpoint_format(weights_path) + + weights_dir = os.path.dirname(weights_path) + checkpoint_root = ( + os.path.dirname(weights_dir) + if weights_dir.endswith("weights") + else weights_dir + ) + + # Ensure a persistent Checkpointer exists and is configured + self._ensure_checkpointer( + config_updates={ + "model_save_format": model_save_format, + "is_peft": is_peft, + "dequantize_base_checkpoint": False, # the saved checkpoint is already dequantized + }, + checkpoint_root=checkpoint_root, + ) + + model_dir = ( + weights_path + if weights_path.endswith("/model") + else os.path.join(weights_path, "model") + ) + + self.checkpointer.load_model( model=self.model, - weights_path=weights_path, - optimizer=self.optimizer if optimizer_path else None, - scheduler=self.scheduler if optimizer_path else None, - optimizer_path=optimizer_path, + model_path=model_dir, ) + if optimizer_path and self.optimizer is not None: + self.checkpointer.load_optimizer( + optimizer=self.optimizer, + model=self.model, + weights_path=optimizer_path, + scheduler=self.scheduler, + ) + + def _ensure_checkpointer( + self, config_updates=None, checkpoint_root: Optional[str] = None + ) -> None: + """Create or update a persistent Automodel Checkpointer bound to this worker ranks. + + Args: + config_updates: Dict of CheckpointingConfig fields to update. + checkpoint_root: Optional root directory for checkpoints. 
+ """ + if config_updates is None: + config_updates = {} + + # Compute dp/tp ranks + dp_rank = torch.distributed.get_rank(self.dp_mesh.get_group()) + tp_rank = torch.distributed.get_rank(self.tp_mesh.get_group()) + pp_rank = 0 + + if self.checkpointer is None: + # Initialize a base config with sensible defaults + base_cfg = AutomodelCheckpointingConfig( + enabled=True, + checkpoint_dir=checkpoint_root or "", + model_save_format=config_updates.get( + "model_save_format", "safetensors" + ), + model_cache_dir=config_updates.get("model_cache_dir", ""), + model_repo_id=config_updates.get("model_repo_id", ""), + save_consolidated=config_updates.get("save_consolidated", False), + is_peft=config_updates.get("is_peft", False), + model_state_dict_keys=getattr(self, "model_state_dict_keys", None), + is_async=config_updates.get("is_async", False), + dequantize_base_checkpoint=config_updates.get( + "dequantize_base_checkpoint", False + ), + ) + self.checkpoint_config = base_cfg + self.checkpointer = Checkpointer( + config=base_cfg, + dp_rank=dp_rank, + tp_rank=tp_rank, + pp_rank=pp_rank, + moe_mesh=None, + ) + else: + # Update mutable config fields on the existing instance + cfg = self.checkpointer.config + if checkpoint_root is not None: + cfg.checkpoint_dir = checkpoint_root + for k, v in config_updates.items(): + if k == "model_save_format": + # Ensure enum type + v = SerializationFormat[v.upper()] if isinstance(v, str) else v + setattr(cfg, k, v) + # Ensure model_state_dict_keys is current + if getattr(self, "model_state_dict_keys", None) is not None: + cfg.model_state_dict_keys = self.model_state_dict_keys + def shutdown(self) -> None: """Shutdown the policy.""" # Clean up extension resources like ZMQ sockets if hasattr(self, "zmq_socket"): self.zmq_socket.close() self.zmq_context.term() + # Close checkpointer resources + if hasattr(self, "checkpointer") and self.checkpointer is not None: + self.checkpointer.close() def start_gpu_profiling(self) -> None: """Start GPU profiling.""" @@ -1941,3 +2089,92 @@ def report_node_ip_and_gpu_id(self) -> list[tuple[str, int]]: ip = ray._private.services.get_node_ip_address() gpu_id = ray.get_gpu_ids()[0] return (ip, gpu_id) + + +def detect_checkpoint_format(weights_path: str) -> tuple[str, bool]: + """Detect model save format and PEFT status from checkpoint directory. + + Args: + weights_path: Path to the checkpoint directory (e.g., weights/model) + + Returns: + tuple: (model_save_format, is_peft) where: + model_save_format is "torch_save" for DCP or "safetensors" for safetensors + is_peft is True if PEFT/adapter patterns are detected + """ + is_peft = False + model_save_format = "safetensors" + try: + # Iterate through all subdirectories and files recursively + all_files = [] + for root, dirs, files in os.walk(weights_path): + all_files.extend(files) + + if any(f.endswith(".distcp") for f in all_files): + model_save_format = "torch_save" + elif any(f.endswith(".safetensors") for f in all_files): + model_save_format = "safetensors" + elif any(f.endswith((".bin", ".pt", ".pth")) for f in all_files): + model_save_format = "torch_save" + + if not is_peft: + is_peft = any("adapter" in f.lower() for f in all_files) + + except (OSError, PermissionError): + pass + + return model_save_format, is_peft + + +def _infer_checkpoint_root(weights_path: str) -> str: + """Infer checkpoint root directory from weights path. + + When weights_path ends with "…/weights/model", we need the parent of + the weights directory (the checkpoint root), not the weights directory itself. 
+ + Args: + weights_path: Path to model weights (e.g., "/path/to/policy/weights/model") + + Returns: + str: Checkpoint root directory (e.g., "/path/to/policy") + """ + weights_dir = os.path.dirname(weights_path) + if weights_dir.endswith("weights"): + return os.path.dirname(weights_dir) + return weights_dir + + +def _resolve_kwargs(callable: Callable, kwargs: dict[str, Any]) -> dict[str, Any]: + """Resolve kwargs for a callable. + + Args: + callable: The callable to resolve kwargs for + kwargs: The kwargs to resolve + + Returns: + The resolved kwargs + """ + + def _resolve_import_class(name: str) -> Any | None: + try: + if name in STRING_TO_DTYPE: + return STRING_TO_DTYPE[name] + return import_class_from_path(name) + except Exception: + return + + signature = ( + inspect.signature(callable) + if inspect.isfunction(callable) + else inspect.signature(callable.__init__) + ) + result = {} + for k, v in kwargs.items(): + if k in signature.parameters: + _maybe_resolved_value = ( + _resolve_import_class(v) if isinstance(v, str) else v + ) + result[k] = ( + _maybe_resolved_value if _maybe_resolved_value is not None else v + ) + return result diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py index 7ad6d99849..e9bab9231c 100644 --- a/nemo_rl/models/policy/utils.py +++ b/nemo_rl/models/policy/utils.py @@ -30,7 +30,7 @@ # Try to import nemo_automodel classes, fallback to None if not available try: - from nemo_automodel.components._transformers.auto_model import ( + from nemo_automodel._transformers.auto_model import ( NeMoAutoModelForCausalLM, NeMoAutoModelForImageTextToText, NeMoAutoModelForTextToWaveform, diff --git a/nemo_rl/utils/automodel_checkpoint.py b/nemo_rl/utils/automodel_checkpoint.py deleted file mode 100644 index a9f0793851..0000000000 --- a/nemo_rl/utils/automodel_checkpoint.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Checkpoint management utilities for HF models.""" - -import os -from typing import Any, Optional - -import torch -from nemo_automodel.components.checkpoint._backports.filesystem import ( - SerializationFormat, -) - -# Apply torch backports for compatibility with torch==2.7.1 -from nemo_automodel.components.checkpoint._torch_backports import apply_patches - -# Import from nemo-automodel -from nemo_automodel.components.checkpoint.checkpointing import ( - CheckpointingConfig, - load_model, - load_optimizer, - save_model, - save_optimizer, -) - -# Apply torch backports for compatibility with torch==2.7.1 -apply_patches() - - -def _infer_checkpoint_root(weights_path: str) -> str: - """Infer checkpoint root directory from weights path. - - When weights_path ends with "…/weights/model", we need the parent of - the weights directory (the checkpoint root), not the weights directory itself. 
- - Args: - weights_path: Path to model weights (e.g., "/path/to/policy/weights/model") - - Returns: - str: Checkpoint root directory (e.g., "/path/to/policy") - """ - weights_dir = os.path.dirname(weights_path) - if weights_dir.endswith("weights"): - return os.path.dirname(weights_dir) - return weights_dir - - -def detect_checkpoint_format(weights_path: str) -> tuple[str, bool]: - """Detect model save format and PEFT status from checkpoint directory. - - Args: - weights_path: Path to the checkpoint directory (e.g., weights/model) - - Returns: - tuple: (model_save_format, is_peft) where: - model_save_format is "torch_save" for DCP or "safetensors" for safetensors - is_peft is True if PEFT/adapter patterns are detected - """ - is_peft = False - model_save_format = "safetensors" - try: - # Iterate through all subdirectories and files recursively - all_files = [] - for root, dirs, files in os.walk(weights_path): - all_files.extend(files) - - if any(f.endswith(".distcp") for f in all_files): - model_save_format = "torch_save" - elif any(f.endswith(".safetensors") for f in all_files): - model_save_format = "safetensors" - elif any(f.endswith((".bin", ".pt", ".pth")) for f in all_files): - model_save_format = "torch_save" - - if not is_peft: - is_peft = any("adapter" in f.lower() for f in all_files) - - except (OSError, PermissionError): - pass - - return model_save_format, is_peft - - -def save_checkpoint( - model: torch.nn.Module, - weights_path: str, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[Any] = None, - optimizer_path: Optional[str] = None, - tokenizer: Optional[Any] = None, - tokenizer_path: Optional[str] = None, - model_save_format: str = "safetensors", - is_peft: bool = False, - peft_config: Optional[Any] = None, - save_consolidated: bool = False, - model_state_dict_keys: Optional[list[str]] = None, -) -> None: - """Save a checkpoint of the model and optionally optimizer state. - - Args: - model: The PyTorch model to save - weights_path: Path to save model weights - optimizer: Optional optimizer to save - scheduler: Optional scheduler to save - optimizer_path: Path to save optimizer state (required if optimizer provided) - tokenizer: Optional tokenizer to save - tokenizer_path: Path to save tokenizer state (required if tokenizer provided) - model_save_format: Format for saving model ("torch_save" or "safetensors") - is_peft: Whether the model uses PEFT - peft_config: PEFT configuration if is_peft is True - save_consolidated: Whether to save consolidated checkpoints (for HF compatibility) - model_state_dict_keys: Copy of the model state dict keys before any parallelization. - If None, will be extracted from the model's current state dict. - """ - # Create checkpoint config - - # Extract model state dict keys if not provided - if model_state_dict_keys is None: - model_state_dict_keys = list(model.state_dict().keys()) - - valid_formats = {"safetensors", "torch_save"} - if model_save_format not in valid_formats: - raise ValueError( - f"Unsupported model_save_format='{model_save_format}'. " - f"Expected one of {sorted(valid_formats)}." 
- ) - - # Ensure target directories exist - os.makedirs(weights_path, exist_ok=True) - if optimizer_path: - os.makedirs(optimizer_path, exist_ok=True) - if tokenizer_path: - os.makedirs(tokenizer_path, exist_ok=True) - - checkpoint_config = CheckpointingConfig( - enabled=True, - checkpoint_dir=_infer_checkpoint_root(weights_path), - model_save_format=model_save_format, - model_cache_dir="", - model_repo_id="", - save_consolidated=save_consolidated, - is_peft=is_peft, - model_state_dict_keys=model_state_dict_keys, - ) - - # Save model using nemo-automodel API - save_model( - model=model, - weights_path=weights_path, - checkpoint_config=checkpoint_config, - peft_config=peft_config, - tokenizer=tokenizer if tokenizer_path is None else None, - ) - - # Save optimizer if provided - if optimizer is not None: - if optimizer_path is None: - raise ValueError( - "optimizer_path must be provided when saving optimizer state" - ) - save_optimizer( - optimizer=optimizer, - model=model, - weights_path=optimizer_path, - scheduler=scheduler, - ) - - # Save tokenizer separately if tokenizer_path provided - if tokenizer is not None and tokenizer_path is not None: - print(f"Saving tokenizer (or processor) to {tokenizer_path}") - tokenizer.save_pretrained(tokenizer_path) - - -def load_checkpoint( - model: torch.nn.Module, - weights_path: str, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[Any] = None, - optimizer_path: Optional[str] = None, -) -> None: - """Load a model weights and optionally optimizer state. - - Args: - model: The PyTorch model whose weights to update - weights_path: Path to load model weights from - optimizer: Optional optimizer to load state into - scheduler: Optional scheduler to load state into - optimizer_path: Path to load optimizer state from (required if optimizer provided) - """ - print(f"Loading weights from {weights_path}") - - model_save_format, is_peft = detect_checkpoint_format(weights_path) - - try: - format_enum = SerializationFormat[model_save_format.upper()] - - # append /model to the weights_path if it doesn't exist - # TODO: remove this once nemo-automodel is updated - if not weights_path.endswith("/model"): - weights_path = os.path.join(weights_path, "model") - - # Load model using nemo-automodel API - load_model( - model=model, - model_path=weights_path, - model_save_format=format_enum, - is_peft=is_peft, - ) - except FileNotFoundError as e: - msg = ( - f"Failed to load model from '{weights_path}': {e}\n" - "Note: DTensorPolicyWorkerV2 expects:\n" - " - Model shards under '/weights/model'\n" - " - Optimizer states under '/optimizer/optim'\n" - "Please verify your checkpoint layout." 
- ) - raise FileNotFoundError(msg) from e - - if optimizer is not None: - if optimizer_path is None: - raise ValueError( - "optimizer_path must be provided when loading optimizer state" - ) - print(f"Loading optimizer from {optimizer_path}") - load_optimizer( - optimizer=optimizer, - model=model, - weights_path=optimizer_path, - scheduler=scheduler, - ) diff --git a/nemo_rl/utils/venvs.py b/nemo_rl/utils/venvs.py index c5511473ea..80d1e12a0a 100644 --- a/nemo_rl/utils/venvs.py +++ b/nemo_rl/utils/venvs.py @@ -14,6 +14,7 @@ import logging import os import shlex +import shutil import subprocess import time from functools import lru_cache @@ -71,8 +72,6 @@ def create_local_venv( # Force rebuild if requested if force_rebuild and os.path.exists(venv_path): logger.info(f"Force rebuilding venv at {venv_path}") - import shutil - shutil.rmtree(venv_path) logger.info(f"Creating new venv at {venv_path}") @@ -89,6 +88,10 @@ def create_local_venv( # https://docs.astral.sh/uv/concepts/projects/config/#project-environment-path env["UV_PROJECT_ENVIRONMENT"] = venv_path + # Set TORCH_CUDA_ARCH_LIST for grouped_gemm & DeepEP installation. Hopper+ architectures are supported. + if "TORCH_CUDA_ARCH_LIST" not in env: + env["TORCH_CUDA_ARCH_LIST"] = "9.0 10.0 12.0" + # Split the py_executable into command and arguments exec_cmd = shlex.split(py_executable) # Command doesn't matter, since `uv` syncs the environment no matter the command. diff --git a/pyproject.toml b/pyproject.toml index 73eb392ba5..5be4aa2dda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,10 +57,13 @@ automodel = [ # Flash-attn version should be selected to satisfy both TE + vLLM requirements (xformers in particular) # https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108 # https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76 - "vllm==0.11.0", # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved + "vllm==0.11.0", # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/811 resolved "flash-attn==2.8.1", "mamba-ssm", "causal-conv1d", + "grouped_gemm @ git+https://github.com/fanshiqing/grouped_gemm@v1.1.4", + "transformer-engine[pytorch]==2.8.0", + "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@bfded34800dfec415b71503f8205181de90b2480", ] vllm = [ "cuda-python", @@ -68,7 +71,7 @@ vllm = [ # deep_ep also needs libibverbs-dev # sudo apt-get update # sudo apt-get install libibverbs-dev - "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@e3908bf5bd0cc6265bcb225d15cd8c996d4759ef", + "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@bfded34800dfec415b71503f8205181de90b2480", "vllm==0.11.0", "num2words>=0.5.14", # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved @@ -227,7 +230,8 @@ deep_gemm = [{ requirement = "torch", match-runtime = true }] transformer-engine = [{ requirement = "torch", match-runtime = true }] transformer-engine-torch = [{ requirement = "torch", match-runtime = true }] mamba-ssm = [{ requirement = "torch", match-runtime = true }] -causal-conv1d = [{ requirement = "torch", match-runtime = true }] +causal-conv1d = ["torch", "setuptools"] +grouped-gemm = ["torch"] # Needed when building from source [[tool.uv.dependency-metadata]] @@ -249,7 +253,7 @@ requires-dist = ["torch", "packaging", "ninja", "causal-conv1d"] [[tool.uv.dependency-metadata]] name = "deep_ep" # This version has to match the version in the commit/rev/tag 
used -version = "v1.1.0+e3908bf" +version = "v1.2.1+bfded34" requires-dist = ["torch", "packaging", "ninja"] [[tool.uv.dependency-metadata]] diff --git a/pyrefly.toml b/pyrefly.toml index a1d64ad6fa..c9c1c42ecb 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -107,7 +107,6 @@ project-includes = [ "nemo_rl/utils/checkpoint.py", "nemo_rl/utils/config.py", "nemo_rl/utils/native_checkpoint.py", - "nemo_rl/utils/automodel_checkpoint.py", "nemo_rl/utils/nsys.py", "nemo_rl/utils/nvml.py", "nemo_rl/utils/packed_tensor.py", diff --git a/tests/unit/utils/test_automodel_checkpoint.py b/tests/unit/utils/test_automodel_checkpoint.py deleted file mode 100644 index 9906a1522f..0000000000 --- a/tests/unit/utils/test_automodel_checkpoint.py +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from tempfile import TemporaryDirectory -from unittest.mock import MagicMock, patch - -import pytest -import torch - -# Skip entire module if nemo_automodel is not available -pytest_plugins = [] -try: - import nemo_automodel # noqa: F401 -except ImportError: - pytest.skip("nemo_automodel not available", allow_module_level=True) - -from nemo_rl.utils.automodel_checkpoint import ( - detect_checkpoint_format, - load_checkpoint, - save_checkpoint, -) - - -class TestModel(torch.nn.Module): - """Simple test model with a forward method.""" - - def __init__(self): - super().__init__() - self.layers = torch.nn.ModuleList( - [ - torch.nn.Linear(4, 4), - torch.nn.LayerNorm(4), - torch.nn.ReLU(), - torch.nn.Linear(4, 1), - ] - ) - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - - -@pytest.fixture -def mock_model(): - """Create a simple mock model for testing.""" - return TestModel() - - -@pytest.fixture -def mock_optimizer(): - """Create a simple mock optimizer for testing.""" - model = torch.nn.Linear(4, 1) - return torch.optim.Adam(model.parameters()) - - -@pytest.mark.automodel -class TestDetectCheckpointFormat: - """Test the detect_checkpoint_format function.""" - - def test_directory_with_safetensors(self): - """Test detection for directories containing safetensors files.""" - with TemporaryDirectory() as tmp_dir: - # Create directory with safetensors files - os.makedirs(os.path.join(tmp_dir, "weights", "model")) - weights_path = os.path.join(tmp_dir, "weights", "model") - - # Create safetensors shard files - with open( - os.path.join( - weights_path, "shard-00001-model-00001-of-00001.safetensors" - ), - "w", - ) as f: - f.write("dummy content") - with open( - os.path.join( - weights_path, "shard-00002-model-00001-of-00001.safetensors" - ), - "w", - ) as f: - f.write("dummy content") - - format_type, is_peft = detect_checkpoint_format(weights_path) - assert format_type == "safetensors" - assert is_peft == False - - def test_directory_with_dcp_format(self): - """Test detection for directories with DCP (Distributed Checkpoint) format.""" - with TemporaryDirectory() as tmp_dir: - # Create directory 
structure like: step_3/policy/optimizer/optim - optim_path = os.path.join(tmp_dir, "step_3", "policy", "optimizer", "optim") - os.makedirs(optim_path) - - # Create DCP files (.distcp + .metadata) - with open(os.path.join(optim_path, "__0_0.distcp"), "w") as f: - f.write("dummy dcp content") - with open(os.path.join(optim_path, "__1_0.distcp"), "w") as f: - f.write("dummy dcp content") - with open(os.path.join(optim_path, ".metadata"), "w") as f: - f.write("dummy metadata") - - format_type, is_peft = detect_checkpoint_format(optim_path) - assert format_type == "torch_save" # DCP uses torch_save format - assert is_peft == False - - def test_directory_with_torch_files(self): - """Test detection for directories containing torch save files.""" - with TemporaryDirectory() as tmp_dir: - model_path = os.path.join(tmp_dir, "model") - os.makedirs(model_path) - - # Create torch save files - with open(os.path.join(model_path, "pytorch_model.bin"), "w") as f: - f.write("dummy content") - - format_type, is_peft = detect_checkpoint_format(model_path) - assert format_type == "torch_save" - assert is_peft == False - - def test_peft_detection_in_filenames(self): - """Test PEFT detection from filenames within directories.""" - with TemporaryDirectory() as tmp_dir: - model_path = os.path.join(tmp_dir, "regular_model") - os.makedirs(model_path) - - # Create file with adapter pattern in name - with open(os.path.join(model_path, "adapter_model.safetensors"), "w") as f: - f.write("dummy content") - - format_type, is_peft = detect_checkpoint_format(model_path) - assert format_type == "safetensors" - assert is_peft == True # Should detect adapter in filename - - def test_default_fallback(self): - """Test default behavior for non-existent directories.""" - # Non-existent directory should default to safetensors, no PEFT - format_type, is_peft = detect_checkpoint_format("/non/existent/directory") - assert format_type == "safetensors" - assert is_peft == False - - def test_expected_structure(self): - """Test with the expected folder structure from the user.""" - with TemporaryDirectory() as tmp_dir: - # Create the expected structure: step_3/policy/weights/model - weights_path = os.path.join(tmp_dir, "step_3", "policy", "weights", "model") - os.makedirs(weights_path) - - # Create safetensors shard files as in the example - with open( - os.path.join( - weights_path, "shard-00001-model-00001-of-00001.safetensors" - ), - "w", - ) as f: - f.write("dummy content") - with open( - os.path.join( - weights_path, "shard-00002-model-00001-of-00001.safetensors" - ), - "w", - ) as f: - f.write("dummy content") - - format_type, is_peft = detect_checkpoint_format(weights_path) - assert format_type == "safetensors" - assert is_peft == False - - """Test the save_checkpoint function.""" - - @pytest.mark.automodel - @patch("nemo_rl.utils.automodel_checkpoint.save_model") - @patch("nemo_rl.utils.automodel_checkpoint.save_optimizer") - def test_save_model_only(self, mock_save_optimizer, mock_save_model, mock_model): - """Test saving model weights only.""" - with TemporaryDirectory() as tmp_dir: - weights_path = os.path.join(tmp_dir, "weights") - os.makedirs(os.path.dirname(weights_path), exist_ok=True) - - # Save checkpoint - save_checkpoint( - model=mock_model, - weights_path=weights_path, - model_save_format="safetensors", - is_peft=False, - ) - - # Verify save_model was called correctly - mock_save_model.assert_called_once() - call_args = mock_save_model.call_args - assert call_args[1]["model"] is mock_model - assert 
call_args[1]["weights_path"] == weights_path - assert ( - call_args[1]["checkpoint_config"].model_save_format.value - == "safetensors" - ) - assert call_args[1]["checkpoint_config"].is_peft == False - - # Verify optimizer saving was not called - mock_save_optimizer.assert_not_called() - - @pytest.mark.automodel - @patch("nemo_rl.utils.automodel_checkpoint.save_model") - @patch("nemo_rl.utils.automodel_checkpoint.save_optimizer") - def test_save_with_optimizer( - self, mock_save_optimizer, mock_save_model, mock_model, mock_optimizer - ): - """Test saving model and optimizer weights.""" - with TemporaryDirectory() as tmp_dir: - weights_path = os.path.join(tmp_dir, "model", "weights") - optimizer_path = os.path.join(tmp_dir, "optimizer", "optim") - os.makedirs(os.path.dirname(weights_path)) - os.makedirs(os.path.dirname(optimizer_path)) - - # Save checkpoint with optimizer - save_checkpoint( - model=mock_model, - weights_path=weights_path, - optimizer=mock_optimizer, - optimizer_path=optimizer_path, - model_save_format="torch_save", - is_peft=True, - ) - - # Verify both model and optimizer saving were called - mock_save_model.assert_called_once() - mock_save_optimizer.assert_called_once() - - # Check optimizer call args - opt_call_args = mock_save_optimizer.call_args - assert opt_call_args[1]["optimizer"] is mock_optimizer - assert opt_call_args[1]["model"] is mock_model - assert opt_call_args[1]["weights_path"] == optimizer_path - - @pytest.mark.automodel - @patch("nemo_rl.utils.automodel_checkpoint.save_model") - def test_save_with_tokenizer(self, mock_save_model, mock_model): - """Test saving with tokenizer.""" - with TemporaryDirectory() as tmp_dir: - weights_path = os.path.join(tmp_dir, "model", "weights") - tokenizer_path = os.path.join(tmp_dir, "tokenizer") - os.makedirs(os.path.dirname(weights_path)) - os.makedirs(tokenizer_path) - - # Create mock tokenizer - mock_tokenizer = MagicMock() - - # Save checkpoint with tokenizer - save_checkpoint( - model=mock_model, - weights_path=weights_path, - tokenizer=mock_tokenizer, - tokenizer_path=tokenizer_path, - ) - - # Verify tokenizer.save_pretrained was called - mock_tokenizer.save_pretrained.assert_called_once_with(tokenizer_path) - - -@pytest.fixture -def mock_experiment(): - """Create a real model, optimizer, and scheduler for integration testing.""" - model = TestModel() - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) - return model, optimizer, scheduler - - -def check_dict_equality(dict1, dict2): - """Recursively check equality of two dictionaries""" - for k in dict1.keys(): - if isinstance(dict1[k], dict): - check_dict_equality(dict1[k], dict2[k]) - elif isinstance(dict1[k], torch.Tensor): - assert torch.allclose(dict1[k], dict2[k]) - else: - assert dict1[k] == dict2[k] - - -@pytest.mark.automodel -class TestSaveLoadIntegration: - """Integration tests that actually save and load checkpoints.""" - - def test_save_and_load_model_only_safetensors(self, mock_experiment): - """Test saving and loading model weights only with safetensors format.""" - test_model, _, _ = mock_experiment - original_state_dict = test_model.state_dict() - - with TemporaryDirectory() as tmp_dir: - weights_path = os.path.join(tmp_dir, "test_model") - - # Save checkpoint - save_checkpoint( - model=test_model, - weights_path=weights_path, - model_save_format="safetensors", - ) - - # Verify files are created - assert os.path.exists(weights_path) - files = 
os.listdir(os.path.join(weights_path, "model")) - assert any(f.endswith(".safetensors") for f in files) - - # Create a new model with different weights - new_model = TestModel() - # Initialize with different values - for param in new_model.parameters(): - param.data.fill_(999.0) - - # Load the checkpoint - load_checkpoint(model=new_model, weights_path=weights_path) - - # Verify the weights match the original - check_dict_equality(new_model.state_dict(), original_state_dict) - - def test_save_and_load_model_only_torch_save(self, mock_experiment): - """Test saving and loading model weights only with torch_save format.""" - test_model, _, _ = mock_experiment - original_state_dict = test_model.state_dict() - - with TemporaryDirectory() as tmp_dir: - weights_path = os.path.join(tmp_dir, "test_model") - - # Save checkpoint - save_checkpoint( - model=test_model, - weights_path=weights_path, - model_save_format="torch_save", - ) - - # Verify files are created - assert os.path.exists(weights_path) - files = os.listdir(os.path.join(weights_path, "model")) - assert any(f.endswith(".distcp") for f in files) - - # Create a new model with different weights - new_model = TestModel() - # Initialize with different values - for param in new_model.parameters(): - param.data.fill_(999.0) - - # Load the checkpoint - load_checkpoint(model=new_model, weights_path=weights_path) - - # Verify the weights match the original - check_dict_equality(new_model.state_dict(), original_state_dict) - - def test_save_and_load_model_and_optimizer(self, mock_experiment): - """Test saving and loading both model and optimizer.""" - test_model, optimizer, scheduler = mock_experiment - - # Take some optimization steps to change optimizer state - for _ in range(5): - loss = torch.nn.functional.mse_loss( - test_model(torch.randn(2, 4)), torch.randn(2, 1) - ) - optimizer.zero_grad() - loss.backward() - optimizer.step() - scheduler.step() - - original_model_state = test_model.state_dict() - original_optimizer_state = optimizer.state_dict() - original_scheduler_state = scheduler.state_dict() - - with TemporaryDirectory() as tmp_dir: - model_path = os.path.join(tmp_dir, "model_and_optimizer", "model_path") - optimizer_path = os.path.join(tmp_dir, "model_and_optimizer", "optimizer") - os.makedirs(os.path.dirname(model_path), exist_ok=True) - os.makedirs(os.path.dirname(optimizer_path), exist_ok=True) - - # Save checkpoint - save_checkpoint( - model=test_model, - weights_path=model_path, - optimizer=optimizer, - scheduler=scheduler, - optimizer_path=optimizer_path, - ) - - # Verify files are created - assert os.path.exists(model_path) - assert os.path.exists(optimizer_path) - - # Create new model, optimizer, and scheduler with different state - new_model = TestModel() - new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.001) - new_scheduler = torch.optim.lr_scheduler.StepLR( - new_optimizer, step_size=4, gamma=0.2 - ) - - # Initialize with different values - for param in new_model.parameters(): - param.data.fill_(999.0) - - # Load the checkpoint - load_checkpoint( - model=new_model, - weights_path=model_path, - optimizer=new_optimizer, - scheduler=new_scheduler, - optimizer_path=optimizer_path, - ) - - # Verify all states match the original - check_dict_equality(new_model.state_dict(), original_model_state) - check_dict_equality(new_optimizer.state_dict(), original_optimizer_state) - assert new_scheduler.state_dict() == original_scheduler_state diff --git a/uv.lock b/uv.lock index 03c163b5ec..8ae775957d 100644 --- a/uv.lock +++ 
b/uv.lock @@ -37,7 +37,7 @@ requires-dist = ["torch", "packaging", "ninja"] [[manifest.dependency-metadata]] name = "deep-ep" -version = "1.1.0+e3908bf" +version = "1.2.1+bfded34" requires-dist = ["torch", "packaging", "ninja"] [[manifest.dependency-metadata]] @@ -350,6 +350,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/22/91616fe707a5c5510de2cac9b046a30defe7007ba8a0c04f9c08f27df312/audioop_lts-0.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:b492c3b040153e68b9fdaff5913305aaaba5bb433d8a7f73d5cf6a64ed3cc1dd", size = 25206, upload-time = "2025-08-05T16:43:16.444Z" }, ] +[[package]] +name = "audioread" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "standard-aifc", marker = "python_full_version >= '3.13'" }, + { name = "standard-sunau", marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/4a/874ecf9b472f998130c2b5e145dcdb9f6131e84786111489103b66772143/audioread-3.1.0.tar.gz", hash = "sha256:1c4ab2f2972764c896a8ac61ac53e261c8d29f0c6ccd652f84e18f08a4cab190", size = 20082, upload-time = "2025-10-26T19:44:13.484Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/16/fbe8e1e185a45042f7cd3a282def5bb8d95bb69ab9e9ef6a5368aa17e426/audioread-3.1.0-py3-none-any.whl", hash = "sha256:b30d1df6c5d3de5dcef0fb0e256f6ea17bdcf5f979408df0297d8a408e2971b4", size = 23143, upload-time = "2025-10-26T19:44:12.016Z" }, +] + [[package]] name = "av" version = "15.0.0" @@ -407,19 +420,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/cd/30110dc0ffcf3b131156077b90e9f60ed75711223f306da4db08eff8403b/beautifulsoup4-4.13.4-py3-none-any.whl", hash = "sha256:9bbbb14bfde9d79f38b8cd5f8c7c85f4b8f2523190ebed90e950a8dea4cb1c4b", size = 187285, upload-time = "2025-04-15T17:05:12.221Z" }, ] -[[package]] -name = "bitsandbytes" -version = "0.45.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, - { name = "torch", version = "2.8.0+cu129", source = { registry = "https://download.pytorch.org/whl/cu129" }, marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/07/b7/cb5ce4d1a382cf53c19ef06c5fc29e85f5e129b4da6527dd207d90a5b8ad/bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:a5453f30cc6aab6ccaac364e6bf51a7808d3da5f71763dffeb6d9694c59136e4", size = 76059261, upload-time = "2025-04-07T13:32:52.573Z" }, - { url = "https://files.pythonhosted.org/packages/a6/4c/77b535e025ce780d2ada8271c1e481fb7337c1df2588a52fe1c9bd87d2e8/bitsandbytes-0.45.5-py3-none-win_amd64.whl", hash = "sha256:ed1c61b91d989d6a33fd05737d6edbf5086d8ebc89235ee632c7a19144085da2", size = 75430204, upload-time = "2025-04-07T13:32:57.553Z" }, -] - [[package]] name = "blake3" version = "1.0.5" @@ -1147,12 +1147,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/57/ecc9ae29fa5b2d90107cd1d9bf8ed19aacb74b2264d986ae9d44fe9bdf87/debugpy-1.8.16-py2.py3-none-any.whl", hash = "sha256:19c9521962475b87da6f673514f7fd610328757ec993bf7ec0d8c96f9a325f9e", size = 5287700, upload-time = "2025-08-06T18:00:42.333Z" }, ] +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, +] + [[package]] name = "decord" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, @@ -1161,8 +1170,8 @@ wheels = [ [[package]] name = "deep-ep" -version = "1.1.0+e3908bf" -source = { git = "https://github.com/deepseek-ai/DeepEP.git?rev=e3908bf5bd0cc6265bcb225d15cd8c996d4759ef#e3908bf5bd0cc6265bcb225d15cd8c996d4759ef" } +version = "1.2.1+bfded34" +source = { git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480#bfded34800dfec415b71503f8205181de90b2480" } dependencies = [ { name = "ninja" }, { name = "packaging" }, @@ -1217,6 +1226,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/ae/afb1487556e2dc827a17097aac8158a25b433a345386f0e249f6d2694ccb/devtools-0.12.2-py3-none-any.whl", hash = "sha256:c366e3de1df4cdd635f1ad8cbcd3af01a384d7abda71900e68d43b04eb6aaca7", size = 19411, upload-time = "2023-09-03T16:56:59.049Z" }, ] +[[package]] +name = "diffusers" +version = "0.35.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "huggingface-hub" }, + { name = "importlib-metadata" }, + { name = "numpy" }, + { name = "pillow" }, + { name = "regex" }, + { name = "requests" }, + { name = "safetensors" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/03/68/288ca23c7c05c73e87ffe5efffc282400ac9b017f7a9bb03883f4310ea15/diffusers-0.35.2.tar.gz", hash = "sha256:30ecd552303edfcfe1724573c3918a8462ee3ab4d529bdbd4c0045f763affded", size = 3366711, upload-time = "2025-10-15T04:05:17.213Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/2e/38d9824f8c6bb048c5ba21c6d4da54c29c162a46b58b3ef907a360a76d3e/diffusers-0.35.2-py3-none-any.whl", hash = "sha256:d50d5e74fdd6dcf55e5c1d304bc52cc7c2659abd1752740d736d7b54078b4db5", size = 4121649, upload-time = "2025-10-15T04:05:14.391Z" }, +] + [[package]] name = "dill" version = "0.3.8" @@ -1432,6 +1460,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/14/42b2651a2f46b022ccd948bca9f2d5af0fd8929c4eec235b8d6d844fbe67/filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d", size = 15988, upload-time = "2025-08-14T16:56:01.633Z" }, ] +[[package]] +name = "fla-core" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "einops" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { 
name = "torch", version = "2.8.0+cu129", source = { registry = "https://download.pytorch.org/whl/cu129" }, marker = "sys_platform != 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/db/3d/79a9d5c8cd973c86f35403931031787dfc6cc97d838a42d4c62e8cbbb66f/fla_core-0.4.0.tar.gz", hash = "sha256:d975022b074e97bfd086dc6b767dccb35e27a9fe36f26f3b26b1c2b68b36a1c8", size = 316316, upload-time = "2025-10-27T08:18:51.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/0c/d52ab65e9c163631895052d70d4111f8530ca52f45beb0895378d1a2a8b5/fla_core-0.4.0-py3-none-any.whl", hash = "sha256:5396f36a9838c99f9e45c70e88e2e0b26688f719d07d2ddd61be16d29327f4ea", size = 438519, upload-time = "2025-10-27T08:18:49.561Z" }, +] + [[package]] name = "flash-attn" version = "2.8.1" @@ -1446,6 +1488,19 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/e8/6d/7066d160bdffa2f9da29a8c3957f266b17a03ca0b3bdc8fdae86d9881fe7/flash_attn-2.8.1.tar.gz", hash = "sha256:0ff003899fcb244f357905b04f622d5c9736887126dd6675f8f4bc52954e3923", size = 8166563, upload-time = "2025-07-10T05:16:39.729Z" } +[[package]] +name = "flash-linear-attention" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fla-core" }, + { name = "transformers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/44/9a/e546815da2bf149e0af58449ff1ca10074165af4384febead438ad46f74c/flash_linear_attention-0.4.0.tar.gz", hash = "sha256:c5d2bf6e1a766af3a4426f07f710b0b87809f7218de21eb313314be6ff1b0dba", size = 157646, upload-time = "2025-10-27T08:18:52.445Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/76/4f716c953608204c970de7cd4045db1af02643d7f19c94a49254834b7563/flash_linear_attention-0.4.0-py3-none-any.whl", hash = "sha256:50c97163f7cb64dc53585194ef36af44d2a6bc545227c4f73bb3ba9062630f1a", size = 290439, upload-time = "2025-10-27T08:18:50.589Z" }, +] + [[package]] name = "flask" version = "3.1.2" @@ -1617,6 +1672,18 @@ http = [ { name = "aiohttp" }, ] +[[package]] +name = "ftfy" +version = "6.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a5/d3/8650919bc3c7c6e90ee3fa7fd618bf373cbbe55dff043bd67353dbb20cd8/ftfy-6.3.1.tar.gz", hash = "sha256:9b3c3d90f84fb267fe64d375a07b7f8912d817cf86009ae134aa03e1819506ec", size = 308927, upload-time = "2024-10-26T00:50:35.149Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/6e/81d47999aebc1b155f81eca4477a616a70f238a2549848c38983f3c22a82/ftfy-6.3.1-py3-none-any.whl", hash = "sha256:7c70eb532015cd2f9adb53f101fb6c7945988d023a085d127d1573dc49dd0083", size = 44821, upload-time = "2024-10-26T00:50:33.425Z" }, +] + [[package]] name = "gguf" version = "0.17.1" @@ -1886,6 +1953,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/27/3d6dcadc8a3214d8522c1e7f6a19554e33659be44546d44a2f7572ac7d2a/groovy-0.1.2-py3-none-any.whl", hash = "sha256:7f7975bab18c729a257a8b1ae9dcd70b7cafb1720481beae47719af57c35fa64", size = 14090, upload-time = "2025-02-28T20:24:55.152Z" }, ] +[[package]] +name = "grouped-gemm" +version = "1.1.4" +source = { git = "https://github.com/fanshiqing/grouped_gemm?rev=v1.1.4#172fada89fa7364fe5d026b3a0dfab58b591ffdd" } +dependencies = [ + { name = "absl-py" }, + { name = "numpy" }, + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, + { name = "torch", version = 
"2.8.0+cu129", source = { registry = "https://download.pytorch.org/whl/cu129" }, marker = "sys_platform != 'darwin'" }, +] + [[package]] name = "grpcio" version = "1.74.0" @@ -1935,6 +2013,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "h2" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hpack" }, + { name = "hyperframe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, +] + [[package]] name = "hatchling" version = "1.27.0" @@ -1965,6 +2056,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/d3/0aaf279f4f3dea58e99401b92c31c0f752924ba0e6c7d7bb07b1dbd7f35e/hf_xet-1.1.8-cp37-abi3-win_amd64.whl", hash = "sha256:4171f31d87b13da4af1ed86c98cf763292e4720c088b4957cf9d564f92904ca9", size = 2801689, upload-time = "2025-08-18T22:01:04.81Z" }, ] +[[package]] +name = "hpack" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -2015,6 +2115,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[package.optional-dependencies] +http2 = [ + { name = "h2" }, +] + [[package]] name = "huggingface-hub" version = "0.34.4" @@ -2048,6 +2153,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b", size = 154547, upload-time = "2023-02-23T18:33:40.801Z" }, ] +[[package]] +name = "hyperframe" +version = "6.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } +wheels = [ + { 
url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, +] + [[package]] name = "identify" version = "2.6.13" @@ -2066,6 +2180,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "imageio-ffmpeg" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/bd/c3343c721f2a1b0c9fc71c1aebf1966a3b7f08c2eea8ed5437a2865611d6/imageio_ffmpeg-0.6.0.tar.gz", hash = "sha256:e2556bed8e005564a9f925bb7afa4002d82770d6b08825078b7697ab88ba1755", size = 25210, upload-time = "2025-01-16T21:34:32.747Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/58/87ef68ac83f4c7690961bce288fd8e382bc5f1513860fc7f90a9c1c1c6bf/imageio_ffmpeg-0.6.0-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.whl", hash = "sha256:9d2baaf867088508d4a3458e61eeb30e945c4ad8016025545f66c4b5aaef0a61", size = 24932969, upload-time = "2025-01-16T21:34:20.464Z" }, + { url = "https://files.pythonhosted.org/packages/40/5c/f3d8a657d362cc93b81aab8feda487317da5b5d31c0e1fdfd5e986e55d17/imageio_ffmpeg-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b1ae3173414b5fc5f538a726c4e48ea97edc0d2cdc11f103afee655c463fa742", size = 21113891, upload-time = "2025-01-16T21:34:00.277Z" }, + { url = "https://files.pythonhosted.org/packages/33/e7/1925bfbc563c39c1d2e82501d8372734a5c725e53ac3b31b4c2d081e895b/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1d47bebd83d2c5fc770720d211855f208af8a596c82d17730aa51e815cdee6dc", size = 25632706, upload-time = "2025-01-16T21:33:53.475Z" }, + { url = "https://files.pythonhosted.org/packages/a0/2d/43c8522a2038e9d0e7dbdf3a61195ecc31ca576fb1527a528c877e87d973/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c7e46fcec401dd990405049d2e2f475e2b397779df2519b544b8aab515195282", size = 29498237, upload-time = "2025-01-16T21:34:13.726Z" }, + { url = "https://files.pythonhosted.org/packages/a0/13/59da54728351883c3c1d9fca1710ab8eee82c7beba585df8f25ca925f08f/imageio_ffmpeg-0.6.0-py3-none-win32.whl", hash = "sha256:196faa79366b4a82f95c0f4053191d2013f4714a715780f0ad2a68ff37483cc2", size = 19652251, upload-time = "2025-01-16T21:34:06.812Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c6/fa760e12a2483469e2bf5058c5faff664acf66cadb4df2ad6205b016a73d/imageio_ffmpeg-0.6.0-py3-none-win_amd64.whl", hash = "sha256:02fa47c83703c37df6bfe4896aab339013f62bf02c5ebf2dce6da56af04ffc0a", size = 31246824, upload-time = "2025-01-16T21:34:28.6Z" }, +] + [[package]] name = "imagesize" version = "1.4.1" @@ -2327,6 +2455,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/60/dfbbf40e3a371388c0e03ff65b01319b7d4023e883df6d7261125772ffdc/latex2sympy2_extended-1.10.2-py3-none-any.whl", hash = "sha256:f910442c5b02a466c1046f47d05cc5285181068b882399281f30102715337fb7", size = 207855, upload-time = "2025-07-02T15:26:04.88Z" }, ] +[[package]] +name = "lazy-loader" +version = "0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/6f/6b/c875b30a1ba490860c93da4cabf479e03f584eba06fe5963f6f6644653d8/lazy_loader-0.4.tar.gz", hash = "sha256:47c75182589b91a4e1a85a136c074285a5ad4d9f39c63e0d7fb76391c4574cd1", size = 15431, upload-time = "2024-04-05T13:03:12.261Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/60/d497a310bde3f01cb805196ac61b7ad6dc5dcf8dce66634dc34364b20b4f/lazy_loader-0.4-py3-none-any.whl", hash = "sha256:342aa8e14d543a154047afb4ba8ef17f5563baad3fc610d7b15b213b0f119efc", size = 12097, upload-time = "2024-04-05T13:03:10.514Z" }, +] + +[[package]] +name = "librosa" +version = "0.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "audioread" }, + { name = "decorator" }, + { name = "joblib" }, + { name = "lazy-loader" }, + { name = "msgpack" }, + { name = "numba" }, + { name = "numpy" }, + { name = "pooch" }, + { name = "scikit-learn" }, + { name = "scipy" }, + { name = "soundfile" }, + { name = "soxr" }, + { name = "standard-aifc", marker = "python_full_version >= '3.13'" }, + { name = "standard-sunau", marker = "python_full_version >= '3.13'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/36/360b5aafa0238e29758729e9486c6ed92a6f37fa403b7875e06c115cdf4a/librosa-0.11.0.tar.gz", hash = "sha256:f5ed951ca189b375bbe2e33b2abd7e040ceeee302b9bbaeeffdfddb8d0ace908", size = 327001, upload-time = "2025-03-11T15:09:54.884Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/ba/c63c5786dfee4c3417094c4b00966e61e4a63efecee22cb7b4c0387dda83/librosa-0.11.0-py3-none-any.whl", hash = "sha256:0b6415c4fd68bff4c29288abe67c6d80b587e0e1e2cfb0aad23e4559504a7fa1", size = 260749, upload-time = "2025-03-11T15:09:52.982Z" }, +] + [[package]] name = "liger-kernel" version = "0.6.2" @@ -3083,10 +3249,14 @@ wheels = [ name = "nemo-automodel" source = { editable = "3rdparty/Automodel-workspace/Automodel" } dependencies = [ - { name = "bitsandbytes", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "datasets" }, + { name = "diffusers" }, + { name = "ftfy" }, + { name = "imageio-ffmpeg" }, { name = "liger-kernel", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, { name = "megatron-fsdp" }, + { name = "mlflow" }, + { name = "opencv-python-headless" }, { name = "pybind11" }, { name = "pyyaml" }, { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, @@ -3098,6 +3268,25 @@ dependencies = [ ] [package.optional-dependencies] +all = [ + { name = "backoff" }, + { name = "flash-linear-attention" }, + { name = "mistral-common", extra = ["opencv"] }, + { name = "numba" }, + { name = "numpy" }, + { name = "perceptron" }, + { name = "pillow" }, + { name = "qwen-omni-utils" }, + { name = "qwen-vl-utils", extra = ["decord"], marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, + { name = "sentencepiece" }, + { name = "timm" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, +] +extra = [ + { name = "flash-linear-attention" }, + { name = "perceptron" }, + { name = "sentencepiece" }, +] fa = [ { name = "flash-attn" }, ] @@ -3110,10 +3299,10 @@ vlm = [ { name = "numba" }, { name = "numpy" }, { name = "pillow" }, - { name = "qwen-vl-utils", extra = ["decord"] }, + { name = "qwen-omni-utils" }, + { name = "qwen-vl-utils", extra = ["decord"], marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" 
}, { name = "timm" }, - { name = "torchcodec" }, - { name = "transformers" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" }, ] [package.dev-dependencies] @@ -3147,36 +3336,45 @@ test = [ [package.metadata] requires-dist = [ { name = "backoff", marker = "extra == 'vlm'" }, - { name = "bitsandbytes", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", specifier = "==0.45.5" }, { name = "datasets", specifier = ">=4.0.0" }, + { name = "diffusers" }, { name = "flash-attn", marker = "extra == 'fa'", specifier = "<=2.8.3" }, + { name = "flash-linear-attention", marker = "extra == 'extra'" }, + { name = "ftfy" }, + { name = "imageio-ffmpeg" }, { name = "liger-kernel", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'", specifier = ">=0.5.9" }, { name = "megatron-fsdp" }, { name = "mistral-common", extras = ["opencv"], marker = "extra == 'vlm'" }, + { name = "mlflow" }, + { name = "nemo-automodel", extras = ["extra"], marker = "extra == 'all'", editable = "3rdparty/Automodel-workspace/Automodel" }, + { name = "nemo-automodel", extras = ["vlm"], marker = "extra == 'all'", editable = "3rdparty/Automodel-workspace/Automodel" }, { name = "numba", marker = "extra == 'vlm'" }, { name = "numpy", marker = "extra == 'vlm'" }, + { name = "opencv-python-headless", specifier = "==4.10.0.84" }, + { name = "perceptron", marker = "extra == 'extra'" }, { name = "pillow", marker = "extra == 'vlm'" }, { name = "pybind11" }, { name = "pyyaml" }, - { name = "qwen-vl-utils", extras = ["decord"], marker = "extra == 'vlm'" }, - { name = "timm", marker = "extra == 'vlm'", specifier = "==1.0.16" }, - { name = "torch", marker = "sys_platform != 'darwin'", specifier = "<=2.8.0", index = "https://download.pytorch.org/whl/cu129" }, - { name = "torch", marker = "sys_platform == 'darwin'", specifier = "<=2.8.0", index = "https://pypi.org/simple" }, + { name = "qwen-omni-utils", marker = "extra == 'vlm'" }, + { name = "qwen-vl-utils", extras = ["decord"], marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'vlm'" }, + { name = "sentencepiece", marker = "extra == 'extra'" }, + { name = "timm", marker = "extra == 'vlm'", specifier = "<=1.0.22" }, + { name = "torch", marker = "sys_platform != 'darwin'", specifier = "<=2.9.0", index = "https://download.pytorch.org/whl/cu129" }, + { name = "torch", marker = "sys_platform == 'darwin'", specifier = "<=2.9.0", index = "https://pypi.org/simple" }, { name = "torchao" }, - { name = "torchcodec", marker = "extra == 'vlm'" }, + { name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'vlm'" }, { name = "torchdata" }, { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'moe'", specifier = "==2.8.0" }, - { name = "transformers", specifier = "<=4.55.4" }, - { name = "transformers", marker = "extra == 'vlm'", specifier = "<=4.55.4" }, + { name = "transformers", specifier = "<=4.57.1" }, { name = "wandb" }, ] -provides-extras = ["vlm", "fa", "moe"] +provides-extras = ["vlm", "fa", "moe", "extra", "all"] [package.metadata.requires-dev] build = [ { name = "setuptools" }, - { name = "torch", marker = "sys_platform != 'darwin'", specifier = "<=2.8.0", index = "https://download.pytorch.org/whl/cu129" }, - { name = "torch", marker = "sys_platform == 'darwin'", specifier = "<=2.8.0", index = "https://pypi.org/simple" }, + { name = "torch", marker = "sys_platform != 'darwin'", specifier = "<=2.9.0", index = 
"https://download.pytorch.org/whl/cu129" }, + { name = "torch", marker = "sys_platform == 'darwin'", specifier = "<=2.9.0", index = "https://pypi.org/simple" }, ] dev = [{ name = "cut-cross-entropy", git = "https://github.com/apple/ml-cross-entropy.git?rev=87a86ab" }] docs = [ @@ -3242,9 +3440,12 @@ dependencies = [ [package.optional-dependencies] automodel = [ { name = "causal-conv1d" }, + { name = "deep-ep" }, { name = "flash-attn" }, + { name = "grouped-gemm" }, { name = "mamba-ssm" }, { name = "nemo-automodel" }, + { name = "transformer-engine", extra = ["pytorch"] }, { name = "vllm" }, ] mcore = [ @@ -3317,11 +3518,13 @@ requires-dist = [ { name = "cuda-python", marker = "extra == 'vllm'" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "debugpy" }, - { name = "deep-ep", marker = "extra == 'vllm'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=e3908bf5bd0cc6265bcb225d15cd8c996d4759ef" }, + { name = "deep-ep", marker = "extra == 'automodel'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, + { name = "deep-ep", marker = "extra == 'vllm'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" }, { name = "deep-gemm", marker = "extra == 'vllm'", git = "https://github.com/deepseek-ai/DeepGEMM.git?rev=7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" }, { name = "flash-attn", marker = "extra == 'automodel'", specifier = "==2.8.1" }, { name = "flash-attn", marker = "extra == 'mcore'", specifier = "==2.8.1" }, { name = "flash-attn", marker = "extra == 'vllm'", specifier = "==2.8.1" }, + { name = "grouped-gemm", marker = "extra == 'automodel'", git = "https://github.com/fanshiqing/grouped_gemm?rev=v1.1.4" }, { name = "hydra-core" }, { name = "mamba-ssm", marker = "extra == 'automodel'", git = "https://github.com/state-spaces/mamba.git?rev=2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }, { name = "mamba-ssm", marker = "extra == 'vllm'", git = "https://github.com/state-spaces/mamba.git?rev=2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }, @@ -3355,6 +3558,7 @@ requires-dist = [ { name = "torchdata" }, { name = "torchvision", marker = "sys_platform != 'darwin'", specifier = ">=0.22.0", index = "https://download.pytorch.org/whl/cu129" }, { name = "torchvision", marker = "sys_platform == 'darwin'", specifier = ">=0.22.0", index = "https://pypi.org/simple" }, + { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'automodel'", specifier = "==2.8.0" }, { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'mcore'", specifier = "==2.8.0" }, { name = "transformers", specifier = ">=4.55.4" }, { name = "triton", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')", index = "https://download.pytorch.org/whl/cu129" }, @@ -4186,6 +4390,24 @@ requires-dist = [ { name = "yappi" }, ] +[[package]] +name = "perceptron" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama" }, + { name = "httpx", extra = ["http2"] }, + { name = "numpy" }, + { name = "pillow" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/60/85db2243d8b550823603d8f9c5845b0dd0f01074e9aabf0b2af0c4f52565/perceptron-0.1.4.tar.gz", hash = "sha256:62fd190efb74925e2cc33c0cd38761e19959be3bdb7b24fbf9e3386d6961f690", size = 78116, upload-time = "2025-11-12T20:00:28.024Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/ef/17/b7cb1a10ebb0a9a4c9fbcd96a28b43d44e08a90f620bab07e644a658d2f1/perceptron-0.1.4-py3-none-any.whl", hash = "sha256:f490a6df6c15167e91e1a528601cae98ce99a30991cf792f9ef83ebc15d335c4", size = 57421, upload-time = "2025-11-12T20:00:26.395Z" }, +] + [[package]] name = "pillow" version = "11.3.0" @@ -4283,6 +4505,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "pooch" +version = "1.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "platformdirs" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/77/b3d3e00c696c16cf99af81ef7b1f5fe73bd2a307abca41bd7605429fe6e5/pooch-1.8.2.tar.gz", hash = "sha256:76561f0de68a01da4df6af38e9955c4c9d1a5c90da73f7e40276a5728ec83d10", size = 59353, upload-time = "2024-06-06T16:53:46.224Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574, upload-time = "2024-06-06T16:53:44.343Z" }, +] + [[package]] name = "pre-commit" version = "4.3.0" @@ -5079,6 +5315,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/32/3836ed85947b06f1d67c07ce16c00b0cf8c053ab0b249d234f9f81ff95ff/pyzmq-27.0.1-cp314-cp314t-win_arm64.whl", hash = "sha256:0fc24bf45e4a454e55ef99d7f5c8b8712539200ce98533af25a5bfa954b6b390", size = 575098, upload-time = "2025-08-03T05:04:27.974Z" }, ] +[[package]] +name = "qwen-omni-utils" +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "av" }, + { name = "librosa" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b8/b1/cc58b03b5eadddc0812cef884d013ed6cc66b09f9b0f5b45123f89dcd056/qwen_omni_utils-0.0.8.tar.gz", hash = "sha256:b5808673e1455f4115cb784a62cdc8e8616576221a01fc738610b0f9268cb33c", size = 8145, upload-time = "2025-06-12T11:02:05.411Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/b1/dcdd69246a3c3c3bd6f6ced58e2307b3afbd894c4412c29fd49dd897e562/qwen_omni_utils-0.0.8-py3-none-any.whl", hash = "sha256:c42bcc633fbfd84d565ff0de9d45fae68a6b57a9b7b97a4b77eda71a0d3ee73a", size = 9218, upload-time = "2025-06-12T11:02:03.981Z" }, +] + [[package]] name = "qwen-vl-utils" version = "0.0.11" @@ -5096,7 +5348,7 @@ wheels = [ [package.optional-dependencies] decord = [ - { name = "decord" }, + { name = "decord", marker = "(platform_machine != 'aarch64' and sys_platform != 'darwin') or sys_platform == 'win32'" }, ] [[package]] @@ -6016,6 +6268,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/5c/bfd6bd0bf979426d405cc6e71eceb8701b148b16c21d2dc3c261efc61c7b/sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca", size = 44415, upload-time = "2024-12-10T12:05:27.824Z" }, ] +[[package]] +name = "standard-aifc" +version = "3.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "audioop-lts", marker = "python_full_version >= '3.13'" }, + { name = "standard-chunk", marker = 
"python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c4/53/6050dc3dde1671eb3db592c13b55a8005e5040131f7509cef0215212cb84/standard_aifc-3.13.0.tar.gz", hash = "sha256:64e249c7cb4b3daf2fdba4e95721f811bde8bdfc43ad9f936589b7bb2fae2e43", size = 15240, upload-time = "2024-10-30T16:01:31.772Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/52/5fbb203394cc852334d1575cc020f6bcec768d2265355984dfd361968f36/standard_aifc-3.13.0-py3-none-any.whl", hash = "sha256:f7ae09cc57de1224a0dd8e3eb8f73830be7c3d0bc485de4c1f82b4a7f645ac66", size = 10492, upload-time = "2024-10-30T16:01:07.071Z" }, +] + +[[package]] +name = "standard-chunk" +version = "3.13.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/06/ce1bb165c1f111c7d23a1ad17204d67224baa69725bb6857a264db61beaf/standard_chunk-3.13.0.tar.gz", hash = "sha256:4ac345d37d7e686d2755e01836b8d98eda0d1a3ee90375e597ae43aaf064d654", size = 4672, upload-time = "2024-10-30T16:18:28.326Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/90/a5c1084d87767d787a6caba615aa50dc587229646308d9420c960cb5e4c0/standard_chunk-3.13.0-py3-none-any.whl", hash = "sha256:17880a26c285189c644bd5bd8f8ed2bdb795d216e3293e6dbe55bbd848e2982c", size = 4944, upload-time = "2024-10-30T16:18:26.694Z" }, +] + +[[package]] +name = "standard-sunau" +version = "3.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "audioop-lts", marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/e3/ce8d38cb2d70e05ffeddc28bb09bad77cfef979eb0a299c9117f7ed4e6a9/standard_sunau-3.13.0.tar.gz", hash = "sha256:b319a1ac95a09a2378a8442f403c66f4fd4b36616d6df6ae82b8e536ee790908", size = 9368, upload-time = "2024-10-30T16:01:41.626Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/ae/e3707f6c1bc6f7aa0df600ba8075bfb8a19252140cd595335be60e25f9ee/standard_sunau-3.13.0-py3-none-any.whl", hash = "sha256:53af624a9529c41062f4c2fd33837f297f3baa196b0cfceffea6555654602622", size = 7364, upload-time = "2024-10-30T16:01:28.003Z" }, +] + [[package]] name = "starlette" version = "0.47.2" @@ -6422,9 +6708,7 @@ name = "torchcodec" version = "0.6.0" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/b3/11326a0e7a3c803a95975cfce4ac88fa4ea1a0d432bb876081046c5a5554/torchcodec-0.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fba260145a239b5afe13336e3a5bc1b089c9c31a073e9a7c2026d4cbd853fdd9", size = 3482584, upload-time = "2025-08-07T08:51:32.535Z" }, { url = "https://files.pythonhosted.org/packages/a7/d1/3f90561df013f6a015ef19de22726b64073fee405f53d3c4b8255ab05a67/torchcodec-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:fdef91a17fb1f1a159ce23710324a9a4e6d6a885275de73700f94a9ad562c6b2", size = 1370954, upload-time = "2025-08-07T08:51:15.021Z" }, - { url = "https://files.pythonhosted.org/packages/87/d0/0b5dd42652e4527d578e1d6239dbb907bf83e502115e517b83a55d8b7f8b/torchcodec-0.6.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:de20cab5df7fa7cdd74ec1dc0d508324685573f86de6789f0ebb860b7ea20b33", size = 3446017, upload-time = "2025-08-07T08:51:34.484Z" }, { url = "https://files.pythonhosted.org/packages/97/62/a938334e39101d4304619b90847d8aef7d1c607c6bcf33638f72931ae990/torchcodec-0.6.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:46dab701a2d809e975a8b07d7ee47ed34f1d903511e374c74cfc1de6a5ab0e3f", size = 
1374794, upload-time = "2025-08-07T08:51:17.355Z" }, ]