Skip to content

Commit

Permalink
fix test functions of models; just semantic fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfix committed Nov 14, 2023
1 parent f380a47 commit 113420d
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions LabsSolutions/02-pytorch-asr/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def ex_ctc():
# The size of our vocabulary (including the blank character)
vocab_size = 44
# The class id for the blank token
blank_id = 43
blank_id = 0

max_spectro_length = 50
min_transcript_length = 10
Expand All @@ -427,7 +427,7 @@ def ex_ctc():
# Compute a dummy vector of probabilities over the vocabulary (including the blank)
# log_probs is here batch_first, i.e. (Batch, Tx, vocab_size)
log_probs = torch.randn(batch_size, max_spectro_length, vocab_size).log_softmax(
dim=1
dim=2
)
spectro_lengths = torch.randint(
low=max_transcript_length, high=max_spectro_length, size=(batch_size,)
Expand All @@ -437,15 +437,15 @@ def ex_ctc():
# targets is here (batch_size, Ty)
targets = torch.randint(
low=0,
high=vocab_size + 1, # include the blank character
high=vocab_size, # include the blank character
size=(batch_size, max_transcript_length),
)
target_lengths = torch.randint(
low=min_transcript_length, high=max_transcript_length, size=(batch_size,)
)
loss = torch.nn.CTCLoss(blank=blank_id)

# The log_probs must be given as (Tx, Batch, vocab_size)
# The log_probs must be given as (Tx, Batch, vocab_size), hence we transpose it
vloss = loss(log_probs.transpose(0, 1), targets, spectro_lengths, target_lengths)

print(f"Our dummy loss equals {vloss}")
Expand Down Expand Up @@ -485,7 +485,6 @@ def ex_pack():


if __name__ == "__main__":
# ex_ctc()
# ex_pack()
test_model()
ex_ctc()
ex_pack()
# SOL@

0 comments on commit 113420d

Please sign in to comment.