diff --git a/pretty_midi/pretty_midi.py b/pretty_midi/pretty_midi.py index 30f2988..76d111c 100644 --- a/pretty_midi/pretty_midi.py +++ b/pretty_midi/pretty_midi.py @@ -287,7 +287,7 @@ def __get_instrument(program, channel, track, create_new): for track_idx, track in enumerate(midi_data.tracks): # Keep track of last note on location: # key = (instrument, note), - # value = (note on time, velocity) + # value = (note-on tick, velocity) last_note_on = collections.defaultdict(list) # Keep track of which instrument is playing in each channel # initialize to program 0 for all channels @@ -306,20 +306,37 @@ def __get_instrument(program, channel, track, create_new): # Store this as the last note-on location note_on_index = (event.channel, event.note) last_note_on[note_on_index].append(( - self.__tick_to_time[event.time], - event.velocity)) + event.time, event.velocity)) # Note offs can also be note on events with 0 velocity elif event.type == 'note_off' or (event.type == 'note_on' and event.velocity == 0): # Check that a note-on exists (ignore spurious note-offs) - if (event.channel, event.note) in last_note_on: + key = (event.channel, event.note) + if key in last_note_on: # Get the start/stop times and velocity of every note - # which was turned on with this instrument/drum/pitch - for start, velocity in last_note_on[ - (event.channel, event.note)]: - end = self.__tick_to_time[event.time] + # which was turned on with this instrument/drum/pitch. + # One note-off may close multiple note-on events from + # previous ticks. In case there's a note-off and then + # note-on at the same tick we keep the open note from + # this tick. + end_tick = event.time + open_notes = last_note_on[key] + + notes_to_close = [ + (start_tick, velocity) + for start_tick, velocity in open_notes + if start_tick != end_tick] + notes_to_keep = [ + (start_tick, velocity) + for start_tick, velocity in open_notes + if start_tick == end_tick] + + for start_tick, velocity in notes_to_close: + start_time = self.__tick_to_time[start_tick] + end_time = self.__tick_to_time[end_tick] # Create the note event - note = Note(velocity, event.note, start, end) + note = Note(velocity, event.note, start_time, + end_time) # Get the program and drum type for the current # instrument program = current_instrument[event.channel] @@ -330,8 +347,14 @@ def __get_instrument(program, channel, track, create_new): program, event.channel, track_idx, 1) # Add the note event instrument.notes.append(note) - # Remove the last note on for this instrument - del last_note_on[(event.channel, event.note)] + + if len(notes_to_close) > 0 and len(notes_to_keep) > 0: + # Note-on on the same tick but we already closed + # some previous notes -> it will continue, keep it. + last_note_on[key] = notes_to_keep + else: + # Remove the last note on for this instrument + del last_note_on[key] # Store pitch bends elif event.type == 'pitchwheel': # Create pitch bend class instance diff --git a/tests/test_pretty_midi.py b/tests/test_pretty_midi.py index d0068c6..c629e73 100644 --- a/tests/test_pretty_midi.py +++ b/tests/test_pretty_midi.py @@ -1,5 +1,7 @@ import pretty_midi import numpy as np +import mido +from tempfile import NamedTemporaryFile def test_get_beats(): @@ -272,3 +274,60 @@ def simple(): for ks, t, k in zip(pm.key_signature_changes, ks_times, ks_keys): assert ks.time == t assert ks.key_number == k + + +def test_properly_order_overlapping_notes(): + def make_mido_track(notes_str, file): + track = mido.MidiTrack() + for line in notes_str.split('\n'): + line = line.strip() + if line: + track.append(mido.Message.from_str(line)) + mido_file = mido.MidiFile() + mido_file.tracks.append(track) + mido_file.save(file=file) + + # two notes with pitch 72 open at once + bad_track = """ + note_on channel=0 note=72 velocity=88 time=0 + note_on channel=0 note=72 velocity=88 time=48 + note_on channel=0 note=72 velocity=0 time=0 + note_on channel=0 note=74 velocity=88 time=48 + note_on channel=0 note=72 velocity=0 time=0 + note_on channel=0 note=72 velocity=88 time=48 + note_on channel=0 note=74 velocity=0 time=0 + note_on channel=0 note=72 velocity=0 time=48 + """ + + # the 72 note is first closed, then another one opened + good_track = """ + note_on channel=0 note=72 velocity=88 time=0 + note_on channel=0 note=72 velocity=0 time=48 + note_on channel=0 note=72 velocity=88 time=0 + note_on channel=0 note=74 velocity=88 time=48 + note_on channel=0 note=72 velocity=0 time=0 + note_on channel=0 note=72 velocity=88 time=48 + note_on channel=0 note=74 velocity=0 time=0 + note_on channel=0 note=72 velocity=0 time=48 + """ + + for kind, track in (('good', good_track), ('bad', bad_track)): + with NamedTemporaryFile() as file: + make_mido_track(track, file) + file.seek(0) + pm_song = pretty_midi.PrettyMIDI(file) + + def extract_notes(pm_track): + return np.array([(note.pitch, note.end-note.start) for note + in pm_track.notes]) + + expected = np.array([[72, 0.05], [72, 0.05], [74, 0.05], [72, 0.05]]) + assert np.allclose(expected, extract_notes(pm_song.instruments[0])) + + with NamedTemporaryFile() as file: + pm_song.write(file) + file.seek(0) + pm_song_written = pretty_midi.PrettyMIDI(file) + + assert np.allclose(expected, + extract_notes(pm_song_written.instruments[0]))