
Mamba model is broken with `f16` precision

Open jorgeantonio21 opened this issue 1 year ago • 1 comment

Running the command:

cargo run --example mamba --release --features metal -- --prompt "Tell me a joke please" --dtype f16

does not work. The problem seems to lie in this code:

// Feed the prompt tokens one at a time, updating the recurrent state.
for &t in tokens.iter() {
    let input = Tensor::new(&[t], &self.device)?;
    // Forward pass through the Mamba model; with --dtype f16 the
    // returned logits come back as NaN.
    let logits = self.model.forward(&input, &mut state)?;
    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(t)? {
        print!("{t}")
    }
}

where `logits` comes back as a Tensor of NaN values.
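
For anyone reproducing this, a quick way to confirm the tensor is full of NaNs is a check along these lines (a minimal sketch against candle's public Tensor API; `has_nan` is a hypothetical debugging helper, not part of the example):

use candle_core::{DType, Result, Tensor};

// Hypothetical debugging helper: returns true if any element of `t` is NaN.
// Copies the tensor back to the host as f32, which is fine for
// logits-sized tensors used while debugging.
fn has_nan(t: &Tensor) -> Result<bool> {
    let values = t.flatten_all()?.to_dtype(DType::F32)?.to_vec1::<f32>()?;
    Ok(values.iter().any(|v| v.is_nan()))
}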

jorgeantonio21 commented on Jun 21 '24

f16 has a far smaller range than f32, so it's quite common for models trained in f32 or bf16 to return some NaNs when you try to evaluate them in f16. Maybe you could try the Python version and see if it does the same? Alternatively, you should be able to run it in bf16, though this only works on CUDA at the moment.
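
To make the range gap concrete, here is a standalone sketch using the `half` crate (which candle itself depends on); the printed values come from the IEEE formats themselves, not from the issue:

use half::{bf16, f16};

fn main() {
    // f16 has a 5-bit exponent: the largest finite value is 65504.
    // bf16 keeps f32's 8-bit exponent: the largest finite value is ~3.4e38.
    println!("f16 max:  {}", f16::MAX);
    println!("bf16 max: {}", bf16::MAX);

    // A value that is unremarkable in f32 overflows to infinity in f16
    // but survives (with reduced precision) in bf16.
    let x = 1.0e5_f32;
    println!("1e5 as f16:  {}", f16::from_f32(x)); // inf
    println!("1e5 as bf16: {}", bf16::from_f32(x)); // ~1e5, rounded
}

So on a CUDA machine, something like `cargo run --example mamba --release --features cuda -- --prompt "Tell me a joke please" --dtype bf16` should sidestep the overflow.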

LaurentMazare commented on Jun 21 '24