Running models with different precisions
I am testing different model architectures, and when loading the model weights (e.g. for the falcon or mamba architectures) with precision either bf16 or f16, I usually get this error:

Candle error: 'unexpected dtype, expected: F32, got: BF16'

I am running the candle examples and passing in a precision of f16 or bf16. Is there a way around this by tweaking the code? Or should I load the weights directly in f16/bf16 precision through some other repo on HuggingFace?
Is it possible that it only happens when retrieving the results back with `to_vec1::<f32>` or equivalent? With a var store, the conversion should be handled for you. When retrieving results, we error out rather than converting the values for you, hence you have to call `to_dtype` beforehand.
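To illustrate, here is a minimal standalone sketch (not from this thread; the tensor and its shape are made up) of casting a half-precision result back to f32 before pulling the values out on the host:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Stand-in for f16/bf16 logits coming out of a model.
    let logits = Tensor::zeros((1, 4), DType::F16, &device)?;
    // Calling to_vec2::<f32>() directly on an F16 tensor errors with a dtype
    // mismatch; casting to F32 first makes the extraction work.
    let values: Vec<Vec<f32>> = logits.to_dtype(DType::F32)?.to_vec2::<f32>()?;
    println!("{values:?}");
    Ok(())
}
```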
A good way to find out where the issue is coming from is to enable `RUST_BACKTRACE=1`; you probably also want the `release-with-debug` profile (or the equivalent in your project) so that line numbers are properly reported.
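If your own project doesn't define such a profile, a minimal sketch of what it can look like in Cargo.toml is below (this is an assumption about your setup, not copied from candle's workspace):

```toml
# Hypothetical profile: optimized build that keeps debug info so that
# backtraces map back to source lines.
[profile.release-with-debug]
inherits = "release"
debug = true
```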
(just trying to guess here, happy to give a more in depth look if you can provide a simple repro)
Thank you @LaurentMazare! The issue was, I believe, that I was not converting the input tensor to a dtype other than f32. I refactored the code from
for &token in tokens.iter() {
    let input = Tensor::new(&[token], &self.device)?;
    let logits = self.model.forward(&input, &mut state)?;
    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(token)? {
        output.push_str(t.as_str());
    }
}
to
for &token in tokens.iter() {
    let input = Tensor::new(&[token], &self.device)?.to_dtype(self.dtype)?;
    let logits = self.model.forward(&input, &mut state)?;
    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(token)? {
        output.push_str(t.as_str());
    }
}
However, running the latter code on my MacBook (with Metal features enabled) I get the following error:

Candle error: Metal contiguous index_select BF16 BF16 not implemented

Is it the case that the current Metal kernels do not support types other than f32?
Most Metal ops should support f32, f16, and bf16; this one was somehow missing, so I added it in #2035. That said, my MacBook doesn't support bf16, so I wasn't able to really test it, but hopefully that will work for you.
Thanks a lot for the PR! Unfortunately, I also have the same issue with other dtypes, including f16:

Candle error: Metal contiguous index_select F16 F16 not implemented
This one is different: you can only index into a tensor with an integer tensor, so u32 f16 makes sense, but f16 f16 wouldn't, as the index cannot be a float.
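To make that concrete, here is a minimal standalone sketch (not from this thread; the table and ids are made up) of the supported combination, where the indices stay u32 while the table being indexed is f16:

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // A small f16 "embedding table" to select rows from.
    let table = Tensor::randn(0f32, 1f32, (8, 4), &device)?.to_dtype(DType::F16)?;
    // Token ids stay u32: this is the supported "u32 f16" case.
    let ids = Tensor::new(&[1u32, 3, 5], &device)?;
    let rows = table.index_select(&ids, 0)?;
    // Casting `ids` to F16 before indexing would be the unsupported "f16 f16"
    // case, since a float cannot be used as an index.
    println!("{:?} {:?}", rows.dims(), rows.dtype());
    Ok(())
}
```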
I see, right. It seems, though, that many of these models do not have support for f16 or bf16. Without erroneously converting the indices to f16, I am getting this error: `dtype mismatch in mul, lhs: F16, rhs: F32`.
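For context, a minimal sketch (made up, not taken from the mamba or falcon implementations) of why this error appears: candle's binary ops require both operands to have the same dtype, so one side has to be cast explicitly.

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let a = Tensor::ones((2, 2), DType::F16, &device)?;
    let b = Tensor::ones((2, 2), DType::F32, &device)?;
    // (&a * &b) would fail with "dtype mismatch in mul"; cast one side first.
    let c = (&a * &b.to_dtype(DType::F16)?)?;
    println!("{:?}", c.dtype());
    Ok(())
}
```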
I am running these experiments on mamba and falcon, and from their implementations it seems these models do not support dtypes other than f32 (the mamba state is hardcoded to f32 precision, and for falcon the mask is also hardcoded to f32 precision). I wonder if it is possible to allow other precision types for these models (including f16 and bf16)?
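To make the kind of change being asked about concrete, here is a hedged sketch (not the actual falcon or mamba code in candle; the helper name is made up) of building a mask in the model's dtype instead of hardcoding f32:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Hypothetical helper: build a causal attention mask in whatever dtype the
// rest of the model runs in, rather than hardcoding F32.
fn causal_mask(seq_len: usize, dtype: DType, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0.0 }))
        .collect();
    // Build in f32 and cast, so the same helper serves f32, f16 and bf16 models.
    Tensor::from_vec(mask, (seq_len, seq_len), device)?.to_dtype(dtype)
}

fn main() -> Result<()> {
    let mask = causal_mask(4, DType::BF16, &Device::Cpu)?;
    println!("{:?} {:?}", mask.dims(), mask.dtype());
    Ok(())
}
```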
Yeah, there is no real limitation for this; I've made #2036 for mamba. It works with bf16 but not with f16 (which is somewhat expected, models trained in f32 or bf16 are likely to break with f16). On my RTX 4080, the speed slightly increases from 320 token/s to 360 token/s, so I wouldn't consider it a big improvement.
This is interesting: on my MacBook Pro it works with f16, but not with bf16. Thanks for the PR @LaurentMazare, it would be great to have this for both the Llama and Falcon models, too.