text-generation-webui
Make the RWKV model cache the RNN state between messages
As opposed to rescanning the character definition plus the whole chat before replying to each new message, it now only has to rescan the last two messages (its own and the user's).
Just realized I'm not deepcopying the cached state 😅 Will fix.
Fixed that. Also fixed a corner case where it caches so much that it has no new tokens to generate from (such as when hitting Regenerate).
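For anyone reading along, here is a rough sketch of the idea, with the deepcopy and no-new-tokens fixes folded in. It assumes the rwkv pip package's `RWKV`/`PIPELINE` interface (`model.forward(tokens, state)` returning logits plus the updated state); the class name, strategy, and defaults are illustrative, not the exact code in this PR.

```python
import copy

from rwkv.model import RWKV
from rwkv.utils import PIPELINE


class CachedRWKVChat:
    """Illustrative wrapper: keep the RNN state together with the text it was computed from."""

    def __init__(self, model_path, tokenizer_path):
        self.model = RWKV(model=model_path, strategy='cuda fp16')
        self.pipeline = PIPELINE(self.model, tokenizer_path)
        self.cached_context = ''   # text the cached state corresponds to
        self.cached_state = None   # RNN state after scanning cached_context

    def generate(self, prompt, max_new_tokens=70, temperature=1.0, top_p=0.85):
        # Reuse the cached state only if the new prompt strictly extends the cached
        # context; the "strictly" part avoids the no-new-tokens corner case
        # (e.g. hitting Regenerate). Otherwise, rescan the whole prompt.
        if self.cached_state is not None and prompt.startswith(self.cached_context) \
                and len(prompt) > len(self.cached_context):
            new_text = prompt[len(self.cached_context):]
            state = copy.deepcopy(self.cached_state)  # don't mutate the cache in place
        else:
            new_text = prompt
            state = None

        out, state = self.model.forward(self.pipeline.encode(new_text), state)

        reply = ''
        for _ in range(max_new_tokens):
            token = self.pipeline.sample_logits(out, temperature=temperature, top_p=top_p)
            reply += self.pipeline.decode([token])
            out, state = self.model.forward([token], state)

        # Record the text this state corresponds to, per BlinkDL's advice.
        self.cached_context = prompt + reply
        self.cached_state = state
        return reply
```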
Excellent work! I'm consistently getting a substantial boost in inference speed after the initial inference.
Output generated in 4.49 seconds (15.60 tokens/s, 70 tokens, context 33, seed 871616523)
Output generated in 2.48 seconds (28.25 tokens/s, 70 tokens, context 33, seed 2053810344)
Output generated in 2.39 seconds (29.23 tokens/s, 70 tokens, context 33, seed 2112064913)
Output generated in 4.25 seconds (16.49 tokens/s, 70 tokens, context 33, seed 1161902304)
Output generated in 2.14 seconds (32.75 tokens/s, 70 tokens, context 33, seed 1307464566)
Output generated in 2.08 seconds (33.70 tokens/s, 70 tokens, context 33, seed 634597054)
Thanks! The increase in tokens/s that you see comes from how ooba measures it: number of tokens generated divided by total inference time, without accounting for the fact that some of that time is spent scanning through the context rather than actually generating new tokens. The PR eliminates most of that scanning time in most cases, so the tokens/s counter is closer to the real generation speed.
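As a back-of-the-envelope illustration using the figures logged above (70 tokens per run):

```python
tokens = 70    # tokens generated per run, from the logs above

cold = 4.49    # seconds when the state is rebuilt from the full context
warm = 2.48    # seconds when the cached state is reused

print(tokens / cold)   # ~15.6 tokens/s -- denominator includes the context scan
print(tokens / warm)   # ~28.2 tokens/s -- closer to the pure generation speed
print(cold - warm)     # ~2.0 s of context scanning skipped by the cache
```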
Does this work with RWKV Raven? I read somewhere that you shouldn't call forward() manually for that model.
Raven is what I tested it with. There's no reason why you can't call forward manually.
If you're referring to this message BlinkDL posted on Discord:
- Never call raw forward() directly. Instead, put it in a function that will record the text corresponding to the state.
That's just advice for state management, not an actual limitation. The generate_from_cached_state function is extremely similar to PIPELINE.generate, so it should be safe. Plus, it does record the text corresponding to the state in self.cached_context, so it follows the advice too.
In my tests, this makes new generations about 2x faster, which is very nice.
My only fear is that the chatrwkv library will change in the future and break this implementation, but if that happens, I can just uncomment the old function call.
The llama-cpp-python library automatically does something similar (matching the prompt against a cache to check whether the prompt has to be re-ingested). Maybe chatrwkv can do the same? @BlinkDL
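For reference, a sketch of that kind of prefix matching, purely illustrative and not llama-cpp-python's or chatrwkv's actual code:

```python
def common_prefix_length(cached_tokens, prompt_tokens):
    """Number of leading tokens shared by the cached prompt and the new prompt."""
    n = 0
    for a, b in zip(cached_tokens, prompt_tokens):
        if a != b:
            break
        n += 1
    return n


def tokens_to_ingest(cached_tokens, prompt_tokens):
    # Only the part of the new prompt past the shared prefix needs to be
    # re-ingested; the cached state already covers the prefix itself.
    return prompt_tokens[common_prefix_length(cached_tokens, prompt_tokens):]
```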