From 7545ecd487eeb7f0d41bf09701f417a487d670ea Mon Sep 17 00:00:00 2001 From: yqqxybm Date: Fri, 8 Aug 2025 17:52:05 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0wav2vec=E5=92=8Cwav2v?= =?UTF-8?q?ec2=E5=AE=8C=E6=95=B4=E4=B8=AD=E6=96=87=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 为wav2vec.py添加逐行中文注释,包含所有类和方法的详细说明 - 为wav2vec2.py添加完整中文注释,覆盖154个配置参数和所有核心组件 - 注释涵盖算法原理、设计决策、工程优化等技术细节 - 提供维度变化、数据流程、掩码策略等深度解析 - 增强代码可读性和教学价值,适合深度学习和语音处理学习 --- fairseq/models/wav2vec/wav2vec.py | 971 +++++++++---- fairseq/models/wav2vec/wav2vec2.py | 2154 ++++++++++++++++++++-------- 2 files changed, 2294 insertions(+), 831 deletions(-) diff --git a/fairseq/models/wav2vec/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py index af6604da10..b26a157ec3 100644 --- a/fairseq/models/wav2vec/wav2vec.py +++ b/fairseq/models/wav2vec/wav2vec.py @@ -1,545 +1,934 @@ # Copyright (c) Facebook, Inc. and its affiliates. +# Facebook公司及其附属机构版权所有 # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -from dataclasses import dataclass, field -import logging -import math -from typing import Optional, Tuple -from omegaconf import II -import sys - -import torch -import torch.nn as nn -import torch.nn.functional as F -from fairseq.dataclass import ChoiceEnum, FairseqDataclass -from fairseq.models import BaseFairseqModel, register_model -from fairseq.modules import ( - Fp32GroupNorm, - Fp32LayerNorm, - GumbelVectorQuantizer, - KmeansVectorQuantizer, - TransposeLast, +# 此源代码基于MIT许可证授权,许可证文件位于项目根目录 + +# ============================================================================ +# Wav2Vec 1.0 模型实现 +# 基于对比预测编码(CPC)的自监督语音表示学习模型 +# 论文: wav2vec: Unsupervised Pre-training for Speech Recognition (2019) +# ============================================================================ + +from dataclasses import dataclass, field # 数据类装饰器,用于定义配置类 +import logging # 日志记录模块 +import math # 数学函数库 +from typing import Optional, Tuple # 类型注解支持 +from omegaconf import II # OmegaConf配置管理库的插值功能 +import sys # 系统相关功能 + +import torch # PyTorch深度学习框架 +import torch.nn as nn # PyTorch神经网络模块 +import torch.nn.functional as F # PyTorch函数式接口 +from fairseq.dataclass import ChoiceEnum, FairseqDataclass # Fairseq数据类和选择枚举 +from fairseq.models import BaseFairseqModel, register_model # Fairseq基础模型类和注册装饰器 +from fairseq.modules import ( # Fairseq预定义模块 + Fp32GroupNorm, # 32位浮点组归一化 + Fp32LayerNorm, # 32位浮点层归一化 + GumbelVectorQuantizer, # Gumbel向量量化器 + KmeansVectorQuantizer, # K-means向量量化器 + TransposeLast, # 转置最后维度的工具模块 ) -from fairseq.tasks import FairseqTask -from fairseq.utils import buffered_arange +from fairseq.tasks import FairseqTask # Fairseq任务基类 +from fairseq.utils import buffered_arange # 缓冲区范围生成工具 -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) # 获取当前模块的日志记录器 -AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"]) -PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"]) -ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"]) -VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", 
"kmeans"]) +# 定义模型配置的枚举选择项 +AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"]) # 聚合器类型:卷积神经网络或门控循环单元 +PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"]) # 特征投影方式:无投影、复用聚合器、新建聚合器 +ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"]) # 激活函数选择:ReLU或GELU +VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", "kmeans"]) # 向量量化类型:无量化、Gumbel或K-means @dataclass class Wav2VecConfig(FairseqDataclass): + """ + Wav2Vec模型配置类 + 包含模型架构、训练策略、向量量化等所有可配置参数 + """ + + # ============================================================================ + # 对比预测编码(CPC)相关参数 + # ============================================================================ prediction_steps: int = field( default=12, metadata={"help": "number of steps ahead to predict"} - ) + ) # 预测步数:模型需要预测未来多少个时间步的表示 + sample_distance: Optional[int] = field( default=None, metadata={ "help": "sample distance from target. does not work properly with cross-sampling" }, - ) + ) # 采样距离:从目标位置采样负样本的距离限制 + cross_sample_negatives: int = field( default=0, metadata={"help": "num of cross sampled negatives"} - ) + ) # 跨样本负样本数:从其他样本中采样的负样本数量 + num_negatives: int = field( default=10, metadata={"help": "num of sampled negatives"} - ) + ) # 负样本数量:每个正样本对应的负样本数量 + + # ============================================================================ + # 卷积网络架构参数 + # ============================================================================ conv_feature_layers: str = field( default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]", metadata={ "help": "convolutional feature extraction layers [(dim, kernel_size, stride), ...]" }, - ) + ) # 特征提取卷积层配置:每层的(输出维度, 卷积核大小, 步长) + conv_aggregator_layers: str = field( default="[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]", metadata={ "help": "convolutional aggregator layers [(dim, kernel_size, stride), ...]" }, - ) + ) # 
聚合器卷积层配置:用于时序信息聚合的卷积层参数 + + # ============================================================================ + # 正则化参数 + # ============================================================================ dropout: float = field( default=0.0, metadata={"help": "dropout to apply within the model"} - ) + ) # 模型内部的dropout概率 + dropout_features: float = field( default=0.0, metadata={"help": "dropout to apply to the features"} - ) + ) # 特征层的dropout概率 + dropout_agg: float = field( default=0.0, metadata={"help": "dropout to apply after aggregation step"} - ) + ) # 聚合步骤后的dropout概率 + + # ============================================================================ + # 聚合器配置 + # ============================================================================ aggregator: AGGREGATOR_CHOICES = field( default="cnn", metadata={"help": "type of aggregator to use"} - ) + ) # 聚合器类型:使用CNN或GRU进行时序信息聚合 + gru_dim: int = field(default=512, metadata={"help": "GRU dimensionality"}) + # GRU聚合器的隐藏维度 + + # ============================================================================ + # 卷积层配置 + # ============================================================================ no_conv_bias: bool = field( default=False, metadata={"help": "if set, does not learn bias for conv layers"} - ) + ) # 是否禁用卷积层的偏置项 + agg_zero_pad: bool = field( default=False, metadata={"help": "if set, zero pads in aggregator instead of repl pad"}, - ) + ) # 聚合器是否使用零填充而非复制填充 + + # ============================================================================ + # 跳跃连接配置 + # ============================================================================ skip_connections_feat: bool = field( default=False, metadata={"help": "if set, adds skip connections to the feature extractor"}, - ) + ) # 特征提取器是否使用跳跃连接 + skip_connections_agg: bool = field( default=True, metadata={"help": "if set, adds skip connections to the aggregator"}, - ) + ) # 聚合器是否使用跳跃连接 + residual_scale: float = field( default=0.5, metadata={"help": "scales residual by sqrt(value)"} - ) + 
) # 残差连接的缩放因子:residual = sqrt(residual_scale) * residual + + # ============================================================================ + # 特征处理配置 + # ============================================================================ log_compression: bool = field( default=True, metadata={"help": "if set, adds a log compression to feature extractor"}, - ) + ) # 是否对特征进行对数压缩:log(|x| + 1) + balanced_classes: bool = field( default=False, metadata={"help": "if set, loss is scaled to balance for number of negatives"}, - ) + ) # 是否平衡正负样本的损失权重 + project_features: PROJECT_FEATURES_CHOICES = field( default="none", metadata={ "help": "if not none, features are projected using the (same or new) aggregator" }, - ) + ) # 特征投影方式:无投影/复用聚合器/新建投影器 + non_affine_group_norm: bool = field( default=False, metadata={"help": "if set, group norm is not affine"} - ) + ) # 组归一化是否使用仿射变换(可学习的scale和shift参数) + + # ============================================================================ + # 时序对齐配置 + # ============================================================================ offset: str = field( default="auto", metadata={ "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" }, - ) + ) # 时序偏移量:auto表示根据感受野自动计算,否则使用指定值 + activation: ACTIVATION_CHOICES = field( default="relu", metadata={ "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" }, - ) + ) # 激活函数类型:ReLU或GELU + + # ============================================================================ + # 向量量化(VQ)配置 + # ============================================================================ vq_type: VQ_TYPE_CHOICES = field( default="none", metadata={"help": "which type of quantizer to use"} - ) + ) # 向量量化器类型:无量化/Gumbel Softmax/K-means + vq_vars: int = field( default=320, metadata={"help": "project to this many vector quantized variables per group"}, - ) + ) # 每组向量量化变量的数量(码本大小) + vq_groups: int = field( default=2, metadata={"help": "number of 
groups of latent variables"} - ) + ) # 潜在变量的分组数量(乘积量化的组数) + vq_dim: int = field( default=0, metadata={ "help": "uses this dimensionality for quantized vectors. 0 to use model dim // groups" }, - ) + ) # 量化向量的维度,0表示使用model_dim // groups + vq_depth: int = field( default=1, metadata={"help": "number of layers for vq weight projection"} - ) + ) # 向量量化权重投影的层数 + combine_groups: bool = field( default=False, metadata={"help": "if set, variables are shared among groups"} - ) + ) # 是否在组间共享变量 + vq_temp: Tuple[float, float, float] = field( default=(2.0, 0.5, 0.999995), metadata={ "help": "temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)" }, - ) + ) # Gumbel Softmax温度参数:(起始温度, 结束温度, 衰减率) + vq_gamma: float = field( default=0.25, metadata={"help": "gamma parameter for kmeans style vector quantization"}, - ) - infonce: bool = II("criterion.infonce") + ) # K-means向量量化的gamma参数(聚类损失权重) + + infonce: bool = II("criterion.infonce") # 是否使用InfoNCE损失(从损失函数配置中继承) -@register_model("wav2vec", dataclass=Wav2VecConfig) +@register_model("wav2vec", dataclass=Wav2VecConfig) # 注册模型到Fairseq框架 class Wav2VecModel(BaseFairseqModel): + """ + Wav2Vec 1.0 主模型类 + + 模型架构: + 1. 卷积特征提取器 (ConvFeatureExtractionModel) + 2. 向量量化器 (可选: GumbelVectorQuantizer/KmeansVectorQuantizer) + 3. 特征聚合器 (ConvAggegator/GRU) + 4. 
对比预测模块 (Wav2VecPredictionsModel) + + 训练目标:通过对比学习预测未来时间步的特征表示 + """ + @classmethod def build_model(cls, cfg: Wav2VecConfig, task: FairseqTask): - """Build a new model instance.""" - - model = Wav2VecModel(cfg) - logger.info(model) + """ + 模型构建工厂方法 + Args: + cfg: 模型配置 + task: Fairseq任务实例 + Returns: + 构建好的Wav2VecModel实例 + """ + model = Wav2VecModel(cfg) # 创建模型实例 + logger.info(model) # 记录模型结构信息 return model def __init__(self, cfg: Wav2VecConfig): - super().__init__() - + """ + 初始化Wav2Vec模型 + Args: + cfg: 模型配置参数 + """ + super().__init__() # 调用父类初始化 + + # 保存预测步数配置 self.prediction_steps = cfg.prediction_steps - offset = cfg.offset + offset = cfg.offset # 时序偏移量 + # ============================================================================ + # 1. 激活函数配置 + # ============================================================================ if cfg.activation == "relu": - activation = nn.ReLU() + activation = nn.ReLU() # 使用ReLU激活函数 elif cfg.activation == "gelu": - activation = nn.GELU() + activation = nn.GELU() # 使用GELU激活函数 else: - raise Exception("unknown activation " + cfg.activation) + raise Exception("unknown activation " + cfg.activation) # 未知激活函数 - feature_enc_layers = eval(cfg.conv_feature_layers) + # ============================================================================ + # 2. 
卷积特征提取器 + # ============================================================================ + feature_enc_layers = eval(cfg.conv_feature_layers) # 解析卷积层配置字符串 self.feature_extractor = ConvFeatureExtractionModel( - conv_layers=feature_enc_layers, - dropout=0.0, - log_compression=cfg.log_compression, - skip_connections=cfg.skip_connections_feat, - residual_scale=cfg.residual_scale, - non_affine_group_norm=cfg.non_affine_group_norm, - activation=activation, + conv_layers=feature_enc_layers, # 卷积层配置列表 + dropout=0.0, # 特征提取器内部不使用dropout + log_compression=cfg.log_compression, # 是否使用对数压缩 + skip_connections=cfg.skip_connections_feat, # 是否使用跳跃连接 + residual_scale=cfg.residual_scale, # 残差缩放因子 + non_affine_group_norm=cfg.non_affine_group_norm, # 组归一化配置 + activation=activation, # 激活函数 ) - embed = feature_enc_layers[-1][0] + embed = feature_enc_layers[-1][0] # 获取最后一层的输出维度作为嵌入维度 - self.vector_quantizer = None + # ============================================================================ + # 3. 向量量化器 (可选) + # ============================================================================ + self.vector_quantizer = None # 默认不使用量化器 if cfg.vq_type == "gumbel": + # 使用Gumbel Softmax向量量化器 self.vector_quantizer = GumbelVectorQuantizer( - dim=embed, - num_vars=cfg.vq_vars, - temp=cfg.vq_temp, - groups=cfg.vq_groups, - combine_groups=cfg.combine_groups, - vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, - time_first=False, - activation=activation, - weight_proj_depth=cfg.vq_depth, - weight_proj_factor=2, + dim=embed, # 输入特征维度 + num_vars=cfg.vq_vars, # 每组量化变量数量 + temp=cfg.vq_temp, # 温度参数(起始,结束,衰减) + groups=cfg.vq_groups, # 量化分组数 + combine_groups=cfg.combine_groups, # 是否组合分组 + vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, # 量化维度 + time_first=False, # 时间维度不在第一维 + activation=activation, # 激活函数 + weight_proj_depth=cfg.vq_depth, # 权重投影层深度 + weight_proj_factor=2, # 权重投影因子 ) elif cfg.vq_type == "kmeans": + # 使用K-means向量量化器 self.vector_quantizer = KmeansVectorQuantizer( - dim=embed, - 
num_vars=cfg.vq_vars, - groups=cfg.vq_groups, - combine_groups=cfg.combine_groups, - vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, - time_first=False, - gamma=cfg.vq_gamma, + dim=embed, # 输入特征维度 + num_vars=cfg.vq_vars, # 每组量化变量数量 + groups=cfg.vq_groups, # 量化分组数 + combine_groups=cfg.combine_groups, # 是否组合分组 + vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, # 量化维度 + time_first=False, # 时间维度不在第一维 + gamma=cfg.vq_gamma, # 聚类损失权重 ) else: + # 验证量化器类型配置 assert ( cfg.vq_type == "none" or cfg.vq_type is None ), "Unknown quantizer type" + # ============================================================================ + # 4. 自动计算时序偏移量 + # ============================================================================ if cfg.offset == "auto": - jin = 0 - rin = 0 + # 根据卷积网络的感受野自动计算偏移量 + jin = 0 # 累积步长 + rin = 0 # 感受野大小 for _, k, stride in feature_enc_layers: if rin == 0: - rin = k - rin = rin + (k - 1) * jin + rin = k # 初始感受野等于第一层卷积核大小 + rin = rin + (k - 1) * jin # 更新感受野大小 if jin == 0: - jin = stride + jin = stride # 初始步长 else: - jin *= stride - offset = math.ceil(rin / jin) + jin *= stride # 累积步长 + offset = math.ceil(rin / jin) # 计算最终偏移量 - offset = int(offset) + offset = int(offset) # 确保偏移量为整数 + # ============================================================================ + # 5. 
聚合器构建函数 + # ============================================================================ def make_aggregator(): + """ + 创建特征聚合器 + Returns: + tuple: (聚合器模块, 输出维度) + """ if cfg.aggregator == "cnn": - agg_layers = eval(cfg.conv_aggregator_layers) - agg_dim = agg_layers[-1][0] + # 使用卷积聚合器 + agg_layers = eval(cfg.conv_aggregator_layers) # 解析聚合器层配置 + agg_dim = agg_layers[-1][0] # 获取最后一层输出维度 feature_aggregator = ConvAggegator( - conv_layers=agg_layers, - embed=embed, - dropout=cfg.dropout, - skip_connections=cfg.skip_connections_agg, - residual_scale=cfg.residual_scale, - non_affine_group_norm=cfg.non_affine_group_norm, - conv_bias=not cfg.no_conv_bias, - zero_pad=cfg.agg_zero_pad, - activation=activation, + conv_layers=agg_layers, # 卷积层配置 + embed=embed, # 输入嵌入维度 + dropout=cfg.dropout, # dropout概率 + skip_connections=cfg.skip_connections_agg, # 跳跃连接 + residual_scale=cfg.residual_scale, # 残差缩放 + non_affine_group_norm=cfg.non_affine_group_norm, # 组归一化 + conv_bias=not cfg.no_conv_bias, # 卷积偏置 + zero_pad=cfg.agg_zero_pad, # 零填充 + activation=activation, # 激活函数 ) elif cfg.aggregator == "gru": - agg_dim = cfg.gru_dim + # 使用GRU聚合器 + agg_dim = cfg.gru_dim # GRU隐藏维度 feature_aggregator = nn.Sequential( - TransposeLast(), + TransposeLast(), # 转置:BxTxC -> BxCxT,为GRU准备 nn.GRU( - input_size=embed, - hidden_size=agg_dim, - num_layers=1, - dropout=cfg.dropout, + input_size=embed, # 输入维度 + hidden_size=agg_dim, # 隐藏层维度 + num_layers=1, # 单层GRU + dropout=cfg.dropout, # dropout ), - TransposeLast(deconstruct_idx=0), + TransposeLast(deconstruct_idx=0), # 转置回来,只取output ) else: raise Exception("unknown aggregator type " + cfg.aggregator) return feature_aggregator, agg_dim + # 创建聚合器实例 self.feature_aggregator, agg_dim = make_aggregator() + # ============================================================================ + # 6. 
对比预测模块 + # ============================================================================ self.wav2vec_predictions = Wav2VecPredictionsModel( - in_dim=agg_dim, - out_dim=embed, - prediction_steps=cfg.prediction_steps, - n_negatives=cfg.num_negatives, - cross_sample_negatives=cfg.cross_sample_negatives, - sample_distance=cfg.sample_distance, - dropout=cfg.dropout, - offset=offset, - balanced_classes=cfg.balanced_classes, - infonce=cfg.infonce, + in_dim=agg_dim, # 输入维度(聚合器输出维度) + out_dim=embed, # 输出维度(特征维度) + prediction_steps=cfg.prediction_steps, # 预测步数 + n_negatives=cfg.num_negatives, # 负样本数量 + cross_sample_negatives=cfg.cross_sample_negatives, # 跨样本负样本数 + sample_distance=cfg.sample_distance, # 采样距离 + dropout=cfg.dropout, # dropout概率 + offset=offset, # 时序偏移量 + balanced_classes=cfg.balanced_classes, # 是否平衡类别 + infonce=cfg.infonce, # 是否使用InfoNCE损失 ) - self.dropout_feats = nn.Dropout(p=cfg.dropout_features) - self.dropout_agg = nn.Dropout(p=cfg.dropout_agg) + # ============================================================================ + # 7. Dropout层 + # ============================================================================ + self.dropout_feats = nn.Dropout(p=cfg.dropout_features) # 特征dropout + self.dropout_agg = nn.Dropout(p=cfg.dropout_agg) # 聚合后dropout + # ============================================================================ + # 8. 
特征投影器 (可选) + # ============================================================================ if cfg.project_features == "none": - self.project_features = None + self.project_features = None # 不使用特征投影 elif cfg.project_features == "same": - self.project_features = self.feature_aggregator + self.project_features = self.feature_aggregator # 复用聚合器 elif cfg.project_features == "new": - self.project_features, _ = make_aggregator() + self.project_features, _ = make_aggregator() # 创建新的聚合器作为投影器 def forward(self, source): - result = {} - + """ + Wav2Vec模型前向传播 + + 数据流: + 原始音频 -> 特征提取 -> [向量量化] -> dropout -> 聚合 -> dropout -> 对比预测 + + Args: + source (Tensor): 原始音频波形 [batch_size, seq_len] + + Returns: + dict: 包含对比预测logits和targets的字典 + - cpc_logits: 对比预测的logits + - cpc_targets: 对比预测的目标 + - 其他量化器相关的输出(如果使用) + """ + result = {} # 存储所有输出结果 + + # ============================================================================ + # 1. 特征提取:原始音频 -> 卷积特征 [B, T] -> [B, C, T] + # ============================================================================ features = self.feature_extractor(source) + + # ============================================================================ + # 2. 向量量化 (可选):连续特征 -> 离散码本 + # ============================================================================ if self.vector_quantizer: - q_res = self.vector_quantizer(features) - features = q_res["x"] + q_res = self.vector_quantizer(features) # 执行向量量化 + features = q_res["x"] # 获取量化后的特征 + # 保存量化器的其他输出(如困惑度、损失等) for k in q_res.keys(): if k != "x": result[k] = q_res[k] - x = self.dropout_feats(features) - x = self.feature_aggregator(x) - x = self.dropout_agg(x) + # ============================================================================ + # 3. 
特征处理:dropout -> 聚合 -> dropout + # ============================================================================ + x = self.dropout_feats(features) # 对特征应用dropout + x = self.feature_aggregator(x) # 时序信息聚合(CNN或GRU) + x = self.dropout_agg(x) # 对聚合结果应用dropout + # ============================================================================ + # 4. 特征投影 (可选):为对比学习准备目标特征 + # ============================================================================ if self.project_features is not None: - features = self.project_features(features) + features = self.project_features(features) # 投影原始特征作为目标 + + # ============================================================================ + # 5. 对比预测:预测未来时间步的特征表示 + # ============================================================================ x, targets = self.wav2vec_predictions(x, features) - result["cpc_logits"] = x - result["cpc_targets"] = targets + result["cpc_logits"] = x # 对比预测的logits + result["cpc_targets"] = targets # 对比预测的目标 return result def upgrade_state_dict_named(self, state_dict, name): + """ + 升级模型状态字典,用于向后兼容 + Args: + state_dict: 模型状态字典 + name: 模型名称 + """ super().upgrade_state_dict_named(state_dict, name) def max_positions(self): - """Maximum length supported by the model.""" + """ + 模型支持的最大序列长度 + Returns: + int: 最大位置数(理论上无限制) + """ return sys.maxsize def get_logits(self, net_output): + """ + 从网络输出中提取logits + Args: + net_output (dict): 网络前向传播输出 + Returns: + Tensor: CPC logits张量 + """ logits = net_output["cpc_logits"] return logits def get_targets(self, sample, net_output): + """ + 从网络输出中提取目标 + Args: + sample: 输入样本(未使用) + net_output (dict): 网络前向传播输出 + Returns: + Tensor: CPC目标张量 + """ t = net_output["cpc_targets"] - if isinstance(t, tuple): + if isinstance(t, tuple): # 如果目标是元组,取第一个元素 t = t[0] - return t.contiguous() + return t.contiguous() # 确保内存连续 def get_target_weights(self, targets, net_output): + """ + 获取目标权重(用于加权损失计算) + Args: + targets: 目标张量(未使用) + net_output (dict): 网络前向传播输出 + Returns: + Tensor or None: 目标权重张量 + """ 
targets = net_output["cpc_targets"] if isinstance(targets, tuple) and targets[-1] is not None: - return targets[-1] + return targets[-1] # 返回权重(元组的最后一个元素) return None def get_extra_losses(self, net_output): + """ + 获取额外的损失项(主要来自向量量化器) + Args: + net_output (dict): 网络前向传播输出 + Returns: + Tensor or None: 额外的损失项 + """ loss = None if "prob_perplexity" in net_output: + # Gumbel VQ的困惑度损失:鼓励使用更多的码本条目 loss = net_output["num_vars"] - net_output["prob_perplexity"] elif "kmeans_loss" in net_output: + # K-means VQ的聚类损失 loss = net_output["kmeans_loss"] return loss def norm_block(is_layer_norm, dim, affine=True): + """ + 创建归一化模块 + Args: + is_layer_norm (bool): 是否使用层归一化(否则使用组归一化) + dim (int): 特征维度 + affine (bool): 是否使用仿射变换(可学习的scale和shift) + Returns: + nn.Module: 归一化模块 + """ if is_layer_norm: + # 层归一化:需要转置以适应1D卷积的输出格式 mod = nn.Sequential( - TransposeLast(), - Fp32LayerNorm(dim, elementwise_affine=affine), - TransposeLast(), + TransposeLast(), # [B, C, T] -> [B, T, C] + Fp32LayerNorm(dim, elementwise_affine=affine), # 在最后一维进行归一化 + TransposeLast(), # [B, T, C] -> [B, C, T] ) else: + # 组归一化:直接在通道维度进行,num_groups=1等价于实例归一化 mod = Fp32GroupNorm(1, dim, affine=affine) return mod class ConvFeatureExtractionModel(nn.Module): + """ + 卷积特征提取器 + + 功能: + 1. 将原始音频波形转换为高维特征表示 + 2. 通过多层1D卷积逐步提取层次化特征 + 3. 
可选的跳跃连接和对数压缩 + + 架构: + - 多层1D卷积 + 组归一化 + 激活函数 + - 可选的残差连接 + - 最终的对数压缩 + """ + def __init__( self, - conv_layers, - dropout, - log_compression, - skip_connections, - residual_scale, - non_affine_group_norm, - activation, + conv_layers, # 卷积层配置列表 + dropout, # dropout概率 + log_compression, # 是否使用对数压缩 + skip_connections, # 是否使用跳跃连接 + residual_scale, # 残差缩放因子 + non_affine_group_norm, # 组归一化是否禁用仿射变换 + activation, # 激活函数 ): super().__init__() def block(n_in, n_out, k, stride): + """ + 创建单个卷积块 + Args: + n_in (int): 输入通道数 + n_out (int): 输出通道数 + k (int): 卷积核大小 + stride (int): 步长 + Returns: + nn.Sequential: 卷积块(Conv1d + Dropout + GroupNorm + Activation) + """ return nn.Sequential( - nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), - nn.Dropout(p=dropout), + nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), # 1D卷积,无偏置 + nn.Dropout(p=dropout), # dropout正则化 norm_block( - is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm + is_layer_norm=False, # 使用组归一化 + dim=n_out, # 归一化维度 + affine=not non_affine_group_norm # 是否使用仿射变换 ), - activation, + activation, # 激活函数(ReLU/GELU) ) - in_d = 1 + # ============================================================================ + # 构建卷积层序列 + # ============================================================================ + in_d = 1 # 初始输入通道数(原始音频为1通道) self.conv_layers = nn.ModuleList() + + # 根据配置创建每一层卷积 for dim, k, stride in conv_layers: self.conv_layers.append(block(in_d, dim, k, stride)) - in_d = dim + in_d = dim # 更新下一层的输入通道数 - self.log_compression = log_compression - self.skip_connections = skip_connections - self.residual_scale = math.sqrt(residual_scale) + # 保存配置参数 + self.log_compression = log_compression # 是否使用对数压缩 + self.skip_connections = skip_connections # 是否使用跳跃连接 + self.residual_scale = math.sqrt(residual_scale) # 残差缩放因子(开平方) def forward(self, x): - # BxT -> BxCxT - x = x.unsqueeze(1) - + """ + 前向传播 + Args: + x (Tensor): 原始音频波形 [batch_size, seq_len] + Returns: + Tensor: 提取的特征 [batch_size, feature_dim, seq_len'] + """ 
+ # ============================================================================ + # 1. 维度扩展:BxT -> Bx1xT (添加通道维度) + # ============================================================================ + x = x.unsqueeze(1) # [B, T] -> [B, 1, T] + + # ============================================================================ + # 2. 逐层卷积特征提取 + # ============================================================================ for conv in self.conv_layers: - residual = x - x = conv(x) + residual = x # 保存残差连接的输入 + x = conv(x) # 卷积变换 + + # 残差连接 (如果启用且通道数匹配) if self.skip_connections and x.size(1) == residual.size(1): - tsz = x.size(2) - r_tsz = residual.size(2) + # 处理时序长度不匹配的情况 + tsz = x.size(2) # 当前输出的时序长度 + r_tsz = residual.size(2) # 残差的时序长度 + + # 下采样残差以匹配当前输出的时序长度 residual = residual[..., :: r_tsz // tsz][..., :tsz] + + # 残差连接并缩放 x = (x + residual) * self.residual_scale + # ============================================================================ + # 3. 对数压缩 (可选):log(|x| + 1) + # ============================================================================ if self.log_compression: - x = x.abs() - x = x + 1 - x = x.log() + x = x.abs() # 取绝对值,避免负数 + x = x + 1 # 加1防止log(0) + x = x.log() # 对数变换,压缩动态范围 return x class ZeroPad1d(nn.Module): + """ + 1D零填充模块 + 对1D张量进行左右零填充 + """ def __init__(self, pad_left, pad_right): + """ + Args: + pad_left (int): 左侧填充数量 + pad_right (int): 右侧填充数量 + """ super().__init__() - self.pad_left = pad_left - self.pad_right = pad_right + self.pad_left = pad_left # 左侧填充数量 + self.pad_right = pad_right # 右侧填充数量 def forward(self, x): - return F.pad(x, (self.pad_left, self.pad_right)) + """ + Args: + x (Tensor): 输入张量 [B, C, T] + Returns: + Tensor: 填充后的张量 [B, C, T + pad_left + pad_right] + """ + return F.pad(x, (self.pad_left, self.pad_right)) # 在最后一维进行填充 class ConvAggegator(nn.Module): + """ + 卷积聚合器 + + 功能: + 1. 对特征提取器的输出进行时序信息聚合 + 2. 通过因果卷积保持时序关系 + 3. 
使用跳跃连接增强信息流 + + 架构: + - 多层因果1D卷积 + - 残差连接 + 维度投影 + - 层归一化 + dropout + """ + def __init__( self, - conv_layers, - embed, - dropout, - skip_connections, - residual_scale, - non_affine_group_norm, - conv_bias, - zero_pad, - activation, + conv_layers, # 卷积层配置列表 + embed, # 输入嵌入维度 + dropout, # dropout概率 + skip_connections, # 是否使用跳跃连接 + residual_scale, # 残差缩放因子 + non_affine_group_norm, # 组归一化是否禁用仿射变换 + conv_bias, # 是否使用卷积偏置 + zero_pad, # 是否使用零填充(否则使用复制填充) + activation, # 激活函数 ): super().__init__() def block(n_in, n_out, k, stride): - # padding dims only really make sense for stride = 1 - ka = k // 2 - kb = ka - 1 if k % 2 == 0 else ka - + """ + 创建单个聚合卷积块 + Args: + n_in (int): 输入通道数 + n_out (int): 输出通道数 + k (int): 卷积核大小 + stride (int): 步长 + Returns: + nn.Sequential: 聚合卷积块 + """ + # ================================================================ + # 因果填充:确保未来信息不泄露 + # ================================================================ + ka = k // 2 # 左侧填充 + kb = ka - 1 if k % 2 == 0 else ka # 右侧填充(偶数核需要-1) + + # 选择填充方式:零填充 vs 复制填充 pad = ( - ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0)) + ZeroPad1d(ka + kb, 0) if zero_pad + else nn.ReplicationPad1d((ka + kb, 0)) ) return nn.Sequential( - pad, - nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias), - nn.Dropout(p=dropout), - norm_block(False, n_out, affine=not non_affine_group_norm), - activation, + pad, # 填充 + nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias), # 1D卷积 + nn.Dropout(p=dropout), # dropout正则化 + norm_block(False, n_out, affine=not non_affine_group_norm), # 组归一化 + activation, # 激活函数 ) - in_d = embed - self.conv_layers = nn.ModuleList() - self.residual_proj = nn.ModuleList() + # ============================================================================ + # 构建聚合器层 + # ============================================================================ + in_d = embed # 初始输入维度 + self.conv_layers = nn.ModuleList() # 卷积层列表 + self.residual_proj = nn.ModuleList() # 残差投影层列表 + for dim, k, stride 
in conv_layers: + # 残差投影:处理维度不匹配的情况 if in_d != dim and skip_connections: + # 需要1x1卷积进行维度变换 self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False)) else: + # 维度匹配,不需要投影 self.residual_proj.append(None) + # 添加聚合卷积块 self.conv_layers.append(block(in_d, dim, k, stride)) - in_d = dim - self.conv_layers = nn.Sequential(*self.conv_layers) - self.skip_connections = skip_connections - self.residual_scale = math.sqrt(residual_scale) + in_d = dim # 更新下一层的输入维度 + + # 注意:这里没有将conv_layers转为Sequential,保持ModuleList以配合residual_proj使用 + self.skip_connections = skip_connections # 是否使用跳跃连接 + self.residual_scale = math.sqrt(residual_scale) # 残差缩放因子 def forward(self, x): + """ + 前向传播 + Args: + x (Tensor): 输入特征 [batch_size, embed_dim, seq_len] + Returns: + Tensor: 聚合后的特征 [batch_size, output_dim, seq_len] + """ + # ============================================================================ + # 逐层聚合处理 + # ============================================================================ for rproj, conv in zip(self.residual_proj, self.conv_layers): - residual = x - x = conv(x) + residual = x # 保存残差连接的输入 + x = conv(x) # 卷积聚合 + + # 跳跃连接处理 if self.skip_connections: if rproj is not None: + # 维度不匹配,需要投影 residual = rproj(residual) + # 残差连接并缩放 x = (x + residual) * self.residual_scale return x class Wav2VecPredictionsModel(nn.Module): + """ + Wav2Vec对比预测模块 + + 功能: + 1. 实现对比预测编码(CPC)的核心逻辑 + 2. 预测未来时间步的特征表示 + 3. 通过负采样进行对比学习 + + 工作流程: + 1. 将聚合器输出投影到多个预测步骤 + 2. 采样负样本进行对比学习 + 3. 
计算预测logits和目标targets + """ + def __init__( self, - in_dim, - out_dim, - prediction_steps, - n_negatives, - cross_sample_negatives, - sample_distance, - dropout, - offset, - balanced_classes, - infonce, + in_dim, # 输入维度(聚合器输出维度) + out_dim, # 输出维度(目标特征维度) + prediction_steps, # 预测步数 + n_negatives, # 负样本数量 + cross_sample_negatives, # 跨样本负样本数量 + sample_distance, # 采样距离限制 + dropout, # dropout概率 + offset, # 时序偏移量 + balanced_classes, # 是否平衡类别权重 + infonce, # 是否使用InfoNCE损失 ): super().__init__() - self.n_negatives = n_negatives - self.cross_sample_negatives = cross_sample_negatives - self.sample_distance = sample_distance + # 保存配置参数 + self.n_negatives = n_negatives # 负样本数量 + self.cross_sample_negatives = cross_sample_negatives # 跨样本负样本数量 + self.sample_distance = sample_distance # 采样距离限制 + + # ============================================================================ + # 预测投影层:将聚合特征投影到多个预测步骤 + # ============================================================================ self.project_to_steps = nn.ConvTranspose2d( in_dim, out_dim, (1, prediction_steps) - ) - self.dropout = nn.Dropout(p=dropout) - self.offset = offset - self.balanced_classes = balanced_classes - self.infonce = infonce + ) # 2D转置卷积:[B,in_dim,T,1] -> [B,out_dim,T,prediction_steps] + + self.dropout = nn.Dropout(p=dropout) # dropout正则化 + self.offset = offset # 时序偏移量 + self.balanced_classes = balanced_classes # 是否平衡类别权重 + self.infonce = infonce # 是否使用InfoNCE损失 def sample_negatives(self, y): - bsz, fsz, tsz = y.shape - - y = y.transpose(0, 1) # BCT -> CBT - y = y.contiguous().view(fsz, -1) # CBT => C(BxT) - - cross_high = tsz * bsz + """ + 负样本采样函数 + + 目标:为对比学习采样负样本 + 策略: + 1. 同批次内采样 (避免采样到真实目标) + 2. 跨批次采样 (增加负样本多样性) + + Args: + y (Tensor): 目标特征 [batch_size, feature_dim, seq_len] + Returns: + Tensor: 负样本 [n_negatives, batch_size, feature_dim, seq_len] + """ + bsz, fsz, tsz = y.shape # 批次大小、特征维度、序列长度 + + # ============================================================================ + # 1. 
数据重组:方便负样本采样 + # ============================================================================ + y = y.transpose(0, 1) # [B,C,T] -> [C,B,T] + y = y.contiguous().view(fsz, -1) # [C,B,T] -> [C,B*T] 展平为特征池 + + # ============================================================================ + # 2. 计算采样范围 + # ============================================================================ + cross_high = tsz * bsz # 跨样本采样的总范围 + # 同样本内采样范围:受sample_distance限制 high = tsz if self.sample_distance is None else min(tsz, self.sample_distance) - assert high > 1 + assert high > 1, "采样范围必须大于1" + # 预分配负样本索引 neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz)) - with torch.no_grad(): + with torch.no_grad(): # 不需要梯度计算 + # ==================================================================== + # 3. 同样本内负样本采样 + # ==================================================================== if self.n_negatives > 0: + # 创建时间步索引:[0,1,2,...,T-1] 重复 n_negatives 次 tszs = ( - buffered_arange(tsz) - .unsqueeze(-1) - .expand(-1, self.n_negatives) - .flatten() + buffered_arange(tsz) # [0,1,2,...,T-1] + .unsqueeze(-1) # [T,1] + .expand(-1, self.n_negatives) # [T,N] + .flatten() # [T*N] ) + # 随机采样负样本索引,避免采样到真实目标 neg_idxs = torch.randint( low=0, high=high - 1, size=(bsz, self.n_negatives * tsz) ) + # 如果采样到的索引 >= 当前时间步,则+1跳过真实目标 neg_idxs[neg_idxs >= tszs] += 1 + # ==================================================================== + # 4. 跨样本负样本采样 + # ==================================================================== if self.cross_sample_negatives > 0: + # 跨样本时间步索引 tszs = ( buffered_arange(tsz) .unsqueeze(-1) @@ -547,6 +936,7 @@ def sample_negatives(self, y): .flatten() ) + # 从整个批次池中采样 cross_neg_idxs = torch.randint( low=0, high=cross_high - 1, @@ -554,76 +944,137 @@ def sample_negatives(self, y): ) cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + # ============================================================================ + # 5.
调整索引以匹配批次结构 + # ============================================================================ if self.n_negatives > 0: + # 为批次中的每个样本调整索引偏移(第i个样本的索引整体平移 i * high) for i in range(1, bsz): neg_idxs[i] += i * high else: + # 只使用跨样本负样本 neg_idxs = cross_neg_idxs + # 合并同样本和跨样本负样本 if self.cross_sample_negatives > 0 and self.n_negatives > 0: neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) - negs = y[..., neg_idxs.view(-1)] + # ============================================================================ + # 6. 提取负样本特征 + # ============================================================================ + negs = y[..., neg_idxs.view(-1)] # 根据索引提取负样本 negs = negs.view( fsz, bsz, self.n_negatives + self.cross_sample_negatives, tsz ).permute( 2, 1, 0, 3 - ) # to NxBxCxT + ) # 重组为 [N_negatives, Batch, Features, Time] return negs def forward(self, x, y): - - x = x.unsqueeze(-1) - x = self.project_to_steps(x) # BxCxTxS - x = self.dropout(x) - - negatives = self.sample_negatives(y) - y = y.unsqueeze(0) - targets = torch.cat([y, negatives], dim=0) # Copies x B x C x T - - copies = targets.size(0) - bsz, dim, tsz, steps = x.shape - steps = min(steps, tsz - self.offset) - + """ + 对比预测前向传播 + + 核心逻辑: + 1. 将聚合特征投影到多个预测步骤 + 2. 采样负样本构建对比学习目标 + 3. 计算预测与目标的相似度 + 4. 返回logits和labels用于损失计算 + + Args: + x (Tensor): 聚合后的上下文特征 [B, agg_dim, T] + y (Tensor): 目标特征 [B, feature_dim, T] + Returns: + tuple: (predictions, labels) + - predictions: 预测logits + - labels: 对比学习的标签 + """ + + # ============================================================================ + # 1. 多步预测投影:上下文 -> 预测 + # ============================================================================ + x = x.unsqueeze(-1) # [B,C,T] -> [B,C,T,1],添加预测步维度 + x = self.project_to_steps(x) # [B,C,T,1] -> [B,out_dim,T,steps] 投影到多步预测 + x = self.dropout(x) # dropout正则化 + + # ============================================================================ + # 2.
构建对比学习目标:正样本 + 负样本 + # ============================================================================ + negatives = self.sample_negatives(y) # 采样负样本 [N_neg, B, C, T] + y = y.unsqueeze(0) # 正样本 [1, B, C, T] + targets = torch.cat([y, negatives], dim=0) # 合并目标 [1+N_neg, B, C, T] + + # ============================================================================ + # 3. 计算预测相似度 + # ============================================================================ + copies = targets.size(0) # 目标总数 = 1 + n_negatives + bsz, dim, tsz, steps = x.shape # 预测张量的形状 + steps = min(steps, tsz - self.offset) # 有效预测步数(考虑偏移) + + # 预分配预测结果张量 predictions = x.new( bsz * copies * (tsz - self.offset + 1) * steps - ((steps + 1) * steps // 2) * copies * bsz - ) + ) # 考虑因果约束后的实际预测数量 + + # ============================================================================ + # 4. 准备标签和权重 + # ============================================================================ if self.infonce: + # InfoNCE:正样本标签为0,其余为负样本 labels = predictions.new_full( (predictions.shape[0] // copies,), 0, dtype=torch.long ) else: + # 二分类:正样本为1,负样本为0 labels = torch.zeros_like(predictions) + + # 类别平衡权重 weights = ( torch.full_like(labels, 1 / self.n_negatives) if self.balanced_classes and not self.infonce else None ) + # ============================================================================ + # 5. 
逐步预测计算 + # ============================================================================ start = end = 0 for i in range(steps): - offset = i + self.offset - end = start + (tsz - offset) * bsz * copies + offset = i + self.offset # 当前预测的时间偏移 + end = start + (tsz - offset) * bsz * copies # 当前步的结束位置 + if self.infonce: + # InfoNCE:计算预测和所有目标的点积 predictions[start:end] = torch.einsum( - "bct,nbct->tbn", x[..., :-offset, i], targets[..., offset:] + "bct,nbct->tbn", # 爱因斯坦求和:批次×特征×时间 与 目标×批次×特征×时间 + x[..., :-offset, i], # 预测特征(去掉未来部分) + targets[..., offset:] # 目标特征(从offset开始) ).flatten() else: - pos_num = (end - start) // copies + # 二分类:计算每个目标的预测分数 + pos_num = (end - start) // copies # 正样本数量 predictions[start:end] = torch.einsum( "bct,nbct->nbt", x[..., :-offset, i], targets[..., offset:] ).flatten() + + # 设置正样本标签为1.0 labels[start : start + pos_num] = 1.0 if weights is not None: weights[start : start + pos_num] = 1.0 start = end - assert end == predictions.numel(), "{} != {}".format(end, predictions.numel()) + + # 验证预测数量正确性 + assert end == predictions.numel(), f"预测数量不匹配: {end} != {predictions.numel()}" + # ============================================================================ + # 6. 格式化输出 + # ============================================================================ if self.infonce: + # InfoNCE:重组为 [样本数, 目标数] 用于交叉熵损失 predictions = predictions.view(-1, copies) else: + # 二分类:添加权重信息 if weights is not None: labels = (labels, weights) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 0faba77f8b..85f71478b7 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -1,48 +1,74 @@ # Copyright (c) Facebook, Inc. and its affiliates. +# Facebook公司及其附属机构版权所有 # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
- -import math -from dataclasses import dataclass, field -from typing import List, Tuple - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from fairseq import utils -from fairseq.data.data_utils import compute_mask_indices -from fairseq.dataclass import ChoiceEnum, FairseqDataclass -from fairseq.distributed import fsdp_wrap -from fairseq.models import BaseFairseqModel, register_model -from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel -from fairseq.modules import ( - Fp32GroupNorm, - Fp32LayerNorm, - GradMultiply, - GumbelVectorQuantizer, - LayerNorm, - MultiheadAttention, - RelPositionalEncoding, - SamePad, - TransposeLast, +# 此源代码基于MIT许可证授权,许可证文件位于项目根目录 + +# ============================================================================ +# Wav2Vec 2.0 模型实现 +# 基于掩码预测的自监督语音表示学习模型 +# 论文: wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (2020) +# +# 主要改进: +# 1. 使用Transformer替代卷积聚合器 +# 2. 采用BERT式的掩码预测而非未来预测 +# 3. 改进的向量量化策略 +# 4. 
更强的表示学习能力 +# ============================================================================ + +import math # 数学函数库 +from dataclasses import dataclass, field # 数据类装饰器,用于定义配置类 +from typing import List, Tuple # 类型注解支持 + +import numpy as np # 数值计算库 +import torch # PyTorch深度学习框架 +import torch.nn as nn # PyTorch神经网络模块 +import torch.nn.functional as F # PyTorch函数式接口 + +from fairseq import utils # Fairseq工具函数 +from fairseq.data.data_utils import compute_mask_indices # 掩码索引计算工具 +from fairseq.dataclass import ChoiceEnum, FairseqDataclass # Fairseq数据类和选择枚举 +from fairseq.distributed import fsdp_wrap # 完全分片数据并行包装器 +from fairseq.models import BaseFairseqModel, register_model # Fairseq基础模型类和注册装饰器 +from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel # 完全分片数据并行 +from fairseq.modules import ( # Fairseq预定义模块 + Fp32GroupNorm, # 32位浮点组归一化 + Fp32LayerNorm, # 32位浮点层归一化 + GradMultiply, # 梯度乘法器(用于特征提取器梯度缩放) + GumbelVectorQuantizer, # Gumbel向量量化器 + LayerNorm, # 层归一化 + MultiheadAttention, # 多头注意力机制 + RelPositionalEncoding, # 相对位置编码 + SamePad, # 保持输入输出尺寸的填充 + TransposeLast, # 转置最后维度的工具模块 ) -from fairseq.modules.checkpoint_activations import checkpoint_wrapper -from fairseq.modules.conformer_layer import ConformerWav2Vec2EncoderLayer -from fairseq.modules.transformer_sentence_encoder import init_bert_params -from fairseq.utils import buffered_arange, index_put, is_xla_tensor +from fairseq.modules.checkpoint_activations import checkpoint_wrapper # 激活检查点包装器(节省内存) +from fairseq.modules.conformer_layer import ConformerWav2Vec2EncoderLayer # Conformer编码器层 +from fairseq.modules.transformer_sentence_encoder import init_bert_params # BERT参数初始化 +from fairseq.utils import buffered_arange, index_put, is_xla_tensor # 工具函数 -from .utils import pad_to_multiple +from .utils import pad_to_multiple # 填充到指定倍数的工具函数 -EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) -MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) -LAYER_TYPE_CHOICES 
= ChoiceEnum(["transformer", "conformer", "trf_adp"]) +# ============================================================================ +# 枚举类型定义 +# ============================================================================ +EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) # 特征提取器模式:默认(组归一化)或层归一化 +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) # 掩码长度分布:静态、均匀、正态、泊松 +LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"]) # 编码器层类型:Transformer、Conformer、带适配器的Transformer @dataclass class Wav2Vec2Config(FairseqDataclass): + """ + Wav2Vec 2.0模型配置类 + 包含模型架构、训练策略、掩码策略、向量量化等所有可配置参数 + 相比Wav2Vec 1.0,增加了Transformer架构和掩码预测相关配置 + """ + + # ============================================================================ + # 特征提取器配置 + # ============================================================================ extractor_mode: EXTRACTOR_MODE_CHOICES = field( default="default", metadata={ @@ -50,380 +76,562 @@ class Wav2Vec2Config(FairseqDataclass): "groups in the first conv block, whereas layer_norm has layer norms in " "every block (meant to use with normalize=True)" }, - ) + ) # 特征提取器模式:default(第一层组归一化)或layer_norm(每层层归一化) + + # ============================================================================ + # Transformer编码器架构配置 + # ============================================================================ encoder_layers: int = field( default=12, metadata={"help": "num encoder layers in the transformer"} - ) + ) # Transformer编码器层数:默认12层(Base模型),Large模型为24层 + encoder_embed_dim: int = field( default=768, metadata={"help": "encoder embedding dimension"} - ) + ) # 编码器嵌入维度:Base模型768,Large模型1024 + encoder_ffn_embed_dim: int = field( default=3072, metadata={"help": "encoder embedding dimension for FFN"} - ) + ) # 前馈网络隐藏层维度:通常为embed_dim的4倍 + encoder_attention_heads: int = field( default=12, metadata={"help": "num encoder attention heads"} - ) + ) # 多头注意力的头数:Base模型12头,Large模型16头 + activation_fn: 
ChoiceEnum(utils.get_available_activation_fns()) = field( default="gelu", metadata={"help": "activation function to use"} - ) + ) # 激活函数:GELU在Transformer中表现更好 + layer_type: LAYER_TYPE_CHOICES = field( default="transformer", metadata={"help": "layer type in encoder"} - ) - # dropouts + ) # 编码器层类型:transformer、conformer或trf_adp(带适配器) + + # ============================================================================ + # Dropout正则化配置 + # ============================================================================ dropout: float = field( default=0.1, metadata={"help": "dropout probability for the transformer"} - ) + ) # Transformer整体dropout概率 + attention_dropout: float = field( default=0.1, metadata={"help": "dropout probability for attention weights"} - ) + ) # 注意力权重dropout概率 + activation_dropout: float = field( default=0.0, metadata={"help": "dropout probability after activation in FFN"} - ) + ) # 前馈网络激活函数后的dropout概率 + encoder_layerdrop: float = field( default=0.0, metadata={"help": "probability of dropping a tarnsformer layer"} - ) + ) # LayerDrop:随机跳过整个Transformer层的概率 + dropout_input: float = field( default=0.0, metadata={"help": "dropout to apply to the input (after feat extr)"}, - ) + ) # 特征提取后、输入编码器前的dropout概率 + dropout_features: float = field( default=0.0, metadata={"help": "dropout to apply to the features (after feat extr)"}, - ) + ) # 特征层的dropout概率(用于未掩码特征) + # ============================================================================ + # 投影和归一化配置 + # ============================================================================ final_dim: int = field( default=0, metadata={ "help": "project final representations and targets to this many dimensions." 
"set to encoder_embed_dim is <= 0" }, - ) + ) # 最终输出维度:0表示使用encoder_embed_dim,>0则投影到指定维度 + layer_norm_first: bool = field( default=False, metadata={"help": "apply layernorm first in the transformer"} - ) + ) # 是否在Transformer中先应用LayerNorm(Pre-LN):True为Pre-LN,False为Post-LN + + # ============================================================================ + # 卷积特征提取器配置 + # ============================================================================ conv_feature_layers: str = field( default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", metadata={ "help": "string describing convolutional feature extraction layers in form of a python list that contains " "[(dim, kernel_size, stride), ...]" }, - ) + ) # 卷积特征提取层配置:相比Wav2Vec 1.0,使用更小的卷积核和步长 + conv_bias: bool = field( default=False, metadata={"help": "include bias in conv encoder"} - ) + ) # 卷积层是否使用偏置项:通常配合归一化时设为False + + # ============================================================================ + # 对比学习和损失函数配置 + # ============================================================================ logit_temp: float = field( default=0.1, metadata={"help": "temperature to divide logits by"} - ) + ) # 对比学习的温度参数:控制相似度分布的锐利程度 + + # ============================================================================ + # 向量量化配置 + # ============================================================================ quantize_targets: bool = field( default=False, metadata={"help": "use quantized targets"} - ) + ) # 是否对目标特征进行向量量化:离散化目标表示 + quantize_input: bool = field( default=False, metadata={"help": "use quantized inputs"} - ) + ) # 是否对输入特征进行向量量化:离散化输入表示 + same_quantizer: bool = field( default=False, metadata={"help": "use same quantizer for inputs and targets"} - ) + ) # 输入和目标是否使用相同的量化器:共享量化参数 + target_glu: bool = field( default=False, metadata={"help": "adds projection + glu to targets"} - ) + ) # 目标特征是否使用GLU(门控线性单元):增强非线性表达能力 + feature_grad_mult: float = field( default=1.0, metadata={"help": "multiply feature extractor var grads 
by this"} - ) + ) # 特征提取器梯度乘数:控制特征提取器的学习速度 + quantizer_depth: int = field( default=1, metadata={"help": "number of quantizer layers"}, - ) + ) # 量化器层数:深度量化器可能有更好的表示能力 + quantizer_factor: int = field( default=3, metadata={ "help": "dimensionality increase for inner quantizer layers (if depth > 1)" }, - ) + ) # 量化器内部层维度扩展因子:depth>1时内部层的维度倍数 + latent_vars: int = field( default=320, metadata={"help": "number of latent variables V in each group of the codebook"}, - ) + ) # 每组码本中的潜在变量数量V:码本大小,影响离散化粒度 + latent_groups: int = field( default=2, metadata={"help": "number of groups G of latent variables in the codebook"}, - ) + ) # 码本分组数量G:乘积量化,总码本大小为V^G + latent_dim: int = field( default=0, metadata={ "help": "if > 0, uses this dimensionality for latent variables. " "otherwise uses final_dim / latent_groups" }, - ) + ) # 潜在变量维度:0表示使用final_dim/latent_groups,>0则使用指定维度 - # masking + # ============================================================================ + # 时序掩码策略配置 (核心创新:BERT式掩码预测) + # ============================================================================ mask_length: int = field(default=10, metadata={"help": "mask length"}) + # 掩码长度:连续掩码的时间步数,默认10个时间步(默认卷积总步长为320个采样点,16kHz音频下每步约20ms,合计约200ms) + mask_prob: float = field( default=0.65, metadata={"help": "probability of replacing a token with mask"} - ) + ) # 掩码概率:约65%的时间步会被选入掩码跨度;允许跨度重叠时实际被掩码的比例会更低(语音信号冗余度大,因此远高于BERT的15%) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( default="static", metadata={"help": "how to choose mask length"} - ) + ) # 掩码长度选择策略:static(固定)、uniform(均匀)、normal(正态)、poisson(泊松) + mask_other: float = field( default=0, metadata={ "help": "secondary mask argument (used for more complex distributions), " "see help in compute_mask_indices" }, - ) + ) # 掩码分布的辅助参数:用于复杂分布(如泊松分布的lambda参数) + no_mask_overlap: bool = field( default=False, metadata={"help": "whether to allow masks to overlap"} - ) + ) # 是否允许掩码重叠:False允许重叠,True禁止重叠 + mask_min_space: int = field( default=1, metadata={"help": "min space between spans (if no overlap is enabled)"}, - ) + ) #
掩码间最小间隔:禁止重叠时,掩码段之间的最小距离 + require_same_masks: bool = field( default=True, metadata={ "help": "whether to number of masked timesteps must be the same across all " "examples in a batch" }, - ) + ) # 是否要求批次内样本的掩码数量相同:保持批处理一致性 + mask_dropout: float = field( default=0.0, metadata={"help": "percent of masks to unmask for each sample"}, - ) - - # channel masking + ) # 掩码dropout:随机取消一部分掩码,增加训练多样性 + + # ============================================================================ + # 通道(特征维度)掩码策略配置 + # ============================================================================ mask_channel_length: int = field( default=10, metadata={"help": "length of the mask for features (channels)"} - ) + ) # 通道掩码长度:连续掩码的特征维度数 + mask_channel_prob: float = field( default=0.0, metadata={"help": "probability of replacing a feature with 0"} - ) + ) # 通道掩码概率:特征维度被掩码的概率(默认不使用) + mask_channel_before: bool = False + # 是否在特征提取后立即应用通道掩码:True为提取后,False为编码器前 + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( default="static", metadata={"help": "how to choose mask length for channel masking"}, - ) + ) # 通道掩码长度选择策略:与时序掩码类似 + mask_channel_other: float = field( default=0, metadata={ "help": "secondary mask argument (used for more complex distributions), " "see help in compute_mask_indicesh" }, - ) + ) # 通道掩码分布的辅助参数 + no_mask_channel_overlap: bool = field( default=False, metadata={"help": "whether to allow channel masks to overlap"} - ) + ) # 是否允许通道掩码重叠 + mask_channel_min_space: int = field( default=1, metadata={"help": "min space between spans (if no overlap is enabled)"}, - ) + ) # 通道掩码间最小间隔 - # negative selection + # ============================================================================ + # 负样本采样策略配置 + # ============================================================================ num_negatives: int = field( default=100, metadata={"help": "number of negative examples from the same sample"}, - ) + ) # 同样本负样本数:从当前样本的其他位置采样的负样本数量 + negatives_from_everywhere: bool = field( default=False, 
metadata={"help": "sample negatives from everywhere, not just masked states"}, - ) + ) # 是否从所有位置采样负样本:False仅从掩码位置,True从所有位置 + cross_sample_negatives: int = field( default=0, metadata={"help": "number of negative examples from the any sample"} - ) + ) # 跨样本负样本数:从批次中其他样本采样的负样本数量 + codebook_negatives: int = field( default=0, metadata={"help": "number of negative examples codebook"} - ) - - # positional embeddings + ) # 码本负样本数:直接从量化器码本采样的负样本数量 + + # ============================================================================ + # 位置编码配置 (相对位置编码) + # ============================================================================ conv_pos: int = field( default=128, metadata={"help": "number of filters for convolutional positional embeddings"}, - ) + ) # 卷积位置编码的卷积核数量:替代绝对位置编码 + conv_pos_groups: int = field( default=16, metadata={"help": "number of groups for convolutional positional embedding"}, - ) + ) # 卷积位置编码的分组数:分组卷积提高效率 + pos_conv_depth: int = field( default=1, metadata={"help": "depth of positional encoder network"}, - ) - + ) # 位置编码网络深度:多层卷积位置编码 + + # ============================================================================ + # 量化器温度和训练配置 + # ============================================================================ latent_temp: Tuple[float, float, float] = field( default=(2, 0.5, 0.999995), metadata={ "help": "temperature for latent variable sampling. 
" "can be tuple of 3 values (start, end, decay)" }, - ) + ) # Gumbel Softmax温度参数:(起始温度, 结束温度, 衰减率) + max_positions: int = field(default=100000, metadata={"help": "Max positions"}) + # 模型支持的最大位置数:理论上的序列长度限制 + checkpoint_activations: bool = field( default=False, metadata={"help": "recompute activations and save memory for extra compute"}, - ) - - # FP16 optimization + ) # 激活检查点:重计算激活以节省内存(增加计算开销) + + # ============================================================================ + # FP16和序列长度优化配置 + # ============================================================================ required_seq_len_multiple: int = field( default=2, metadata={ "help": "pad the input to encoder such that the sequence length is divisible by multiple" }, - ) + ) # 编码器输入序列长度必须是此数的倍数:优化FP16性能 + crop_seq_to_multiple: int = field( default=1, metadata={ "help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple" }, - ) - - # Conformer + ) # 特征提取器输出裁剪到此数的倍数:对齐要求 + + # ============================================================================ + # Conformer架构专用配置 + # ============================================================================ depthwise_conv_kernel_size: int = field( default=31, metadata={ "help": "depthwise-conv-kernel-size for convolution in conformer layer" }, - ) + ) # Conformer中深度可分离卷积的核大小:通常为奇数 + attn_type: str = field( default="", metadata={"help": "if espnet use ESPNET MHA"}, - ) + ) # 注意力类型:兼容ESPnet的多头注意力实现 + pos_enc_type: str = field( default="abs", metadata={"help": "Positional encoding type to use in conformer"}, - ) + ) # 位置编码类型:abs(绝对)、rel_pos(相对)、rope(旋转位置编码) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) - - # Adapter num + # 是否使用FP16:混合精度训练标志 + + # ============================================================================ + # 适配器(Adapter)模块配置 + # ============================================================================ adp_num: int = field( default=-1 - ) + ) # 适配器数量:-1表示不使用适配器 + 
adp_dim: int = field( default=64 - ) + ) # 适配器隐藏层维度:瓶颈架构的中间维度 + adp_act_fn: str = field( default="relu" - ) + ) # 适配器激活函数:ReLU、GELU等 + adp_trf_idx: str = field( default="all", - ) + ) # 适配器应用的Transformer层索引:all表示所有层,或指定层范围 -@register_model("wav2vec2", dataclass=Wav2Vec2Config) +@register_model("wav2vec2", dataclass=Wav2Vec2Config) # 注册模型到Fairseq框架 class Wav2Vec2Model(BaseFairseqModel): + """ + Wav2Vec 2.0 主模型类 + + 相比Wav2Vec 1.0的主要改进: + 1. 使用Transformer编码器替代卷积聚合器 + 2. 采用掩码预测任务而非对比预测编码 + 3. 更灵活的向量量化策略 + 4. 支持Conformer等新架构 + + 模型架构: + 1. 卷积特征提取器 (ConvFeatureExtractionModel) + 2. 可选的特征投影层 + 3. 掩码应用 (时序掩码 + 通道掩码) + 4. Transformer/Conformer编码器 + 5. 最终投影层和对比学习头 + """ + def __init__(self, cfg: Wav2Vec2Config): - super().__init__() - self.cfg = cfg + """ + 初始化Wav2Vec 2.0模型 + Args: + cfg: 模型配置参数 + """ + super().__init__() # 调用父类初始化 + self.cfg = cfg # 保存配置引用 - feature_enc_layers = eval(cfg.conv_feature_layers) - self.embed = feature_enc_layers[-1][0] + # ============================================================================ + # 1. 卷积特征提取器初始化 + # ============================================================================ + feature_enc_layers = eval(cfg.conv_feature_layers) # 解析卷积层配置字符串 + self.embed = feature_enc_layers[-1][0] # 获取最后一层的输出维度作为嵌入维度 self.feature_extractor = ConvFeatureExtractionModel( - conv_layers=feature_enc_layers, - dropout=0.0, - mode=cfg.extractor_mode, - conv_bias=cfg.conv_bias, + conv_layers=feature_enc_layers, # 卷积层配置列表 + dropout=0.0, # 特征提取器内部不使用dropout + mode=cfg.extractor_mode, # 归一化模式:default或layer_norm + conv_bias=cfg.conv_bias, # 是否使用卷积偏置 ) + # ============================================================================ + # 2. 
特征投影层 (可选) + # ============================================================================ self.post_extract_proj = ( nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input else None - ) - - self.crop_seq_to_multiple = cfg.crop_seq_to_multiple - - self.mask_prob = cfg.mask_prob - self.mask_selection = cfg.mask_selection - self.mask_other = cfg.mask_other - self.mask_length = cfg.mask_length - self.no_mask_overlap = cfg.no_mask_overlap - self.mask_min_space = cfg.mask_min_space - - self.mask_channel_prob = cfg.mask_channel_prob - self.mask_channel_before = cfg.mask_channel_before - self.mask_channel_selection = cfg.mask_channel_selection - self.mask_channel_other = cfg.mask_channel_other - self.mask_channel_length = cfg.mask_channel_length - self.no_mask_channel_overlap = cfg.no_mask_channel_overlap - self.mask_channel_min_space = cfg.mask_channel_min_space - - self.dropout_input = nn.Dropout(cfg.dropout_input) - self.dropout_features = nn.Dropout(cfg.dropout_features) - - self.feature_grad_mult = cfg.feature_grad_mult - - self.quantizer = None - self.input_quantizer = None - - self.n_negatives = cfg.num_negatives - self.cross_sample_negatives = cfg.cross_sample_negatives - self.codebook_negatives = cfg.codebook_negatives - self.negatives_from_everywhere = cfg.negatives_from_everywhere - - self.logit_temp = cfg.logit_temp - - final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim - + ) # 当特征维度与编码器维度不匹配且不使用输入量化时,添加线性投影层 + + # ============================================================================ + # 3. 序列长度处理配置 + # ============================================================================ + self.crop_seq_to_multiple = cfg.crop_seq_to_multiple # 裁剪序列长度到指定倍数 + + # ============================================================================ + # 4. 
时序掩码参数配置 + # ============================================================================ + self.mask_prob = cfg.mask_prob # 掩码概率 + self.mask_selection = cfg.mask_selection # 掩码长度选择策略 + self.mask_other = cfg.mask_other # 掩码分布辅助参数 + self.mask_length = cfg.mask_length # 掩码长度 + self.no_mask_overlap = cfg.no_mask_overlap # 是否禁止掩码重叠 + self.mask_min_space = cfg.mask_min_space # 掩码间最小间隔 + + # ============================================================================ + # 5. 通道掩码参数配置 + # ============================================================================ + self.mask_channel_prob = cfg.mask_channel_prob # 通道掩码概率 + self.mask_channel_before = cfg.mask_channel_before # 是否在特征提取后立即掩码 + self.mask_channel_selection = cfg.mask_channel_selection # 通道掩码长度选择策略 + self.mask_channel_other = cfg.mask_channel_other # 通道掩码分布辅助参数 + self.mask_channel_length = cfg.mask_channel_length # 通道掩码长度 + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap # 是否禁止通道掩码重叠 + self.mask_channel_min_space = cfg.mask_channel_min_space # 通道掩码间最小间隔 + + # ============================================================================ + # 6. Dropout层初始化 + # ============================================================================ + self.dropout_input = nn.Dropout(cfg.dropout_input) # 输入dropout + self.dropout_features = nn.Dropout(cfg.dropout_features) # 特征dropout + + # ============================================================================ + # 7. 特征提取器梯度控制 + # ============================================================================ + self.feature_grad_mult = cfg.feature_grad_mult # 特征提取器梯度乘数 + + # ============================================================================ + # 8. 量化器初始化 (延后初始化) + # ============================================================================ + self.quantizer = None # 目标量化器 + self.input_quantizer = None # 输入量化器 + + # ============================================================================ + # 9. 
负样本采样配置 + # ============================================================================ + self.n_negatives = cfg.num_negatives # 同样本负样本数 + self.cross_sample_negatives = cfg.cross_sample_negatives # 跨样本负样本数 + self.codebook_negatives = cfg.codebook_negatives # 码本负样本数 + self.negatives_from_everywhere = cfg.negatives_from_everywhere # 是否从所有位置采样负样本 + + # ============================================================================ + # 10. 对比学习温度参数 + # ============================================================================ + self.logit_temp = cfg.logit_temp # 对比学习的温度参数 + + # ============================================================================ + # 11. 最终输出维度计算 + # ============================================================================ + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim # 最终投影维度 + + # ============================================================================ + # 12. 目标量化器初始化 (可选) + # ============================================================================ if cfg.quantize_targets: + # 计算量化向量维度 vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim self.quantizer = GumbelVectorQuantizer( - dim=self.embed, - num_vars=cfg.latent_vars, - temp=cfg.latent_temp, - groups=cfg.latent_groups, - combine_groups=False, - vq_dim=vq_dim, - time_first=True, - weight_proj_depth=cfg.quantizer_depth, - weight_proj_factor=cfg.quantizer_factor, + dim=self.embed, # 输入特征维度 + num_vars=cfg.latent_vars, # 每组码本大小 + temp=cfg.latent_temp, # 温度调度参数 + groups=cfg.latent_groups, # 量化分组数 + combine_groups=False, # 不合并分组 + vq_dim=vq_dim, # 量化向量维度 + time_first=True, # 时间维度在前(与Wav2Vec 1.0不同) + weight_proj_depth=cfg.quantizer_depth, # 权重投影层深度 + weight_proj_factor=cfg.quantizer_factor, # 权重投影扩展因子 ) - self.project_q = nn.Linear(vq_dim, final_dim) + self.project_q = nn.Linear(vq_dim, final_dim) # 量化特征投影到最终维度 else: + # 不使用量化时直接投影原始特征 self.project_q = nn.Linear(self.embed, final_dim) + # 
============================================================================ + # 13. 输入量化器初始化 (可选) + # ============================================================================ if cfg.quantize_input: if cfg.same_quantizer and self.quantizer is not None: + # 输入和目标共享同一个量化器 vq_dim = final_dim self.input_quantizer = self.quantizer else: + # 为输入创建独立的量化器 vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim self.input_quantizer = GumbelVectorQuantizer( - dim=self.embed, - num_vars=cfg.latent_vars, - temp=cfg.latent_temp, - groups=cfg.latent_groups, - combine_groups=False, - vq_dim=vq_dim, - time_first=True, - weight_proj_depth=cfg.quantizer_depth, - weight_proj_factor=cfg.quantizer_factor, + dim=self.embed, # 输入特征维度 + num_vars=cfg.latent_vars, # 每组码本大小 + temp=cfg.latent_temp, # 温度调度参数 + groups=cfg.latent_groups, # 量化分组数 + combine_groups=False, # 不合并分组 + vq_dim=vq_dim, # 量化向量维度 + time_first=True, # 时间维度在前 + weight_proj_depth=cfg.quantizer_depth, # 权重投影层深度 + weight_proj_factor=cfg.quantizer_factor, # 权重投影扩展因子 ) - self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) + self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) # 量化输入投影到编码器维度 + # ============================================================================ + # 14. 掩码嵌入向量 + # ============================================================================ self.mask_emb = nn.Parameter( torch.FloatTensor(cfg.encoder_embed_dim).uniform_() - ) - encoder_cls = TransformerEncoder + ) # 掩码令牌的可学习嵌入向量,用均匀分布初始化 + + # ============================================================================ + # 15. 
编码器选择和初始化 + # ============================================================================ + encoder_cls = TransformerEncoder # 默认使用Transformer编码器 if cfg.layer_type == "conformer" and cfg.pos_enc_type in ["rel_pos", "rope"]: - encoder_cls = ConformerEncoder + encoder_cls = ConformerEncoder # 使用Conformer编码器(需要相对位置编码) - self.encoder = encoder_cls(cfg) - self.layer_norm = LayerNorm(self.embed) + self.encoder = encoder_cls(cfg) # 初始化编码器 + self.layer_norm = LayerNorm(self.embed) # 特征归一化层 + # ============================================================================ + # 16. 目标GLU层 (可选) + # ============================================================================ self.target_glu = None if cfg.target_glu: self.target_glu = nn.Sequential( nn.Linear(final_dim, final_dim * 2), nn.GLU() - ) + ) # 门控线性单元:增强目标特征的非线性表达能力 - self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + # ============================================================================ + # 17. 最终投影层 + # ============================================================================ + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) # 编码器输出投影到最终维度 def upgrade_state_dict_named(self, state_dict, name): + """ + 升级状态字典以兼容新版本的Fairseq + Args: + state_dict: 模型状态字典 + name: 模型名称 + Returns: + 更新后的状态字典 + """ super().upgrade_state_dict_named(state_dict, name) - """Upgrade a (possibly old) state dict for new versions of fairseq.""" return state_dict @classmethod def build_model(cls, cfg: Wav2Vec2Config, task=None): - """Build a new model instance.""" - + """ + 构建新的模型实例 (工厂方法) + Args: + cfg: 模型配置 + task: 任务配置(可选) + Returns: + Wav2Vec2Model实例 + """ return cls(cfg) def apply_mask( @@ -433,104 +641,162 @@ def apply_mask( mask_indices=None, mask_channel_indices=None, ): - B, T, C = x.shape + """ + 应用掩码到输入特征 (Wav2Vec 2.0的核心创新) + + 掩码策略: + 1. 通道掩码(可选):在特征维度上掩码,类似SpecAugment + 2. 
时序掩码:在时间维度上掩码,类似BERT的[MASK] + + Args: + x: 输入特征 [B, T, C] + padding_mask: 填充掩码 [B, T] + mask_indices: 预计算的时序掩码索引(可选) + mask_channel_indices: 预计算的通道掩码索引(可选) + + Returns: + masked_x: 掩码后的特征 [B, T, C] + mask_indices: 时序掩码索引 [B, T] (用于损失计算) + """ + B, T, C = x.shape # 批次大小、时间步数、特征维度 + # ============================================================================ + # 1. 早期通道掩码 (特征提取后立即掩码) + # ============================================================================ if self.mask_channel_prob > 0 and self.mask_channel_before: mask_channel_indices = compute_mask_indices( - (B, C), - None, - self.mask_channel_prob, - self.mask_channel_length, - self.mask_channel_selection, - self.mask_channel_other, - no_overlap=self.no_mask_channel_overlap, - min_space=self.mask_channel_min_space, + (B, C), # 掩码形状:批次×特征维度 + None, # 通道掩码不考虑padding + self.mask_channel_prob, # 通道掩码概率 + self.mask_channel_length, # 连续掩码的特征维度数 + self.mask_channel_selection, # 掩码长度分布策略 + self.mask_channel_other, # 分布辅助参数 + no_overlap=self.no_mask_channel_overlap, # 是否禁止重叠 + min_space=self.mask_channel_min_space, # 最小间隔 ) + # 转换为张量并扩展到时间维度:[B, C] -> [B, T, C] mask_channel_indices = ( torch.from_numpy(mask_channel_indices) .to(x.device) - .unsqueeze(1) - .expand(-1, T, -1) + .unsqueeze(1) # [B, C] -> [B, 1, C] + .expand(-1, T, -1) # [B, 1, C] -> [B, T, C] ) - x[mask_channel_indices] = 0 + x[mask_channel_indices] = 0 # 将掩码位置置零 + # ============================================================================ + # 2. 
时序掩码 (核心:BERT式掩码预测) + # ============================================================================ if self.mask_prob > 0: if mask_indices is None: mask_indices = compute_mask_indices( - (B, T), - padding_mask, - self.mask_prob, - self.mask_length, - self.mask_selection, - self.mask_other, - min_masks=2, - no_overlap=self.no_mask_overlap, - min_space=self.mask_min_space, - require_same_masks=self.cfg.require_same_masks, - mask_dropout=self.cfg.mask_dropout, + (B, T), # 掩码形状:批次×时间步 + padding_mask, # 考虑序列的有效长度 + self.mask_prob, # 掩码概率(65%) + self.mask_length, # 连续掩码长度(10步) + self.mask_selection, # 掩码长度分布策略 + self.mask_other, # 分布辅助参数 + min_masks=2, # 最少掩码数量 + no_overlap=self.no_mask_overlap, # 是否禁止重叠 + min_space=self.mask_min_space, # 最小间隔 + require_same_masks=self.cfg.require_same_masks, # 批次内掩码数量一致 + mask_dropout=self.cfg.mask_dropout, # 掩码dropout ) mask_indices = torch.from_numpy(mask_indices).to(x.device) + # 用可学习的掩码嵌入向量替换掩码位置的特征 x = index_put(x, mask_indices, self.mask_emb) else: mask_indices = None + # ============================================================================ + # 3. 
晚期通道掩码 (编码器前掩码) + # ============================================================================ if self.mask_channel_prob > 0 and not self.mask_channel_before: if mask_channel_indices is None: mask_channel_indices = compute_mask_indices( - (B, C), - None, - self.mask_channel_prob, - self.mask_channel_length, - self.mask_channel_selection, - self.mask_channel_other, - no_overlap=self.no_mask_channel_overlap, - min_space=self.mask_channel_min_space, + (B, C), # 掩码形状:批次×特征维度 + None, # 通道掩码不考虑padding + self.mask_channel_prob, # 通道掩码概率 + self.mask_channel_length, # 连续掩码的特征维度数 + self.mask_channel_selection, # 掩码长度分布策略 + self.mask_channel_other, # 分布辅助参数 + no_overlap=self.no_mask_channel_overlap, # 是否禁止重叠 + min_space=self.mask_channel_min_space, # 最小间隔 ) + # 转换为张量并扩展到时间维度 mask_channel_indices = ( torch.from_numpy(mask_channel_indices) .to(x.device) - .unsqueeze(1) - .expand(-1, T, -1) + .unsqueeze(1) # [B, C] -> [B, 1, C] + .expand(-1, T, -1) # [B, 1, C] -> [B, T, C] ) - x = index_put(x, mask_channel_indices, 0) + x = index_put(x, mask_channel_indices, 0) # 将掩码位置置零 return x, mask_indices def sample_negatives(self, y, num, padding_count=None): - + """ + 采样负样本用于对比学习 (InfoNCE损失) + + 负样本策略: + 1. 同样本负样本:从当前样本的其他时间步采样 + 2. 跨样本负样本:从批次中其他样本采样 + + Args: + y: 目标特征 [B, T, C] + num: 每个位置需要的负样本总数 + padding_count: 填充的时间步数(可选) + + Returns: + negs: 负样本特征 [N, B, T, C] (N=负样本数) + neg_idxs: 负样本索引 + """ + if self.n_negatives == 0 and self.cross_sample_negatives == 0: - return y.new(0) + return y.new(0) # 不使用负样本时返回空张量 - bsz, tsz, fsz = y.shape - y = y.view(-1, fsz) # BTC => (BxT)C + bsz, tsz, fsz = y.shape # 批次大小、时间步数、特征维度 + y = y.view(-1, fsz) # 重塑为 [B*T, C] 便于索引 - # FIXME: what happens if padding_count is specified? 
- cross_high = tsz * bsz - high = tsz - (padding_count or 0) + # ============================================================================ + # 计算采样范围 + # ============================================================================ + cross_high = tsz * bsz # 跨样本采样的总范围 + high = tsz - (padding_count or 0) # 有效时间步数(排除填充) + with torch.no_grad(): - assert high > 1, f"{bsz,tsz,fsz}" + assert high > 1, f"有效时间步数必须>1, 当前: {bsz,tsz,fsz}" + # ======================================================================== + # 1. 同样本负样本采样 + # ======================================================================== if self.n_negatives > 0: + # 为每个掩码位置生成时间步索引 tszs = ( - buffered_arange(num) - .unsqueeze(-1) - .expand(-1, self.n_negatives) - .flatten() + buffered_arange(num) # [0, 1, ..., num-1] + .unsqueeze(-1) # [num, 1] + .expand(-1, self.n_negatives) # [num, n_negatives] + .flatten() # [num * n_negatives] ) + # 随机采样负样本索引(避免采样到正样本位置) neg_idxs = torch.randint( low=0, high=high - 1, size=(bsz, self.n_negatives * num) ) + # 当索引>=当前位置时,向后偏移1以跳过正样本 neg_idxs[neg_idxs >= tszs] += 1 + # ======================================================================== + # 2. 跨样本负样本采样 + # ======================================================================== if self.cross_sample_negatives > 0: tszs = ( - buffered_arange(num) - .unsqueeze(-1) - .expand(-1, self.cross_sample_negatives) - .flatten() + buffered_arange(num) # [0, 1, ..., num-1] + .unsqueeze(-1) # [num, 1] + .expand(-1, self.cross_sample_negatives) # [num, cross_negatives] + .flatten() # [num * cross_negatives] ) + # 从整个批次范围内随机采样 cross_neg_idxs = torch.randint( low=0, high=cross_high - 1, @@ -538,40 +804,81 @@ def sample_negatives(self, y, num, padding_count=None): ) cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + # ============================================================================ + # 3. 
合并负样本索引 + # ============================================================================ if self.n_negatives > 0: + # 为同样本负样本添加批次偏移 neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high) else: neg_idxs = cross_neg_idxs if self.cross_sample_negatives > 0 and self.n_negatives > 0: + # 合并同样本和跨样本负样本 neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) - negs = y[neg_idxs.view(-1)] + # ============================================================================ + # 4. 根据索引提取负样本特征 + # ============================================================================ + negs = y[neg_idxs.view(-1)] # 提取负样本特征 negs = negs.view( bsz, num, self.n_negatives + self.cross_sample_negatives, fsz ).permute( 2, 0, 1, 3 - ) # to NxBxTxC + ) # 重塑为 [N, B, T, C] 格式 + return negs, neg_idxs def compute_preds(self, x, y, negatives): - - neg_is_pos = (y == negatives).all(-1) - y = y.unsqueeze(0) - targets = torch.cat([y, negatives], dim=0) - + """ + 计算对比学习的预测logits (InfoNCE损失核心) + + 计算上下文表示与正/负样本的相似度,用于对比学习 + + Args: + x: 上下文表示(编码器输出) [B, T, C] + y: 正样本目标特征 [B, T, C] + negatives: 负样本特征 [N, B, T, C] + + Returns: + logits: 相似度分数 [N+1, B, T] (第0维是正样本,其余是负样本) + """ + + # ============================================================================ + # 1. 检测负样本中是否有与正样本相同的 (避免虚假负样本) + # ============================================================================ + neg_is_pos = (y == negatives).all(-1) # 检查负样本是否与正样本相同 + + # ============================================================================ + # 2. 合并正样本和负样本 + # ============================================================================ + y = y.unsqueeze(0) # [B, T, C] -> [1, B, T, C] + targets = torch.cat([y, negatives], dim=0) # [N+1, B, T, C] + + # ============================================================================ + # 3. 
计算余弦相似度 + # ============================================================================ logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1) - logits = logits / self.logit_temp - logits = logits.type_as(x) - + # 余弦相似度:cos(θ) = (x·y)/(||x||·||y||),范围[-1, 1] + + # ============================================================================ + # 4. 温度缩放 + # ============================================================================ + logits = logits / self.logit_temp # 温度缩放:τ越小分布越锐利 + logits = logits.type_as(x) # 保持原始数据类型 + + # ============================================================================ + # 5. 处理虚假负样本 (将相同的负样本设为负无穷) + # ============================================================================ if is_xla_tensor(logits) or neg_is_pos.any(): if not hasattr(self, "_inftensor"): - fillval = -float(2**30) + fillval = -float(2**30) # 接近负无穷的值 self._inftensor = ( torch.tensor(fillval).to(x.device) if is_xla_tensor(logits) else float("-inf") ) + # 将虚假负样本的logits设为负无穷(在softmax中概率趋于0) logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) return logits @@ -605,243 +912,430 @@ def forward( padding_count=None, corpus_key=None, ): + """ + Wav2Vec 2.0主模型前向传播 + + 完整的自监督学习流程: + 1. 特征提取:原始音频 -> 卷积特征 + 2. 掩码应用:BERT式掩码预测 + 3. 上下文编码:Transformer编码器 + 4. 对比学习:InfoNCE损失 + + Args: + source: 原始音频波形 [B, T] + padding_mask: 填充掩码 [B, T] + mask: 是否应用掩码(训练时True,推理时False) + features_only: 是否只返回特征(用于下游任务) + layer: 目标编码器层 + mask_indices: 预计算的掩码索引 + mask_channel_indices: 预计算的通道掩码索引 + padding_count: 填充计数 + corpus_key: 语料库键(适配器选择) + + Returns: + result: 包含logits、掩码、困惑度等的字典 + """ + # ============================================================================ + # 1. 
卷积特征提取 (原始音频 -> 高层特征) + # ============================================================================ if self.feature_grad_mult > 0: - features = self.feature_extractor(source) + # 允许梯度回传到特征提取器 + features = self.feature_extractor(source) # [B, T] -> [B, C, T] if self.feature_grad_mult != 1.0: + # 梯度缩放:控制特征提取器的学习速度 features = GradMultiply.apply(features, self.feature_grad_mult) else: + # 冻结特征提取器:不更新预训练的卷积层 with torch.no_grad(): features = self.feature_extractor(source) - features_pen = features.float().pow(2).mean() - - features = features.transpose(1, 2) - features = self.layer_norm(features) - unmasked_features = features.clone() - + # ============================================================================ + # 2. 特征惩罚项 (正则化) + # ============================================================================ + features_pen = features.float().pow(2).mean() # L2惩罚项:防止特征过大 + + # ============================================================================ + # 3. 特征格式转换和归一化 + # ============================================================================ + features = features.transpose(1, 2) # [B, C, T] -> [B, T, C] (Transformer格式) + features = self.layer_norm(features) # 层归一化:稳定训练 + unmasked_features = features.clone() # 保存未掩码的特征用于目标计算 + + # ============================================================================ + # 4. 
填充掩码处理 (处理变长序列) + # ============================================================================ if padding_mask is not None and padding_mask.any(): - input_lengths = (1 - padding_mask.long()).sum(-1) - # apply conv formula to get real output_lengths + # 计算每个样本的有效长度 + input_lengths = (1 - padding_mask.long()).sum(-1) # [B] + # 根据卷积下采样计算输出长度 output_lengths = self._get_feat_extract_output_lengths(input_lengths) + # 重新构建padding_mask以匹配特征序列长度 padding_mask = torch.zeros( features.shape[:2], dtype=features.dtype, device=features.device ) - # these two operations makes sure that all values - # before the output lengths indices are attended to + # 标记有效序列的最后一个位置 padding_mask[ ( torch.arange(padding_mask.shape[0], device=padding_mask.device), - output_lengths - 1, + output_lengths - 1, # 最后一个有效位置 ) ] = 1 + # 使用cumsum创建掩码:有效位置为False,填充位置为True padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() else: padding_mask = None + # ============================================================================ + # 5. 序列长度裁剪 (满足模型架构要求) + # ============================================================================ time_steps_to_drop = features.size(1) % self.crop_seq_to_multiple if time_steps_to_drop != 0: + # 裁剪到指定倍数,确保后续处理的兼容性 features = features[:, :-time_steps_to_drop] unmasked_features = unmasked_features[:, :-time_steps_to_drop] if padding_mask is not None: padding_mask = padding_mask[:, :-time_steps_to_drop] + # ============================================================================ + # 6. 特征投影 (维度对齐) + # ============================================================================ if self.post_extract_proj is not None: + # 将卷积特征投影到编码器维度 features = self.post_extract_proj(features) - features = self.dropout_input(features) - unmasked_features = self.dropout_features(unmasked_features) - - num_vars = None - code_ppl = None - prob_ppl = None - curr_temp = None - + # ============================================================================ + # 7. 
输入dropout (防止过拟合) + # ============================================================================ + features = self.dropout_input(features) # 对输入特征应用dropout + unmasked_features = self.dropout_features(unmasked_features) # 对未掩码特征应用dropout + + # ============================================================================ + # 8. 量化器状态变量初始化 + # ============================================================================ + num_vars = None # 码本变量数量 + code_ppl = None # 码本困惑度 + prob_ppl = None # 概率困惑度 + curr_temp = None # 当前温度参数 + + # ============================================================================ + # 9. 输入量化 (可选) + # ============================================================================ if self.input_quantizer: + # 对输入特征进行向量量化 q = self.input_quantizer(features, produce_targets=False) - features = q["x"] - num_vars = q["num_vars"] - code_ppl = q["code_perplexity"] - prob_ppl = q["prob_perplexity"] - curr_temp = q["temp"] - features = self.project_inp(features) - + features = q["x"] # 量化后的特征 + num_vars = q["num_vars"] # 码本大小 + code_ppl = q["code_perplexity"] # 码本使用的均匀程度 + prob_ppl = q["prob_perplexity"] # 概率分布的均匀程度 + curr_temp = q["temp"] # Gumbel softmax温度 + features = self.project_inp(features) # 投影到编码器维度 + + # ============================================================================ + # 10. 掩码应用 (核心:BERT式掩码语言建模) + # ============================================================================ if mask: + # 应用时序掩码和通道掩码 x, mask_indices = self.apply_mask( - features, - padding_mask, - mask_indices=mask_indices, - mask_channel_indices=mask_channel_indices, + features, # 输入特征 + padding_mask, # 填充掩码 + mask_indices=mask_indices, # 预计算的掩码索引 + mask_channel_indices=mask_channel_indices, # 预计算的通道掩码索引 ) + if not is_xla_tensor(x) and mask_indices is not None: - # tpu-comment: reducing the size in a dynamic way causes - # too many recompilations on xla. 
+ # 提取被掩码位置的原始特征作为预测目标 y = unmasked_features[mask_indices].view( unmasked_features.size(0), -1, unmasked_features.size(-1) - ) + ) # [B, masked_timesteps, C] else: + # XLA模式下或未指定掩码时使用全部特征 y = unmasked_features else: - x = features - y = unmasked_features - mask_indices = None - + # 推理模式:不应用掩码 + x = features # 编码器输入 + y = unmasked_features # 目标特征 + mask_indices = None # 无掩码 + + # ============================================================================ + # 11. Transformer编码器 (上下文建模) + # ============================================================================ x, layer_results = self.encoder( - x, padding_mask=padding_mask, layer=layer, corpus_key=corpus_key - ) - + x, # 掩码后的特征 [B, T, C] + padding_mask=padding_mask, # 填充掩码 [B, T] + layer=layer, # 目标层 + corpus_key=corpus_key # 语料库键(适配器) + ) # 输出:x [B, T, C], layer_results [(x, z, lr), ...] + + # ============================================================================ + # 12. 特征提取模式 (下游任务使用) + # ============================================================================ if features_only: + # 只返回编码器特征,不进行对比学习 return { - "x": x, - "padding_mask": padding_mask, - "features": unmasked_features, - "layer_results": layer_results, + "x": x, # 编码器输出 [B, T, C] + "padding_mask": padding_mask, # 填充掩码 [B, T] + "features": unmasked_features, # 未掩码的原始特征 + "layer_results": layer_results, # 各层的输出结果 } + # ============================================================================ + # 13. 
目标量化和负样本采样 (对比学习核心) + # ============================================================================ if self.quantizer: + # ------------------------------------------------------------------------ + # 13.1 目标特征量化 + # ------------------------------------------------------------------------ if self.negatives_from_everywhere: + # 从完整的未掩码特征中量化和采样 q = self.quantizer(unmasked_features, produce_targets=False) - y = q["x"] - num_vars = q["num_vars"] - code_ppl = q["code_perplexity"] - prob_ppl = q["prob_perplexity"] - curr_temp = q["temp"] - y = self.project_q(y) - + y = q["x"] # 量化后的目标特征 + num_vars = q["num_vars"] # 码本大小 + code_ppl = q["code_perplexity"] # 码本困惑度 + prob_ppl = q["prob_perplexity"] # 概率困惑度 + curr_temp = q["temp"] # 当前温度 + y = self.project_q(y) # 投影到最终维度 + + # 从量化特征中采样负样本 negs, _ = self.sample_negatives( y, - mask_indices[0].sum(), + mask_indices[0].sum(), # 掩码位置总数 padding_count=padding_count, ) + # 提取掩码位置的目标特征 y = y[mask_indices].view(y.size(0), -1, y.size(-1)) else: + # 仅对掩码位置的特征进行量化 q = self.quantizer(y, produce_targets=False) - y = q["x"] - num_vars = q["num_vars"] - code_ppl = q["code_perplexity"] - prob_ppl = q["prob_perplexity"] - curr_temp = q["temp"] + y = q["x"] # 量化后的目标特征 + num_vars = q["num_vars"] # 码本大小 + code_ppl = q["code_perplexity"] # 码本困惑度 + prob_ppl = q["prob_perplexity"] # 概率困惑度 + curr_temp = q["temp"] # 当前温度 - y = self.project_q(y) + y = self.project_q(y) # 投影到最终维度 + # 从目标特征中采样负样本 negs, _ = self.sample_negatives( y, - y.size(1), + y.size(1), # 目标序列长度 padding_count=padding_count, ) + # ------------------------------------------------------------------------ + # 13.2 码本负样本 (增加对比学习难度) + # ------------------------------------------------------------------------ if self.codebook_negatives > 0: + # 直接从码本中采样负样本 cb_negs = self.quantizer.sample_from_codebook( y.size(0) * y.size(1), self.codebook_negatives ) cb_negs = cb_negs.view( self.codebook_negatives, y.size(0), y.size(1), -1 - ) # order doesnt matter - cb_negs = self.project_q(cb_negs) - negs 
= torch.cat([negs, cb_negs], dim=0) + ) # [codebook_negs, B, T, C] + cb_negs = self.project_q(cb_negs) # 投影到最终维度 + negs = torch.cat([negs, cb_negs], dim=0) # 合并负样本 else: - y = self.project_q(y) + # ------------------------------------------------------------------------ + # 13.3 无量化器模式:直接使用连续特征 + # ------------------------------------------------------------------------ + y = self.project_q(y) # 投影目标特征 if self.negatives_from_everywhere: + # 从完整特征中采样负样本 negs, _ = self.sample_negatives( unmasked_features, y.size(1), padding_count=padding_count, ) - negs = self.project_q(negs) + negs = self.project_q(negs) # 投影负样本 else: + # 从目标特征中采样负样本 negs, _ = self.sample_negatives( y, y.size(1), padding_count=padding_count, ) + # ============================================================================ + # 14. 预测头计算 (对比学习损失) + # ============================================================================ if not is_xla_tensor(x): - # tpu-comment: reducing the size in a dynamic way causes - # too many recompilations on xla. - x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + # 提取掩码位置的上下文表示 + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) # [B, masked_T, C] + # 可选的目标门控单元 (GLU) if self.target_glu: - y = self.target_glu(y) - negs = self.target_glu(negs) + y = self.target_glu(y) # 对目标特征应用GLU + negs = self.target_glu(negs) # 对负样本应用GLU - x = self.final_proj(x) - x = self.compute_preds(x, y, negs) + # 最终投影和对比预测 + x = self.final_proj(x) # 投影上下文表示到最终维度 + x = self.compute_preds(x, y, negs) # 计算对比学习logits + # ============================================================================ + # 15. 
构建返回结果 + # ============================================================================ result = { - "x": x, - "padding_mask": padding_mask, - "features_pen": features_pen, + "x": x, # 对比学习logits [N+1, B, T] + "padding_mask": padding_mask, # 填充掩码 + "features_pen": features_pen, # 特征惩罚项 } + # 添加量化器相关的统计信息 if prob_ppl is not None: - result["prob_perplexity"] = prob_ppl - result["code_perplexity"] = code_ppl - result["num_vars"] = num_vars - result["temp"] = curr_temp + result["prob_perplexity"] = prob_ppl # 概率困惑度 + result["code_perplexity"] = code_ppl # 码本困惑度 + result["num_vars"] = num_vars # 码本大小 + result["temp"] = curr_temp # 当前温度 return result def quantize(self, x): - assert self.quantizer is not None - x = self.feature_extractor(x) - x = x.transpose(1, 2) - x = self.layer_norm(x) - return self.quantizer.forward_idx(x) + """ + 对输入音频进行量化 (返回离散码本索引) + + Args: + x: 原始音频波形 [B, T] + + Returns: + 量化索引 [B, T'] + """ + assert self.quantizer is not None, "模型必须配置量化器" + x = self.feature_extractor(x) # 特征提取 [B, T] -> [B, C, T] + x = x.transpose(1, 2) # [B, C, T] -> [B, T, C] + x = self.layer_norm(x) # 层归一化 + return self.quantizer.forward_idx(x) # 返回量化索引 def extract_features( self, source, padding_mask, mask=False, layer=None, corpus_key=None ): + """ + 提取特征的便利接口 (用于下游任务) + + Args: + source: 原始音频波形 [B, T] + padding_mask: 填充掩码 [B, T] + mask: 是否应用掩码 + layer: 目标编码器层 + corpus_key: 语料库键 + + Returns: + 特征字典:包含编码器输出、填充掩码、层结果等 + """ res = self.forward( source, padding_mask, mask=mask, - features_only=True, + features_only=True, # 只返回特征,不进行对比学习 layer=layer, corpus_key=corpus_key, ) return res def get_logits(self, net_output): - logits = net_output["x"] - logits = logits.transpose(0, 2) - logits = logits.reshape(-1, logits.size(-1)) + """ + 获取对比学习的logits (用于损失计算) + + Args: + net_output: 模型输出字典 + + Returns: + 展平的logits [B*T, N+1] (N+1: 1个正样本+N个负样本) + """ + logits = net_output["x"] # [N+1, B, T] 对比学习logits + logits = logits.transpose(0, 2) # [N+1, B, T] -> [T, B, N+1] + logits = 
logits.reshape(-1, logits.size(-1)) # [T*B, N+1] return logits def get_targets(self, sample, net_output, expand_steps=True): - x = net_output["x"] + """ + 获取对比学习的目标标签 (正样本索引始终为0) + + Args: + sample: 数据样本 + net_output: 模型输出 + expand_steps: 是否扩展到所有时间步 + + Returns: + 目标标签 [B*T] (全为0,表示第0个位置是正样本) + """ + x = net_output["x"] # [N+1, B, T] + # 对比学习中,正样本总是第0个位置,所以目标全为0 return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long) def get_extra_losses(self, net_output): + """ + 计算额外的正则化损失 + + 包括: + 1. 量化器损失:鼓励码本的均匀使用 + 2. 特征惩罚项:防止特征幅值过大 + + Args: + net_output: 模型输出字典 + + Returns: + 正则化损失列表 + """ pen = [] if "prob_perplexity" in net_output: + # 量化器多样性损失:鼓励使用更多的码本条目 + # 损失 = (总码本数 - 实际困惑度) / 总码本数 + # 当困惑度接近码本数时,损失接近0(理想状态) pen.append( (net_output["num_vars"] - net_output["prob_perplexity"]) / net_output["num_vars"] ) if "features_pen" in net_output: + # 特征L2惩罚项:防止特征幅值过大 pen.append(net_output["features_pen"]) return pen def remove_pretraining_modules(self, last_layer=None): - self.quantizer = None - self.project_q = None - self.target_glu = None - self.final_proj = None + """ + 移除预训练模块 (用于下游任务微调) + + 移除自监督学习相关的组件: + - 量化器 + - 投影层 + - GLU门控 + - 最终投影 + + Args: + last_layer: 保留的最后一层编码器索引 + """ + self.quantizer = None # 移除量化器 + self.project_q = None # 移除目标投影层 + self.target_glu = None # 移除目标GLU + self.final_proj = None # 移除最终投影层 if last_layer is not None: + # 只保留前N层编码器(用于轻量化或特定任务) self.encoder.layers = nn.ModuleList( l for i, l in enumerate(self.encoder.layers) if i <= last_layer ) class ConvFeatureExtractionModel(nn.Module): + """ + 卷积特征提取器 (Wav2Vec 2.0的音频特征提取模块) + + 将原始音频波形转换为高层语义特征,相比Wav2Vec 1.0: + 1. 使用更小的卷积核和步长(降低下采样倍数) + 2. 支持不同的归一化策略(组归一化/层归一化) + 3. 
使用GELU激活函数替代ReLU + + 默认配置:[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2 + 总下采样倍数:5 × 2^4 × 2^2 = 320 (vs Wav2Vec1.0的160) + """ + def __init__( self, conv_layers: List[Tuple[int, int, int]], @@ -849,9 +1343,18 @@ def __init__( mode: str = "default", conv_bias: bool = False, ): + """ + 初始化卷积特征提取器 + + Args: + conv_layers: 卷积层配置列表 [(输出维度, 核大小, 步长), ...] + dropout: dropout概率 + mode: 归一化模式 ("default"=组归一化, "layer_norm"=层归一化) + conv_bias: 是否使用卷积偏置 + """ super().__init__() - assert mode in {"default", "layer_norm"} + assert mode in {"default", "layer_norm"}, f"不支持的模式: {mode}" def block( n_in, @@ -862,83 +1365,138 @@ def block( is_group_norm=False, conv_bias=False, ): + """ + 构建单个卷积块 + + 架构: Conv1d -> Dropout -> Normalization -> GELU + """ def make_conv(): conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) - nn.init.kaiming_normal_(conv.weight) + nn.init.kaiming_normal_(conv.weight) # He初始化,适合ReLU/GELU return conv assert ( is_layer_norm and is_group_norm - ) == False, "layer norm and group norm are exclusive" + ) == False, "层归一化和组归一化互斥" if is_layer_norm: + # 层归一化模式:每层都使用LayerNorm return nn.Sequential( make_conv(), nn.Dropout(p=dropout), nn.Sequential( - TransposeLast(), - Fp32LayerNorm(dim, elementwise_affine=True), - TransposeLast(), + TransposeLast(), # [B, C, T] -> [B, T, C] + Fp32LayerNorm(dim, elementwise_affine=True), # 层归一化 + TransposeLast(), # [B, T, C] -> [B, C, T] ), - nn.GELU(), + nn.GELU(), # GELU激活函数(相比ReLU更平滑) ) elif is_group_norm: + # 组归一化模式:仅第一层使用GroupNorm return nn.Sequential( make_conv(), nn.Dropout(p=dropout), - Fp32GroupNorm(dim, dim, affine=True), + Fp32GroupNorm(dim, dim, affine=True), # 组归一化(每个通道一组) nn.GELU(), ) else: + # 无归一化模式 return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) - in_d = 1 + # ============================================================================ + # 构建卷积层序列 + # ============================================================================ + in_d = 1 # 输入维度(原始音频) self.conv_layers = 
nn.ModuleList() + for i, cl in enumerate(conv_layers): - assert len(cl) == 3, "invalid conv definition: " + str(cl) - (dim, k, stride) = cl + assert len(cl) == 3, f"卷积层配置错误: {cl}" + (dim, k, stride) = cl # 输出维度、核大小、步长 self.conv_layers.append( block( - in_d, - dim, - k, - stride, - is_layer_norm=mode == "layer_norm", - is_group_norm=mode == "default" and i == 0, - conv_bias=conv_bias, + in_d, # 输入维度 + dim, # 输出维度 + k, # 卷积核大小 + stride, # 步长 + is_layer_norm=mode == "layer_norm", # 是否使用层归一化 + is_group_norm=mode == "default" and i == 0, # 仅第一层使用组归一化 + conv_bias=conv_bias, # 是否使用偏置 ) ) - in_d = dim + in_d = dim # 更新下一层的输入维度 def forward(self, x): - - # BxT -> BxCxT - x = x.unsqueeze(1) - + """ + 前向传播 + + Args: + x: 原始音频波形 [B, T] (批次大小, 时间步) + + Returns: + features: 卷积特征 [B, C, T'] (批次大小, 特征维度, 下采样后时间步) + """ + + # ============================================================================ + # 1. 添加通道维度: [B, T] -> [B, 1, T] + # ============================================================================ + x = x.unsqueeze(1) # 为1D卷积添加通道维度 + + # ============================================================================ + # 2. 逐层卷积特征提取 + # ============================================================================ for conv in self.conv_layers: - x = conv(x) + x = conv(x) # 应用卷积块:Conv1d -> Dropout -> Norm -> GELU - return x + return x # [B, C, T'] 其中T' = T / 下采样倍数 def make_conv_pos(e, k, g, is_batch_norm=False): + """ + 创建卷积位置编码层 (Wav2Vec 2.0的位置编码创新) + + 使用组卷积实现相对位置编码,相比绝对位置编码: + 1. 能够处理任意长度的序列 + 2. 具有平移不变性 + 3. 参数量更少 + + Args: + e: 嵌入维度 + k: 卷积核大小(控制位置感受野) + g: 组数(减少参数,增加归纳偏置) + is_batch_norm: 是否使用批归一化 + + Returns: + 位置编码模块: Conv1d -> (BatchNorm/WeightNorm) -> Padding -> GELU + """ + # ============================================================================ + # 1. 
构建组卷积层 + # ============================================================================ pos_conv = nn.Conv1d( - e, - e, - kernel_size=k, - padding=k // 2, - groups=g, + e, # 输入通道数 + e, # 输出通道数(保持维度不变) + kernel_size=k, # 卷积核大小 + padding=k // 2, # 保持序列长度不变的填充 + groups=g, # 组卷积(减少参数:e*k -> e*k/g) ) + + # ============================================================================ + # 2. 权重初始化 (基于方差缩放) + # ============================================================================ dropout = 0 - std = math.sqrt((4 * (1.0 - dropout)) / (k * e)) - nn.init.normal_(pos_conv.weight, mean=0, std=std) - nn.init.constant_(pos_conv.bias, 0) + std = math.sqrt((4 * (1.0 - dropout)) / (k * e)) # 方差缩放初始化 + nn.init.normal_(pos_conv.weight, mean=0, std=std) # 正态分布初始化 + nn.init.constant_(pos_conv.bias, 0) # 偏置置零 + # ============================================================================ + # 3. 归一化策略选择和模块构建 + # ============================================================================ if not is_batch_norm: + # 权重归一化:稳定训练,适合序列模型 pos_conv = nn.utils.weight_norm(pos_conv, name="weight", dim=2) pos_conv = nn.Sequential(pos_conv, SamePad(k), nn.GELU()) else: + # 批归一化:可能在变长序列上不稳定 batch_norm = nn.BatchNorm1d(e) pos_conv = nn.Sequential(batch_norm, pos_conv, SamePad(k), nn.GELU()) @@ -946,99 +1504,152 @@ def make_conv_pos(e, k, g, is_batch_norm=False): class TransformerEncoder(nn.Module): + """ + Transformer编码器 (Wav2Vec 2.0的上下文建模核心) + + 支持三种编码器架构: + 1. Transformer: 标准自注意力机制 + 2. Conformer: 卷积增强的Transformer(结合CNN和自注意力) + 3. 
TRF_ADP: 带适配器的Transformer(用于参数高效迁移学习) + """ + def build_encoder_layer(self, args: Wav2Vec2Config, **kwargs): + """ + 构建编码器层 + + Args: + args: 模型配置 + **kwargs: 额外参数(如layer_idx用于适配器索引) + + Returns: + 编码器层实例 + """ if args.layer_type == "transformer": + # ======================================================================== + # 标准Transformer编码器层 + # ======================================================================== layer = TransformerSentenceEncoderLayer( - embedding_dim=self.embedding_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=self.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - activation_fn=args.activation_fn, - layer_norm_first=args.layer_norm_first, + embedding_dim=self.embedding_dim, # 嵌入维度 + ffn_embedding_dim=args.encoder_ffn_embed_dim, # 前馈网络维度 + num_attention_heads=args.encoder_attention_heads, # 注意力头数 + dropout=self.dropout, # 主dropout + attention_dropout=args.attention_dropout, # 注意力dropout + activation_dropout=args.activation_dropout, # 激活函数dropout + activation_fn=args.activation_fn, # 激活函数类型 + layer_norm_first=args.layer_norm_first, # Pre-LN vs Post-LN ) elif args.layer_type == "conformer": + # ======================================================================== + # Conformer编码器层 (卷积+自注意力) + # ======================================================================== layer = ConformerWav2Vec2EncoderLayer( - embed_dim=self.embedding_dim, - ffn_embed_dim=args.encoder_ffn_embed_dim, - attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, - activation_fn="swish", - attn_type=args.attn_type, - use_fp16=args.fp16, - pos_enc_type="abs", + embed_dim=self.embedding_dim, # 嵌入维度 + ffn_embed_dim=args.encoder_ffn_embed_dim, # 前馈网络维度 + attention_heads=args.encoder_attention_heads, # 注意力头数 + dropout=args.dropout, # dropout概率 + 
depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, # 深度卷积核大小 + activation_fn="swish", # Swish激活函数 + attn_type=args.attn_type, # 注意力类型 + use_fp16=args.fp16, # 是否使用半精度 + pos_enc_type="abs", # 绝对位置编码 ) elif args.layer_type == "trf_adp": + # ======================================================================== + # 适配器增强的Transformer (参数高效迁移学习) + # ======================================================================== use_adp = False if args.adp_trf_idx == "all": + # 所有层都使用适配器 use_adp = True else: + # 仅指定层使用适配器 (格式: "start:end") adp_trf_idx = list(range(*[int(g) for g in args.adp_trf_idx.split(":")])) if kwargs.get("layer_idx", None) in adp_trf_idx: use_adp = True + if use_adp: + # 带适配器的Transformer层 layer = TransformerSentenceEncoderWithAdapterLayer( - embedding_dim=self.embedding_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=self.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - activation_fn=args.activation_fn, - layer_norm_first=args.layer_norm_first, - adapter_num=args.adp_num, - adapter_dim=args.adp_dim, - adapter_act_fn=args.adp_act_fn, + embedding_dim=self.embedding_dim, # 嵌入维度 + ffn_embedding_dim=args.encoder_ffn_embed_dim, # 前馈网络维度 + num_attention_heads=args.encoder_attention_heads, # 注意力头数 + dropout=self.dropout, # 主dropout + attention_dropout=args.attention_dropout, # 注意力dropout + activation_dropout=args.activation_dropout, # 激活函数dropout + activation_fn=args.activation_fn, # 激活函数类型 + layer_norm_first=args.layer_norm_first, # Pre-LN vs Post-LN + adapter_num=args.adp_num, # 适配器数量 + adapter_dim=args.adp_dim, # 适配器维度 + adapter_act_fn=args.adp_act_fn, # 适配器激活函数 ) else: + # 标准Transformer层(不使用适配器) layer = TransformerSentenceEncoderLayer( - embedding_dim=self.embedding_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=self.dropout, - 
attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - activation_fn=args.activation_fn, - layer_norm_first=args.layer_norm_first, + embedding_dim=self.embedding_dim, # 嵌入维度 + ffn_embedding_dim=args.encoder_ffn_embed_dim, # 前馈网络维度 + num_attention_heads=args.encoder_attention_heads, # 注意力头数 + dropout=self.dropout, # 主dropout + attention_dropout=args.attention_dropout, # 注意力dropout + activation_dropout=args.activation_dropout, # 激活函数dropout + activation_fn=args.activation_fn, # 激活函数类型 + layer_norm_first=args.layer_norm_first, # Pre-LN vs Post-LN ) - layer = fsdp_wrap(layer) + # ============================================================================ + # 优化配置:分布式训练和内存优化 + # ============================================================================ + layer = fsdp_wrap(layer) # FSDP包装:用于大模型分布式训练 if args.checkpoint_activations: + # 激活检查点:用时间换显存 layer = checkpoint_wrapper(layer) return layer def __init__(self, args: Wav2Vec2Config, skip_pos_conv: bool = False, override_encoder_layer: int = None): + """ + 初始化Transformer编码器 + + Args: + args: 模型配置 + skip_pos_conv: 是否跳过位置卷积 + override_encoder_layer: 覆盖编码器层数 + """ super().__init__() - self.dropout = args.dropout - self.embedding_dim = args.encoder_embed_dim - self.required_seq_len_multiple = args.required_seq_len_multiple + # ============================================================================ + # 基础配置 + # ============================================================================ + self.dropout = args.dropout # dropout概率 + self.embedding_dim = args.encoder_embed_dim # 嵌入维度 + self.required_seq_len_multiple = args.required_seq_len_multiple # 序列长度倍数要求 + # ============================================================================ + # 位置编码配置 (三种模式) + # ============================================================================ pos_conv_depth = getattr(args, "pos_conv_depth", 1) if pos_conv_depth > 1: - num_layers = args.pos_conv_depth - k = max(3, args.conv_pos // num_layers) + 
# 深度位置卷积:多层卷积块 + num_layers = args.pos_conv_depth # 位置卷积层数 + k = max(3, args.conv_pos // num_layers) # 每层卷积核大小 def make_conv_block(e, k, g, l): + """构建多层位置卷积块""" return nn.Sequential( *[ nn.Sequential( nn.Conv1d( - e, - e, - kernel_size=k, - padding=k // 2, - groups=g, + e, # 输入/输出维度 + e, # 保持维度不变 + kernel_size=k, # 卷积核大小 + padding=k // 2, # 保持序列长度 + groups=g, # 组卷积 ), - SamePad(k), - TransposeLast(), - LayerNorm(e, elementwise_affine=False), - TransposeLast(), - nn.GELU(), + SamePad(k), # 相同填充 + TransposeLast(), # [B,C,T] -> [B,T,C] + LayerNorm(e, elementwise_affine=False), # 层归一化 + TransposeLast(), # [B,T,C] -> [B,C,T] + nn.GELU(), # GELU激活 ) - for _ in range(l) + for _ in range(l) # 重复l层 ] ) @@ -1046,37 +1657,69 @@ def make_conv_block(e, k, g, l): self.embedding_dim, k, args.conv_pos_groups, num_layers ) elif skip_pos_conv: + # 无位置编码:适用于预训练特征已包含位置信息的情况 self.pos_conv = None else: + # 标准位置卷积:单层组卷积 self.pos_conv = make_conv_pos( - self.embedding_dim, - args.conv_pos, - args.conv_pos_groups, - is_batch_norm=args.conv_pos_batch_norm + self.embedding_dim, # 嵌入维度 + args.conv_pos, # 卷积核大小 + args.conv_pos_groups, # 组数 + is_batch_norm=args.conv_pos_batch_norm # 是否使用批归一化 if hasattr(args, "conv_pos_batch_norm") else False, ) + # ============================================================================ + # 编码器层构建 + # ============================================================================ if override_encoder_layer is None: - encoder_layers = args.encoder_layers + encoder_layers = args.encoder_layers # 使用配置中的层数 else: - encoder_layers = override_encoder_layer + encoder_layers = override_encoder_layer # 使用覆盖的层数 self.layers = nn.ModuleList( [self.build_encoder_layer(args, layer_idx=ii) for ii in range(encoder_layers)] ) - self.layer_norm_first = args.layer_norm_first - self.layer_norm = LayerNorm(self.embedding_dim) - self.layerdrop = args.encoder_layerdrop - - self.apply(init_bert_params) + + # ============================================================================ 
+ # 归一化和正则化配置 + # ============================================================================ + self.layer_norm_first = args.layer_norm_first # Pre-LN vs Post-LN + self.layer_norm = LayerNorm(self.embedding_dim) # 最终层归一化 + self.layerdrop = args.encoder_layerdrop # 层dropout概率 + + # ============================================================================ + # 参数初始化 + # ============================================================================ + self.apply(init_bert_params) # 使用BERT风格的参数初始化 def forward(self, x, padding_mask=None, layer=None, corpus_key=None): + """ + Transformer编码器前向传播 + + Args: + x: 输入特征 [B, T, C] + padding_mask: 填充掩码 [B, T] + layer: 目标层索引(如果只需要前N层的输出) + corpus_key: 语料库键(用于适配器选择) + + Returns: + x: 编码后的特征 [B, T, C] + layer_results: 各层的输出结果列表 + """ + # ============================================================================ + # 1. 特征提取:通过所有Transformer层 + # ============================================================================ x, layer_results = self.extract_features( x, padding_mask, layer, corpus_key=corpus_key ) + # ============================================================================ + # 2. 最终归一化 (Pre-LN模式下需要) + # ============================================================================ if self.layer_norm_first and layer is None: + # Pre-LN: 在最后应用层归一化 x = self.layer_norm(x) return x, layer_results @@ -1089,80 +1732,145 @@ def extract_features( min_layer=0, corpus_key=None, ): + """ + 特征提取:通过Transformer层序列处理输入特征 + + Args: + x: 输入特征 [B, T, C] + padding_mask: 填充掩码 [B, T] + tgt_layer: 目标层索引(如果指定,只计算到该层) + min_layer: 最小层索引(从该层开始记录结果) + corpus_key: 语料库键(用于适配器选择) + + Returns: + x: 最终特征 [B, T, C] + layer_results: 各层结果 [(x, z, lr), ...] + """ + # ============================================================================ + # 1. 
填充位置置零 (避免填充位置参与计算) + # ============================================================================ if padding_mask is not None: x = index_put(x, padding_mask, 0) + # ============================================================================ + # 2. 位置编码 (卷积位置编码) + # ============================================================================ if self.pos_conv is not None: - x_conv = self.pos_conv(x.transpose(1, 2)) - x_conv = x_conv.transpose(1, 2) - x = x + x_conv + x_conv = self.pos_conv(x.transpose(1, 2)) # [B, T, C] -> [B, C, T] -> 卷积 -> [B, C, T] + x_conv = x_conv.transpose(1, 2) # [B, C, T] -> [B, T, C] + x = x + x_conv # 残差连接:原特征 + 位置编码 + # ============================================================================ + # 3. 输入层归一化 (Post-LN模式) + # ============================================================================ if not self.layer_norm_first: + # Post-LN: 在输入端应用层归一化 x = self.layer_norm(x) - # pad to the sequence length dimension + # ============================================================================ + # 4. 序列长度填充 (满足模型要求的倍数关系) + # ============================================================================ x, pad_length = pad_to_multiple( x, self.required_seq_len_multiple, dim=-2, value=0 ) if pad_length > 0 and padding_mask is None: + # 为新增的填充位置创建掩码 padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) padding_mask[:, -pad_length:] = True else: + # 同步填充padding_mask padding_mask, _ = pad_to_multiple( padding_mask, self.required_seq_len_multiple, dim=-1, value=True ) + + # ============================================================================ + # 5. 应用dropout + # ============================================================================ x = F.dropout(x, p=self.dropout, training=self.training) - # B x T x C -> T x B x C + # ============================================================================ + # 6. 
转换为Transformer格式 [B, T, C] -> [T, B, C] + # ============================================================================ x = x.transpose(0, 1) - layer_results = [] - r = None + # ============================================================================ + # 7. 逐层处理 + # ============================================================================ + layer_results = [] # 存储各层的输出结果 + r = None # 目标层的输出 for i, layer in enumerate(self.layers): + # ------------------------------------------------------------------------ + # 7.1 层dropout (随机跳过某些层以提高泛化性) + # ------------------------------------------------------------------------ dropout_probability = np.random.random() if self.layerdrop > 0 else 1 if not self.training or (dropout_probability > self.layerdrop): + # ------------------------------------------------------------------------ + # 7.2 处理FSDP包装的层 + # ------------------------------------------------------------------------ layer_check = layer if isinstance(layer, FullyShardedDataParallel): layer_check = layer.unwrapped_module + + # ------------------------------------------------------------------------ + # 7.3 根据层类型选择调用方式 + # ------------------------------------------------------------------------ if (corpus_key is None) or ( not isinstance(layer_check, ( TransformerSentenceEncoderWithAdapterLayer, ) ) ): + # 标准Transformer层或Conformer层 x, (z, lr) = layer( x, self_attn_padding_mask=padding_mask, need_weights=False ) else: + # 带适配器的层(需要corpus_key参数) x, (z, lr) = layer( x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key, ) + + # ------------------------------------------------------------------------ + # 7.4 记录层结果 + # ------------------------------------------------------------------------ if i >= min_layer: layer_results.append((x, z, lr)) + + # ------------------------------------------------------------------------ + # 7.5 检查是否达到目标层 + # ------------------------------------------------------------------------ if i == tgt_layer: - r = x 
+ r = x # 保存目标层的输出 break + # ============================================================================ + # 8. 使用目标层输出(如果指定) + # ============================================================================ if r is not None: x = r - # T x B x C -> B x T x C + # ============================================================================ + # 9. 转换回标准格式 [T, B, C] -> [B, T, C] + # ============================================================================ x = x.transpose(0, 1) - # undo paddding + # ============================================================================ + # 10. 移除填充 (恢复原始序列长度) + # ============================================================================ if pad_length > 0: - x = x[:, :-pad_length] + x = x[:, :-pad_length] # 移除序列末尾的填充 def undo_pad(a, b, c): + """移除layer_results中的填充""" return ( - a[:-pad_length], - b[:-pad_length] if b is not None else b, - c[:-pad_length], + a[:-pad_length], # 移除x的填充 + b[:-pad_length] if b is not None else b, # 移除z的填充 + c[:-pad_length], # 移除lr的填充 ) layer_results = [undo_pad(*u) for u in layer_results] @@ -1170,96 +1878,196 @@ def undo_pad(a, b, c): return x, layer_results def max_positions(self): - """Maximum output length supported by the encoder.""" + """ + 获取编码器支持的最大序列长度 + + Returns: + 最大位置数(通常用于位置编码的上限) + """ return self.args.max_positions def upgrade_state_dict_named(self, state_dict, name): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" + """ + 升级状态字典以兼容新版本的Fairseq + + 用于模型版本兼容性,处理参数名称变化等 + + Args: + state_dict: 旧版本的状态字典 + name: 模块名称 + + Returns: + 升级后的状态字典 + """ return state_dict class ConformerEncoder(TransformerEncoder): + """ + Conformer编码器 (卷积增强的Transformer) + + Conformer = Transformer + CNN,结合了两种架构的优势: + 1. Transformer: 全局长距离依赖建模 + 2. 
CNN: 局部特征提取和位置不变性 + + 架构创新: + - 多头注意力 + 卷积模块的串联设计 + - 残差连接和Layer Norm的优化布局 + - Swish激活函数和相对位置编码 + """ + def build_encoder_layer(self, args): + """ + 构建Conformer编码器层 + + Args: + args: 模型配置 + + Returns: + ConformerWav2Vec2EncoderLayer实例 + """ layer = ConformerWav2Vec2EncoderLayer( - embed_dim=self.embedding_dim, - ffn_embed_dim=args.encoder_ffn_embed_dim, - attention_heads=args.encoder_attention_heads, - dropout=args.dropout, - depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, - activation_fn="swish", - attn_type=args.attn_type, - pos_enc_type=args.pos_enc_type, - use_fp16=args.fp16, # only used for rope + embed_dim=self.embedding_dim, # 嵌入维度 + ffn_embed_dim=args.encoder_ffn_embed_dim, # 前馈网络维度 + attention_heads=args.encoder_attention_heads, # 注意力头数 + dropout=args.dropout, # dropout概率 + depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, # 深度卷积核大小 + activation_fn="swish", # Swish激活函数(对语音更有效) + attn_type=args.attn_type, # 注意力类型 + pos_enc_type=args.pos_enc_type, # 位置编码类型 + use_fp16=args.fp16, # 半精度训练(仅用于RoPE) ) - layer = fsdp_wrap(layer) + layer = fsdp_wrap(layer) # FSDP包装 if args.checkpoint_activations: - layer = checkpoint_wrapper(layer) + layer = checkpoint_wrapper(layer) # 激活检查点 return layer def __init__(self, args): - super().__init__(args) - self.args = args - self.dropout = args.dropout - self.embedding_dim = args.encoder_embed_dim - self.pos_enc_type = args.pos_enc_type - max_source_positions = self.max_positions() - + """ + 初始化Conformer编码器 + + Args: + args: 模型配置 + """ + super().__init__(args) # 调用父类TransformerEncoder的初始化 + + # ============================================================================ + # 基础配置 + # ============================================================================ + self.args = args # 保存配置引用 + self.dropout = args.dropout # dropout概率 + self.embedding_dim = args.encoder_embed_dim # 嵌入维度 + self.pos_enc_type = args.pos_enc_type # 位置编码类型 + max_source_positions = self.max_positions() # 最大序列长度 + + # 
============================================================================ + # 位置编码配置 (支持多种类型) + # ============================================================================ if self.pos_enc_type == "rel_pos": + # 相对位置编码:T5/Transformer-XL风格 self.embed_positions = RelPositionalEncoding( max_source_positions, self.embedding_dim ) elif self.pos_enc_type == "rope": + # 旋转位置编码:RoFormer风格,无需显式位置嵌入 self.embed_positions = None else: + # 抛出异常:不支持的位置编码类型 raise Exception("Unsupported positional encoding type") + # ============================================================================ + # 编码器层构建 + # ============================================================================ self.layers = nn.ModuleList( [self.build_encoder_layer(args) for _ in range(args.encoder_layers)] ) - self.layer_norm_first = args.layer_norm_first - self.layer_norm = LayerNorm(self.embedding_dim) - self.layerdrop = args.encoder_layerdrop - - self.apply(init_bert_params) + + # ============================================================================ + # 归一化和正则化配置 + # ============================================================================ + self.layer_norm_first = args.layer_norm_first # Pre-LN vs Post-LN + self.layer_norm = LayerNorm(self.embedding_dim) # 最终层归一化 + self.layerdrop = args.encoder_layerdrop # 层dropout概率 + + # ============================================================================ + # 参数初始化 + # ============================================================================ + self.apply(init_bert_params) # 使用BERT风格的参数初始化 def extract_features(self, x, padding_mask=None, tgt_layer=None): + """ + Conformer特征提取 (与标准Transformer的差异) + + Args: + x: 输入特征 [B, T, C] + padding_mask: 填充掩码 [B, T] + tgt_layer: 目标层索引 + + Returns: + x: 编码特征 [B, T, C] + layer_results: 各层结果列表 + """ + # ============================================================================ + # 1. 
填充位置置零 + # ============================================================================ if padding_mask is not None: x = index_put(x, padding_mask, 0) - # B x T x C -> T x B x C + # ============================================================================ + # 2. 转换维度格式 [B, T, C] -> [T, B, C] + # ============================================================================ x = x.transpose(0, 1) - # B X T X C here + # ============================================================================ + # 3. 位置编码计算 (相对位置编码) + # ============================================================================ position_emb = None if self.pos_enc_type == "rel_pos": + # 相对位置编码:基于序列内位置关系 position_emb = self.embed_positions(x) + # ============================================================================ + # 4. 输入层归一化 (Post-LN模式) + # ============================================================================ if not self.layer_norm_first: x = self.layer_norm(x) + # ============================================================================ + # 5. 输入dropout + # ============================================================================ x = F.dropout(x, p=self.dropout, training=self.training) + # ============================================================================ + # 6. 逐层Conformer处理 + # ============================================================================ layer_results = [] r = None for i, layer in enumerate(self.layers): + # 层级dropout:随机跳过某些层 dropout_probability = np.random.random() if not self.training or (dropout_probability > self.layerdrop): + # Conformer层前向传播 (注意力+卷积+前馈) x, z = layer( x, self_attn_padding_mask=padding_mask, need_weights=False, - position_emb=position_emb, + position_emb=position_emb, # 传递位置编码 ) if tgt_layer is not None: layer_results.append((x, z)) if i == tgt_layer: - r = x + r = x # 保存目标层输出 break + # ============================================================================ + # 7. 
使用目标层输出(如果指定) + # ============================================================================ if r is not None: x = r - # T x B x C -> B x T x C + # ============================================================================ + # 8. 转换回标准格式 [T, B, C] -> [B, T, C] + # ============================================================================ x = x.transpose(0, 1) return x, layer_results @@ -1267,8 +2075,14 @@ def extract_features(self, x, padding_mask=None, tgt_layer=None): class TransformerSentenceEncoderLayer(nn.Module): """ - Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained - models. + 标准Transformer编码器层 (BERT/XLM风格) + + 架构: Multi-Head Self-Attention + Position-wise FFN + 支持Pre-LN和Post-LN两种归一化模式 + + 结构: + - Pre-LN: LN -> Attn -> Residual -> LN -> FFN -> Residual + - Post-LN: Attn -> Residual -> LN -> FFN -> Residual -> LN """ def __init__( @@ -1282,34 +2096,63 @@ def __init__( activation_fn: str = "relu", layer_norm_first: bool = False, ) -> None: - + """ + 初始化Transformer编码器层 + + Args: + embedding_dim: 嵌入维度 (默认768) + ffn_embedding_dim: 前馈网络隐层维度 (默认3072) + num_attention_heads: 注意力头数 (默认8) + dropout: 主dropout概率 + attention_dropout: 注意力dropout概率 + activation_dropout: 激活函数dropout概率 + activation_fn: 激活函数类型 (relu/gelu等) + layer_norm_first: 是否使用Pre-LN (默认False为Post-LN) + """ super().__init__() - # Initialize parameters - self.embedding_dim = embedding_dim - self.dropout = dropout - self.activation_dropout = activation_dropout - - # Initialize blocks - self.activation_fn = utils.get_activation_fn(activation_fn) + + # ============================================================================ + # 基础参数配置 + # ============================================================================ + self.embedding_dim = embedding_dim # 嵌入维度 + self.dropout = dropout # 主dropout概率 + self.activation_dropout = activation_dropout # 激活函数dropout概率 + + # ============================================================================ + # 子模块初始化 + # 
============================================================================ + self.activation_fn = utils.get_activation_fn(activation_fn) # 激活函数 + + # 多头自注意力机制 self.self_attn = MultiheadAttention( - self.embedding_dim, - num_attention_heads, - dropout=attention_dropout, - self_attention=True, + self.embedding_dim, # Query/Key/Value维度 + num_attention_heads, # 注意力头数 + dropout=attention_dropout, # 注意力dropout + self_attention=True, # 自注意力模式 ) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(self.activation_dropout) - self.dropout3 = nn.Dropout(dropout) + # 三个dropout层 (注意力后、激活函数后、前馈网络后) + self.dropout1 = nn.Dropout(dropout) # 注意力输出dropout + self.dropout2 = nn.Dropout(self.activation_dropout) # 激活函数dropout + self.dropout3 = nn.Dropout(dropout) # 前馈网络输出dropout - self.layer_norm_first = layer_norm_first + # ============================================================================ + # 归一化策略配置 + # ============================================================================ + self.layer_norm_first = layer_norm_first # Pre-LN vs Post-LN - # layer norm associated with the self attention layer - self.self_attn_layer_norm = LayerNorm(self.embedding_dim) - self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) - self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + # ============================================================================ + # 注意力子层组件 + # ============================================================================ + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) # 注意力层归一化 - # layer norm associated with the position wise feed-forward NN + # ============================================================================ + # 前馈网络子层组件 + # ============================================================================ + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) # FFN第一层 + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) # FFN第二层 + + # 前馈网络层归一化 self.final_layer_norm = LayerNorm(self.embedding_dim) 
def forward( @@ -1321,81 +2164,154 @@ def forward( att_args=None, ): """ - LayerNorm is applied either before or after the self-attention/ffn - modules similar to the original Transformer imlementation. + Transformer编码器层前向传播 + + 支持两种归一化策略: + - Pre-LN (layer_norm_first=True): 层归一化在子层之前 + - Post-LN (layer_norm_first=False): 层归一化在残差连接之后 + + Args: + x: 输入特征 [T, B, C] 或 [B, T, C] + self_attn_mask: 自注意力掩码 + self_attn_padding_mask: 填充掩码 + need_weights: 是否返回注意力权重 + att_args: 注意力额外参数 + + Returns: + x: 输出特征 [T, B, C] 或 [B, T, C] + (attn, layer_result): 注意力权重和层结果 """ + # ============================================================================ + # 保存输入用于残差连接 + # ============================================================================ residual = x if self.layer_norm_first: - x = self.self_attn_layer_norm(x) + # ======================================================================== + # Pre-LN模式: LN -> Attn -> Residual -> LN -> FFN -> Residual + # ======================================================================== + + # ------------------------------------------------------------------------ + # 1. 自注意力子层 (Pre-LN) + # ------------------------------------------------------------------------ + x = self.self_attn_layer_norm(x) # 注意力前的层归一化 x, attn = self.self_attn( - query=x, - key=x, - value=x, - key_padding_mask=self_attn_padding_mask, - attn_mask=self_attn_mask, - need_weights=False, + query=x, # Query向量 + key=x, # Key向量 (自注意力中与Query相同) + value=x, # Value向量 + key_padding_mask=self_attn_padding_mask, # 填充掩码 + attn_mask=self_attn_mask, # 注意力掩码 + need_weights=False, # 不返回注意力权重 ) - x = self.dropout1(x) - x = residual + x - - residual = x - x = self.final_layer_norm(x) - x = self.activation_fn(self.fc1(x)) - x = self.dropout2(x) - x = self.fc2(x) - - layer_result = x - - x = self.dropout3(x) - x = residual + x + x = self.dropout1(x) # 注意力输出dropout + x = residual + x # 残差连接 + + # ------------------------------------------------------------------------ + # 2. 
前馈网络子层 (Pre-LN) + # ------------------------------------------------------------------------ + residual = x # 更新残差连接的基准 + x = self.final_layer_norm(x) # 前馈网络前的层归一化 + x = self.activation_fn(self.fc1(x)) # 第一层线性变换 + 激活函数 + x = self.dropout2(x) # 激活函数后dropout + x = self.fc2(x) # 第二层线性变换 + + layer_result = x # 保存层输出结果 + + x = self.dropout3(x) # 前馈网络输出dropout + x = residual + x # 残差连接 + else: + # ======================================================================== + # Post-LN模式: Attn -> Residual -> LN -> FFN -> Residual -> LN + # ======================================================================== + + # ------------------------------------------------------------------------ + # 1. 自注意力子层 (Post-LN) + # ------------------------------------------------------------------------ x, attn = self.self_attn( - query=x, - key=x, - value=x, - key_padding_mask=self_attn_padding_mask, - need_weights=False, + query=x, # Query向量 + key=x, # Key向量 + value=x, # Value向量 + key_padding_mask=self_attn_padding_mask, # 填充掩码 + need_weights=False, # 不返回注意力权重 ) - x = self.dropout1(x) - x = residual + x - - x = self.self_attn_layer_norm(x) + x = self.dropout1(x) # 注意力输出dropout + x = residual + x # 残差连接 + x = self.self_attn_layer_norm(x) # 残差连接后的层归一化 - residual = x - x = self.activation_fn(self.fc1(x)) - x = self.dropout2(x) - x = self.fc2(x) + # ------------------------------------------------------------------------ + # 2. 
前馈网络子层 (Post-LN) + # ------------------------------------------------------------------------ + residual = x # 更新残差连接的基准 + x = self.activation_fn(self.fc1(x)) # 第一层线性变换 + 激活函数 + x = self.dropout2(x) # 激活函数后dropout + x = self.fc2(x) # 第二层线性变换 - layer_result = x + layer_result = x # 保存层输出结果 - x = self.dropout3(x) - x = residual + x - x = self.final_layer_norm(x) + x = self.dropout3(x) # 前馈网络输出dropout + x = residual + x # 残差连接 + x = self.final_layer_norm(x) # 残差连接后的层归一化 return x, (attn, layer_result) class AdapterFast(nn.Module): + """ + 快速适配器模块 (参数高效的迁移学习) + + 适配器的核心思想: + 1. 在预训练模型中插入少量可训练参数 + 2. 冻结主模型,只训练适配器参数 + 3. 实现高效的任务特定化 + + 架构: LayerNorm -> Linear -> Activation -> Linear + 相比标准适配器的优化: + - 使用3D张量存储多个适配器参数,避免ModuleList开销 + - 优化了训练吞吐量 + """ + def __init__(self, adapter_num, input_dim, hidden_dim, act_fn): """ - Implements adapter modules directly with 3D tensor weight as parameters - and without using ModuleList orto speed up training throughput. + 初始化快速适配器 + + Args: + adapter_num: 适配器数量(支持多任务/多语料) + input_dim: 输入维度 + hidden_dim: 隐层维度(通常比input_dim小,实现降维) + act_fn: 激活函数类型 """ super().__init__() - self.adapter_num = adapter_num - self.input_dim = input_dim - self.hidden_dim = hidden_dim + # ============================================================================ + # 基础配置 + # ============================================================================ + self.adapter_num = adapter_num # 适配器数量 + self.input_dim = input_dim # 输入维度 + self.hidden_dim = hidden_dim # 隐层维度 + + # ============================================================================ + # 适配器权重参数 (3D张量,批量存储) + # ============================================================================ + # 下投影:input_dim -> hidden_dim (降维) self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim)) + # 上投影:hidden_dim -> input_dim (升维) self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim)) + # 偏置项 self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim)) self.b_b = 
nn.Parameter(torch.empty(adapter_num, input_dim)) - self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim)) - self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim)) - self.act_fn = nn.Identity() + # ============================================================================ + # 层归一化参数 + # ============================================================================ + self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim)) # 缩放参数 + self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim)) # 偏移参数 + + # ============================================================================ + # 激活函数配置 + # ============================================================================ + self.act_fn = nn.Identity() # 默认恒等映射 if act_fn == "relu": self.act_fn = nn.ReLU() elif act_fn == "gelu": @@ -1403,45 +2319,81 @@ def __init__(self, adapter_num, input_dim, hidden_dim, act_fn): elif act_fn == "selu": self.act_fn = nn.SELU() else: - raise ValueError(f"unsupported {act_fn}") - + raise ValueError(f"不支持的激活函数: {act_fn}") self.input_dim = input_dim - self.reset_parameters() + self.reset_parameters() # 参数初始化 def reset_parameters(self): + """ + 参数初始化 (Kaiming初始化 + 均匀分布偏置) + """ for ii in range(self.adapter_num): + # ======================================================================== + # 权重矩阵初始化 (Kaiming均匀分布) + # ======================================================================== nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5)) nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5)) + + # ======================================================================== + # 偏置初始化 (基于fan_in的均匀分布) + # ======================================================================== fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii]) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(self.b_a[ii], -bound, bound) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii]) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 
nn.init.uniform_(self.b_b[ii], -bound, bound) - nn.init.ones_(self.ln_W) - nn.init.zeros_(self.ln_b) + # ======================================================================== + # 层归一化参数初始化 + # ======================================================================== + nn.init.ones_(self.ln_W) # 缩放参数初始化为1 + nn.init.zeros_(self.ln_b) # 偏移参数初始化为0 def forward(self, x, adapter_id): - ii = adapter_id - h = x - h = F.layer_norm(h, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii]) - h = F.linear(h, self.W_a[ii], self.b_a[ii]) - h = self.act_fn(h) - h = F.linear(h, self.W_b[ii], self.b_b[ii]) + """ + 适配器前向传播 + + Args: + x: 输入特征 [B, T, C] + adapter_id: 适配器索引 (选择使用哪个适配器) + + Returns: + 适配器输出 [B, T, C] + """ + ii = adapter_id # 适配器索引 + h = x # 输入特征 + + # ============================================================================ + # 适配器计算流程: LN -> Linear -> Act -> Linear + # ============================================================================ + h = F.layer_norm(h, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii]) # 层归一化 + h = F.linear(h, self.W_a[ii], self.b_a[ii]) # 下投影(降维) + h = self.act_fn(h) # 激活函数 + h = F.linear(h, self.W_b[ii], self.b_b[ii]) # 上投影(升维) + outputs = h return outputs def extra_repr(self): - return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim)) + """返回模块的字符串表示""" + return ('adapter={}, input_dim={}, hidden_dim={}'.format( + self.adapter_num, self.input_dim, self.hidden_dim)) class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer): """ - Implements a Transformer Encoder Layer with adapters used in BERT/XLM style pre-trained - models. An adapter module is added along with vanilla Transformer module. + 带适配器的Transformer编码器层 (参数高效迁移学习) + + 在标准Transformer层的基础上添加适配器模块: + 1. 继承标准TransformerSentenceEncoderLayer的所有功能 + 2. 在层输出后插入适配器模块 + 3. 
支持多语料/多任务的参数高效微调 + + 架构: Transformer Layer + Adapter Layer + Residual Connection """ def __init__( @@ -1458,7 +2410,25 @@ def __init__( adapter_dim=64, adapter_act_fn="relu", ) -> None: - + """ + 初始化带适配器的Transformer编码器层 + + Args: + embedding_dim: 嵌入维度 + ffn_embedding_dim: 前馈网络隐层维度 + num_attention_heads: 注意力头数 + dropout: 主dropout概率 + attention_dropout: 注意力dropout概率 + activation_dropout: 激活函数dropout概率 + activation_fn: 激活函数类型 + layer_norm_first: 是否使用Pre-LN + adapter_num: 适配器数量(支持多语料) + adapter_dim: 适配器隐层维度 + adapter_act_fn: 适配器激活函数 + """ + # ============================================================================ + # 调用父类初始化 (标准Transformer层) + # ============================================================================ super().__init__( embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, @@ -1468,12 +2438,21 @@ def __init__( activation_dropout=activation_dropout, activation_fn=activation_fn, layer_norm_first=layer_norm_first, - ) - self.adapter_num = adapter_num - self.adapter_dim = adapter_dim - self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn) + # ============================================================================ + # 适配器模块配置 + # ============================================================================ + self.adapter_num = adapter_num # 适配器数量 + self.adapter_dim = adapter_dim # 适配器隐层维度 + + # 创建快速适配器模块 + self.adapter_layer = AdapterFast( + adapter_num, # 适配器数量 + self.embedding_dim, # 输入维度(与Transformer层输出维度相同) + self.adapter_dim, # 隐层维度(通常更小) + adapter_act_fn # 激活函数 + ) def forward( self, @@ -1484,7 +2463,29 @@ def forward( att_args=None, corpus_key=None, ): - + """ + 带适配器的Transformer层前向传播 + + 计算流程: + 1. 标准Transformer层前向传播 + 2. 根据corpus_key选择适配器 + 3. 
适配器输出与Transformer输出相加(残差连接) + + Args: + x: 输入特征 [T, B, C] 或 [B, T, C] + self_attn_mask: 自注意力掩码 + self_attn_padding_mask: 填充掩码 + need_weights: 是否返回注意力权重 + att_args: 注意力额外参数 + corpus_key: 语料库键列表(用于选择适配器) + + Returns: + x: 输出特征 [T, B, C] 或 [B, T, C] + (attn, layer_result): 注意力权重和层结果 + """ + # ============================================================================ + # 1. 标准Transformer层前向传播 + # ============================================================================ x, (attn, layer_result) = super().forward( x=x, self_attn_mask=self_attn_mask, @@ -1492,8 +2493,19 @@ def forward( need_weights=need_weights, att_args=att_args, ) - assert corpus_key is not None - assert len(set(corpus_key)) == 1, f"corpus_key items are not same {corpus_key}" - y = self.adapter_layer(x, corpus_key[0]) - x = x + y + + # ============================================================================ + # 2. 适配器模块处理 + # ============================================================================ + assert corpus_key is not None, "适配器层需要corpus_key参数" + assert len(set(corpus_key)) == 1, f"批次内corpus_key必须相同: {corpus_key}" + + # 通过适配器处理 + y = self.adapter_layer(x, corpus_key[0]) # 使用第一个corpus_key作为适配器索引 + + # ============================================================================ + # 3. 残差连接 (适配器 + Transformer输出) + # ============================================================================ + x = x + y # 残差连接:原始特征 + 适配器特征 + return x, (attn, layer_result)