Skip to content

Commit

Permalink
Merge pull request #16 from CameronChurchwell/dev
Browse files Browse the repository at this point in the history
torbi
  • Loading branch information
maxrmorrison authored Jan 14, 2024
2 parents 0fe93b9 + 61e7b4b commit 63c36f5
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions penn/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def initial(self):
initial = torch.zeros(2 * penn.PITCH_BINS)
initial[penn.PITCH_BINS:] = 1 / penn.PITCH_BINS

return initial

@functools.cached_property
def transition(self):
"""Create the Viterbi transition matrix for PYIN"""
Expand All @@ -114,6 +116,8 @@ def transition(self):
torch.tensor([[.99, .01], [.01, .99]]),
transition)

return transition


class Viterbi(Decoder):

Expand All @@ -129,11 +133,11 @@ def __call__(self, logits):
gpu = (
None if distributions.device.type == 'cpu'
else distributions.device.index)
bins = torbi.decode(
distributions[0].T,
self.transition(),
self.initial(),
gpu)
bins = torbi.from_probabilities(
observation=distributions[0].T.unsqueeze(dim=0),
transition=self.transition,
initial=self.initial,
gpu=gpu)

# Convert to frequency in Hz
if self.local_expected_value:
Expand All @@ -149,7 +153,7 @@ def __call__(self, logits):
return bins.T, pitch.T

@functools.cached_property
def transition(self):
def initial(self):
"""Create uniform initial probabilities"""
return torch.full((penn.PITCH_BINS,), 1 / penn.PITCH_BINS)

Expand Down

0 comments on commit 63c36f5

Please sign in to comment.