
Add key/value caching for autoregressive generation

zygi opened this pull request · 6 comments

Modifies the model and generation code to support key/value caching. This should be almost a no-op in terms of behavior, apart from some numerical-stability-related changes.

Tested locally. Semantic token generation only speeds up a little, likely because the context is small, but coarse code generation is sped up quite substantially. Tested by generating audio with randomness held constant and confirming that very-similar-sounding results are produced.
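For anyone unfamiliar with the technique, below is a minimal sketch of what key/value caching looks like in a GPT-style attention layer. It is illustrative only; the function and variable names are hypothetical and this is not the actual code in this patch.

import torch

def cached_attention(q, k_new, v_new, kv_cache=None):
    # q, k_new, v_new: (batch, heads, new_tokens, head_dim) projections
    # computed for only the tokens added this step.
    if kv_cache is not None:
        k_past, v_past = kv_cache
        k_new = torch.cat([k_past, k_new], dim=2)  # reuse cached keys
        v_new = torch.cat([v_past, v_new], dim=2)  # reuse cached values
    # When decoding one token at a time, q has length 1 and may attend to
    # the entire cached history, so no causal mask is needed for that row.
    att = torch.softmax(q @ k_new.transpose(-2, -1) / k_new.size(-1) ** 0.5, dim=-1)
    out = att @ v_new
    return out, (k_new, v_new)  # thread the updated cache into the next step

Each generation step then feeds in projections for just the newest token instead of re-running attention over the whole sequence, which is where the speedup comes from.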

zygi · Apr 21 '23 02:04

Amazing, thanks! Lots going on right now, but will defo try to merge in the next couple of days. When I played with it in the past I didn't see much speedup on modern GPUs with flash attention, but maybe that's different on other GPUs/CPUs? Or maybe I had a bug, haha. Did you try to benchmark on CPU? I know it's slow, but it might still be way faster.

gkucsko · Apr 21 '23 19:04

This patch makes a huge difference for me on a Radeon 7900 XTX (running ROCm 5.5 rc4). I tested it with my extension for the Oobabooga LLM frontend, and generation is about 4.5 times faster with the k/v cache enabled. That timing includes the text generation step, which the cache obviously doesn't affect, so the audio synthesis itself is sped up even more. I went from roughly one minute for 15 seconds of audio to realtime generation.

wsippel · Apr 21 '23 20:04

> Amazing, thanks! Lots going on right now, but will defo try to merge in the next couple of days. When I played with it in the past I didn't see much speedup on modern GPUs with flash attention, but maybe that's different on other GPUs/CPUs? Or maybe I had a bug, haha. Did you try to benchmark on CPU? I know it's slow, but it might still be way faster.

I haven't tested CPU yet. Profiling on GPU currently shows that a significant fraction of the time is spent on Python/torch overhead, so I'm working on squeezing the code through torch.compile to see if that improves things.
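For anyone who wants to reproduce that kind of measurement, one straightforward way is torch.profiler; this is a sketch rather than the exact setup used above.

import torch
from torch.profiler import profile, ProfilerActivity
from bark import generate_audio

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    generate_audio("Hello world.", use_kv_caching=True)

# Ops showing long CPU-side times next to short CUDA kernel times point
# to Python/dispatch overhead rather than actual GPU compute.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))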

zygi · Apr 21 '23 20:04

much appreciated!

gkucsko · Apr 22 '23 00:04

CPU benchmarks: there's a massive improvement. Time to generate an example sentence drops from ~30 minutes to ~2 minutes. Tested on an i9-13900K with the following code:

from bark import SAMPLE_RATE, generate_audio
from IPython.display import Audio

text_prompt = """
     Hello, my name is Suno. And, uh — and I like pizza. [laughs]
     But I also have other interests such as playing tic tac toe.
"""
# Set use_kv_caching=False to time the uncached baseline.
audio_array = generate_audio(text_prompt, history_prompt="en_speaker_1", use_kv_caching=True)
Audio(audio_array, rate=SAMPLE_RATE)  # plays the result in a notebook

torch.compile: bad news, unfortunately. PyTorch doesn't properly support the combination of torch._dynamo, dynamic shapes, and AMP. I needed to switch to nightly just to get it to show signs of life, and even then it doesn't function well. Dropping either AMP or dynamic shapes makes it work, but the net change in performance is negative. In conclusion: no torch.compile for now, unless someone is willing to do significant refactoring.
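For reference, the failing combination is roughly the following; this is a hypothetical stand-alone sketch with a stand-in module, not the bark model itself.

import torch

# Stand-in module; the real case is bark's GPT being traced by torch._dynamo.
model = torch.nn.Linear(256, 256).cuda()
model = torch.compile(model, dynamic=True)   # dynamic shapes
with torch.autocast(device_type="cuda"):     # AMP (mixed precision)
    for seq_len in (16, 32, 64):             # varying lengths exercise the dynamic-shape path
        x = torch.randn(1, seq_len, 256, device="cuda")
        y = model(x)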

zygi · Apr 22 '23 00:04

Maybe you can try GPTCache. It provides similarity search, customizable embedding functions, storage, and customizable similarity evaluation of cached results, so the cache can be controlled more flexibly. 🤗 Hope GPTCache can help you.
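For context, GPTCache caches whole responses (reusing results across identical or similar queries), which is different from the per-step key/value caching this PR adds. Its README quickstart at the time looked roughly like this; it targets OpenAI chat calls, so using it with bark would need a custom adapter.

from gptcache import cache
from gptcache.adapter import openai

cache.init()            # default: exact-match caching in a local store
cache.set_openai_key()  # reads OPENAI_API_KEY from the environment

# A second identical request is served from the cache instead of the API.
response = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what is k/v caching?"}],
)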

SimFG · Apr 22 '23 03:04

Boom, amazing, thanks! Will prob make it the default in another commit.

gkucsko · Apr 22 '23 19:04