Skip to content

Commit

Permalink
Decoding fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
maxrmorrison committed Jan 14, 2024
1 parent 0fe93b9 commit 9e515d2
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions penn/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def __call__(self, logits):
else distributions.device.index)
bins = torbi.decode(
distributions[0].T,
self.transition(),
self.initial(),
self.transition,
self.initial,
gpu)

# Convert to frequency in Hz
Expand All @@ -103,6 +103,7 @@ def initial(self):
"""Create initial probability matrix for PYIN"""
initial = torch.zeros(2 * penn.PITCH_BINS)
initial[penn.PITCH_BINS:] = 1 / penn.PITCH_BINS
return initial

@functools.cached_property
def transition(self):
Expand All @@ -113,6 +114,7 @@ def transition(self):
transition = torch.kron(
torch.tensor([[.99, .01], [.01, .99]]),
transition)
return transition


class Viterbi(Decoder):
Expand All @@ -123,16 +125,16 @@ def __init__(self, local_expected_value=True):
def __call__(self, logits):
"""Decode pitch using viterbi decoding (from librosa)"""
distributions = torch.nn.functional.softmax(logits, dim=1)
distributions = distributions.permute(2, 1, 0)
distributions = distributions.permute(2, 1, 0) # F x C x 1 -> 1 x C x F

# Viterbi decoding
gpu = (
None if distributions.device.type == 'cpu'
else distributions.device.index)
bins = torbi.decode(
distributions[0].T,
self.transition(),
self.initial(),
self.transition,
self.initial,
gpu)

# Convert to frequency in Hz
Expand All @@ -149,7 +151,7 @@ def __call__(self, logits):
return bins.T, pitch.T

@functools.cached_property
def transition(self):
def initial(self):
"""Create uniform initial probabilities"""
return torch.full((penn.PITCH_BINS,), 1 / penn.PITCH_BINS)

Expand Down

0 comments on commit 9e515d2

Please sign in to comment.