[ONNX export] Musicgen for text-to-audio

Open xenova opened this issue 2 years ago • 5 comments

Feature request

Musicgen was recently added to 🤗 Transformers (see the model doc), and it would be great to be able to export these models to ONNX with Optimum.

Motivation

This will allow me to support music generation models in Transformers.js.

Your contribution

I will integrate into transformers.js once available in optimum.

xenova • Aug 21 '23

Hi, I'm also interested in converting the Musicgen model to ONNX format so I can try deploying it on-device. May I know whether it is supported in Optimum now?

kanger45 • Sep 12 '23

It would be great to have this feature. By the way, how can I get Transformers.js?

MaiZhiHao • Sep 12 '23

> May i know is it support on Optimum now?

Not yet 😇 cc @fxmarty

> Btw, how can I get the transformers.js ?

You can check out the repo or the documentation. However, since Musicgen is not yet available in Optimum, it won't be available in Transformers.js until then.

xenova • Sep 12 '23

Hi @xenova,

May I know if there is a plan or schedule for Optimum to support converting it to an ONNX model?

kanger45 • Sep 22 '23

Any update?

zeke-john • Jan 29 '24

Hi @kanger45 @MaiZhiHao @zeke-john, https://github.com/huggingface/optimum/pull/1779 has been merged. It exports Musicgen in several parts so that audio samples can be generated conditioned on a text prompt (reference: https://huggingface.co/docs/transformers/model_doc/musicgen#text-conditional-generation). The export uses the decoder KV cache. The following subcomponents are exported:

  • text_encoder.onnx: corresponds to the text encoder part in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L1457.
  • encodec_decode.onnx: corresponds to the decoding part of the EnCodec audio encoder in https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/musicgen/modeling_musicgen.py#L2472-L2480.
  • decoder_model.onnx: the Musicgen decoder without past key values as input, computing cross-attention. Not required at inference (use decoder_model_merged.onnx instead).
  • decoder_with_past_model.onnx: the Musicgen decoder with past_key_values as input (KV cache filled), not computing cross-attention. Not required at inference (use decoder_model_merged.onnx instead).
  • decoder_model_merged.onnx: the two previous models fused into one to avoid duplicating weights. A boolean input, use_cache_branch, selects which branch to run. In the first forward pass, when the KV cache is still empty, dummy past key value inputs must be passed; they are ignored when use_cache_branch=False.
  • build_delay_pattern_mask.onnx: a model that takes input_ids, pad_token_id and max_length as inputs and builds the delay pattern mask for the input_ids. Implements https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/musicgen/modeling_musicgen.py#L1054.
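To illustrate what build_delay_pattern_mask.onnx is about, here is a simplified sketch of the Musicgen delay pattern, assuming codes are given as one row of tokens per codebook: codebook k is shifted right by k steps and the gaps are filled with the pad token. The function names are hypothetical, and this is not the Transformers implementation, which additionally marks the positions still to be generated with -1, returns a pattern mask, and handles interleaved stereo codebooks.

```python
from typing import List, Sequence


def apply_delay_pattern(codes: Sequence[Sequence[int]], pad_token_id: int) -> List[List[int]]:
    """Shift codebook k right by k steps (simplified illustration, hypothetical helper).

    codes: one row of tokens per codebook, all rows of equal length seq_len.
    Returns rows of length seq_len + num_codebooks - 1, padded with pad_token_id.
    """
    num_codebooks = len(codes)
    delayed = []
    for k, row in enumerate(codes):
        # k pad tokens before the row, (num_codebooks - 1 - k) after it
        delayed.append([pad_token_id] * k + list(row) + [pad_token_id] * (num_codebooks - 1 - k))
    return delayed


def revert_delay_pattern(delayed: Sequence[Sequence[int]]) -> List[List[int]]:
    """Undo apply_delay_pattern by reading each row back from its offset k."""
    num_codebooks = len(delayed)
    seq_len = len(delayed[0]) - num_codebooks + 1
    return [list(row[k:k + seq_len]) for k, row in enumerate(delayed)]


codes = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
delayed = apply_delay_pattern(codes, pad_token_id=2048)
for row in delayed:
    print(row)
```

With 4 codebooks and 3 frames, each row grows to 3 + 4 - 1 = 6 positions, so at any time step the model predicts codebook 0 for the newest frame while the higher codebooks lag behind; revert_delay_pattern recovers the original rows.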

These models can be used, e.g., in Transformers.js; for now, there is no runtime implementation for them in Optimum.
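Since there is no Optimum runtime for these parts, the branch switching of decoder_model_merged.onnx has to be driven by hand. The sketch below shows only that control flow, under stated assumptions: step is a hypothetical callable that would wrap onnxruntime.InferenceSession.run on the merged decoder, and the real input/output names of the exported graph, the delay pattern mask and token sampling are all left out.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical step signature: (input_ids, encoder_hidden_states, past, use_cache_branch)
# -> (next_ids, present). The real graph's input/output names are not modelled here.
DecoderStep = Callable[[List[int], Any, Any, bool], Tuple[List[int], Any]]


def run_merged_decoder(step: DecoderStep, start_ids: List[int], encoder_hidden_states: Any,
                       dummy_past: Any, max_new_steps: int) -> List[List[int]]:
    """Drive a merged decoder: no-cache branch first, cache branch afterwards."""
    input_ids = start_ids
    past = dummy_past          # placeholder inputs; ignored while use_cache_branch is False
    use_cache_branch = False   # first pass: the KV cache is still empty
    generated: List[List[int]] = []
    for _ in range(max_new_steps):
        next_ids, past = step(input_ids, encoder_hidden_states, past, use_cache_branch)
        generated.append(next_ids)
        input_ids = next_ids     # with a filled cache, only the newest tokens are fed
        use_cache_branch = True  # every later pass reuses the cache returned by the previous one
    return generated
```

Keeping the ONNX session behind a plain callable makes the loop easy to exercise with a fake step function before wiring in real sessions.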

fxmarty • Apr 10 '24

@fxmarty Would this work for fine-tuned Musicgen models? I used this repo to fine-tune the medium model, and the output is a .pt model.

zeke-john • Apr 10 '24

@zeke-john Yes, it should work as long as the checkpoint (and model repo) follows the Transformers layout (e.g. https://huggingface.co/facebook/musicgen-small/tree/main). .bin and .safetensors are supported; I'm not sure about .pt.
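A quick pre-export sanity check along these lines can be sketched as follows. It is only a heuristic, assuming a Transformers-style repo is recognisable by a config.json plus weights saved as .safetensors or .bin (single file or sharded with an index); the helper name is hypothetical, and passing the check says nothing about whether the weights themselves match the Musicgen architecture.

```python
from typing import Iterable

# Heuristic only: weight file names commonly found in a Transformers-style model repo.
WEIGHT_FILES = {
    "model.safetensors",
    "pytorch_model.bin",
    "model.safetensors.index.json",  # sharded safetensors
    "pytorch_model.bin.index.json",  # sharded .bin
}


def looks_like_transformers_checkpoint(filenames: Iterable[str]) -> bool:
    """Return True if the file listing has a config.json and a recognised weight file."""
    names = set(filenames)
    return "config.json" in names and bool(names & WEIGHT_FILES)
```

For a local folder this could be called as looks_like_transformers_checkpoint(os.listdir("path/to/finetuned-musicgen")), where the path is a placeholder; a folder holding only a .pt file fails the check.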

fxmarty • Apr 10 '24

Are there any supported ways to fine-tune Musicgen, other than the way I did it, so that it stays a Transformers model? Or can a .pt model be converted into the Transformers format?

zeke-john • Apr 10 '24

@zeke-john You can try https://github.com/huggingface/transformers/blob/main/src/transformers/models/musicgen/convert_musicgen_transformers.py, which should let you convert from the AudioCraft format to the Transformers format.

fxmarty • Apr 10 '24

@fxmarty After exporting the ONNX models, how can we run them locally?

Dannyjhl • Jun 14 '24

@fxmarty Would it be possible to add support for the stereo model? The export seems to fail when looking up the right index for the number of heads (I think), since the stereo model has more codebooks. It could have something to do with how the config file is read and the model is loaded; I'm not fully sure.

Using framework PyTorch: 2.1.0+cu121
/home/[USER]/projects/[project]/optimum/optimum/exporters/onnx/model_patcher.py:942: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  if len(audio_codes) != 1:
/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/transformers/models/encodec/modeling_encodec.py:433: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  quantized_out = torch.tensor(0.0, device=codes.device)
/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/transformers/models/encodec/modeling_encodec.py:434: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  for i, indices in enumerate(codes):
Traceback (most recent call last):
  File "/home/[USER]/projects/[project]/optimum/test.py", line 3, in <module>
    model = ORTModelForCausalLM.from_pretrained('facebook/musicgen-stereo-small', export=True, task='text-to-audio')
  File "/home/[USER]/projects/[project]/optimum/optimum/onnxruntime/modeling_ort.py", line 737, in from_pretrained
    return super().from_pretrained(
  File "/home/[USER]/projects/[project]/optimum/optimum/modeling_base.py", line 438, in from_pretrained
    return from_pretrained_method(
  File "/home/[USER]/projects/[project]/optimum/optimum/onnxruntime/modeling_decoder.py", line 653, in _from_transformers
    main_export(
  File "/home/[USER]/projects/[project]/optimum/optimum/exporters/onnx/__main__.py", line 374, in main_export
    onnx_export_from_model(
  File "/home/[USER]/projects/[project]/optimum/optimum/exporters/onnx/convert.py", line 1188, in onnx_export_from_model
    _, onnx_outputs = export_models(
  File "/home/[USER]/projects/[project]/optimum/optimum/exporters/onnx/convert.py", line 782, in export_models
    export(
  File "/home/[USER]/projects/[project]/optimum/optimum/exporters/onnx/convert.py", line 887, in export
    export_output = export_pytorch(
  File "/home/[USER]/projects/[project]/optimum/optimum/exporters/onnx/convert.py", line 583, in export_pytorch
    onnx_export(
  File "<@beartype(torch.onnx.utils.export) at 0x79bbd8720ee0>", line 440, in export
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "<@beartype(torch.onnx.utils._model_to_graph) at 0x79bbd8721d80>", line 12, in _model_to_graph
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/[USER]/projects/[project]/optimum/optimum/exporters/onnx/model_patcher.py", line 944, in patched_forward
    audio_values = self._model._decode_frame(audio_codes[0], audio_scales)
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/transformers/models/encodec/modeling_encodec.py", line 702, in _decode_frame
    embeddings = self.quantizer.decode(codes)
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/transformers/models/encodec/modeling_encodec.py", line 435, in decode
    layer = self.layers[i]
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/nn/modules/container.py", line 293, in __getitem__
    return self._modules[self._get_abs_string_index(idx)]
  File "/home/[USER]/anaconda3/envs/[ENV]/lib/python3.10/site-packages/torch/nn/modules/container.py", line 283, in _get_abs_string_index
    raise IndexError(f'index {idx} is out of range')
IndexError: index 4 is out of range

Revess • Oct 10 '24
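One plausible reading of the traceback above: the stereo checkpoints interleave left- and right-channel codebooks, so the decoder produces 8 codebook rows, while the EnCodec quantizer loops over the rows and indexes one quantizer layer per row (self.layers[i]); with 4 layers this fails at i = 4. As far as I can tell, Transformers' own generate() decodes stereo output per channel, taking even-indexed codebooks as the left channel and odd-indexed ones as the right. The sketch below only shows that split on plain lists; the helper name is hypothetical and it is not a fix for the exporter.

```python
from typing import List, Sequence, Tuple


def split_stereo_codebooks(codes: Sequence[Sequence[int]]) -> Tuple[List[List[int]], List[List[int]]]:
    """Split interleaved stereo codebook rows into (left, right).

    Assumes even-indexed rows belong to the left channel and odd-indexed rows
    to the right channel, so each half matches the mono quantizer depth.
    """
    if len(codes) % 2 != 0:
        raise ValueError("expected an even number of interleaved codebooks")
    left = [list(row) for row in codes[0::2]]
    right = [list(row) for row in codes[1::2]]
    return left, right


# 8 interleaved rows (stereo) -> two halves of 4 rows, one per channel
codes = [[10 * k, 10 * k + 1] for k in range(8)]
left, right = split_stereo_codebooks(codes)
```

Each half could then be decoded separately by a 4-layer quantizer and the two resulting waveforms stacked as channels.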