
Buffer donation for Metal plugin

Open • dlwh opened this issue 11 months ago • 1 comment

Buffer donation would be nice. I don't see an existing issue for it, so I'm opening this one for tracking and to ask whether it's on the Apple JAX Metal team's roadmap.

>>> import jax
>>> jax.devices()
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-03-12 16:07:16.498439: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

[METAL(id=0)]
>>> import jax.numpy as jnp
>>> x = jax.jit(lambda x: x)(jnp.zeros((4, 5)))
>>> x = jax.jit(lambda x: x, donate_args=True)(jnp.zeros((4, 5)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: jit() got an unexpected keyword argument 'donate_args'
>>> x = jax.jit(lambda x: x, donate_argnums=(0,))(jnp.zeros((4, 5)))
/opt/homebrew/Caskroom/miniforge/base/envs/jax_metal/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[4,5]).
Donation is not implemented for ('METAL',).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
>>>
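Buffer donation lets a jitted function reuse the device memory of an input (selected with `donate_argnums`) for its output, instead of allocating a new buffer. The sketch below shows one way to check whether a given backend actually honors donation. It assumes a recent JAX where `jax.Array` exposes `is_deleted()`; the helper name `try_donate` is hypothetical, not a JAX API. It treats donation as ignored either when JAX emits the "Some donated buffers were not usable" warning seen in the log above, or when the input array is still alive after the call.

```python
import warnings

import jax
import jax.numpy as jnp


def try_donate(shape=(4, 5)):
    """Hypothetical helper: run a jitted update with its input donated.

    Returns (result, donation_honored). donation_honored is False if JAX warned
    that the donated buffer was not usable, or if the input was not invalidated.
    """
    # Fresh jit wrapper on every call so tracing/lowering (where the
    # "donated buffers were not usable" warning is raised) is not cached.
    step = jax.jit(lambda x: x + 1.0, donate_argnums=(0,))
    x = jnp.zeros(shape, dtype=jnp.float32)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        y = step(x)
        y.block_until_ready()

    not_usable = any(
        "donated buffers were not usable" in str(w.message) for w in caught
    )
    # When donation succeeds, JAX invalidates the donated input array,
    # so `x` must not be used again after the call.
    donation_honored = bool(x.is_deleted()) and not not_usable
    return y, donation_honored


if __name__ == "__main__":
    y, honored = try_donate()
    print(jax.devices(), "donation honored:", honored)
```

Based on the warning in the log above, one would expect `honored` to be `False` on the METAL plugin at the time of this report; whether it is `True` on other backends depends on the JAX version and platform.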

dlwh, Mar 12 '24 23:03

Thanks for requesting this feature. We will track it and post an update here when we have a plan.

shuhand0, Mar 13 '24 01:03