91 changes: 86 additions & 5 deletions apps/realtime-asr/stream_asr.py
@@ -45,6 +45,12 @@ def __init__(self, config: StreamingConfig):
        self.offset = 0
        self.total_frames_processed = 0
        self.accumulated_text = ""  # Accumulate text across all chunks

        # Transducer predictor caches
        self.pred_cache_m = None
        self.pred_cache_c = None
        self.pred_input = None

        self.reset_cache()

        # Audio capture - prefer PyAudio on macOS for better stability
@@ -134,6 +140,16 @@ def reset_cache(self):
        self.total_frames_processed = 0
        self.accumulated_text = ""  # Reset accumulated text

        # Reset the transducer predictor cache if the model has a predictor
        if hasattr(self.model.model, "predictor"):
            predictor = self.model.model.predictor
            self.pred_cache_m, self.pred_cache_c = predictor.init_state(
                batch_size=1, method="zero", device=self.device
            )
            # Initialize the predictor input with the blank token
            blank_id = self.model.model.blank
            self.pred_input = torch.tensor([blank_id]).reshape(1, 1).to(self.device)
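            # Priming with <blank> gives the first forward_step a well-defined
            # "previous token"; this input and the zeroed LSTM state are
            # re-created every time the stream is reset.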

    def extract_features(self, audio_chunk: np.ndarray) -> torch.Tensor:
        """Extract fbank features from audio chunk"""
        # Convert to torch tensor
@@ -188,21 +204,86 @@ def process_chunk(self, audio_chunk: np.ndarray) -> Tuple[torch.Tensor, str]:
    def decode(self, encoder_out: torch.Tensor) -> str:
        """Decode encoder output to text"""
        text: str
-        if hasattr(self.model.model, "ctc"):
+        if self.model.config.model == "asr_model":
            # CTC decoding
            ctc_probs = self.model.model.ctc.log_softmax(encoder_out)  # [B, T, vocab]
            topk = ctc_probs.argmax(dim=-1)  # [B, T]
            hyps = [hyp.tolist() for hyp in topk]
            text = str(get_output(hyps, self.model.char_dict, self.model.config.model)[0])
-        elif hasattr(self.model, "decoder"):
-            # Transducer or attention decoder
-            # Implement appropriate decoding here
-            text = "[Decoder output]"
+        elif self.model.config.model == "transducer":
            # Transducer decoding using the streaming-optimized search
            hyps = self.decode_transducer_streaming(encoder_out)
            text = str(get_output([hyps], self.model.char_dict, self.model.config.model)[0])
        else:
            text = "[Unknown decoder type]"

        return text

    def decode_transducer_streaming(self, encoder_out: torch.Tensor, n_steps: int = 64) -> list:
        """
        Streaming transducer decoder based on optimized_search.

        This function processes encoder output frame by frame and maintains
        predictor state across chunks for streaming inference.

        Args:
            encoder_out: Encoder output tensor [B=1, T, E]
            n_steps: Maximum non-blank predictions per frame

        Returns:
            List of predicted token IDs (without blanks)
        """
        model = self.model.model
        blank_id = model.blank

        max_len = encoder_out.size(1)

        # Use the persistent predictor cache carried across chunks
        cache_m = self.pred_cache_m
        cache_c = self.pred_cache_c
        pred_input = self.pred_input

        # Output buffer for this chunk
        chunk_hyps = []

        # Process each encoder frame
        for t in range(max_len):
            encoder_out_t = encoder_out[:, t : t + 1, :]  # [B=1, 1, E]

            # Allow up to n_steps non-blank predictions per frame
            for step in range(1, n_steps + 1):
                # Forward through the predictor
                pred_out_step, new_cache = model.predictor.forward_step(
                    pred_input, (cache_m, cache_c)
                )  # [B=1, 1, P]

                # Forward through the joint network
                joint_out_step = model.joint(encoder_out_t, pred_out_step)  # [B=1, 1, V]
                joint_out_probs = joint_out_step.log_softmax(dim=-1)

                # Take the best prediction for this step
                joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()  # scalar
                if joint_out_max == blank_id:
                    # Blank prediction - move to the next frame
                    break

                # Non-blank prediction: emit the token
                chunk_hyps.append(joint_out_max.item())

                # Update predictor input and cache for the next step
                pred_input = joint_out_max.reshape(1, 1)
                cache_m, cache_c = new_cache

        # Persist the cache for the next chunk
        self.pred_cache_m = cache_m
        self.pred_cache_c = cache_c
        self.pred_input = pred_input
        return chunk_hyps

    def run(self):
        """Main streaming loop"""
        print("\n" + "=" * 60)
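To make the streaming contract above concrete, here is a minimal usage sketch. `StreamingASR`, `config`, `audio`, and the chunk size are illustrative assumptions, not part of this diff; only `reset_cache`, `process_chunk`, and the persistent `pred_*` attributes mirror the code above.

```python
# Hypothetical driver loop - class and variable names are assumptions.
import numpy as np

asr = StreamingASR(config)  # hypothetical wrapper holding the model in this diff
asr.reset_cache()           # zeroes encoder state and the predictor cache/<blank> input

chunk_samples = 10240       # 640 ms at 16 kHz, matching the largest dynamic chunk
for start in range(0, len(audio), chunk_samples):
    chunk = audio[start : start + chunk_samples].astype(np.float32)
    _, text = asr.process_chunk(chunk)  # returns (encoder_out, text) per this diff
    # decode_transducer_streaming() returns only the tokens emitted for this
    # chunk; the last non-blank token and LSTM cache persist in self.pred_*,
    # so the next chunk continues the hypothesis rather than restarting it.
    print(text, end="", flush=True)
```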
137 changes: 137 additions & 0 deletions examples/asr/rnnt/conf/chunkformer-rnnt-small-vie-stream-dct.yaml
@@ -0,0 +1,137 @@
# network architecture
# encoder related
encoder: chunkformer
encoder_conf:
output_size: 256 # dimension of attention
attention_heads: 4
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
    input_layer: dw_striding # encoder input type; you can choose conv2d, conv2d6, or conv2d8
normalize_before: true
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'chunk_rel_pos'
selfattention_layer_type: 'chunk_rel_seflattn'
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes the model converge faster
# Enable the settings below for joint training on full and chunk context
dynamic_conv: true
    dynamic_chunk_sizes: [4, 6, 8] # one encoder frame = 80 ms, so chunks of [320, 480, 640] ms
    # Note that the effective left-context span depends on the number of encoder layers
dynamic_left_context_sizes: [40, 50, 60]
dynamic_right_context_sizes: [0] # No right context for streaming
streaming: true

joint: transducer_joint
joint_conf:
enc_output_size: 256
pred_output_size: 256
join_dim: 512
prejoin_linear: True
postjoin_linear: false
joint_mode: 'add'
activation: 'tanh'

predictor: rnn
predictor_conf:
embed_size: 256
output_size: 256
embed_dropout: 0.1
hidden_size: 256
num_layers: 2
bias: true
rnn_type: 'lstm'
dropout: 0.1

decoder: bitransformer
decoder_conf:
attention_heads: 4
dropout_rate: 0.1
linear_units: 2048
num_blocks: 3
positional_dropout_rate: 0.1
r_num_blocks: 3
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1

tokenizer: bpe
tokenizer_conf:
symbol_table_path: 'data/lang_char/train_hf_bpe1024_units.txt'
split_with_space: false
bpe_path: 'data/lang_char/train_hf_bpe1024.model'
non_lang_syms_path: null
is_multilingual: false
num_languages: 1
special_tokens:
<blank>: 0
<unk>: 1
<sos>: 2
<eos>: 2

ctc: ctc
ctc_conf:
ctc_blank_id: 0

cmvn: global_cmvn
cmvn_conf:
cmvn_file: 'data/train_hf/global_cmvn'
is_json_cmvn: true

# hybrid transducer+ctc+attention
model: transducer
model_conf:
transducer_weight: 0.75
ctc_weight: 0.1
attention_weight: 0.15
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
reverse_weight: 0.3
enable_k2: True

dataset: asr
dataset_conf:
filter_conf:
max_length: 40960
min_length: 0
token_max_length: 400
token_min_length: 1
resample_conf:
resample_rate: 16000
speed_perturb: true
fbank_conf:
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 2
num_f_mask: 2
max_t: 50
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1000
sort: False
sort_conf:
sort_size: 500 # sort_size should be less than shuffle_size
batch_conf:
batch_type: 'dynamic' # static or dynamic
max_frames_in_batch: 300000
pad_feat: True


grad_clip: 5
accum_grad: 2
max_epoch: 200
log_interval: 100

optim: adamw
optim_conf:
lr: 0.001
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 15000
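A quick sanity check on the timing comments in this config (plain Python, not part of the PR): with the 10 ms fbank frame shift and the 8x subsampling that `dw_striding` is assumed to apply, one encoder frame covers 80 ms, so the dynamic chunk and left-context sizes translate to the durations below.

```python
# Sanity-check sketch for the streaming settings above (illustrative only).
# Assumes dw_striding subsamples by 8x, consistent with the config's
# "one encoder frame = 80 ms" comment (10 ms fbank shift * 8 = 80 ms/frame).

FRAME_SHIFT_MS = 10            # fbank_conf.frame_shift
SUBSAMPLING = 8                # assumed factor for the dw_striding input layer
MS_PER_ENC_FRAME = FRAME_SHIFT_MS * SUBSAMPLING   # 80 ms

for frames in [4, 6, 8]:       # dynamic_chunk_sizes
    print(f"chunk: {frames} frames -> {frames * MS_PER_ENC_FRAME} ms of new audio")
for frames in [40, 50, 60]:    # dynamic_left_context_sizes
    print(f"left context: {frames} frames -> {frames * MS_PER_ENC_FRAME / 1000:.1f} s")

# dynamic_right_context_sizes is [0], so no future audio is required and the
# per-chunk algorithmic latency is just the chunk length (320/480/640 ms).
```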