8-bit inference (#512)
@glerzing Do you have an example run using 8bit?
There are a few things to improve, I'm working on it. I'll also add an example.
@glerzing Thank you for the great PR, do you have any update on this, or anything you need help with?
I added `from_pretrained_kwargs` to the model config to add some flexibility to how the model is loaded.
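Roughly, whatever is placed under `from_pretrained_kwargs` is forwarded to `from_pretrained` when the base model is created. A minimal sketch of how that could be used for 8-bit loading (the model name and the keys shown are just the usual transformers/bitsandbytes loader arguments, picked for illustration):

```python
from trlx.data.default_configs import default_ppo_config

# Sketch: everything under from_pretrained_kwargs is forwarded to
# AutoModelForCausalLM.from_pretrained when the policy model is built.
config = default_ppo_config()
config.model.model_path = "lvwerra/gpt2-imdb"  # example model, for illustration
config.model.from_pretrained_kwargs = {
    "load_in_8bit": True,   # quantize weights with bitsandbytes LLM.int8()
    "device_map": "auto",   # let accelerate place modules on available devices
}
```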
When testing, I ran into 2 problems with `ppo_sentiments_8bit.py` when it executes the `generate` function:
- Something similar to https://github.com/Vision-CAIR/MiniGPT-4/issues/74: traceback_2.txt
- An error due to a matrix being in half precision instead of a boolean matrix: traceback.txt
In both cases, it doesn't look related to trlx. Quantization can introduce bugs because it additionally relies on accelerate and bitsandbytes, which have their own dependencies, so mismatched library versions can cause problems. With the library versions listed in requirements.txt, I run into the 2nd problem; with the latest versions, I run into the 1st one.
@PhungVanDuy If you have time, can you help debug this? I think having lower-precision inference and training options will be very useful.
@glerzing Are you able to get quantized model inference working with our package requirements? (but without any training)
No. With version 4.28.1 of the transformers library, as pinned in trlx, I get `RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half`. With >= 4.30.0, I get `RuntimeError: probability tensor contains either 'inf', 'nan' or element < 0`, which I guess happens further along in the processing (my guess is that this bug is also present with version 4.28.1, but the processing doesn't get that far).
Actually, adding the argument `torch_dtype=torch.bfloat16` to `from_pretrained` and using a more recent version of the transformers library solves the issue and makes it possible to run `ppo_sentiments_8bit.py`.
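In plain transformers terms, the working combination looks roughly like this (a sketch; the model name is a placeholder and the version constraint is the one discussed above):

```python
import torch
from transformers import AutoModelForCausalLM

# Workaround discussed above (transformers >= 4.30): keep the 8-bit quantized
# weights but run the non-quantized modules in bfloat16, which avoids the
# half-precision mask / NaN-probability errors during generate().
model = AutoModelForCausalLM.from_pretrained(
    "lvwerra/gpt2-imdb",        # placeholder model name
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```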
@glerzing @Dahoas I tried to run inference with 8-bit, but I don't think this approach makes inference faster: https://wandb.ai/pvduy/trlx/reports/8bit-Sentiment-Rollout--Vmlldzo0OTUxOTM5
This is also mentioned by the author here:
> The main purpose of the LLM.int8() method is to make large models more accessible without performance degradation. But the method would be less useful if it is very slow. So we benchmarked the generation speed of multiple models. We find that BLOOM-176B with LLM.int8() is about 15% to 23% slower than the fp16 version – which is still quite acceptable.
Let's come up with another idea, like using vLLM; in my experiments vLLM actually improves inference time. I will work in that direction.
Thanks for checking this. Were you able to run this experiment with trlx's pinned transformers version? Or will we need to update it?
On the inference speedup side, vLLM seems like a good idea. In general, implementing some kind of asynchronous PPO like V-trace seems promising.
I have to update that one; I guess we should also update the transformers version to support LLaMA 2.
I am checking vLLM to see how hard it would be to integrate. Thank you for your suggestion on asynchronous PPO.
I was wondering if there should be an example of how to train 16-bit models. Now that there is the config argument `from_pretrained_kwargs`, you can easily set `torch_dtype=torch.bfloat16`, which doesn't seem obvious to newcomers. On the other hand, I'm not sure whether it's worth adding another file `ppo_sentiments_16bit.py` just to show that we can easily do that.
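For reference, the override itself is tiny; a hypothetical `ppo_sentiments_16bit.py` would essentially only differ from the default example by something like this sketch:

```python
import torch
from trlx.data.default_configs import default_ppo_config

# Sketch: load the base model weights in bfloat16 instead of fp32.
config = default_ppo_config()
config.model.from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
```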
@glerzing Checking in on the state of this PR. Do you have any more features you would like to add? If not, let's get it merged sometime this week.