[Feature Request] Cannot create tensor from raw bytes + dtypes

Open · Narsil opened this issue 1 year ago · 3 comments

It is currently not possible to create or manipulate tensors containing bfloat16 outside of MLX. For example, this does not work:

import jax.numpy as jnp
import mlx.core as mx

x = mx.array(jnp.ones((2, 2), dtype=jnp.bfloat16))  # fails for bfloat16

As far as I can tell, this is because everything built on memoryview and the like fails on bf16. This matters for safetensors: it is needed to implement in-memory loading (for instance weights = safetensors.mlx.load(f.read())) as well as lazy loading of individual tensors.
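To make the in-memory case concrete: a safetensors file is an 8-byte little-endian header length, a JSON header giving each tensor's dtype string, shape and byte offsets, then the raw tensor data. The stdlib-only sketch below assumes that layout; parse_safetensors is a hypothetical helper, not part of the safetensors package. It suggests that getting to a (dtype, shape, raw bytes) triple is easy, and the missing piece is turning that triple into an MLX array when the dtype is "BF16".

```python
import json
import struct


def parse_safetensors(data: bytes) -> dict:
    """Hypothetical helper: split an in-memory safetensors blob into
    (dtype, shape, raw) per tensor.

    Layout: an 8-byte little-endian u64 header size N, then N bytes of JSON
    header, then the tensor data. "data_offsets" are relative to the start
    of the data section.
    """
    (header_size,) = struct.unpack("<Q", data[:8])
    header = json.loads(data[8 : 8 + header_size])
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    body = memoryview(data)[8 + header_size :]  # zero-copy view of the tensor data
    tensors = {}
    for name, info in header.items():
        begin, end = info["data_offsets"]
        tensors[name] = (info["dtype"], tuple(info["shape"]), body[begin:end])
    return tensors


# Usage: a hand-built blob holding one 2x2 BF16 tensor of ones (0x3F80 each).
hdr = json.dumps({"w": {"dtype": "BF16", "shape": [2, 2], "data_offsets": [0, 8]}}).encode()
blob = struct.pack("<Q", len(hdr)) + hdr + b"\x80\x3f" * 4
dtype, shape, raw = parse_safetensors(blob)["w"]
print(dtype, shape, len(raw))  # BF16 (2, 2) 8
```

Only the raw bytes, a shape and a dtype string come out of this step, which is why a constructor taking exactly those three would be enough.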

Loading from a path with mx.load("file.safetensors") already works great.

As far as I can tell, memoryview objects lose the actual dtype anyway.
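A small stdlib illustration of the point about memoryview: a memoryview only carries the struct-style format character its exporter declares, and that mini-language has "e" for IEEE float16 but nothing for bfloat16. The byte values below are bf16 1.0 and 2.0 in little-endian order, and the uint16 output assumes a little-endian host.

```python
import array

# A memoryview only knows what its exporter declares: a struct-style format
# character, an itemsize and a shape. Exporters with a real dtype declare it...
f32 = memoryview(array.array("f", [1.0, 2.0]))
print(f32.format, f32.itemsize)  # f 4

# ...but raw bytes are exported as plain unsigned bytes ("B").
raw = bytes([0x80, 0x3F, 0x00, 0x40])  # bfloat16 1.0 and 2.0, little-endian
m = memoryview(raw)
print(m.format, m.itemsize, m.shape)  # B 1 (4,)

# There is no format code for bfloat16, so the closest a bf16 payload can get
# is uint16 ("H"); the "bfloat16" meaning is gone at this point.
as_u16 = m.cast("H")
print(as_u16.format, as_u16.tolist())  # H [16256, 16384] on a little-endian host
```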

I feel that any API taking bytes + shape + dtype would work very generically (with or without copying, depending on constraints). It would let me correctly implement all dtypes supported by MLX within safetensors itself, and let others do advanced things like loading files directly from network sockets.
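A rough sketch of what such a bytes + shape + dtype constructor has to do, using NumPy as a stand-in target because MLX does not expose such a function here. from_bytes and the _DTYPES table are hypothetical names; the dtype strings are a subset of the safetensors ones. Since NumPy has no bfloat16 either, BF16 is widened to float32 by placing each 16-bit pattern in the high half of a float32, which is exact but doubles memory. A native implementation could keep the data as bf16.

```python
import numpy as np

# Hypothetical helper illustrating the requested "bytes + shape + dtype" API.
# Subset of safetensors dtype strings mapped to little-endian NumPy dtypes.
_DTYPES = {
    "F64": np.dtype("<f8"),
    "F32": np.dtype("<f4"),
    "F16": np.dtype("<f2"),
    "I64": np.dtype("<i8"),
    "I32": np.dtype("<i4"),
    "U8": np.dtype("u1"),
    "BOOL": np.dtype("?"),
}


def from_bytes(raw, shape, dtype: str) -> np.ndarray:
    """Build an array from a raw little-endian buffer, a shape and a dtype string."""
    count = int(np.prod(shape, dtype=np.int64))  # 1 for a scalar shape ()
    if dtype == "BF16":
        # Read the 16-bit patterns, then shift them into the high half of a
        # uint32 and reinterpret as float32 (every bf16 value is exact in f32).
        bits = np.frombuffer(raw, dtype="<u2", count=count)
        return (bits.astype(np.uint32) << 16).view(np.float32).reshape(shape)
    # Zero-copy view over the buffer (read-only when raw is immutable bytes).
    return np.frombuffer(raw, dtype=_DTYPES[dtype], count=count).reshape(shape)


x = from_bytes(b"\x80\x3f\x00\x40\x00\xc0\x80\x3e", (2, 2), "BF16")
print(x.tolist())  # [[1.0, 2.0], [-2.0, 0.25]]
```

Unknown dtype strings raise KeyError, and a buffer shorter than the shape requires makes np.frombuffer raise ValueError.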

Thanks a lot for this work.

Narsil, Jul 29 '24 13:07

Quick comment after re-reading this:

The feature request is not about fixing the memoryview path for JAX -> MLX, but really about a way to create tensors from raw bytes. (The memoryview example just showcases why this is necessary, but we cannot expect JAX/TF to be installed for this to work.)

Narsil, Jul 30 '24 07:07

Can you give an example of what you mean, or how that would look? As far as I understand, the Python buffer protocol does not support bfloat16.

E.g. memoryview(jnp.ones((2, 2), dtype=jnp.bfloat16)) raises an error.

awni, Jul 30 '24 13:07

@Narsil I still don't fully understand what API you are looking for, or what's missing. Right now you can create an array from a Python memoryview object, which should be pretty flexible:

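import numpy as np
import mlx.core as mx

a = np.array([1, 2, 3])

# Expose the NumPy array's buffer and build an MLX array from it
buffer = memoryview(a)
a_mx = mx.array(buffer)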

Does that work for you, or are you looking for something different? Maybe what's missing is support for bfloat16?
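For dtypes that do have a struct format code, raw bytes + shape + dtype can already be packaged as a memoryview, because memoryview.cast attaches both a format and a shape. In the sketch below np.asarray stands in for mx.array(buf), so it runs without MLX; that mx.array reads the shape the same way is an assumption. The native "f" cast also assumes a little-endian host. bfloat16 has no format code, so this route does not cover it.

```python
import struct

import numpy as np

# For dtypes with a struct format code, "raw bytes + shape + dtype" can be
# expressed as a memoryview: cast() attaches both the format and the shape.
raw = struct.pack("<4f", 1.0, 2.0, 3.0, 4.0)  # four little-endian float32s
buf = memoryview(raw).cast("f", (2, 2))       # native "f"; assumes a little-endian host
print(buf.format, buf.shape)  # f (2, 2)

# A buffer-protocol consumer sees the declared dtype and shape.
# np.asarray stands in for mx.array(buf) so this runs without MLX installed.
arr = np.asarray(buf)
print(arr.dtype, arr.shape)  # float32 (2, 2)
```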

awni, Aug 01 '24 20:08