Results: 2 comments by Yunfei Cheng
Thanks @awni for the reply! I did another experiment to explicitly cast the dtypes to float32 before calling sdpa. Here's the code and output ```python import time import mlx.core as...
Sure! For example:
```
q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.37887302093622566 ms
all casted to mlx.core.float32, q mlx.core.bfloat16, k mlx.core.bfloat16, v mlx.core.float32, mask mlx.core.float32 takes 0.2908905468757439 ms
...
```
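For readers who want to try a comparison like this, here is a minimal sketch of such a timing experiment. It is not the code from the comment above, which is truncated. The tensor shapes, the iteration counts, and the helper names `time_ms` and `bench_sdpa` are all assumptions made for illustration, and the cast is done once outside the timed region, which may differ from the original setup.

```python
import time

try:
    import mlx.core as mx  # only available where MLX is installed
except ImportError:
    mx = None


def time_ms(fn, warmup=5, iters=50):
    """Return the mean wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) * 1e3 / iters


def bench_sdpa(cast_to=None, B=1, H=8, L=128, D=64):
    """Time sdpa on mixed-dtype inputs, optionally casting every input first.

    Shapes are hypothetical; dtypes mirror the mix shown in the output above.
    """
    q = mx.random.normal((B, H, L, D)).astype(mx.bfloat16)
    k = mx.random.normal((B, H, L, D)).astype(mx.bfloat16)
    v = mx.random.normal((B, H, L, D))  # float32 by default
    mask = mx.zeros((L, L))             # additive float32 mask
    if cast_to is not None:
        q, k, v, mask = (x.astype(cast_to) for x in (q, k, v, mask))
    mx.eval(q, k, v, mask)  # materialize inputs before timing

    def run():
        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=D ** -0.5, mask=mask
        )
        mx.eval(out)  # MLX is lazy, so force the computation

    return time_ms(run)


if __name__ == "__main__" and mx is not None:
    print(f"mixed dtypes: {bench_sdpa():.4f} ms")
    print(f"all float32:  {bench_sdpa(mx.float32):.4f} ms")
```

Calling `mx.eval` inside the timed function matters because MLX evaluates lazily; without it the loop would mostly measure graph construction instead of the attention kernel.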