Re-land: Make `as_strided_copy` materialize a new tensor with `index`.
Re-land: #6624
This PR adds a fast path on top of the changes in #6624.

Fast path: keeps the old behavior of `as_strided_copy`
- Taken only when the given size and strides specify a non-overlapping and dense tensor

Slow path: new behavior (sketched below)
- Slower, due to CPU dispatch and computation
- Should work with any argument combination
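For intuition, here is a hypothetical Python sketch of what the slow path does conceptually. The actual implementation is C++ inside torch_xla; `as_strided_by_index` and its structure are illustrative assumptions, not the real code. The idea: compute the flat storage indices that `as_strided` would read, then gather them in one shot, materializing a new tensor instead of aliasing storage.

```python
import torch

def as_strided_by_index(t, size, stride, storage_offset=0):
    # Hypothetical sketch (not the actual torch_xla C++ code): build the
    # flat storage indices that as_strided(size, stride, storage_offset)
    # would read, then gather them with take(), which materializes a
    # fresh tensor. Assumes t is contiguous so flatten() matches storage
    # order.
    idx = torch.full(size, storage_offset, dtype=torch.long)
    for dim, (length, step) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[dim] = length
        idx = idx + (torch.arange(length, dtype=torch.long) * step).reshape(shape)
    return t.flatten().take(idx)

# Handles overlapping windows, which the fast path's
# non-overlapping-and-dense check would reject:
base = torch.arange(12.0)
out = as_strided_by_index(base, size=(3, 4), stride=(2, 1))
assert torch.equal(out, torch.as_strided(base, (3, 4), (2, 1)))
```

Because the index computation happens on CPU before dispatch, this is slower than the fast path, but it places no restrictions on the size/stride combination.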
cc @miladm @JackCaoG @lsy323
I will test for the regression described here on the GPU machine I have access to.
The dynamo issue can be fixed by rebasing; it's fine to ignore.
@lsy323 Could you help me check whether the regression is gone?
Do we need this PR in the 2.3 release? It is a rather dangerous change; if we don't have a strong reason, I'd rather leave it in nightly for now.
@vanbasten23 can you please help @ysiraichi benchmark this fix on TPU and confirm the perf outcome?
@JackCaoG given the risk, I'd be OK with leaving this PR out of 2.3.
Yeah, unless there is a strong reason, I would prefer to leave this out of the 2.3 release.
Do we have bandwidth to test this one? Otherwise we can merge and see if the DDP test starts to fail tomorrow...
I'm running the tests in https://github.com/pytorch/xla/pull/6624#issuecomment-1984717508.
@ysiraichi sorry for the delayed response. I tested on my v3-8. Before this PR (master branch 6ac32233a238cfb351f9aa87dfd0308ecf547a96):
```
root@67df528db184:/ansible# PJRT_DEVICE=TPU python pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --ddp --pjrt_distributed --fake_data --batch_size=256
Epoch 1 train begin 03:32:23
| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/5 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/7 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/4 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/1 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/3 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/3 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/1 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/7 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/2 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/5 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/6 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/4 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/3 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/1 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/2 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/6 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/0 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/5 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/7 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/4 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
```
With the PR:
```
root@67df528db184:/ansible# PJRT_DEVICE=TPU python pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --ddp --pjrt_distributed --fake_data --batch_size=256
| Training Device=xla:0/4 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/7 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
Epoch 1 train begin 04:06:25
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/3 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/5 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/1 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/7 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/2 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/1 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/5 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/3 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/4 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/6 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/2 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/6 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/4 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/1 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/3 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/7 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/0 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/5 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
```
I don't see any slowdown. The change LGTM. Thanks, Yukio.
Thanks, @vanbasten23.