From d9e25fe8c287cf69fa95499122060118afe5a039 Mon Sep 17 00:00:00 2001
From: khanld
Date: Tue, 18 Nov 2025 12:14:04 +0700
Subject: [PATCH 1/3] add stream rnn-t

---
 apps/realtime-asr/stream_asr.py | 94 +++++++++++++++++++++++++++++++--
 1 file changed, 89 insertions(+), 5 deletions(-)

diff --git a/apps/realtime-asr/stream_asr.py b/apps/realtime-asr/stream_asr.py
index 91ad227..447f25d 100644
--- a/apps/realtime-asr/stream_asr.py
+++ b/apps/realtime-asr/stream_asr.py
@@ -45,6 +45,12 @@ def __init__(self, config: StreamingConfig):
         self.offset = 0
         self.total_frames_processed = 0
         self.accumulated_text = ""  # Accumulate text across all chunks
+
+        # Transducer predictor caches
+        self.pred_cache_m = None
+        self.pred_cache_c = None
+        self.pred_input = None
+
         self.reset_cache()

         # Audio capture - prefer PyAudio on macOS for better stability
@@ -133,6 +139,18 @@ def reset_cache(self):
         self.offset = 0
         self.total_frames_processed = 0
         self.accumulated_text = ""  # Reset accumulated text
+
+        # Reset transducer predictor cache if the model has transducer
+        if hasattr(self.model.model, "predictor"):
+            predictor = self.model.model.predictor
+            self.pred_cache_m, self.pred_cache_c = predictor.init_state(
+                batch_size=1, method="zero", device=self.device
+            )
+            # Initialize with blank token
+            blank_id = self.model.model.blank
+            self.pred_input = (
+                torch.tensor([blank_id]).reshape(1, 1).to(self.device)
+            )

     def extract_features(self, audio_chunk: np.ndarray) -> torch.Tensor:
         """Extract fbank features from audio chunk"""
@@ -188,20 +206,86 @@ def process_chunk(self, audio_chunk: np.ndarray) -> Tuple[torch.Tensor, str]:

     def decode(self, encoder_out: torch.Tensor) -> str:
         """Decode encoder output to text"""
         text: str
-        if hasattr(self.model.model, "ctc"):
+        if self.model.config.model == "asr_model":
             # CTC decoding
             ctc_probs = self.model.model.ctc.log_softmax(encoder_out)  # [B, T, vocab]
             topk = ctc_probs.argmax(dim=-1)  # [B, T]
             hyps = [hyp.tolist() for hyp in topk]
             text = str(get_output(hyps, self.model.char_dict, self.model.config.model)[0])
-        elif hasattr(self.model, "decoder"):
-            # Transducer or attention decoder
-            # Implement appropriate decoding here
-            text = "[Decoder output]"
+        elif self.model.config.model == "transducer":
+            # Transducer decoding using streaming optimized search
+            hyps = self.decode_transducer_streaming(encoder_out)
+            text = str(get_output([hyps], self.model.char_dict, self.model.config.model)[0])
         else:
             text = "[Unknown decoder type]"
         return text
+
+    def decode_transducer_streaming(self, encoder_out: torch.Tensor, n_steps: int = 64) -> list:
+        """
+        Streaming transducer decoder based on optimized_search.
+
+        This function processes encoder output frame by frame and maintains
+        predictor state across chunks for streaming inference.
+
+        Args:
+            encoder_out: Encoder output tensor [B=1, T, E]
+            n_steps: Maximum non-blank predictions per frame
+
+        Returns:
+            List of predicted token IDs (without blanks)
+        """
+        model = self.model.model
+        blank_id = model.blank
+
+        batch_size = encoder_out.size(0)
+        max_len = encoder_out.size(1)
+
+        # Use persistent predictor cache across chunks
+        cache_m = self.pred_cache_m
+        cache_c = self.pred_cache_c
+        pred_input = self.pred_input
+
+        # Output buffer for this chunk
+        chunk_hyps = []
+
+        # Process each frame
+        for t in range(max_len):
+            encoder_out_t = encoder_out[:, t : t + 1, :]  # [B=1, 1, E]
+
+            # Allow up to n_steps non-blank predictions per frame
+            for step in range(1, n_steps + 1):
+                # Forward through predictor
+                pred_out_step, new_cache = model.predictor.forward_step(
+                    pred_input, (cache_m, cache_c)
+                )  # [B=1, 1, P]
+
+                # Forward through joint network
+                joint_out_step = model.joint(encoder_out_t, pred_out_step)  # [B=1, 1, V]
+                joint_out_probs = joint_out_step.log_softmax(dim=-1)
+
+                # Get best prediction
+                joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()  # scalar
+                if joint_out_max == blank_id:
+                    # Blank prediction - move to next frame
+                    break
+                else:
+                    # Non-blank prediction
+                    chunk_hyps.append(joint_out_max.item())
+
+                # Update predictor input and cache for next step
+                pred_input = joint_out_max.reshape(1, 1)
+                cache_m, cache_c = new_cache
+
+                # Check if we've reached max steps per frame
+                if step >= n_steps:
+                    break
+
+        # Update persistent cache for next chunk
+        self.pred_cache_m = cache_m
+        self.pred_cache_c = cache_c
+        self.pred_input = pred_input
+        return chunk_hyps

     def run(self):
         """Main streaming loop"""
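For readers unfamiliar with transducer decoding, the sketch below restates the rule that the new decode_transducer_streaming method follows: process one encoder frame at a time, allow at most n_steps non-blank emissions per frame, and carry the predictor state plus the last emitted token across chunks. The ToyPredictor and ToyJoint modules, their dimensions, and the greedy_chunk_decode helper are illustrative stand-ins rather than the repository's actual predictor/joint API; the point is only that carrying (state, last token) makes chunked decoding match full-utterance greedy decoding, which is why the patch keeps pred_cache_m, pred_cache_c, and pred_input on the streamer object.

# Minimal, self-contained sketch of greedy streaming transducer decoding.
# ToyPredictor/ToyJoint are toy stand-ins, not the repo's modules.
import torch
import torch.nn as nn

BLANK_ID = 0

class ToyPredictor(nn.Module):
    """Single-layer LSTM label predictor with an explicit (h, c) state."""
    def __init__(self, vocab=8, dim=16):
        super().__init__()
        self.dim = dim
        self.embed = nn.Embedding(vocab, dim)
        self.rnn = nn.LSTM(dim, dim, batch_first=True)

    def init_state(self, batch_size=1):
        zeros = torch.zeros(1, batch_size, self.dim)
        return zeros, zeros.clone()

    def forward_step(self, y, state):
        out, new_state = self.rnn(self.embed(y), state)  # y: [1, 1] token ids
        return out, new_state                            # out: [1, 1, dim]

class ToyJoint(nn.Module):
    """Additive joint network projecting to the vocabulary."""
    def __init__(self, dim=16, vocab=8):
        super().__init__()
        self.proj = nn.Linear(dim, vocab)

    def forward(self, enc, pred):
        return self.proj(torch.tanh(enc + pred))         # [1, 1, vocab]

def greedy_chunk_decode(enc_chunk, predictor, joint, state, last_tok, n_steps=4):
    """Decode one chunk of encoder frames; return hyps and the carried state."""
    hyps = []
    for t in range(enc_chunk.size(1)):
        enc_t = enc_chunk[:, t:t + 1, :]
        for _ in range(n_steps):
            pred_out, new_state = predictor.forward_step(last_tok, state)
            tok = joint(enc_t, pred_out).argmax(dim=-1).item()
            if tok == BLANK_ID:
                break                                    # blank: advance to next frame
            hyps.append(tok)                             # non-blank: emit, update state
            last_tok = torch.tensor([[tok]])
            state = new_state
    return hyps, state, last_tok

torch.manual_seed(0)
predictor, joint = ToyPredictor(), ToyJoint()
frames = torch.randn(1, 6, 16)                           # pretend encoder output, 6 frames

# Decoding all 6 frames in one pass ...
state, tok = predictor.init_state(), torch.tensor([[BLANK_ID]])
full, _, _ = greedy_chunk_decode(frames, predictor, joint, state, tok)

# ... matches decoding two 3-frame chunks with the state carried over.
state, tok = predictor.init_state(), torch.tensor([[BLANK_ID]])
a, state, tok = greedy_chunk_decode(frames[:, :3], predictor, joint, state, tok)
b, state, tok = greedy_chunk_decode(frames[:, 3:], predictor, joint, state, tok)
assert full == a + b
print(full)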
From 4f625cffcc4212bf7d9aa6559cbf34a1caafd246 Mon Sep 17 00:00:00 2001
From: khanld
Date: Wed, 19 Nov 2025 14:10:00 +0700
Subject: [PATCH 2/3] add stream rnn-t training config

---
 ...chunkformer-rnnt-small-vie-stream-dct.yaml | 137 ++++++++++++++++++
 1 file changed, 137 insertions(+)
 create mode 100644 examples/asr/rnnt/conf/chunkformer-rnnt-small-vie-stream-dct.yaml

diff --git a/examples/asr/rnnt/conf/chunkformer-rnnt-small-vie-stream-dct.yaml b/examples/asr/rnnt/conf/chunkformer-rnnt-small-vie-stream-dct.yaml
new file mode 100644
index 0000000..5827c6e
--- /dev/null
+++ b/examples/asr/rnnt/conf/chunkformer-rnnt-small-vie-stream-dct.yaml
@@ -0,0 +1,137 @@
+# network architecture
+# encoder related
+encoder: chunkformer
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: dw_striding  # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: true
+    cnn_module_kernel: 15
+    use_cnn_module: True
+    activation_type: 'swish'
+    pos_enc_layer_type: 'chunk_rel_pos'
+    selfattention_layer_type: 'chunk_rel_seflattn'
+    cnn_module_norm: 'layer_norm'  # using nn.LayerNorm makes model converge faster
+    # Enable the settings below for joint training on full and chunk context
+    dynamic_conv: true
+    dynamic_chunk_sizes: [4, 6, 8]  # 1 frame = 80 ms -> [320, 480, 640] ms
+    # Note that the left context span is relative, depending on the number of encoder layers
+    dynamic_left_context_sizes: [40, 50, 60]
+    dynamic_right_context_sizes: [0]  # No right context for streaming
+    streaming: true
+
+joint: transducer_joint
+joint_conf:
+    enc_output_size: 256
+    pred_output_size: 256
+    join_dim: 512
+    prejoin_linear: True
+    postjoin_linear: false
+    joint_mode: 'add'
+    activation: 'tanh'
+
+predictor: rnn
+predictor_conf:
+    embed_size: 256
+    output_size: 256
+    embed_dropout: 0.1
+    hidden_size: 256
+    num_layers: 2
+    bias: true
+    rnn_type: 'lstm'
+    dropout: 0.1
+
+decoder: bitransformer
+decoder_conf:
+    attention_heads: 4
+    dropout_rate: 0.1
+    linear_units: 2048
+    num_blocks: 3
+    positional_dropout_rate: 0.1
+    r_num_blocks: 3
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+
+tokenizer: bpe
+tokenizer_conf:
+    symbol_table_path: 'data/lang_char/train_hf_bpe1024_units.txt'
+    split_with_space: false
+    bpe_path: 'data/lang_char/train_hf_bpe1024.model'
+    non_lang_syms_path: null
+    is_multilingual: false
+    num_languages: 1
+    special_tokens:
+        <blank>: 0
+        <unk>: 1
+        <sos>: 2
+        <eos>: 2
+
+ctc: ctc
+ctc_conf:
+    ctc_blank_id: 0
+
+cmvn: global_cmvn
+cmvn_conf:
+    cmvn_file: 'data/train_hf/global_cmvn'
+    is_json_cmvn: true
+
+# hybrid transducer+ctc+attention
+model: transducer
+model_conf:
+    transducer_weight: 0.75
+    ctc_weight: 0.1
+    attention_weight: 0.15
+    lsm_weight: 0.1  # label smoothing option
+    length_normalized_loss: false
+    reverse_weight: 0.3
+    enable_k2: True
+
+dataset: asr
+dataset_conf:
+    filter_conf:
+        max_length: 40960
+        min_length: 0
+        token_max_length: 400
+        token_min_length: 1
+    resample_conf:
+        resample_rate: 16000
+    speed_perturb: true
+    fbank_conf:
+        num_mel_bins: 80
+        frame_shift: 10
+        frame_length: 25
+        dither: 1.0
+    spec_aug: true
+    spec_aug_conf:
+        num_t_mask: 2
+        num_f_mask: 2
+        max_t: 50
+        max_f: 10
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 1000
+    sort: False
+    sort_conf:
+        sort_size: 500  # sort_size should be less than shuffle_size
+    batch_conf:
+        batch_type: 'dynamic'  # static or dynamic
+        max_frames_in_batch: 300000
+        pad_feat: True
+
+
+grad_clip: 5
+accum_grad: 2
+max_epoch: 200
+log_interval: 100
+
+optim: adamw
+optim_conf:
+    lr: 0.001
+scheduler: warmuplr  # pytorch v1.1.0+ required
+scheduler_conf:
+    warmup_steps: 15000
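To relate the streaming settings above to wall-clock latency: with the 10 ms fbank frame shift from dataset_conf and the roughly 8x encoder subsampling implied by the "1 frame = 80 ms" comment, each encoder frame covers 80 ms, so the dynamic chunk and context sizes translate as in the small sketch below. The constants and helper are illustrative, not part of the repository.

# Convert the streaming config's frame counts into milliseconds.
# Assumes a 10 ms fbank frame shift and 8x encoder subsampling
# (one encoder frame = 80 ms), as the config comment indicates.
FRAME_SHIFT_MS = 10
SUBSAMPLING = 8
MS_PER_ENC_FRAME = FRAME_SHIFT_MS * SUBSAMPLING  # 80 ms

def to_ms(frames):
    return [n * MS_PER_ENC_FRAME for n in frames]

print(to_ms([4, 6, 8]))     # dynamic_chunk_sizes        -> [320, 480, 640] ms
print(to_ms([40, 50, 60]))  # dynamic_left_context_sizes -> [3200, 4000, 4800] ms
print(to_ms([0]))           # dynamic_right_context_sizes -> [0] ms (no look-ahead)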
From 9e5c2be679c78e9db8f64d1fb1f4d8a2576a60a3 Mon Sep 17 00:00:00 2001
From: khanld
Date: Wed, 26 Nov 2025 10:26:33 +0000
Subject: [PATCH 3/3] fix lint

---
 apps/realtime-asr/stream_asr.py | 39 +++++++++++++++------------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/apps/realtime-asr/stream_asr.py b/apps/realtime-asr/stream_asr.py
index 447f25d..9a26db4 100644
--- a/apps/realtime-asr/stream_asr.py
+++ b/apps/realtime-asr/stream_asr.py
@@ -45,12 +45,12 @@ def __init__(self, config: StreamingConfig):
         self.offset = 0
         self.total_frames_processed = 0
         self.accumulated_text = ""  # Accumulate text across all chunks
-
+
         # Transducer predictor caches
         self.pred_cache_m = None
         self.pred_cache_c = None
         self.pred_input = None
-
+
         self.reset_cache()

         # Audio capture - prefer PyAudio on macOS for better stability
@@ -139,7 +139,7 @@ def reset_cache(self):
         self.offset = 0
         self.total_frames_processed = 0
         self.accumulated_text = ""  # Reset accumulated text
-
+
         # Reset transducer predictor cache if the model has transducer
         if hasattr(self.model.model, "predictor"):
             predictor = self.model.model.predictor
@@ -148,9 +148,7 @@ def reset_cache(self):
             )
             # Initialize with blank token
             blank_id = self.model.model.blank
-            self.pred_input = (
-                torch.tensor([blank_id]).reshape(1, 1).to(self.device)
-            )
+            self.pred_input = torch.tensor([blank_id]).reshape(1, 1).to(self.device)

     def extract_features(self, audio_chunk: np.ndarray) -> torch.Tensor:
         """Extract fbank features from audio chunk"""
@@ -220,50 +218,49 @@ def decode(self, encoder_out: torch.Tensor) -> str:
             text = "[Unknown decoder type]"
         return text
-
+
     def decode_transducer_streaming(self, encoder_out: torch.Tensor, n_steps: int = 64) -> list:
         """
         Streaming transducer decoder based on optimized_search.
-
+
         This function processes encoder output frame by frame and maintains
         predictor state across chunks for streaming inference.
-
+
         Args:
             encoder_out: Encoder output tensor [B=1, T, E]
             n_steps: Maximum non-blank predictions per frame
-
+
         Returns:
             List of predicted token IDs (without blanks)
         """
         model = self.model.model
         blank_id = model.blank
-
-        batch_size = encoder_out.size(0)
+
         max_len = encoder_out.size(1)
-
+
         # Use persistent predictor cache across chunks
         cache_m = self.pred_cache_m
         cache_c = self.pred_cache_c
         pred_input = self.pred_input
-
+
         # Output buffer for this chunk
         chunk_hyps = []
-
+
         # Process each frame
         for t in range(max_len):
            encoder_out_t = encoder_out[:, t : t + 1, :]  # [B=1, 1, E]
-
+
             # Allow up to n_steps non-blank predictions per frame
             for step in range(1, n_steps + 1):
                 # Forward through predictor
                 pred_out_step, new_cache = model.predictor.forward_step(
                     pred_input, (cache_m, cache_c)
                 )  # [B=1, 1, P]
-
+
                 # Forward through joint network
                 joint_out_step = model.joint(encoder_out_t, pred_out_step)  # [B=1, 1, V]
                 joint_out_probs = joint_out_step.log_softmax(dim=-1)
-
+
                 # Get best prediction
                 joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()  # scalar
                 if joint_out_max == blank_id:
@@ -272,15 +269,15 @@ def decode_transducer_streaming(self, encoder_out: torch.Tensor, n_steps: int =
                 else:
                     # Non-blank prediction
                     chunk_hyps.append(joint_out_max.item())
-
+
                     # Update predictor input and cache for next step
                     pred_input = joint_out_max.reshape(1, 1)
                     cache_m, cache_c = new_cache
-
+
                 # Check if we've reached max steps per frame
                 if step >= n_steps:
                     break
-
+
         # Update persistent cache for next chunk
         self.pred_cache_m = cache_m
         self.pred_cache_c = cache_c