Add support for `device` and `copy` kwargs in `from_dlpack` to match Array API

Open Micky774 opened this issue 4 months ago • 6 comments

Towards https://github.com/google/jax/issues/20200

cf. https://github.com/data-apis/array-api/pull/741

Note: "In principle, arbitrary cross-device copies could be allowed too, but the consensus in https://github.com/data-apis/array-api/issues/626 was that limiting to device-to-host copies is enough for now."

This PR also implements the optional device-to-device transfer.

Default behavior is preserved when `device=None` and `copy=None`.
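For reference, the copy/device rules from the Array API `from_dlpack` specification can be sketched as a small decision function. This is a simplified model, not JAX's implementation: devices are plain strings such as `"cpu:0"`, the helper name `resolve_from_dlpack` is hypothetical, and it assumes a copy is needed only when the target device differs from the source (real implementations may also have to copy for other reasons).

```python
from typing import Optional, Tuple


def resolve_from_dlpack(
    src_device: str,
    device: Optional[str] = None,
    copy: Optional[bool] = None,
) -> Tuple[str, bool]:
    """Return (target_device, will_copy) under the Array API copy rules.

    Simplified model: devices are strings, and a copy is assumed to be needed
    only when the data has to move to a different device.
    """
    # device=None: the result stays on the device that owns the input buffer.
    target = src_device if device is None else device
    needs_copy = target != src_device

    if copy is True:
        # Always copy, even when the buffer could be shared.
        return target, True
    if copy is False:
        if needs_copy:
            # The spec asks for BufferError when a copy cannot be avoided.
            raise BufferError(f"copy=False, but {src_device} -> {target} needs a copy")
        return target, False
    # copy=None: share the buffer when possible, otherwise copy.
    return target, needs_copy


print(resolve_from_dlpack("gpu:0", device="cpu:0"))  # ('cpu:0', True)
```

In this model, leaving both arguments as `None` keeps the array on its source device without copying, while `copy=False` combined with a cross-device request is the one combination that must fail.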

Micky774, Mar 11 '24 18:03

Lint errors might indicate a problem:

jax/_src/dlpack.py:130: error: Argument 2 to "dlpack_managed_tensor_to_buffer" has incompatible type "Client"; expected "Device"  [arg-type]
jax/_src/dlpack.py:130: error: Argument 3 to "dlpack_managed_tensor_to_buffer" has incompatible type "Client | None"; expected "int | None"  [arg-type]

jakevdp, Mar 11 '24 22:03

This is a bug in the xla_extension stub annotations. I'll open a PR in XLA to resolve this. For now, I've added an inline ignore.
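An inline ignore like the one mentioned above can be scoped to a single mypy error code, so it silences only the `arg-type` mismatch on that line. The sketch below assumes mypy; `Device`, `Client`, and `to_buffer` are hypothetical stand-ins, not the real `xla_extension` signature.

```python
from typing import Optional


class Device:
    """Hypothetical stand-in for the type the stub annotation expects."""


class Client:
    """Hypothetical stand-in for the object the caller actually passes."""


def to_buffer(capsule: object, device: Device, stream: Optional[int] = None) -> str:
    # The annotation is narrower than what works at runtime, so a type checker
    # rejects a Client here even though the call itself succeeds.
    return type(device).__name__


client = Client()
# Scoping the ignore to [arg-type] silences only that error code on this line;
# any other mypy error on the same line would still be reported.
result = to_buffer(object(), client)  # type: ignore[arg-type]
```

Scoping the ignore keeps it from hiding unrelated problems until the upstream stub annotations are corrected.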

Micky774, Mar 11 '24 22:03

Internal pytype tests are failing with many variations of this error:

File "/jax/_src/third_party/scipy/interpolate.py", line 4, in <module>: Couldn't import pyi for 'jax.numpy' [pyi-error]
  No xla_client.Device in module jax._src.lib, referenced from 'jax.numpy'

jakevdp, Mar 12 '24 01:03

Can you change your commit message to something more informative? Thanks!

jakevdp, Mar 12 '24 19:03

@jakevdp, are the internal tests still failing? If so, I'll update the typing to use the `_Device` type so this doesn't stall.

Micky774, Mar 18 '24 12:03

Still failing:

File "third_party/py/jax/__init__.py", line 163, in <module>: Couldn't import pyi for 'jax.numpy' [pyi-error]
  Can't find pyi for 'jaxlib.xla_client', referenced from 'jax.numpy'

I think the issue is that `xla_client` doesn't have an interface file, so depending on it fails. That's why we use `Device = Any` in other similar locations, e.g. here: https://github.com/google/jax/blob/aaeeaf5f0caa497d2f6e33d995cdd88a07ee523a/jax/numpy/__init__.pyi#L21-L22 and here: https://github.com/google/jax/blob/aaeeaf5f0caa497d2f6e33d995cdd88a07ee523a/jax/_src/basearray.pyi#L22-L23
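The `Device = Any` pattern can be sketched in stub form: the alias lets a signature mention devices without importing a module that has no interface file. This is a hypothetical illustration, not the actual contents of `jax/numpy/__init__.pyi`; the `from_dlpack` signature shown is an assumption modeled on the kwargs this PR adds.

```python
from typing import Any, Optional

# The interface file cannot import the concrete device class (the module that
# defines it ships no .pyi), so the alias widens it to Any for type checkers.
Device = Any


def from_dlpack(x: Any, /, *, device: Optional[Device] = None,
                copy: Optional[bool] = None) -> Any:
    """Stub-style signature: in a real .pyi the body is just `...`."""
    ...
```

The trade-off is that the checker no longer verifies what is passed as `device`; that precision is given up to keep the stub importable.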

jakevdp, Mar 18 '24 16:03

This looks good and is probably more or less ready to merge. However, there's a subtlety here that I've been thinking about with @yashk2810: the behavior of `device=XXX` under `jit` and other transformations is ambiguous. Currently `device_put` is a no-op in this context, which means this function will silently ignore the device. In the short term I think we'd prefer to make that an error: `from_dlpack` should basically fail within any transformation, because its semantics are impure. It reads an external buffer that isn't tracked by JAX's normal tracing mechanisms, so if, e.g., the buffer changes between the first and second call to the function, caching semantics may lead to incorrect outputs.

All of this is somewhat second-order though, so we should probably merge this change and iterate from there.

jakevdp, Apr 04 '24 17:04