
Can't loop over a model implementation based on the examples more than N times (after 7-20+ iterations it ends up breaking)

Open groovybits opened this issue 1 year ago • 12 comments

Hi, I'm hitting these errors, and my code even tries to retry, but there seems to be an issue when I continuously use the gemma and mistral examples integrated into my code over 7+ prompt iterations. It breaks the weights so that they come out as 0, and then if I retry it breaks the shape.

crash4.txt

Here is my code: https://github.com/groovybits/rsllm/blob/main/src/candle_gemma.rs

Basically I have it threaded so it can stream out the tokens and handle concurrency. Yet I do see these issues, and the only fix seems to be re-initializing the model and tokens, which doesn't seem possible while keeping the binary running. That's the only caveat I have so far with this vs. other ways of running models. (This use case really beats the hell out of the LLM and the models, with 24/7 storytelling at high volume maxing out an M2 Ultra GPU with 192GB.) Once this is fixed, this is definitely the best option I've tried after TypeScript and Python: Rust is really efficient at the threading and the low-latency concurrent pipeline that processes the LLM text plus the media generated from the chunked text segments. Thank you for this; otherwise I would have been stuck with the llama.cpp API and other less optimized ways of doing this that aren't unified into one binary. (With metavoice it will be mostly Rust; I doubt NDI will ever be Rust native, but the bindings work well for sending the assets produced with Candle!)

groovybits avatar Mar 23 '24 09:03 groovybits

It's a bit tricky to investigate as the code you're pointing at is fairly large. Maybe you could come up with a small standalone repro that would make it easier to look at?

My guess from the error is that you're reusing the same model across multiple prompts without resetting the kv cache, though I can't tell for sure from here. The gemma model (like most transformer based models in candle) has mutable state because of the kv cache. If you want to use the same model to serve different users, you have to clone it so that each model has its own cache; cloning is fairly cheap as the weights are shared between all the clones. Note that you can also reset the kv cache, but then the model loses all its context.
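As a rough illustration of the two options (a sketch only; `base_model` is a hypothetical already-loaded gemma model, not something from the rsllm code):

```rust
use candle_transformers::models::gemma::Model;

// Option 1: give each session its own clone. The weight tensors are shared
// under the hood, so the clone is cheap; only the kv cache is per-clone.
fn model_for_new_session(base_model: &Model) -> Model {
    base_model.clone()
}

// Option 2: reuse a single model but wipe its kv cache between unrelated
// prompts; the model then loses all previously processed context.
fn reset_between_prompts(model: &mut Model) {
    model.clear_kv_cache();
}
```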

LaurentMazare avatar Mar 23 '24 09:03 LaurentMazare

Ah, so the model actually keeps context? I keep context on my side too, but it does feel like what you're describing is probably what's occurring. I guess this is something lower-level to deal with when running it this way. Very interesting, I think that may already help; it sounds like I need to clear the kv cache? (Any example I could run with?)

Thank you!

groovybits avatar Mar 23 '24 11:03 groovybits

Resetting the kv cache each time would work but would result in very poor performance. You probably want to do the same as all the candle examples: process the initial prompt as a whole first, and then proceed token by token (thanks to the internal kv cache). When it comes to serving multiple users with the same model in parallel, I would recommend cloning the model rather than resetting the cache; again, otherwise performance is likely to be really bad.
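For reference, a minimal sketch of that loop, closely following the candle gemma example (assuming `model`, `tokenizer`, `logits_processor`, `device`, `prompt`, `sample_len` and `eos_token` are already set up as in that example; exact crate paths may differ slightly between candle versions):

```rust
use anyhow::Error as E;
use candle_core::{DType, Tensor};

// Encode the full prompt once up front.
let mut tokens = tokenizer
    .encode(prompt, true)
    .map_err(E::msg)?
    .get_ids()
    .to_vec();

for index in 0..sample_len {
    // First iteration: feed the whole prompt. Later iterations: feed only the
    // most recent token, since the kv cache already holds everything before it.
    let context_size = if index > 0 { 1 } else { tokens.len() };
    let start_pos = tokens.len().saturating_sub(context_size);
    let input = Tensor::new(&tokens[start_pos..], &device)?.unsqueeze(0)?;
    let logits = model.forward(&input, start_pos)?;
    let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
    let next_token = logits_processor.sample(&logits)?;
    tokens.push(next_token);
    if next_token == eos_token {
        break;
    }
}
```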

LaurentMazare avatar Mar 23 '24 11:03 LaurentMazare

If you are interested in having multiple requests, mistral.rs provides that functionality by managing the KV cache behind the scenes.

EricLBuehler avatar Mar 23 '24 11:03 EricLBuehler

I did this; I suspect it will give me a clean slate on each run?

diff --git a/src/candle_gemma.rs b/src/candle_gemma.rs
index 913bd24..b02b2d1 100644
--- a/src/candle_gemma.rs
+++ b/src/candle_gemma.rs
@@ -59,6 +59,7 @@ impl TextGeneration {
     async fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         let verbose_prompt: bool = false;

+        self.model.clear_kv_cache();
         self.tokenizer.clear();
         let mut tokens = self
             .tokenizer
diff --git a/src/candle_mistral.rs b/src/candle_mistral.rs
index bfc6570..2870775 100644
--- a/src/candle_mistral.rs
+++ b/src/candle_mistral.rs
@@ -67,6 +67,10 @@ impl TextGeneration {

     async fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         let verbose_prompt: bool = false;
+        match &mut self.model {
+            Model::Mistral(m) => m.clear_kv_cache(),
+            Model::Quantized(m) => m.clear_kv_cache(),
+        };
         self.tokenizer.clear();
         let mut tokens = self
             .tokenizer

groovybits avatar Mar 23 '24 14:03 groovybits

It will, but as mentioned earlier performance will be really poor.

LaurentMazare avatar Mar 23 '24 14:03 LaurentMazare

> It will, but as mentioned earlier performance will be really poor.

Ah okay, I'm cloning them now instead for the moment. At least it sounds like that will avoid crashing, and it seems to run fine, better than the kv cache clear (which ran, but I could feel some slowness).

I guess persisting the model and managing the cache sounds ideal, which I'm going to work on, with the general goal of loading the models only once while also keeping the kv cache from filling up :) It already runs nicely for me vs. the Python performance, so everything else is gravy, and it can only get better from here. I'm definitely interested in figuring out how to persist and optimize the caching, without it filling up, as far as possible with Candle. Reading that example and attempting to grok it.

groovybits avatar Mar 23 '24 16:03 groovybits

Great that it ends up working well. Not sure what you mean by "avoid having the kv cache fill up too"; the cloning strategy should work for most use cases: you load the "empty" model once at startup, and each time you spawn a thread to handle queries for a separate client you clone that "empty" model and keep the clone for the duration of the client "session". Not sure if that makes sense but anyway :)
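A rough sketch of that pattern (hypothetical helper names like `load_gemma_model`, `load_tokenizer`, `generate` and `client_prompts` stand in for the existing setup and generation code in candle_gemma.rs):

```rust
use std::thread;

// Load the weights once at startup; this base model's kv cache stays empty.
let base_model = load_gemma_model(&device)?;
let tokenizer = load_tokenizer()?;

let handles: Vec<_> = client_prompts
    .into_iter()
    .map(|prompt: String| {
        // Cheap clone per session: the weight tensors are shared between all
        // clones, but each clone gets its own (empty) kv cache.
        let mut model = base_model.clone();
        let tokenizer = tokenizer.clone();
        thread::spawn(move || generate(&mut model, &tokenizer, &prompt))
    })
    .collect();

// Wait for all sessions to finish.
for handle in handles {
    handle.join().expect("generation thread panicked")?;
}
```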

LaurentMazare avatar Mar 23 '24 16:03 LaurentMazare

> Great that it ends up working well. Not sure what you mean by "avoid having the kv cache fill up too"; the cloning strategy should work for most use cases: you load the "empty" model once at startup, and each time you spawn a thread to handle queries for a separate client you clone that "empty" model and keep the clone for the duration of the client "session". Not sure if that makes sense but anyway :)

Yes, I believe I'm doing that. I figured there was more to it, but that's wonderful; it seems to work great again and hasn't crashed yet :D and I suspect it won't, since this makes a lot of sense. I can see how the caching was actually causing issues when I needed a clean slate, so now all is good! Thanks.

groovybits avatar Mar 23 '24 17:03 groovybits

Hmmm, it actually still does this with either clearing or cloning :/ It seems like I'm doing it right, so I'm not sure what could still be causing it. Reopening for now and investigating more.

It does seem to run longer now, but it still hits the issue, which breaks it for subsequent runs.

groovybits avatar Mar 24 '24 03:03 groovybits

Also, I'm now seeing that over time the stable diffusion output becomes all black. Does it also need some kind of kv-cache-style reset?

I never ran it this long before because the LLM side of candle would eventually break. I added llama.cpp API support, and running on that it now goes long enough for the stable diffusion side to break :/ It did break before too, but I thought it was the input token count; I fixed that, and now, running longer, it goes all black. I detect the all-black images and, when that happens, reuse the last good image. It now seems to hit that often after 10+ iterations.

So what is the right way to reset SD? I clone it too, with no fix.

Is there any command to reset this in Candle, or a way to add one? I don't see how this can run in loops with continuous stability, and I don't see any examples of it. My system effectively tests this by running 24/7 generation in a loop, and it seems to catch these things. It's making me want to switch to using an API for another SD server, probably the automatic one, which is a shame because it's Python, and no other SD implementation seems as light on Metal as this one. I'd ideally like to stay all Rust/Candle, but I'm about to fall back to APIs only, so I'm hoping there's a solution that doesn't require deep internal knowledge of LLMs and Candle itself.

It feels like the code is capable of a single run, but multiple runs over time either hit bugs with this issue, or I somehow really messed up the examples. I made minimal changes to them, so I essentially lifted them, ran them in a loop, and found these two issues: zero-valued weights/output in LLM mode (gemma and mistral, no difference) and SD going all black. Both seem to come from some kv cache / token run-out over time from not really resetting, which I still don't get: I don't save anything, the code re-initializes each time, so I don't understand what keeps state even after I clone the model each run or do the kv cache clear in the LLMs; both fail to fix the issues.

groovybits avatar Mar 26 '24 08:03 groovybits

Re stable diffusion, I don't think there is any form of caching there, so I would expect it to work in a loop. The examples actually let you specify num_samples; I just used it to generate 100 images and all were fine, and there was no increase in CUDA memory usage during the whole run (I haven't tried on Metal but hopefully it's similar).

Re the LLM, having a kv cache is fairly common and something you're likely to have to deal with in any language. I'm not sure what you mean by "break" and how it actually breaks, so you may want to give more details, but the way to go should be to clone the model for each "session" and then use it as per the examples.

And there is no notion of global state in candle, nor anything you could "reset".

LaurentMazare avatar Mar 26 '24 09:03 LaurentMazare