Quantized tensor loading support with candle_nn::VarBuilder
Quantized models contain tensors of multiple data types (fp16, int32, ...), but candle_nn::VarBuilder uses a single dtype to load all tensors. Tested with Llama AWQ, e.g.:
use candle_core::DType;

#[test]
fn test_varbb() -> candle_core::Result<()> {
    let device = candle_core::Device::new_cuda(0)?;
    let model_weight_files = vec!["./model.safetensors"];
    let vb = unsafe {
        candle_nn::VarBuilder::from_mmaped_safetensors(&model_weight_files, DType::F16, &device)?
    };
    // This tensor is stored as int32 on disk, so fetching it through a
    // VarBuilder created with DType::F16 fails with a dtype error.
    let test_tensor0 = vb.pp("model.layers.1.self_attn.q_proj");
    let test_tensor0 = test_tensor0.get((5120_usize, 640_usize), "qweight")?;
    println!("{:?}", test_tensor0.dtype());
    // This tensor is stored as fp16 and loads fine.
    let test_tensor1 = vb.pp("model.layers.0.input_layernorm");
    let test_tensor1 = test_tensor1.get(5120_usize, "weight")?;
    println!("{:?}", test_tensor1.dtype());
    Ok(())
}
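For comparison, the stored dtypes can also be checked from Rust without going through VarBuilder at all. A minimal sketch using the safetensors crate directly (my own addition, not part of candle's API); the parsed header reports each tensor's on-disk dtype, so no cast is attempted:

use safetensors::SafeTensors;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Parse the safetensors header; tensors are not converted, so the
    // metadata shows the true on-disk dtypes and shapes.
    let bytes = std::fs::read("./model.safetensors")?;
    let st = SafeTensors::deserialize(&bytes)?;
    for name in [
        "model.layers.1.self_attn.q_proj.qweight",
        "model.layers.0.input_layernorm.weight",
    ] {
        let view = st.tensor(name)?;
        println!("{name}: {:?} {:?}", view.dtype(), view.shape());
    }
    Ok(())
}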
Checking the same file with Python:
#!/usr/bin/python3
from safetensors.numpy import load_file

loaded = load_file("./model.safetensors")
x = loaded["model.layers.1.self_attn.q_proj.qweight"]
print(x.dtype, x.shape)
y = loaded["model.layers.0.input_layernorm.weight"]
print(y.dtype, y.shape)
Output:

int32 (5120, 640)
float16 (5120,)
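Until VarBuilder can express per-tensor dtypes, one possible workaround is to keep the F16 VarBuilder for the float tensors and decode the quantized int32 payload manually with the safetensors crate. A sketch under that assumption (not an official candle API; safetensors stores data little-endian):

use candle_core::{DType, Device};
use candle_nn::VarBuilder;

fn load_awq_layer() -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::new_cuda(0)?;
    // Float tensors still go through the F16 VarBuilder as before.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["./model.safetensors"], DType::F16, &device)?
    };
    let ln = vb
        .pp("model.layers.0.input_layernorm")
        .get(5120_usize, "weight")?;
    println!("layernorm dtype: {:?}", ln.dtype());

    // Quantized weights: decode the raw little-endian int32 bytes manually,
    // since the single-dtype VarBuilder cannot load them.
    let bytes = std::fs::read("./model.safetensors")?;
    let st = safetensors::SafeTensors::deserialize(&bytes)?;
    let view = st.tensor("model.layers.1.self_attn.q_proj.qweight")?;
    let qweight: Vec<i32> = view
        .data()
        .chunks_exact(4)
        .map(|b| i32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect();
    println!("qweight shape {:?}, {} values", view.shape(), qweight.len());
    Ok(())
}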
Hi, has this issue been resolved?