mt3
Bugfix in event decoder (?)
Background
While running this decoder on a sequence of tokens (midifile -> NoteSequence -> event_tokens), I realised that after every note is processed and we return to processing shifts, the cur_time is reset (since start_time is 0). This then causes an exception to be thrown from note_sequences#decode_note_event:
if time < state.current_time:
  raise ValueError('event time < current time, %f < %f' % (
      time, state.current_time))
I believe cur_time should be an offset from state.current_time, not from start_time. This allows the decoder to pick up where it left off along the timeline.
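The behaviour described above can be reproduced with a small toy model. This is a simplified sketch, not the mt3 implementation: the tuple token format, `ToyState`, `toy_decode_note`, `toy_decode`, and the value `STEPS_PER_SECOND = 100` are all assumptions made for illustration. It contrasts measuring shifts from `start_time` with measuring them from `state.current_time`.

```python
from dataclasses import dataclass, field

STEPS_PER_SECOND = 100  # assumed toy resolution, not read from any codec


@dataclass
class ToyState:
    """Hypothetical stand-in for a note decoding state."""
    current_time: float = 0.0
    note_times: list = field(default_factory=list)


def toy_decode_note(state, time):
    # Same guard as the check quoted above: time must not move backwards.
    if time < state.current_time:
        raise ValueError('event time < current time, %f < %f' % (
            time, state.current_time))
    state.current_time = time
    state.note_times.append(time)


def toy_decode(state, tokens, start_time, offset_from_state):
    """Decode ('shift', n) and ('note', pitch) tokens.

    The step counter is cleared after every note, as described in the text.
    offset_from_state selects the base that shifts are added to.
    """
    cur_steps = 0
    cur_time = start_time
    for kind, value in tokens:
        if kind == 'shift':
            cur_steps += value
            base = state.current_time if offset_from_state else start_time
            cur_time = base + cur_steps / STEPS_PER_SECOND
        else:
            cur_steps = 0
            toy_decode_note(state, cur_time)


# Two notes separated by single-step shifts.
tokens = [('shift', 1), ('shift', 1), ('note', 60), ('shift', 1), ('note', 62)]

# Offset from start_time: the second note lands before the first.
try:
    toy_decode(ToyState(), tokens, start_time=0.0, offset_from_state=False)
    outcome = 'decoded'
except ValueError:
    outcome = 'ValueError'

# Offset from state.current_time: decoding continues along the timeline.
fixed = ToyState()
toy_decode(fixed, tokens, start_time=0.0, offset_from_state=True)
```

In this toy setup the first variant raises on the second note, while the second variant places the two notes at increasing times.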
Code example:
subsequences = note_seq.split_note_sequence(ns, 1)
event_batches = []
for i, subseq in enumerate(subsequences):
    subseq = note_seq.apply_sustain_control_changes(subseq)
    midi_times, midi_events = midi.note_sequence_to_events(subseq)
    del subseq.control_changes[:]
    # audio_times (the audio frame times) is defined elsewhere; not shown here.
    events, _, _, _, _ = midi.encode_midi_events(audio_times, midi_times, midi_events)
    event_batches.append(events)
reconstructed = midi.event_batches_to_note_sequence(event_batches, codec=utils.CODEC)
midi.note_sequence_to_midi_file(reconstructed, 'moo.mid')
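# midi.py
from typing import Sequence, Tuple

import note_seq
from mt3 import event_codec, note_sequences, run_length_encoding
# utils is the author's own module providing CODEC; it is not shown here.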
def midi_file_to_note_sequence(midi_path) -> note_seq.NoteSequence:
    """
    Load a midi file as a NoteSequence
    """
    print(f"Converting midi file to note sequence: {midi_path}")
    ns = note_seq.midi_file_to_note_sequence(midi_path)
    return ns
def note_sequence_to_events(ns: note_seq.NoteSequence) -> Tuple[Sequence[float], Sequence[note_sequences.NoteEventData]]:
    return note_sequences.note_sequence_to_onsets_and_offsets_and_programs(ns)
def event_batches_to_note_sequence(event_batches, codec: event_codec.Codec=utils.CODEC) -> note_seq.NoteSequence:
    print("converting event batches to note sequence")
    decoding_state = note_sequences.NoteDecodingState()
    total_invalid_ids = 0
    total_dropped_events = 0
    for events in event_batches:
        invalid_ids, dropped_events = run_length_encoding.decode_events(
            state=decoding_state,
            tokens=events,
            start_time=decoding_state.current_time,
            max_time=None,
            codec=codec,
            decode_event_fn=note_sequences.decode_note_event
        )
        total_invalid_ids += invalid_ids
        total_dropped_events += dropped_events
    ns = note_sequences.flush_note_decoding_state(decoding_state)
    print(f'Dropped {total_dropped_events} events')
    print(f'Invalid ids: {total_invalid_ids}')
    return ns
def note_sequence_to_midi_file(ns: note_seq.NoteSequence, midi_path: str):
    """
    Write a NoteSequence to a midi file
    """
    print(f"Converting events to midi file: {midi_path}")
    return note_seq.midi_io.note_sequence_to_midi_file(ns, midi_path)
def encode_midi_events(
    audio_frame_times: Sequence[float],
    midi_event_times: Sequence[float],
    midi_event_values: Sequence[note_sequences.NoteEventData]
) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int], Sequence[int]]:
    events, event_start_indices, event_end_indices, state_events, state_event_indices = run_length_encoding.encode_and_index_events(
        state=note_sequences.NoteEncodingState(),
        event_times=midi_event_times,
        event_values=midi_event_values,
        encode_event_fn=note_sequences.note_event_data_to_events,
        codec=utils.CODEC,
        frame_times=audio_frame_times,
        encoding_state_to_events_fn=note_sequences.note_encoding_state_to_events
    )
    return events, event_start_indices, event_end_indices, state_events, state_event_indices
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
I've signed the CLA, can that build be rerun?
Can you provide a little more information on how you're using the decoder? I'm guessing it's for some kind of custom setup.
The way we're using it (as illustrated in metrics_utils.event_predictions_to_ns), the ground truth for the current time offset comes from the start_time passed into decode_events. That's then used to set state.current_time. We do this because we're decoding independently-inferred chunks of the full sequence.
> We do this because we're decoding independently-inferred chunks of the full sequence.
Right, I had suspected as much. I was just trying to understand how the MT3 + note-seq libraries work and wrote some code to split a midi file into subsequences, to see if I could then reconstruct the original midi file. So my input would be a sequence of events corresponding to the entire midi file (several minutes' worth of events).
I can see how this function would work as-is for a small slice containing only a single note event + some shift events, but it would fail if it encounters multiple note events (unless there are increasingly more shifts between subsequent note events).
So perhaps it's by design - but I believe this change is still an improvement, as it should not change functionality for the 'small slice' use-case and will prevent errors in a longer slice use-case.
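To make the "increasingly more shifts" condition concrete, here is a rough sketch that computes the note times a decoder would see if every run of shifts were measured from the chunk start. It is a toy model, not mt3 code: the tuple token format, the helper names `note_times_from_start` and `is_monotonic`, and the value of 100 steps per second are assumptions for illustration.

```python
STEPS_PER_SECOND = 100  # assumed toy resolution


def note_times_from_start(tokens, start_time=0.0):
    """Return note times when each run of shifts is counted from start_time.

    tokens is a list of ('shift', n) or ('note', pitch) tuples (hypothetical
    format). The step counter restarts after every note.
    """
    times = []
    cur_steps = 0
    cur_time = start_time
    for kind, value in tokens:
        if kind == 'shift':
            cur_steps += value
            cur_time = start_time + cur_steps / STEPS_PER_SECOND
        else:
            cur_steps = 0
            times.append(cur_time)
    return times


def is_monotonic(times):
    """True when the note times never move backwards."""
    return all(a <= b for a, b in zip(times, times[1:]))


# A single note after some shifts: always fine.
single = [('shift', 2), ('note', 60)]
# Shift runs of 2 then 1 steps: the second note falls before the first.
shrinking = [('shift', 2), ('note', 60), ('shift', 1), ('note', 62)]
# Shift runs of 2 then 3 steps: each run is longer than the last, so times rise.
growing = [('shift', 2), ('note', 60), ('shift', 3), ('note', 62)]

for name, toks in [('single', single), ('shrinking', shrinking), ('growing', growing)]:
    print(name, is_monotonic(note_times_from_start(toks)))
```

Only the `shrinking` sequence produces a note time earlier than the previous one, which is the case that would trip the `event time < current time` check.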
I have updated the description with my relevant code.
I think it still doesn't work for our case because state.current_time at the end of one chunk isn't necessarily the right start time of the subsequent chunk. For example, what if there are several chunks in a row with no note events? The start time of the chunk needs to come from some external source.
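A rough illustration of that objection, under assumed values: three chunks of 2 seconds each, where the first two contain no note events and the third has one note 0.5 s into the chunk. The chunk layout, the tuple token format, and the helper `decode_chunk` are hypothetical and not part of mt3. The point is only that carrying over the last note time cannot recover where an empty chunk ended.

```python
STEPS_PER_SECOND = 100  # assumed toy resolution


def decode_chunk(tokens, base_time):
    """Return (note_times, last_note_time) for one chunk.

    Shifts accumulate within the chunk and are measured from base_time.
    last_note_time stays at base_time when the chunk contains no notes.
    """
    note_times = []
    last_note_time = base_time
    cur_steps = 0
    for kind, value in tokens:
        if kind == 'shift':
            cur_steps += value
        else:
            last_note_time = base_time + cur_steps / STEPS_PER_SECOND
            note_times.append(last_note_time)
    return note_times, last_note_time


# Two empty chunks, then a chunk with one note 50 steps (0.5 s) in.
chunks = [[], [], [('shift', 50), ('note', 60)]]
external_starts = [0.0, 2.0, 4.0]  # assumed chunk start times

# Start times supplied from outside, one per chunk.
external = []
for tokens, start in zip(chunks, external_starts):
    times, _ = decode_chunk(tokens, start)
    external.extend(times)

# Start time carried over from the last decoded note instead.
carried = []
current_time = 0.0
for tokens in chunks:
    times, current_time = decode_chunk(tokens, current_time)
    carried.extend(times)

print(external)  # [4.5]
print(carried)   # [0.5]
```

With external start times the note lands at 4.5 s. With the carried-over time the two empty chunks contribute nothing, so the same note ends up at 0.5 s.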