candle
Mamba model is broken with `f16` precision
Running the command:
cargo run --example mamba --release --features metal -- --prompt "Tell me a joke please" --dtype f16
does not work. The problem seems to lie in this code:
for &t in tokens.iter() {
    let input = Tensor::new(&[t], &self.device)?;
    let logits = self.model.forward(&input, &mut state)?;
    next_logits = Some(logits);
    if let Some(t) = self.tokenizer.next_token(t)? {
        print!("{t}")
    }
}
where `logits` is a `Tensor` of null values.
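
To narrow down at which token the bad values first show up, a small check along these lines could be called inside the loop. This is only a sketch under assumptions: `inspect_logits` and `NonFiniteReport` are hypothetical helpers and not part of candle, the logits are assumed to have already been copied out of the tensor into a plain `&[f32]` slice (for example after converting to `f32`), and "null" is assumed to mean NaN or infinite values.

```rust
/// Counts of the non-finite entries found in a slice of logits.
/// Hypothetical helper type for debugging; not part of candle.
#[derive(Debug, PartialEq, Eq)]
pub struct NonFiniteReport {
    /// Number of NaN entries.
    pub nan: usize,
    /// Number of +inf / -inf entries.
    pub infinite: usize,
    /// Total number of entries inspected.
    pub total: usize,
}

impl NonFiniteReport {
    /// True when every inspected value was finite.
    pub fn is_clean(&self) -> bool {
        self.nan == 0 && self.infinite == 0
    }
}

/// Scan a slice of logits (already converted to f32) and count
/// how many entries are NaN and how many are infinite.
pub fn inspect_logits(logits: &[f32]) -> NonFiniteReport {
    let nan = logits.iter().filter(|v| v.is_nan()).count();
    let infinite = logits.iter().filter(|v| v.is_infinite()).count();
    NonFiniteReport {
        nan,
        infinite,
        total: logits.len(),
    }
}
```

Calling `inspect_logits` once per token and printing the report when `is_clean()` returns false would show whether the values are already bad after the first `forward` call or only degrade later in the prompt.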