
8-bit inference (#512)

Open · glerzing opened this issue on Jun 24 '23 · 13 comments

glerzing · Jun 24 '23

@glerzing Do you have an example run using 8bit?

Dahoas · Jul 10 '23

There are a few things to improve, I'm working on it. I'll also add an example.

glerzing · Jul 13 '23

> There are a few things to improve, I'm working on it. I'll also add an example.

@glerzing Thank you for the great PR. Do you have any update on this, or anything you need help with?

PhungVanDuy · Jul 17 '23

I added from_pretrained_kwargs to the model config to add some flexibility to how the model is loaded.

When testing, I ran into two problems with ppo_sentiments_8bit.py when it executes the generate function:

  • Something similar to https://github.com/Vision-CAIR/MiniGPT-4/issues/74: traceback_2.txt
  • An error due to a matrix being in half precision instead of boolean: traceback.txt

In both cases, it doesn't look related to trlx. Quantization can introduce bugs because it additionally relies on accelerate and bitsandbytes, which have their own dependencies, and version mismatches between the libraries can cause problems. With the library versions listed in requirements.txt I run into the second problem; with the latest versions I run into the first one.

glerzing · Jul 17 '23
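For reference, a minimal sketch of how such a config might look. The from_pretrained_kwargs field is the one added in this PR; the surrounding layout (default_ppo_config, the lvwerra/gpt2-imdb model from the sentiment example) is assumed for illustration, with the dict forwarded to transformers' from_pretrained when the model is built.

```python
from trlx.data.default_configs import default_ppo_config

# Illustrative only: config layout assumed from trlX's default configs;
# `from_pretrained_kwargs` is the field added in this PR.
config = default_ppo_config()
config.model.model_path = "lvwerra/gpt2-imdb"
config.model.from_pretrained_kwargs = {
    "load_in_8bit": True,   # bitsandbytes int8 quantization
    "device_map": "auto",   # let accelerate place the weights
}
```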

@PhungVanDuy If you have time, can you help debug this? I think having lower-precision inference and training options would be very useful.

Dahoas · Jul 21 '23

@glerzing Are you able to get quantized model inference working with our package requirements? (but without any training)

Dahoas · Jul 21 '23

No. With version 4.28.1 of the transformers library, as pinned by trlx, I get RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half; with version >= 4.30.0, I get RuntimeError: probability tensor contains either 'inf', 'nan' or element < 0, which I guess happens further along in the processing (my guess is that this second bug is also present with 4.28.1, but the processing doesn't get that far).

glerzing · Jul 21 '23
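As a sanity check outside trlX, the version-dependent failure can be probed with a bare 8-bit generate call in transformers; the sketch below assumes only transformers, accelerate, and bitsandbytes are installed and uses gpt2 as a stand-in model.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Standalone 8-bit generation check (no trlX involved); gpt2 is a stand-in.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    load_in_8bit=True,   # requires bitsandbytes + accelerate
    device_map="auto",
)
inputs = tok("This movie was", return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=True, max_new_tokens=20)
print(tok.decode(out[0], skip_special_tokens=True))
```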

Actually, adding the argument torch_dtype=torch.bfloat16 to from_pretrained and using a more recent version of the transformers library solves the issue and makes it possible to run ppo_sentiments_8bit.py.

glerzing · Jul 22 '23
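In config terms, the fix amounts to passing torch_dtype alongside load_in_8bit, so the modules that bitsandbytes leaves un-quantized run in bfloat16 rather than fp16; the layout below is the same assumed sketch as above.

```python
import torch
from trlx.data.default_configs import default_ppo_config

# Same assumed layout as the earlier sketch, with the bfloat16 fix added.
config = default_ppo_config()
config.model.from_pretrained_kwargs = {
    "load_in_8bit": True,
    "device_map": "auto",
    "torch_dtype": torch.bfloat16,  # keep non-int8 modules in bf16 instead of fp16
}
```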

@glerzing @Dahoas I tried running inference in 8-bit, but I don't think this approach makes inference any faster: https://wandb.ai/pvduy/trlx/reports/8bit-Sentiment-Rollout--Vmlldzo0OTUxOTM5

This is also mentioned by the author here:

> The main purpose of the LLM.int8() method is to make large models more accessible without performance degradation. But the method would be less useful if it is very slow. So we benchmarked the generation speed of multiple models. We find that BLOOM-176B with LLM.int8() is about 15% to 23% slower than the fp16 version – which is still quite acceptable.

Let's look at another approach, such as vLLM; in my experiments vLLM actually speeds up inference. I will work in that direction.

PhungVanDuy · Jul 23 '23
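For comparison, offline generation with vLLM looks roughly like the sketch below (not part of this PR; class and parameter names are taken from vLLM's own API, with gpt2 as a placeholder model).

```python
from vllm import LLM, SamplingParams

# Rough sketch of a vLLM rollout path; gpt2 is a placeholder model.
llm = LLM(model="gpt2")
params = SamplingParams(temperature=1.0, max_tokens=32)
outputs = llm.generate(["This movie was"], params)
print(outputs[0].outputs[0].text)
```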

Thanks for checking this. Were you able to run this experiment with trlX's pinned transformers version, or will we need to update it?

On the inference speedup side, vLLM seems like a good idea. In general, implementing some kind of asynchronous PPO like v-trace seems promising.

Dahoas · Jul 24 '23

> Thanks for checking this. Were you able to run this experiment with trlX's pinned transformers version, or will we need to update it?
>
> On the inference speedup side, vLLM seems like a good idea. In general, implementing some kind of asynchronous PPO like v-trace seems promising.

I had to update it for that run; I think we should also update the pinned transformers version to support LLaMA 2.

I am looking into vLLM to see how hard it would be to integrate. Thank you for the suggestion on asynchronous PPO.

PhungVanDuy · Jul 24 '23

I was wondering whether there should be an example of how to train 16-bit models. Now that the config argument from_pretrained_kwargs exists, you can easily set torch_dtype=torch.bfloat16, but that may not be obvious to newcomers. On the other hand, I'm not sure it's worth adding another file, ppo_sentiments_16bit.py, just to show that it's easy to do (a sketch follows below).

glerzing · Aug 07 '23
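For what it's worth, a hypothetical ppo_sentiments_16bit.py would essentially reduce to one extra line over the fp32 example (again assuming the config layout sketched earlier):

```python
import torch
from trlx.data.default_configs import default_ppo_config

# Hypothetical 16-bit variant: the only change is the from_pretrained_kwargs entry.
config = default_ppo_config()
config.model.from_pretrained_kwargs = {"torch_dtype": torch.bfloat16}
```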

@glerzing Checking in on the state of this PR. Do you have any more features you would like to add? If not, let's get it merged sometime this week.

Dahoas · Aug 28 '23