jax
Add support for device kwarg in astype, and add matching utility func
Towards https://github.com/google/jax/issues/20200
This PR adds a private device-placement utility, `jax._src.numpy.util._place_array`, to manage array-API-compliant array placement behavior for use in the `jax.numpy` namespace. Copies are mediated by the `lax._array_copy` utility, while device transfer is performed via `api.device_put`.
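A minimal sketch of the placement logic described above, written against public APIs only. It assumes `jnp.array(x, copy=True)` as a public stand-in for the private `lax._array_copy`, and `place_array` is a hypothetical name, not the PR's actual `_place_array`. It deliberately does not yet enforce `copy=False`; how to handle that is the open question in this thread.

```python
import jax
import jax.numpy as jnp


def place_array(x, device=None, copy=None):
    """Hypothetical sketch of array-API-style placement; not the PR's code."""
    if copy:
        # Explicit copy requested: materialize a fresh buffer first
        # (public stand-in for the private lax._array_copy).
        x = jnp.array(x, copy=True)
    if device is not None:
        # Transfer (or re-shard) via device_put. Whether this copies when
        # `x` already lives on `device` is up to device_put itself.
        x = jax.device_put(x, device)
    return x


x = jnp.arange(3)
y = place_array(x, copy=True)   # new buffer, same values
dev = jax.devices()[0]
z = place_array(x, device=dev)  # placed on the first available device
```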
cc: @jakevdp
@yashk2810 Sorry for the delay in getting this message out. I've updated the PR with simplifications to the `_place_array` utility, as you suggested. Originally I intended to maximize strict compliance with array API semantics with respect to setting `copy=False` while still specifying `device`, hence the attempts at detecting when a transfer is not required in the first place. (In hindsight, the checks weren't quite right, since I'm still getting familiar with the sharding system.)
Is there already a mechanism to predict when two shards are equivalent, in the sense that calling `device_put(x, shard)` would be a no-op if `x` has a sharding equivalent to `shard`? If so, or if you think something like that is feasible, it may be worth including here.
cc: @jakevdp, if you have any insights regarding how we want to handle `copy=False, device=...` behavior here.
The easiest thing would just be to error if `copy=False` and `device` are both specified.
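The conservative rule suggested here can be sketched as a guard at the top of the placement helper. This is a hypothetical illustration (`place_array_conservative` is not a JAX or PR function), again using `jnp.array(x, copy=True)` in place of the private copy utility:

```python
import jax
import jax.numpy as jnp


def place_array_conservative(x, device=None, copy=None):
    """Hypothetical sketch: reject copy=False together with device."""
    if copy is False and device is not None:
        # We cannot cheaply prove that device_put avoids a copy, so refuse
        # the combination instead of silently copying.
        raise ValueError("copy=False is not supported together with device=...")
    if copy:
        x = jnp.array(x, copy=True)
    if device is not None:
        x = jax.device_put(x, device)
    return x


dev = jax.devices()[0]
try:
    place_array_conservative(jnp.ones(2), device=dev, copy=False)
    rejected = False
except ValueError:
    rejected = True
```

The trade-off is that some calls that would never need a copy (for example, `device=x.sharding`) are rejected too.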
> Is there already a mechanism to predict when two shards are equivalent in the sense that calling `device_put(x, shard)` would be a no-op if `x` has sharding equivalent to `shard`?
sorry, I don't understand what this means. Also, why do you need such a mechanism?
I was looking for a way to tell if `x` already follows whatever sharding/layout `device` would specify. That way we could short-circuit to a no-op and avoid raising an error in situations like `_place_array(x, device=x.sharding, copy=False)`.
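One way to express that check with public APIs is sketched below. `already_placed` is a hypothetical helper; it relies on `jax.sharding.Sharding.is_equivalent_to(other, ndim)` when `device` is a sharding and on `Array.devices()` when `device` is a single `jax.Device`. It only tells you whether the placement already matches; it is a heuristic, not a guarantee about whether `device_put` internally copies.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Sharding


def already_placed(x, device):
    """Hypothetical helper: True if `x` already matches the requested placement."""
    if isinstance(device, Sharding):
        # Two sharding objects can differ yet describe the same layout
        # for an array of this rank.
        return x.sharding.is_equivalent_to(device, x.ndim)
    # Otherwise treat `device` as a single jax.Device.
    return x.devices() == {device}


dev = jax.devices()[0]
x = jax.device_put(jnp.arange(4), dev)
```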
> I was looking for a way to tell if `x` already follows whatever sharding/layout `device` would specify
`device_put` should do that for you :)
Yash, the issue is that the semantics of `copy=False` are "error if a copy is required", so `device_put` handling that transparently doesn't help: we need to know what's happening in order to know whether to error.
Ok, then why not error instead of doing the no-op logic here? Whether to transfer or be a no-op should be `device_put`'s job. If `device` is not None and `copy=False`, then we should error, since it makes no sense, right?
That's what I suggested above: the most conservative thing would be to error if `copy=False` and `device` are used together. But the more "correct" thing would be to somehow introspect whether or not `device_put` forces a copy in any particular case.
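The "more correct" variant can be approximated by honoring `copy=False` only when the placement already matches, and erroring otherwise. This is a hypothetical sketch (`place_array_introspecting` and `_already_placed` are illustrative names, not JAX APIs). Its key assumption is that "placement already matches" implies "no copy needed"; that is exactly the part that would need a real introspection helper from `device_put`'s side to be trustworthy.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Sharding


def _already_placed(x, device):
    # Hypothetical check: equivalent sharding, or the same single device.
    if isinstance(device, Sharding):
        return x.sharding.is_equivalent_to(device, x.ndim)
    return x.devices() == {device}


def place_array_introspecting(x, device=None, copy=None):
    """Hypothetical sketch: honor copy=False only when no transfer is needed."""
    if device is not None and copy is False:
        if not _already_placed(x, device):
            # A transfer (hence a copy) would be required.
            raise ValueError("copy=False but a device transfer is required")
        # Placement already matches: return the input untouched.
        return x
    if copy:
        x = jnp.array(x, copy=True)
    if device is not None:
        x = jax.device_put(x, device)
    return x


dev = jax.devices()[0]
x = jax.device_put(jnp.arange(4), dev)
same = place_array_introspecting(x, device=x.sharding, copy=False)
same_dev = place_array_introspecting(x, device=dev, copy=False)
```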
We could factor out a function that determines this and share it here, but `device_put` is complex, so let's error for now. Please file a bug against me, and I'll provide such a function that you can call here.