Awni Hannun
Thanks for the benchmarks everyone! There is clearly an unexpected performance cliff on M1 machines here as MLX is substantially faster on M2+. We'll need to take a deeper look...
@arnold-yan I took a look at this benchmark. The performance issue turns out to be from the gradient of the second call to `nn.Upsample`. It uses nearest neighbor interpolation by...
@arnold-yan you're right, the benchmark is slower with linear 😓. I made a mistake. Let me keep digging.
Hi @arnold-yan https://github.com/ml-explore/mlx/pull/1541 should improve your benchmark a lot. I ran it on an M1 Max and M3 Max and the numbers are now: Machine | MLX | PT ----...
Can you give an example of what you mean / how that would look? As far as I understand, the Python buffer protocol does not support bfloat16. E.g. `memoryview(jnp.ones((2, 2),...
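To illustrate the limitation (a sketch using NumPy, since the restriction lives in the buffer protocol itself rather than any one library): the protocol describes element types with struct-style format codes, and there simply is no code for bfloat16.

```python
import numpy as np

# Dtypes the buffer protocol knows about expose a memoryview with a
# struct-style format code: "f" for float32, "e" for float16, etc.
a = np.ones((2, 2), dtype=np.float32)
print(memoryview(a).format)  # "f"

# float16 ("e") is the only 16-bit float the protocol defines;
# bfloat16 has no format code, so a bfloat16 array cannot describe
# itself through a memoryview.
b = np.ones(2, dtype=np.float16)
print(memoryview(b).format)  # "e"
```

This is why `memoryview(jnp.ones((2, 2), dtype=jnp.bfloat16))` fails: there is no format character the array could report.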
@Narsil I'm still not fully understanding what API you are looking for / what's missing? Right now you can create an array from a Python memoryview object which should be...
I kind of like this design: it's simple, easy to follow, and gives us a lot of control over how to shard the model...
> The sharding functions now also take a groups argument. This assumes the linear layer is a fused one and splits it according to the groups argument (evenly or percentage...
> Otherwise one node will get all the queries and no keys and so on.

Ah that makes sense now. Some suggestions on alternative names:
- shards
- segments
- ...
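A rough sketch of the even-split case being discussed (the function name and layout here are my own illustration, not the PR's actual API): to keep each node from getting all the queries and no keys, a fused weight is first split into its groups, each group is sharded evenly, and the per-shard slices are re-concatenated.

```python
import numpy as np

def shard_fused_weight(w, n_shards, groups):
    """Split a fused linear weight row-wise so each shard receives an
    equal slice of every group (e.g. Q, K, V), rather than one node
    ending up with a whole group.

    `w` has shape (out_features, in_features); `groups` is the number
    of sub-layers fused along the output dimension.
    """
    # Split into the fused groups first, then shard each group evenly,
    # then rebuild one balanced weight per shard.
    per_group = np.split(w, groups, axis=0)
    sharded = [np.split(g, n_shards, axis=0) for g in per_group]
    return [
        np.concatenate([sharded[g][s] for g in range(groups)], axis=0)
        for s in range(n_shards)
    ]

# A fused QKV weight: 3 groups of 4 output rows each, split over 2 shards.
w = np.arange(12 * 5).reshape(12, 5)
parts = shard_fused_weight(w, n_shards=2, groups=3)
# Each shard holds 2 rows from each of Q, K, and V: shape (6, 5).
print([p.shape for p in parts])  # [(6, 5), (6, 5)]
```

The percentage-split variant would replace the even `np.split` inside each group with index-based slicing, but the group-then-shard ordering is the same.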
@NripeshN are you planning to come back to this?