bitsandbytes

Saving a 4-bit Llama model to TorchScript fails

Open dcy0577 opened this issue 1 year ago • 1 comments

System Info

bitsandbytes=0.42.0 transformers=4.37.1

Reproduction

import torch
import torch.nn as nn
import transformers
from transformers import BitsAndBytesConfig

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()

        # get the llama model
        llama_model = transformers.LlamaModel 

        # 4-bit NF4 quantization config (QLoRA-style)
        bnb_config_4bit = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=False,  # whether to use double quantization
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # load the base model weights with quantization
        self.model = llama_model.from_pretrained(
            "Llama_weights/llama-2-7b-hf-weights",
            low_cpu_mem_usage=True,
            device_map=0,
            quantization_config=bnb_config_4bit,
            torchscript=True,
        )
        
        self.model.output_hidden_states = False
        
    def forward(self, tokens_tensor):
        self.model.eval()
        o = self.model(tokens_tensor, output_hidden_states=False)
        return o[0]


model = Wrapper()
model.eval()    
with torch.no_grad():
    dummy_tokens_tensor = torch.randint(0, 1000, (1, 50), dtype=torch.long).to("cuda")
    outputs = model(dummy_tokens_tensor)
    trace_model = torch.jit.trace(model, [dummy_tokens_tensor])  # this works, but emits some tracer warnings
    print("traced_model done")
    torch.jit.save(trace_model, "llama_4bit.pt") # --> error!

The line torch.jit.save(trace_model, "llama_4bit.pt") raises the following error:

RuntimeError: 
Could not export Python function call 'MatMul4Bit'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/autograd/function.py(506): apply
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py(577): matmul_4bit
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/bitsandbytes/nn/modules.py(256): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(386): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(798): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py(1070): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/accelerate/hooks.py(165): new_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home///test/test_torchscript_llama.py(62): forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1488): _slow_forward
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/nn/modules/module.py(1501): _call_impl
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/jit/_trace.py(1056): trace_module
/home//miniconda3/envs/t4rec_23.06/lib/python3.10/site-packages/torch/jit/_trace.py(794): trace
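
The traceback points at MatMul4Bit, which is called through torch.autograd.Function.apply. As far as I can tell, the same failure can be reproduced without bitsandbytes: the tracer records a Python-defined autograd function as an opaque Python call, and torch.jit.save refuses to serialize it. Below is a minimal sketch under that assumption. Double, Tiny and try_save are hypothetical names made up for illustration; they are not part of bitsandbytes or transformers.

```python
import io

import torch
import torch.nn as nn


class Double(torch.autograd.Function):
    """Hypothetical stand-in for a Python-defined op such as MatMul4Bit."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


class Tiny(nn.Module):
    def forward(self, x):
        # Same call pattern as in the traceback: Function.apply inside forward.
        return Double.apply(x)


def try_save(module):
    """Return None if torch.jit.save succeeds, otherwise the error message."""
    try:
        torch.jit.save(module, io.BytesIO())
        return None
    except RuntimeError as exc:
        return str(exc)


with torch.no_grad():
    # check_trace=False only skips the post-trace verification pass.
    traced = torch.jit.trace(Tiny(), torch.ones(2), check_trace=False)

# Tracing succeeds, but the graph keeps the call as an opaque Python op.
python_ops = [
    node for node in traced.inlined_graph.nodes()
    if node.kind() == "prim::PythonOp"
]
save_error = try_save(traced)

print("opaque Python ops in graph:", len(python_ops))
print("save failed:", save_error is not None)
```

If this sketch matches what happens in the 4-bit model, the save step fails because of how the 4-bit matmul is implemented (a Python autograd function), not because of the Wrapper module or the tracing call in the reproduction.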

Expected behavior

torch.jit.save should work properly for a quantized model.

dcy0577, Feb 01 '24 12:02

Thanks a lot @dcy0577! I will have a look and loop back here.

younesbelkada, Feb 01 '24 22:02