Re-land: Make `as_strided_copy` materialize a new tensor with `index`.

ysiraichi opened this pull request.

Re-land: #6624

This PR adds a fast path on top of #6624 changes.

Fast path: keep the old behavior of `as_strided_copy`

  • Taken only after checking that the size and strides specify a non-overlapping and dense tensor
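The fast-path condition can be illustrated with a small pure-Python check. This is a sketch similar in spirit to the non-overlapping-and-dense test PyTorch performs on tensor geometry, not the code from this PR; the function name `is_non_overlapping_and_dense` here is used for illustration only, and the sketch ignores the storage offset and the size of the underlying storage.

```python
from typing import Sequence


def is_non_overlapping_and_dense(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    """Return True if (sizes, strides) address every element of one contiguous
    block exactly once, possibly with the dimensions permuted.

    Illustrative sketch only: storage offset and storage size are not checked.
    """
    if len(sizes) != len(strides):
        raise ValueError("sizes and strides must have the same length")
    # Dimensions of size 0 or 1 never change which elements are addressed.
    dims = [d for d in range(len(sizes)) if sizes[d] >= 2]
    # Walk dimensions from innermost (smallest stride) to outermost.
    dims.sort(key=lambda d: strides[d])
    expected_stride = 1
    for d in dims:
        if strides[d] != expected_stride:
            # Stride too large leaves a gap; stride too small causes overlap.
            return False
        expected_stride *= sizes[d]
    return True
```

For example, sizes `[2, 3]` with strides `[1, 2]` (a transposed view of a contiguous 3x2 buffer) passes, while sizes `[2, 2]` with strides `[1, 1]` addresses some elements twice and fails, so under this sketch it would need the slow path.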

Slow path: new behavior

  • Slower due to CPU dispatch and computation
  • Should work with any argument combination
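The idea behind materializing the result with `index` can be sketched without any tensor library: compute the flat storage position of every output element from the sizes, strides, and storage offset, then gather those positions from the flattened input. This is a pure-Python sketch assuming a flat (1-D) input list; the helper names are hypothetical and this is not the PR's implementation. Because every output element gets its own explicit index, overlapping or gapped strides are handled the same way as dense ones.

```python
import itertools
from typing import List, Sequence


def as_strided_flat_indices(
    sizes: Sequence[int], strides: Sequence[int], storage_offset: int = 0
) -> List[int]:
    """Flat storage position of each output element, in row-major output order."""
    indices = []
    # itertools.product varies the last dimension fastest (row-major order).
    for pos in itertools.product(*(range(s) for s in sizes)):
        indices.append(storage_offset + sum(p * st for p, st in zip(pos, strides)))
    return indices


def as_strided_copy_by_index(
    flat_input: Sequence, sizes: Sequence[int], strides: Sequence[int], storage_offset: int = 0
) -> list:
    """Gather a new, independent (flattened) copy of the strided view."""
    idx = as_strided_flat_indices(sizes, strides, storage_offset)
    for i in idx:
        if not 0 <= i < len(flat_input):
            raise IndexError(f"strided view reads out of bounds at position {i}")
    return [flat_input[i] for i in idx]


# Overlapping strides: positions 0, 1, 1, 2 of the input are gathered.
print(as_strided_copy_by_index(list(range(6)), [2, 2], [1, 1]))
```

Computing one index per output element is what makes this path general, and it is also plausibly why the PR describes it as slower than the fast path.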

cc @miladm @JackCaoG @lsy323

ysiraichi, Mar 08 '24 14:03

I will test for the regression described here on the GPU machine I have access to.

ysiraichi, Mar 08 '24 14:03

The dynamo issue can be fixed by rebasing, so it is fine to ignore.

JackCaoG, Mar 08 '24 18:03

@lsy323 Could you help me check whether the regression is gone?

ysiraichi, Mar 08 '24 19:03

Do we need this PR in the 2.3 release? It is a rather dangerous change; if we don't have a strong reason, I'd rather leave it in nightly for now.

JackCaoG, Mar 08 '24 22:03

@vanbasten23 can you please help @ysiraichi benchmark this fix on TPU and confirm perf outcome?

@JackCaoG Given the risk, I'm OK with leaving this PR out of 2.3.

miladm, Mar 11 '24 17:03

Yeah, unless there is a strong reason, I would prefer to leave this out of the 2.3 release.

JackCaoG, Mar 11 '24 18:03

Do we have bandwidth to test this one? Otherwise, we can merge it and see whether the DDP test starts failing tomorrow.

JackCaoG, Mar 18 '24 23:03

I'm running the tests in https://github.com/pytorch/xla/pull/6624#issuecomment-1984717508.

vanbasten23, Mar 19 '24 03:03

@ysiraichi, sorry for the delayed response. I tested on my v3-8. Before this PR (master branch at commit 6ac32233a238cfb351f9aa87dfd0308ecf547a96):

root@67df528db184:/ansible# PJRT_DEVICE=TPU python pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --ddp --pjrt_distributed --fake_data --batch_size=256
Epoch 1 train begin 03:32:23
| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/5 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/7 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/4 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/1 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/3 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=03:33:05
| Training Device=xla:1/3 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/1 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/7 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/2 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/5 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/6 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:0/4 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=03:36:09
| Training Device=xla:1/3 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/1 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/2 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/6 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/0 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/5 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:1/7 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42
| Training Device=xla:0/4 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=03:37:42

With the PR:

root@67df528db184:/ansible# PJRT_DEVICE=TPU python pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --log_steps=200 --ddp --pjrt_distributed --fake_data --batch_size=256
| Training Device=xla:0/4 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/7 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
Epoch 1 train begin 04:06:25
| Training Device=xla:0/0 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/3 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/5 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/1 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.89620 Rate=0.00 GlobalRate=0.00 Time=04:07:07
| Training Device=xla:1/7 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/2 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/1 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/5 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/0 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:1/3 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/4 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/6 Epoch=1 Step=200 Loss=0.05069 Rate=0.00 GlobalRate=0.00 Time=04:09:56
| Training Device=xla:0/2 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/6 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/4 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/1 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/3 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/7 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:0/0 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29
| Training Device=xla:1/5 Epoch=1 Step=400 Loss=0.01512 Rate=0.00 GlobalRate=0.00 Time=04:11:29

I don't see any slowdown. The change LGTM. Thanks, Yukio.
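Since the Rate and GlobalRate fields in both logs read 0.00, the wall-clock timestamps are the usable signal for comparing the two runs. A small sketch that computes the elapsed seconds between the Step=0, Step=200, and Step=400 lines (timestamps copied from the logs above; both runs are assumed to stay within a single day):

```python
from datetime import datetime

# Timestamps of the Step=0, Step=200 and Step=400 lines in each log above.
LOGS = {
    "before": ["03:33:05", "03:36:09", "03:37:42"],
    "with_pr": ["04:07:07", "04:09:56", "04:11:29"],
}


def interval_seconds(stamps):
    """Seconds elapsed between consecutive HH:MM:SS timestamps (same day assumed)."""
    times = [datetime.strptime(s, "%H:%M:%S") for s in stamps]
    return [int((b - a).total_seconds()) for a, b in zip(times, times[1:])]


for name, stamps in LOGS.items():
    print(name, interval_seconds(stamps))
```

This prints `[184, 93]` for the run before the PR and `[169, 93]` with the PR. The second 200-step interval takes 93 seconds in both runs, consistent with the "no slowdown" reading; the first interval presumably also includes warm-up such as graph compilation, so it is a noisier comparison.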

vanbasten23, Mar 19 '24 04:03

Thanks, @vanbasten23.

ysiraichi, Mar 19 '24 13:03