Phi-3 implementation seems to be buggy on Metal devices
After running the following command multiple times:
cargo run --release --features metal --example phi -- --model 3 --prompt "The best thing about coding in rust is "
I noticed severely degraded token-generation performance on my MacBook Pro M3. After profiling the issue, I realized that with repeat_penalty = 1.1, roughly 50 seconds (on average) are spent in
candle_transformers::utils::apply_repeat_penalty(
    &logits,
    self.repeat_penalty,
    &tokens[start_at..],
)?
and notably, most of that time is spent on the following line:
let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
I find this weird, as my MacBook Pro M3 is very fast for a Mamba 2.8b model, which is roughly the size of Phi-3 mini. Also, the above operation allocates a data buffer of roughly 30-40 thousand f32's, which is definitely not that large.
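For anyone wanting to reproduce the timing in isolation, here is a minimal sketch; it assumes the crates.io names candle_core / candle_transformers with the metal feature enabled, and the 32064-entry logits vector and 64-token context are just stand-ins for the real values:

use std::time::Instant;

use candle_core::{Device, Tensor};
use candle_transformers::utils::apply_repeat_penalty;

fn main() -> candle_core::Result<()> {
    // Stand-ins for the real setup: a vocab-sized logits vector on the Metal
    // device and a short context of recent token ids.
    let device = Device::new_metal(0)?;
    let logits = Tensor::rand(0f32, 1f32, 32064, &device)?;
    let tokens: Vec<u32> = (0..64).collect();

    let start = Instant::now();
    let _penalized = apply_repeat_penalty(&logits, 1.1, &tokens)?;
    println!("apply_repeat_penalty took {:?}", start.elapsed());
    Ok(())
}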
As a follow-up, a few lines above I see that the forward pass is done through:
Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?,
which might have an impact on the layout of the tensor and could therefore affect the allocation. I can keep investigating, though.
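If the worry is that the .i((.., 0, ..)) indexing leaves the logits in a non-contiguous layout, a quick hypothetical helper like the one below could confirm or rule that out; it assumes the 1-d logits tensor that ends up in apply_repeat_penalty, and that layout()/is_contiguous()/contiguous() are the right candle_core calls for inspecting and normalizing the layout:

use candle_core::{DType, Result, Tensor};

// Hypothetical helper: report whether the logits are contiguous and force a
// contiguous copy before the dtype conversion / device-to-host transfer.
fn logits_to_vec(logits: &Tensor) -> Result<Vec<f32>> {
    println!("logits contiguous: {}", logits.layout().is_contiguous());
    logits.contiguous()?.to_dtype(DType::F32)?.to_vec1::<f32>()
}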
PS: I haven't tested this same command on CUDA devices yet.
One thing that might be tricky here is that the Metal API might act in an asynchronous way. This could explain why you end up spending most of the time in the operation that retrieves the logits from the device. One way to check this would be to remove the repeat penalty and verify that the slowness properly disappears (there is also a device.synchronize() that should ensure that all the GPU ops are complete at that point, but I think it might be broken for Metal at the moment).
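A sketch of what that check could look like, assuming Device::synchronize() behaves as intended on the Metal backend (which, per the note above, may not be the case yet):

use std::time::Instant;

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::new_metal(0)?;
    let x = Tensor::rand(0f32, 1f32, (1024, 1024), &device)?;

    // Time some GPU work with an explicit sync, so queued Metal ops cannot
    // leak their cost into whichever later op first reads data back.
    let start = Instant::now();
    let y = x.matmul(&x)?;
    device.synchronize()?;
    println!("matmul took {:?} (synchronized)", start.elapsed());

    // Without the synchronize(), the elapsed time above would mostly measure
    // how long it takes to queue the op, and the real cost would surface in
    // the first op that copies results back to the host, e.g.:
    let _host_copy = y.to_vec2::<f32>()?;
    Ok(())
}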
After further inspection, and after removing the penalty, I realized there is also a considerable amount of time spent on
let next_token = self.logits_processor.sample(&logits)?;
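In case it helps to time this step in isolation too, a small sketch along the same lines as above (the LogitsProcessor::new seed/temperature/top-p values are arbitrary and the 32064-entry logits tensor is a stand-in); the slow part is presumably again the device-to-host read of the logits inside sample:

use std::time::Instant;

use candle_core::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> candle_core::Result<()> {
    let device = Device::new_metal(0)?;
    // Stand-in logits kept on the Metal device.
    let logits = Tensor::rand(0f32, 1f32, 32064, &device)?;
    let mut processor = LogitsProcessor::new(42, Some(0.8), None);

    let start = Instant::now();
    let next_token = processor.sample(&logits)?;
    println!("sampled token {next_token} in {:?}", start.elapsed());
    Ok(())
}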
Regarding the asynchronicity of the Metal API: why doesn't it cause issues for other small models, like Mamba 2.8b or a tiny Llama?
PS: after running on an RTX 4090, I find inference on this model particularly fast, roughly 90 tokens/sec.
I just gave it a try on my MacBook M2 Pro 16GB and it's indeed extremely slow. I think it might be because we make the computations using f32, so with 3.8b parameters this fills the memory, as it would take roughly 4 * 3.8GB (also, f32 computations are a lot slower than bf16 ones).
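To put a number on that back-of-the-envelope estimate (taking 3.8B as the parameter count of Phi-3 mini), as a quick sketch:

// Rough weight-memory footprint for ~3.8B parameters at each dtype.
fn main() {
    let params = 3.8e9_f64;
    println!("f32 : {:.1} GB", params * 4.0 / 1e9); // ~15.2 GB, more than a 16GB machine can spare
    println!("f16 : {:.1} GB", params * 2.0 / 1e9); // ~7.6 GB
    println!("bf16: {:.1} GB", params * 2.0 / 1e9); // ~7.6 GB
}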
The model was initially created in bf16, which is what gets used with CUDA, but this is not available in candle at the moment; it's being worked on though. I've just added a new parameter so that you can specify the dtype via --dtype f16 for example. With this you can try using f16; it seems to be a lot faster, but note that f16 has a much narrower range than bf16, so it's quite possible that you would run into NaNs when trying to use it.
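For example, combining the new flag with the command from the top of the thread should look something like this (assuming --dtype composes with the other options exactly as shown there):

cargo run --release --features metal --example phi -- --model 3 --dtype f16 --prompt "The best thing about coding in rust is "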