
Upload pretrained model to run inference

Open greenbech opened this issue 5 years ago • 3 comments

It would be great if someone could upload a pretrained model so that we could try this project without having to train it ourselves. Waiting a week for training (as mentioned in #10) is quite a big commitment if you primarily just want to check out the performance on some .wav files.

And I would also like to say this repo is very well written and educational. Thanks!

greenbech avatar Apr 09 '20 10:04 greenbech

Hi, please try this one, trained for 500,000 iterations on the MAESTRO dataset.

I haven't touched the model in a while, but torch.load('model-500000.pt') should be able to load the PyTorch model.
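
In case it helps, a minimal loading sketch (CPU inference assumed; since the file is the fully pickled module, run this where the onsets_and_frames package is importable so that unpickling can find the model classes):

import torch

# map_location='cpu' lets the GPU-trained checkpoint load on a CPU-only machine
model = torch.load('model-500000.pt', map_location='cpu')
model.eval()  # inference mode: disables dropout and freezes batch-norm statistics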

jongwook avatar Apr 09 '20 23:04 jongwook

The provided file works great, thanks a lot! I didn't need to use torch.load('model-500000.pt') since both evaluate.py and transcribe.py take the model file as an argument.
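
(For context, what transcribe.py does with the model-file argument internally is roughly the sketch below — my paraphrase based on the traceback and the snippet further down in this comment, not the exact code from the repo.)

import torch
from transcribe import load_and_process_audio, transcribe  # functions in the repo's transcribe.py

# roughly the body of transcribe_file:
model = torch.load('model-500000.pt', map_location='cpu').eval()
audio = load_and_process_audio('some_recording.flac', sequence_length=None, device='cpu')  # placeholder path
predictions = transcribe(model, audio)  # onset / offset / frame / velocity activations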

However, I first got this error message when trying to run the scripts:

Traceback (most recent call last):
  File "transcribe.py", line 101, in <module>
    transcribe_file(**vars(parser.parse_args()))
  File "transcribe.py", line 74, in transcribe_file
    predictions = transcribe(model, audio)
  File "transcribe.py", line 53, in transcribe
    onset_pred, offset_pred, _, frame_pred, velocity_pred = model(mel)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/greenbech/git/onsets-and-frames/onsets_and_frames/transcriber.py", line 87, in forward
    onset_pred = self.onset_stack(mel)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/container.py", line 100, in forward
    input = module(input)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/greenbech/git/onsets-and-frames/onsets_and_frames/lstm.py", line 29, in forward
    output[:, start:end, :], (h, c) = self.rnn(x[:, start:end, :], (h, c))
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 558, in forward
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
  File "/Users/greenbech/.pyenv/versions/3.7.5/lib/python3.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
    type(self).__name__, name))
AttributeError: 'LSTM' object has no attribute '_flat_weights'

Downgrading from 1.4.0 to torch==1.2.0 fixed it for me.
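
A possible alternative to pinning an old torch (not something I have tried here, just the standard PyTorch recipe) would be to export the checkpoint once as a plain state dict in a torch==1.2 environment and then load that on newer versions, instead of unpickling the full module:

# run once, in a torch==1.2 environment that can still load the pickled model
import torch

model = torch.load('model-500000.pt', map_location='cpu')
torch.save(model.state_dict(), 'model-500000-state_dict.pt')  # parameters only, no pickled module code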

It is also quite cumbersome to resample the audio file to 16 kHz beforehand, so I added this locally to transcribe.py:

import librosa
import numpy as np

# SAMPLE_RATE is the repo's 16 kHz sample-rate constant (already in scope in transcribe.py)


def float_samples_to_int16(y):
    """Convert a floating-point numpy array of audio samples to int16."""
    # From https://github.com/tensorflow/magenta/blob/671501934ff6783a7912cc3e0e628fd0ea2dc609/magenta/music/audio_io.py#L48
    if not issubclass(y.dtype.type, np.floating):
        raise ValueError('input samples not floating-point')
    return (y * np.iinfo(np.int16).max).astype(np.int16)


def load_and_process_audio(flac_path, sequence_length, device):

    random = np.random.RandomState(seed=42)

    # librosa resamples to SAMPLE_RATE on load and returns float32 in [-1, 1],
    # so convert back to int16, which the rest of the script expects
    audio, sr = librosa.load(flac_path, sr=SAMPLE_RATE)
    audio = float_samples_to_int16(audio)

    assert sr == SAMPLE_RATE
    assert audio.dtype == 'int16'
    ...

There might be more elegant ways of doing this, but I was not able to convert to int16 with librosa or resample with soundfile.read.

I also think the model you provided should be made available in the README so that others can try it out without digging through this issue. I was thinking either directly in ./data/pretrained, which is the easiest setup but increases the repo size unnecessarily, or via the Drive URL you provided.

Would you mind a PR with this?

greenbech avatar Apr 10 '20 12:04 greenbech

Yeah! I'll need to do some housekeeping to make the checkpoint work across versions. A PR is welcome! Thanks :D
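
For reference, a rough sketch of what the loading side could look like once the checkpoint is re-exported as a state dict (as sketched above) — the constructor arguments are a guess based on train.py and the repo constants, so they may need adjusting:

import torch
from onsets_and_frames import OnsetsAndFrames, N_MELS, MIN_MIDI, MAX_MIDI

# build the model from code and load only the parameters, so nothing
# version-specific (e.g. LSTM._flat_weights) is baked into the checkpoint
model = OnsetsAndFrames(N_MELS, MAX_MIDI - MIN_MIDI + 1)
model.load_state_dict(torch.load('model-500000-state_dict.pt', map_location='cpu'))
model.eval()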

jongwook avatar Apr 10 '20 18:04 jongwook