jax
Add support for `device` and `copy` kwargs in `from_dlpack` to match Array API
Towards https://github.com/google/jax/issues/20200
cf. https://github.com/data-apis/array-api/pull/741
[!NOTE] "In principle, arbitrary cross-device copies could be allowed too, but the consensus in https://github.com/data-apis/array-api/issues/626 was that limiting to device-to-host copies is enough for now". This PR includes the optional device-to-device transfer.
Default behavior is preserved when `device=None` and `copy=None`.
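The following is a minimal pure-Python sketch of the `device`/`copy` decision rules as the Array API describes them for `from_dlpack`. With `copy=True` it always copies. With `copy=False` it never copies and raises `BufferError` when a copy would be required. With `copy=None` it copies only when needed, and `device=None` means "stay on the producer's device". The helper `plan_from_dlpack` and the string device names are hypothetical. This is not JAX's implementation, and it models only the cross-device reason for copying.

```python
from typing import Optional


def plan_from_dlpack(source_device: str, device: Optional[str] = None,
                     copy: Optional[bool] = None) -> str:
    """Hypothetical helper: decide what from_dlpack(x, device=..., copy=...) does.

    Returns "alias" (reuse the producer's buffer) or "copy".
    Devices are modeled as plain strings such as "cpu:0" or "gpu:0".
    """
    # device=None: the result stays on the same device as the input.
    target = source_device if device is None else device
    # Only cross-device movement is modeled as forcing a copy in this sketch.
    needs_copy = target != source_device
    if copy is True:
        return "copy"  # always copy, even if aliasing were possible
    if copy is False:
        if needs_copy:
            raise BufferError(
                f"copy=False, but moving from {source_device} to {target} requires a copy")
        return "alias"
    # copy=None: reuse the existing buffer if possible, otherwise copy.
    return "copy" if needs_copy else "alias"
```

For example, `plan_from_dlpack("gpu:0")` keeps the zero-copy default. `plan_from_dlpack("gpu:0", device="cpu:0")` is the device-to-host copy, and adding `copy=False` to that call raises `BufferError`.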
Lint errors might indicate a problem:
jax/_src/dlpack.py:130: error: Argument 2 to "dlpack_managed_tensor_to_buffer" has incompatible type "Client"; expected "Device" [arg-type]
jax/_src/dlpack.py:130: error: Argument 3 to "dlpack_managed_tensor_to_buffer" has incompatible type "Client | None"; expected "int | None" [arg-type]
This is a bug in the `xla_extension` stub annotations. I'll open a PR in XLA to resolve this. For now, I've added an inline ignore.
Internal pytype tests are failing with many variations of this error:
File "/jax/_src/third_party/scipy/interpolate.py", line 4, in <module>: Couldn't import pyi for 'jax.numpy' [pyi-error]
No xla_client.Device in module jax._src.lib, referenced from 'jax.numpy'
Can you change your commit message to something more informative? Thanks!
@jakevdp, are the internal tests still failing? If so, I'll update the typing to use the `_Device` type so this doesn't get stalled.
Still failing:
File "third_party/py/jax/__init__.py", line 163, in <module>: Couldn't import pyi for 'jax.numpy' [pyi-error]
Can't find pyi for 'jaxlib.xla_client', referenced from 'jax.numpy'
I think the issue is that `xla_client`, which the new annotation tries to depend on, doesn't have an interface file. That's why we use `Device = Any` in other similar locations, e.g. here: https://github.com/google/jax/blob/aaeeaf5f0caa497d2f6e33d995cdd88a07ee523a/jax/numpy/__init__.pyi#L21-L22
and here:
https://github.com/google/jax/blob/aaeeaf5f0caa497d2f6e33d995cdd88a07ee523a/jax/_src/basearray.pyi#L22-L23
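The following sketches the stub pattern described above. The device type is aliased to `Any` so the stub never imports `jaxlib.xla_client`. The signature follows the Array API's form of `from_dlpack` (positional-only input, keyword-only `device` and `copy`). It is an illustrative stand-in and is not copied from JAX's actual `.pyi` files. `Array = Any` is likewise a placeholder assumption so the snippet runs on its own.

```python
from typing import Any, Optional

# jaxlib.xla_client ships no interface file that pytype can import,
# so stubs alias the device type to Any instead of importing it.
Device = Any
Array = Any  # placeholder for jax.Array in this standalone sketch


def from_dlpack(x: Any, /, *, device: Optional[Device] = None,
                copy: Optional[bool] = None) -> Array:
    """Stub-style signature only; a real .pyi file would use `...` as the body."""
    ...
```

The trade-off is that the type checker can no longer catch a wrong device type at this boundary, in exchange for stubs that load without an interface file for `xla_client`.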
This looks good, and is probably more or less ready to merge. However, there's a subtlety here that I've been thinking about with @yashk2810: the behavior of `device=XXX` under `jit` and other transformations is ambiguous. Currently `device_put` is a no-op in this context, which means that this function will silently ignore the device. In the short term, I think we'd prefer to make that an error. `from_dlpack` should basically fail within any transformation, because its semantics are impure: it reads an external buffer that isn't tracked by JAX's normal tracing mechanisms. For example, if the buffer changes between the first and second call to the function, caching semantics may lead to incorrect outputs.
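The following small pure-Python analogy makes the caching concern concrete. `functools.lru_cache` stands in for JAX's trace/compile cache, which is keyed roughly on the abstract signature of the arguments rather than on external state. This is an assumption for illustration, not how `jit` is implemented. A plain list stands in for a buffer owned by another library and exposed through DLPack.

```python
import functools

# Stand-in for memory owned by another library and exposed through DLPack.
external_buffer = [1.0, 2.0, 3.0]


@functools.lru_cache(maxsize=None)
def total_plus(offset: float) -> float:
    # Impure: reads external_buffer, which is not part of the cache key.
    return sum(external_buffer) + offset


first = total_plus(0.0)      # 6.0, computed from the buffer and cached
external_buffer[0] = 100.0   # the producer mutates the buffer in place
second = total_plus(0.0)     # cache hit: still 6.0 (stale)
fresh = sum(external_buffer) + 0.0  # 105.0, what an uncached call would see
```

Because `second` differs from `fresh`, the cached result silently disagrees with the buffer's current contents.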
All of this is somewhat second-order though, so we should probably merge this change and iterate from there.