
Boilerplate MPS code

Open · huksley opened this issue 1 year ago · 6 comments

Adds get_device() and renames _clear_cuda_cache to _clear_device_cache. This does not add Apple Silicon support yet, but it makes that easier to implement in the future. Uses CUDA first if it is available.
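A rough sketch of what such a helper can look like (illustrative only, not the actual diff; the MPS branch shows where future Apple Silicon support would slot in):

import torch

def get_device():
    # Prefer CUDA when present, otherwise fall back to MPS (Apple Silicon), then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"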

Run as PYTORCH_ENABLE_MPS_FALLBACK=1 python 1.py

from bark import SAMPLE_RATE, generate_audio
from scipy.io.wavfile import write as write_wav

text_prompt = """
    WOMAN: I would like an oat milk latte please.
    MAN: Wow, that's expensive!
"""
audio_array = generate_audio(text_prompt)
# save the generated audio to disk
write_wav("bark_generation.wav", SAMPLE_RATE, audio_array)

Takes around 15 minutes on an M2 Pro with 32 GB and generates only noise, or I'm just writing the WAV file the wrong way :)

FIXME: aten::_weight_norm_interface is not currently implemented for MPS, as tracked in:

https://github.com/pytorch/pytorch/issues/77764#issuecomment-1516967265

huksley avatar Apr 20 '23 21:04 huksley

haha awesome thanks. not much experience here, would love for someone to test and verify

gkucsko avatar Apr 21 '23 19:04 gkucsko

cleaned up and rebased on current main branch @gkucsko @sevdari

huksley avatar Apr 23 '23 19:04 huksley

fantastic, thanks. do we have any benchmarks here just to make sure it actually improves things vs CPU? also gotta double-check it doesn't break anything on GPU, but it looks good at a quick glance

gkucsko avatar Apr 24 '23 15:04 gkucsko

hm, i'm getting much faster prediction speeds on CPU..

gkucsko avatar Apr 24 '23 22:04 gkucsko

Technically works for me on a MacBook Pro M1 (macOS Monterey 12.6.5, Python 3.11.3, PyTorch nightly build via Conda install), but the output is only noise.

Quick performance test with the following prompt:

text_prompt = """
     Hello, my name is Suno. And, uh — and I like pizza. [laughs] 
     But I also have other interests such as playing tic tac toe.
"""

The first inference step takes nearly the same time on CPU and GPU (~44 seconds), while the second inference step is about twice as fast on the GPU (1:45 vs. 3:08 minutes on CPU).

domcross avatar Apr 25 '23 08:04 domcross

the noise output comes from torch.multinomial; somehow that seems bugged under MPS and returns something that looks like constant numeric overflow. shuttling to CPU first and then back works.
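A minimal sketch of that workaround (function and tensor names are made up for illustration; the real change belongs in bark's sampling code):

import torch

def sample_next_token(logits: torch.Tensor) -> torch.Tensor:
    # torch.multinomial appears to misbehave on MPS, so sample on CPU
    # and move the result back to the original device afterwards.
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs.to("cpu"), num_samples=1)
    return next_token.to(logits.device)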

gkucsko avatar Apr 25 '23 13:04 gkucsko

Anyone else getting this?

NotImplementedError: The operator 'aten::_weight_norm_interface' is not currently implemented for the MPS device.
If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764.
As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op.
WARNING: this will be slower than running natively on MPS

olso avatar Apr 25 '23 20:04 olso

Anyone else getting this?

Yes. You need to set PYTORCH_ENABLE_MPS_FALLBACK=1, either by running export PYTORCH_ENABLE_MPS_FALLBACK=1 in your shell session or by prefixing your Python call, e.g. PYTORCH_ENABLE_MPS_FALLBACK=1 python myscript.py
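Setting it from Python also works; to be safe, do it before torch (or bark) gets imported, e.g.:

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # set before torch is imported

from bark import generate_audio  # imports torch with the fallback variable already set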

domcross avatar Apr 25 '23 21:04 domcross

ok, just simplified device placement a bit and added experimental MPS support (disabled by default via an environment variable): https://github.com/suno-ai/bark/commit/6c26fb7b3463334ae9fb4d63dae52f3c29506db0

on M1 Max I am seeing: [benchmark screenshot, numbers not reproduced here]

so it kinda looks like the semantic model is better on CPU but coarse benefits from MPS (note that this is on PyTorch nightly). some more experiments needed...

gkucsko avatar Apr 25 '23 21:04 gkucsko

hmm, looks like it's warmup-related; running again is faster for both. alright, considering this experimental MPS support for now, which makes inference roughly 2x faster. gonna close this PR for now. MPS can now be enabled via:

import os
os.environ["SUNO_ENABLE_MPS"] = "True"
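For a full run, something along these lines should work; as far as I can tell bark reads SUNO_ENABLE_MPS when its modules are imported, so set the variables before the import (prompt and output filename below are just placeholders):

import os
os.environ["SUNO_ENABLE_MPS"] = "True"            # opt in to the experimental MPS path
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"   # CPU fallback for ops missing on MPS

from bark import SAMPLE_RATE, generate_audio
from scipy.io.wavfile import write as write_wav

audio_array = generate_audio("Hello, my name is Suno.")
write_wav("bark_mps_test.wav", SAMPLE_RATE, audio_array)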

gkucsko avatar Apr 25 '23 21:04 gkucsko

Got IndexError: index out of range in self when turning on SUNO_ENABLE_MPS

File ~/Code/Miniforge3/lib/python3.9/site-packages/bark/api.py:66, in semantic_to_waveform(semantic_tokens, history_prompt, temp, silent, output_full)
     54 coarse_tokens = generate_coarse(
     55     semantic_tokens,
     56     history_prompt=history_prompt,
   (...)
     59     use_kv_caching=True
     60 )
     61 fine_tokens = generate_fine(
     62     coarse_tokens,
     63     history_prompt=history_prompt,
     64     temp=0.5,
     65 )
---> 66 audio_arr = codec_decode(fine_tokens)
     67 if output_full:
     68     full_generation = {
     69         "semantic_prompt": semantic_tokens,
     70         "coarse_prompt": coarse_tokens,
     71         "fine_prompt": fine_tokens,
     72     }

File ~/Code/Miniforge3/lib/python3.9/site-packages/bark/generation.py:860, in codec_decode(fine_tokens)
    858 arr = arr.to(device)
    859 arr = arr.transpose(0, 1)
--> 860 emb = model.quantizer.decode(arr)
    861 out = model.decoder(emb)
    862 audio_arr = out.detach().cpu().numpy().squeeze()

File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/vq.py:112, in ResidualVectorQuantizer.decode(self, codes)
    109 def decode(self, codes: torch.Tensor) -> torch.Tensor:
    110     """Decode the given codes to the quantized representation.
    111     """
--> 112     quantized = self.vq.decode(codes)
    113     return quantized

File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/core_vq.py:361, in ResidualVectorQuantization.decode(self, q_indices)
    359 for i, indices in enumerate(q_indices):
    360     layer = self.layers[i]
--> 361     quantized = layer.decode(indices)
    362     quantized_out = quantized_out + quantized
    363 return quantized_out

File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/core_vq.py:288, in VectorQuantization.decode(self, embed_ind)
    287 def decode(self, embed_ind):
--> 288     quantize = self._codebook.decode(embed_ind)
    289     quantize = self.project_out(quantize)
    290     quantize = rearrange(quantize, "b n d -> b d n")

File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/core_vq.py:202, in EuclideanCodebook.decode(self, embed_ind)
    201 def decode(self, embed_ind):
--> 202     quantize = self.dequantize(embed_ind)
    203     return quantize

File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/core_vq.py:188, in EuclideanCodebook.dequantize(self, embed_ind)
    187 def dequantize(self, embed_ind):
--> 188     quantize = F.embedding(embed_ind, self.embed)
    189     return quantize

File ~/Code/Miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:2210, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2204     # Note [embedding_renorm set_grad_enabled]
   2205     # XXX: equivalent to
   2206     # with torch.no_grad():
   2207     #   torch.embedding_renorm_
   2208     # remove once script supports set_grad_enabled
   2209     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self

My device: M1 Pro MBP, torch.__version__: 2.0.0

hywhuangyuwei avatar Apr 26 '23 19:04 hywhuangyuwei

try torch nightly
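For reference, the Apple Silicon nightly wheels can typically be installed with something like the command below (check pytorch.org for the current instructions, since the index URL may change):

pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu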

gkucsko avatar Apr 26 '23 22:04 gkucsko