
Add support for device kwarg in astype, and add matching utility func

Open Micky774 opened this issue 9 months ago • 9 comments

Towards https://github.com/google/jax/issues/20200

This PR adds a private device-placement utility jax._src.numpy.util._place_array to manage array API compliant array placement behavior for use in the jax.numpy namespace. Copies are mediated by the lax._array_copy utility, while device transfer is performed via api.device_put.
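As a rough illustration of the copy-then-transfer behavior described above, here is a minimal sketch that uses only public JAX APIs. It is not the PR's code: `place_array` is a hypothetical stand-in name, and `jnp.array(..., copy=True)` is used here in place of the private `lax._array_copy`.

```python
import jax
import jax.numpy as jnp


def place_array(x, device=None, copy=False):
    """Hypothetical stand-in for a placement helper (not the PR's implementation)."""
    out = jnp.asarray(x)
    if copy:
        # Force a fresh buffer before any transfer, so the result never aliases x.
        out = jnp.array(out, copy=True)
    if device is not None:
        # `device` may be a jax.Device or a jax.sharding.Sharding.
        out = jax.device_put(out, device)
    return out


x = jnp.arange(4)
dev = jax.devices()[0]
y = place_array(x, device=dev, copy=True)
print(y.devices())
```

The sketch deliberately ignores the `copy=False` case, which is the open question in this thread.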

cc: @jakevdp

Micky774 avatar May 06 '24 16:05 Micky774

@yashk2810 Sorry for the delay in getting this message out. I've updated the PR with simplifications to the _place_array utility, as you suggested. Originally I intended to maximize strict compliance with array API semantics wrt setting copy=False while still specifying device, hence the attempts at detecting when a transfer is not required in the first place (in hindsight, the checks weren't quite right, since I'm still getting familiar with the sharding system).

Is there already a mechanism to predict when two shards are equivalent in the sense that calling device_put(x, shard) would be a no-op if x has sharding equivalent to shard? If so, or if you think something like that is feasible, then it may be worth including here.

cc: @jakevdp if you have any insights regarding how we want to handle copy=False, device=... behavior here

Micky774 avatar May 10 '24 17:05 Micky774

The easiest thing would just be to error if copy=False and device is specified.
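A minimal sketch of that conservative check, assuming the array API convention that `copy=None` means "copy only if needed". The function name `check_copy_and_device` and the error message are hypothetical, not taken from the PR.

```python
def check_copy_and_device(copy, device):
    """Hypothetical validation: reject copy=False combined with an explicit device."""
    if copy is False and device is not None:
        raise ValueError(
            "copy=False cannot be combined with an explicit device, because "
            "the transfer may require a copy."
        )


# copy=None and copy=True pass; copy=False with a device raises.
check_copy_and_device(copy=None, device="some-device")
try:
    check_copy_and_device(copy=False, device="some-device")
except ValueError as err:
    print("rejected:", err)
```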

jakevdp avatar May 10 '24 17:05 jakevdp

> Is there already a mechanism to predict when two shards are equivalent in the sense that calling device_put(x, shard) would be a no-op if x has sharding equivalent to shard?

sorry, I don't understand what this means. Also, why do you need such a mechanism?

yashk2810 avatar May 10 '24 17:05 yashk2810

> sorry, I don't understand what this means. Also, why do you need such a mechanism?

I was looking for a way to tell if x already follows whatever sharding/layout device would specify. That way we could short-cut to a no-op and avoid raising an error in situations like _place_array(x, device=x.sharding, copy=False).
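One way such a check might be approximated with public APIs is sketched below. `placement_is_noop` is a hypothetical name, and the sketch only compares the requested placement against the array's current one via `Sharding.is_equivalent_to` and `Array.devices()`. It does not establish whether `device_put` would actually avoid a copy, which is the part that needs a dedicated helper.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Sharding


def placement_is_noop(x, device):
    """Hypothetical check: does x already have the requested placement?"""
    if device is None:
        return True
    if isinstance(device, Sharding):
        # Compare the requested sharding with the array's current sharding.
        return x.sharding.is_equivalent_to(device, x.ndim)
    # Otherwise assume `device` is a single jax.Device.
    return x.devices() == {device}


x = jnp.arange(4)
dev = next(iter(x.devices()))
print(placement_is_noop(x, x.sharding))
print(placement_is_noop(x, dev))
```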

Micky774 avatar May 10 '24 18:05 Micky774

> I was looking for a way to tell if x already follows whatever sharding/layout device would specify

device_put should do that for you :)

yashk2810 avatar May 10 '24 18:05 yashk2810

Yash - the issue is that the semantics of copy=False are "error if a copy is required", so device_put handling that transparently doesn't help because we need to know what's happening in order to know whether to error.

jakevdp avatar May 10 '24 19:05 jakevdp

Ok, then why not error instead of doing the no-op logic here? Whether to transfer or treat it as a no-op should be device_put's job. If device is not None and copy=False, then we should error, since it makes no sense, right?

yashk2810 avatar May 10 '24 19:05 yashk2810

That's what I suggested above: the most conservative thing would be to error if copy=False and device are used together. But the more "correct" thing would be to somehow introspect whether or not device_put forces a copy in any particular case.

jakevdp avatar May 10 '24 19:05 jakevdp

We can factor out a function that determines this and share it here, but device_put is complex, so let's error for now. File a bug against me to provide such a function, which you can then call here.

yashk2810 avatar May 10 '24 20:05 yashk2810