Boilerplate MPS code
Adds get_device() and renames _clear_cuda_cache to _clear_device_cache. This does not add Apple Silicon support yet, but it makes that easier to implement in the future. CUDA is still used first when available.
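For illustration, a minimal sketch of the device-selection logic described above (the real helper is in the diff; this just mirrors the "CUDA first" behavior):

import torch

def get_device():
    # Prefer CUDA when present; fall back to MPS on Apple Silicon, else CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")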
Run as PYTORCH_ENABLE_MPS_FALLBACK=1 python 1.py
from bark import generate_audio

text_prompt = """
WOMAN: I would like an oat milk latte please.
MAN: Wow, that's expensive!
"""
audio_array = generate_audio(text_prompt)
Takes around 15 minutes on an M2 Pro with 32 GB and generates only noise, or maybe I'm just writing the WAV file the wrong way :)
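For reference, a minimal way to write the output to a WAV file, mirroring the usage in bark's README (assumes scipy is installed):

from scipy.io.wavfile import write as write_wav
from bark import SAMPLE_RATE

# audio_array is a float32 numpy array at bark's SAMPLE_RATE (24 kHz)
write_wav("bark_generation.wav", SAMPLE_RATE, audio_array)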
FIXME: aten::_weight_norm_interface is not currently implemented in MPS, as tracked in:
https://github.com/pytorch/pytorch/issues/77764#issuecomment-1516967265
haha awesome thanks. not much experience here, would love for someone to test and verify
cleaned up and rebased on current main branch @gkucsko @sevdari
fantastic, thanks. do we have any benchmarks here, just to make sure it actually improves things vs CPU? also gotta double-check it doesn't break anything on GPU, but looks good at a quick glance
hm, i'm getting much faster prediction speeds on CPU..
Technically works for me on a MacBook Pro M1 (macOS Monterey 12.6.5, Python 3.11.3, PyTorch nightly build via Conda install), but there is only noise output.
Quick performance test with the following prompt:
text_prompt = """
Hello, my name is Suno. And, uh — and I like pizza. [laughs]
But I also have other interests such as playing tic tac toe.
"""
The first inference step is nearly the same on CPU and GPU (~44 seconds), while the second inference step is about twice as fast on GPU (1:45 vs 3:08 minutes).
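For context, the two inference steps can be timed separately via bark's API; a minimal sketch (the stage split is my assumption about how the numbers above were measured):

import time
from bark import preload_models
from bark.api import text_to_semantic, semantic_to_waveform

text_prompt = "Hello, my name is Suno. And, uh, I like pizza. [laughs]"

preload_models()  # exclude model loading from the timing

t0 = time.time()
semantic_tokens = text_to_semantic(text_prompt)      # first step: semantic model
t1 = time.time()
audio_array = semantic_to_waveform(semantic_tokens)  # second step: coarse/fine + codec
t2 = time.time()
print(f"semantic: {t1 - t0:.0f}s, waveform: {t2 - t1:.0f}s")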
the noise output comes from torch.multinomial, which somehow seems bugged under MPS and returns something that looks like constant numeric overflow. shuttling to CPU first and then back works.
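A minimal sketch of that workaround (safe_multinomial is a hypothetical helper name; bark's actual sampling code is inline):

import torch

def safe_multinomial(probs: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
    # torch.multinomial appears broken on MPS here, so sample on CPU
    # and move the drawn indices back to the original device.
    if probs.device.type == "mps":
        return torch.multinomial(probs.cpu(), num_samples).to(probs.device)
    return torch.multinomial(probs, num_samples)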
Anyone else getting this?
NotImplementedError: The operator 'aten::_weight_norm_interface' is not currently implemented for the MPS device.
If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764.
As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op.
WARNING: this will be slower than running natively on MPS
Anyone else getting this?
Yes. You need to set PYTORCH_ENABLE_MPS_FALLBACK=1, either by running export PYTORCH_ENABLE_MPS_FALLBACK=1 in your console session or by prefixing your Python call, e.g. PYTORCH_ENABLE_MPS_FALLBACK=1 python myscript.py
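To set it from Python instead, as far as I can tell it has to happen before torch is imported, since the flag is read when torch initializes; a minimal sketch:

import os
# must be set before any torch import, or the fallback flag has no effect
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from bark import generate_audio  # this imports torch internally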
ok, just simplified device placement a bit and added experimental MPS support (disabled by default, toggled via an environment variable): https://github.com/suno-ai/bark/commit/6c26fb7b3463334ae9fb4d63dae52f3c29506db0
on M1 Max I am seeing:
so it kinda looks like the semantic model is better on CPU but coarse benefits from MPS (note that this is on PyTorch nightly). some more experiments needed...
hmm, looks like it's warmup-related; running again is faster for both. Alright, treating this as experimental MPS support for now, which makes inference roughly 2x faster. gonna close this PR for now. MPS can now be enabled via:
import os
os.environ["SUNO_ENABLE_MPS"] = "True"
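A minimal end-to-end sketch of enabling it (note the variable has to be set before importing bark, since bark appears to read it at module import time):

import os
os.environ["SUNO_ENABLE_MPS"] = "True"  # read by bark at import time

from bark import generate_audio, preload_models

preload_models()  # models now land on the MPS device
audio_array = generate_audio("Hello from Apple Silicon!")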
Got IndexError: index out of range in self when turning on SUNO_ENABLE_MPS
File ~/Code/Miniforge3/lib/python3.9/site-packages/bark/api.py:66, in semantic_to_waveform(semantic_tokens, history_prompt, temp, silent, output_full)
54 coarse_tokens = generate_coarse(
55 semantic_tokens,
56 history_prompt=history_prompt,
(...)
59 use_kv_caching=True
60 )
61 fine_tokens = generate_fine(
62 coarse_tokens,
63 history_prompt=history_prompt,
64 temp=0.5,
65 )
---> 66 audio_arr = codec_decode(fine_tokens)
67 if output_full:
68 full_generation = {
69 "semantic_prompt": semantic_tokens,
70 "coarse_prompt": coarse_tokens,
71 "fine_prompt": fine_tokens,
72 }
File ~/Code/Miniforge3/lib/python3.9/site-packages/bark/generation.py:860, in codec_decode(fine_tokens)
858 arr = arr.to(device)
859 arr = arr.transpose(0, 1)
--> 860 emb = model.quantizer.decode(arr)
861 out = model.decoder(emb)
862 audio_arr = out.detach().cpu().numpy().squeeze()
File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/vq.py:112, in ResidualVectorQuantizer.decode(self, codes)
109 def decode(self, codes: torch.Tensor) -> torch.Tensor:
110 """Decode the given codes to the quantized representation.
111 """
--> 112 quantized = self.vq.decode(codes)
113 return quantized
File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/core_vq.py:361, in ResidualVectorQuantization.decode(self, q_indices)
359 for i, indices in enumerate(q_indices):
360 layer = self.layers[i]
--> 361 quantized = layer.decode(indices)
362 quantized_out = quantized_out + quantized
363 return quantized_out
File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/core_vq.py:288, in VectorQuantization.decode(self, embed_ind)
287 def decode(self, embed_ind):
--> 288 quantize = self._codebook.decode(embed_ind)
289 quantize = self.project_out(quantize)
290 quantize = rearrange(quantize, "b n d -> b d n")
File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/core_vq.py:202, in EuclideanCodebook.decode(self, embed_ind)
201 def decode(self, embed_ind):
--> 202 quantize = self.dequantize(embed_ind)
203 return quantize
File ~/Code/Miniforge3/lib/python3.9/site-packages/encodec/quantization/core_vq.py:188, in EuclideanCodebook.dequantize(self, embed_ind)
187 def dequantize(self, embed_ind):
--> 188 quantize = F.embedding(embed_ind, self.embed)
189 return quantize
File ~/Code/Miniforge3/lib/python3.9/site-packages/torch/nn/functional.py:2210, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2204 # Note [embedding_renorm set_grad_enabled]
2205 # XXX: equivalent to
2206 # with torch.no_grad():
2207 # torch.embedding_renorm_
2208 # remove once script supports set_grad_enabled
2209 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2210 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self
My device: M1 Pro MBP
torch.__version__: 2.0.0
try torch nightly