diff --git a/LabsSolutions/02-pytorch-asr/models.py b/LabsSolutions/02-pytorch-asr/models.py
index 78f736a..833eabf 100644
--- a/LabsSolutions/02-pytorch-asr/models.py
+++ b/LabsSolutions/02-pytorch-asr/models.py
@@ -418,7 +418,7 @@ def ex_ctc():
     # The size of our vocabulary (including the blank character)
     vocab_size = 44
     # The class id for the blank token
-    blank_id = 43
+    blank_id = 0
 
     max_spectro_length = 50
     min_transcript_length = 10
@@ -427,7 +427,7 @@ def ex_ctc():
     # Compute a dummy vector of probabilities over the vocabulary (including the blank)
     # log_probs is here batch_first, i.e. (Batch, Tx, vocab_size)
     log_probs = torch.randn(batch_size, max_spectro_length, vocab_size).log_softmax(
-        dim=1
+        dim=2
     )
     spectro_lengths = torch.randint(
         low=max_transcript_length, high=max_spectro_length, size=(batch_size,)
@@ -437,7 +437,7 @@ def ex_ctc():
     # targets is here (batch_size, Ty)
     targets = torch.randint(
         low=0,
-        high=vocab_size + 1,  # include the blank character
+        high=vocab_size,  # include the blank character
         size=(batch_size, max_transcript_length),
     )
     target_lengths = torch.randint(
@@ -445,7 +445,7 @@
     )
 
     loss = torch.nn.CTCLoss(blank=blank_id)
-    # The log_probs must be given as (Tx, Batch, vocab_size)
+    # The log_probs must be given as (Tx, Batch, vocab_size), hence we transpose it
     vloss = loss(log_probs.transpose(0, 1), targets, spectro_lengths, target_lengths)
 
     print(f"Our dummy loss equals {vloss}")
@@ -485,7 +485,6 @@
 
 
 if __name__ == "__main__":
-    # ex_ctc()
-    # ex_pack()
-    test_model()
+    ex_ctc()
+    ex_pack()
 # SOL@