
Quantized tensors load support with candle_nn::VarBuilder

Open · yinqiwen opened this issue 1 year ago · 1 comment

Quantized models contain tensors with several data types (fp16, int32, ...), but candle_nn::VarBuilder loads every tensor with a single dtype. Tested with a Llama AWQ model.

For example:

#[test]
fn test_varbb() -> candle_core::Result<()> {
    use candle_core::DType;

    let device = candle_core::Device::new_cuda(0)?;

    // The whole VarBuilder is created with a single dtype (F16 here).
    let model_weight_files = vec!["./model.safetensors"];
    let vb = unsafe {
        candle_nn::VarBuilder::from_mmaped_safetensors(&model_weight_files, DType::F16, &device)?
    };

    // qweight is stored as int32 in the AWQ checkpoint, so loading it through an
    // F16 VarBuilder fails with a dtype error.
    let test_tensor0 = vb.pp("model.layers.1.self_attn.q_proj");
    let test_tensor0 = test_tensor0.get((5120_usize, 640_usize), "qweight")?;
    println!("{:?}", test_tensor0.dtype());

    // This weight really is fp16, so it loads fine.
    let test_tensor1 = vb.pp("model.layers.0.input_layernorm");
    let test_tensor1 = test_tensor1.get(5120_usize, "weight")?;
    println!("{:?}", test_tensor1.dtype());

    Ok(())
}

Verified with Python:

#!/usr/bin/python3

from safetensors.numpy import load_file

# Load the raw tensors and inspect their on-disk dtypes.
loaded = load_file("./model.safetensors")

x = loaded['model.layers.1.self_attn.q_proj.qweight']
print(x.dtype, x.shape)

y = loaded['model.layers.0.input_layernorm.weight']
print(y.dtype, y.shape)

which prints:

int32 (5120, 640)
float16 (5120,)
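
For reference, the same per-tensor dtype check can also be done from Rust with the safetensors crate directly, without going through candle_nn::VarBuilder. This is only a minimal sketch: it assumes the same ./model.safetensors file and tensor names as above, and it only inspects metadata rather than building candle tensors.

use std::fs;

use safetensors::SafeTensors;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Read the whole file; for large checkpoints a memory map would be preferable.
    let buffer = fs::read("./model.safetensors")?;
    let st = SafeTensors::deserialize(&buffer)?;

    // Each tensor keeps its own on-disk dtype: int32 for qweight,
    // float16 for the layernorm weight.
    for name in [
        "model.layers.1.self_attn.q_proj.qweight",
        "model.layers.0.input_layernorm.weight",
    ] {
        let view = st.tensor(name)?;
        println!("{name}: {:?} {:?}", view.dtype(), view.shape());
    }
    Ok(())
}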

yinqiwen, Mar 19 '24

Hi, has this issue been resolved?

lucasjinreal, Mar 09 '25