
Backward takes so much time (nearly 44s, but ~0.8s on GPU) in huge models

Open farhadinima75 opened this issue 3 years ago • 3 comments

🐛 Bug

I was trying to train a segmentation model on Colab (free version) with single-core torch_xla.

The model is huge, and the backward step takes a very long time.

It gets stuck here for ~44s just for the backward calculations: [screenshot]

But for smaller models, it takes 3s to complete a cycle.

Input and output shapes: (16, 3, 256, 256) and (16, 2, 256, 256). Here is the metrics report after the first iteration:

Metric: CompileTime
  TotalSamples: 19
  Accumulator: 722ms446.172us
  ValueRate: 008ms298.252us / second
  Rate: 0.21824 / second
  Percentiles: 1%=002ms861.164us; 5%=002ms861.164us; 10%=002ms450.221us; 20%=008ms803.322us; 50%=016ms614.016us; 80%=071ms913.140us; 90%=158ms419.131us; 95%=180ms318.931us; 99%=180ms318.931us
Metric: DeviceLockWait
  TotalSamples: 20
  Accumulator: 001ms129.113us
  ValueRate: 012.969us / second
  Rate: 0.229728 / second
  Percentiles: 1%=004.208us; 5%=004.441us; 10%=004.588us; 20%=004.711us; 50%=005.662us; 80%=007.207us; 90%=023.808us; 95%=998.744us; 99%=998.744us
Metric: ExecuteTime
  TotalSamples: 19
  Accumulator: 58s180ms918.458us
  ValueRate: 710ms706.824us / second
  Rate: 0.231771 / second
  Percentiles: 1%=002ms028.711us; 5%=002ms028.711us; 10%=003ms653.155us; 20%=028ms359.028us; 50%=02s147ms330.984us; 80%=05s991ms335.959us; 90%=06s014ms470.205us; 95%=14s408ms616.776us; 99%=14s408ms616.776us
Metric: InboundData
  TotalSamples: 19
  Accumulator: 1.78GB
  ValueRate: 22.73MB / second
  Rate: 0.236406 / second
  Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=1.00MB; 50%=64.00MB; 80%=128.00MB; 90%=256.00MB; 95%=512.00MB; 99%=512.00MB
Metric: InputOutputAliasCount
  TotalSamples: 18
  Accumulator: 600.00
  ValueRate: 6.89 / second
  Rate: 0.206849 / second
  Percentiles: 1%=1.00; 5%=1.00; 10%=1.00; 20%=3.00; 50%=3.00; 80%=3.00; 90%=3.00; 95%=555.00; 99%=555.00
Metric: IrValueTensorToXlaData
  TotalSamples: 245
  Accumulator: 15s192ms657.388us
  ValueRate: 165ms391.118us / second
  Rate: 2.66731 / second
  Percentiles: 1%=001ms366.195us; 5%=001ms478.816us; 10%=002ms575.362us; 20%=002ms687.775us; 50%=002ms018.077us; 80%=020ms163.611us; 90%=104ms163.374us; 95%=213ms295.221us; 99%=02s582ms906.310us
Metric: OutboundData
  TotalSamples: 270
  Accumulator: 1.13GB
  ValueRate: 12.61MB / second
  Rate: 2.9308 / second
  Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=8.00B; 50%=1.00KB; 80%=512.00KB; 90%=9.00MB; 95%=16.00MB; 99%=128.00MB
Metric: ReleaseDataHandlesTime
  TotalSamples: 31
  Accumulator: 01s012ms027.950us
  ValueRate: 013ms613.706us / second
  Rate: 0.386378 / second
  Percentiles: 1%=001ms430.383us; 5%=001ms434.815us; 10%=002ms621.342us; 20%=003ms946.111us; 50%=005ms690.084us; 80%=013ms002.678us; 90%=071ms239.845us; 95%=212ms512.513us; 99%=312ms274.607us
Metric: TensorToData
  TotalSamples: 270
  Accumulator: 30s664ms358.184us
  ValueRate: 322ms072.148us / second
  Rate: 2.93145 / second
  Percentiles: 1%=001ms236.247us; 5%=001ms461.173us; 10%=002ms563.105us; 20%=002ms693.867us; 50%=002ms037.846us; 80%=014ms010.957us; 90%=104ms878.930us; 95%=213ms287.579us; 99%=02s608ms018.549us
Metric: TensorsGraphSize
  TotalSamples: 19
  Accumulator: 9367.00
  ValueRate: 107.59 / second
  Rate: 0.218241 / second
  Percentiles: 1%=4.00; 5%=4.00; 10%=4.00; 20%=71.00; 50%=325.00; 80%=509.00; 90%=628.00; 95%=4094.00; 99%=4094.00
Metric: TransferFromServerTime
  TotalSamples: 19
  Accumulator: 14s511ms862.286us
  ValueRate: 168ms107.536us / second
  Rate: 0.236406 / second
  Percentiles: 1%=002ms546.927us; 5%=002ms546.927us; 10%=017ms201.227us; 20%=020ms422.553us; 50%=251ms316.779us; 80%=01s175ms365.453us; 90%=02s993ms617.187us; 95%=04s273ms417.892us; 99%=04s273ms417.892us
Metric: TransferToServerTime
  TotalSamples: 270
  Accumulator: 30s644ms355.238us
  ValueRate: 322ms818.428us / second
  Rate: 2.93111 / second
  Percentiles: 1%=001ms222.412us; 5%=001ms442.896us; 10%=002ms542.858us; 20%=002ms665.676us; 50%=002ms989.946us; 80%=014ms968.429us; 90%=104ms846.026us; 95%=213ms229.207us; 99%=02s608ms968.433us
Metric: TransferToServerTransformTime
  TotalSamples: 270
  Accumulator: 778ms546.591us
  ValueRate: 008ms440.128us / second
  Rate: 2.9308 / second
  Percentiles: 1%=037.569us; 5%=046.240us; 10%=052.238us; 20%=065.222us; 50%=129.463us; 80%=551.956us; 90%=003ms573.463us; 95%=010ms633.565us; 99%=061ms107.966us
Metric: UnwrapXlaData
  TotalSamples: 5001
  Accumulator: 003ms520.459us
  ValueRate: 020.561us / second
  Rate: 66.9731 / second
  Percentiles: 1%=000.046us; 5%=000.047us; 10%=000.050us; 20%=000.053us; 50%=000.080us; 80%=000.250us; 90%=000.476us; 95%=000.703us; 99%=002.198us
Metric: WrapXlaData
  TotalSamples: 871
  Accumulator: 004ms757.688us
  ValueRate: 040.792us / second
  Rate: 9.45533 / second
  Percentiles: 1%=000.289us; 5%=000.295us; 10%=000.302us; 20%=000.312us; 50%=000.518us; 80%=009.572us; 90%=013.214us; 95%=015.614us; 99%=027.176us
Counter: CreateCompileHandles
  Value: 19
Counter: CreateDataHandles
  Value: 871
Counter: CreateXlaTensor
  Value: 2569
Counter: DestroyDataHandles
  Value: 299
Counter: DestroyXlaTensor
  Value: 2022
Counter: DeviceDataCacheMiss
  Value: 24
Counter: MarkStep
  Value: 1
Counter: ReleaseDataHandles
  Value: 299
Counter: UncachedCompile
  Value: 19
Counter: XRTAllocateFromTensor_Empty
  Value: 65
Counter: XrtCompile_Empty
  Value: 387
Counter: XrtExecuteChained_Empty
  Value: 384
Counter: XrtExecute_Empty
  Value: 387
Counter: XrtMemoryInfo_Empty
  Value: 384
Counter: XrtRead_Empty
  Value: 393
Counter: XrtReleaseAllocationHandle_Empty
  Value: 387
Counter: XrtReleaseCompileHandle_Empty
  Value: 384
Counter: XrtSessionCount
  Value: 4
Counter: XrtSubTuple_Empty
  Value: 384
Counter: aten::_local_scalar_dense
  Value: 1
Counter: aten::count_nonzero.dim_IntList
  Value: 1
Counter: aten::nonzero
  Value: 2
Counter: aten::prelu_backward
  Value: 14
Counter: xla::_copy_from
  Value: 951
Counter: xla::_to_cpu
  Value: 18
Counter: xla::adaptive_max_pool2d
  Value: 4
Counter: xla::adaptive_max_pool2d_backward
  Value: 4
Counter: xla::add
  Value: 261
Counter: xla::addcdiv_
  Value: 103
Counter: xla::addcmul
  Value: 103
Counter: xla::argmax
  Value: 1
Counter: xla::bernoulli_
  Value: 14
Counter: xla::binary_cross_entropy
  Value: 1
Counter: xla::binary_cross_entropy_backward
  Value: 1
Counter: xla::cat
  Value: 10
Counter: xla::convolution_backward_overrideable
  Value: 40
Counter: xla::convolution_overrideable
  Value: 70
Counter: xla::div
  Value: 131
Counter: xla::empty
  Value: 516
Counter: xla::eq
  Value: 1
Counter: xla::expand
  Value: 11
Counter: xla::fill_
  Value: 3
Counter: xla::ge
  Value: 1
Counter: xla::index_put_
  Value: 2
Counter: xla::lt
  Value: 1
Counter: xla::max
  Value: 5
Counter: xla::max_pool2d
  Value: 8
Counter: xla::mean
  Value: 9
Counter: xla::mul
  Value: 369
Counter: xla::native_batch_norm
  Value: 19
Counter: xla::native_batch_norm_backward
  Value: 19
Counter: xla::neg
  Value: 2
Counter: xla::nonzero
  Value: 2
Counter: xla::prelu
  Value: 14
Counter: xla::relu
  Value: 34
Counter: xla::rsub
  Value: 1
Counter: xla::scatter
  Value: 5
Counter: xla::select
  Value: 2
Counter: xla::sigmoid
  Value: 16
Counter: xla::sigmoid_backward
  Value: 11
Counter: xla::slice
  Value: 22
Counter: xla::sqrt
  Value: 103
Counter: xla::squeeze
  Value: 3
Counter: xla::sum
  Value: 12
Counter: xla::threshold_backward
  Value: 8
Counter: xla::unbind
  Value: 2
Counter: xla::unsqueeze
  Value: 2
Counter: xla::view
  Value: 3
Counter: xla::zero_
  Value: 212
Metric: XrtAllocateFromTensor
  TotalSamples: 5131
  Accumulator: 18s494ms546.177us
  Mean: 001ms449.975us
  StdDev: 007ms678.546us
  Rate: 1.23659 / second
  Percentiles: 25%=233.842us; 50%=305.679us; 80%=002ms038.856us; 90%=003ms306.599us; 95%=004ms711.623us; 99%=012ms759.300us
Metric: XrtCompile
  TotalSamples: 151
  Accumulator: 06m19s885ms728.558us
  Mean: 03s509ms170.388us
  StdDev: 05s203ms439.975us
  Rate: 0.0269124 / second
  Percentiles: 25%=008ms589.342us; 50%=162ms228.809us; 80%=04s569ms581.619us; 90%=10s525ms976.274us; 95%=14s169ms839.647us; 99%=25s328ms783.977us
Metric: XrtExecute
  TotalSamples: 1194
  Accumulator: 33m45s293ms512.954us
  Mean: 01s483ms222.346us
  StdDev: 02s462ms682.989us
  Rate: 0.211562 / second
  Percentiles: 25%=044ms045.291us; 50%=228ms794.137us; 80%=03s764ms507.283us; 90%=05s712ms269.956us; 95%=08s216ms841.280us; 99%=09s716ms908.911us
Metric: XrtReadLiteral
  TotalSamples: 3024
  Accumulator: 02m04s700ms007.964us
  Mean: 040ms054.487us
  StdDev: 082ms408.927us
  Rate: 0.350906 / second
  Percentiles: 25%=252.296us; 50%=007ms790.780us; 80%=047ms164.740us; 90%=140ms525.371us; 95%=213ms949.222us; 99%=414ms844.775us
Metric: XrtReleaseAllocation
  TotalSamples: 3590
  Accumulator: 35s101ms726.085us
  Mean: 916.331us
  StdDev: 008ms892.063us
  Rate: 0.448014 / second
  Percentiles: 25%=036.827us; 50%=295.057us; 80%=700.869us; 90%=001ms308.927us; 95%=002ms882.938us; 99%=007ms698.175us
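
A report like the one above can be printed with torch_xla's debug metrics module; below is a minimal sketch of the call, with the actual model and training step elided:

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()

# ... build the model, move it to `device`, run one forward/backward/optimizer step ...

xm.mark_step()               # flush the lazily traced graph and execute it on the device
print(met.metrics_report())  # prints the Metric:/Counter: entries shown above
```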

Any hint on how to increase the speed would be appreciated.

farhadinima75 · Aug 04 '22 14:08

Ah, it is actually caused by some of the ops falling back to the CPU:

Counter: aten::_local_scalar_dense
  Value: 1
Counter: aten::count_nonzero.dim_IntList
  Value: 1
Counter: aten::nonzero
  Value: 2
Counter: aten::prelu_backward
  Value: 14

nonzero and count_nonzero.dim_IntList will likely require dynamic shape support, which we are actively working on with the PyTorch team. @miladm can give you more updates on that front. prelu_backward should be something we can support, but until the dynamic shape feature is added, your model will take a significant performance hit.
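
A quick way to list such fallbacks programmatically is to scan the counter names for the aten:: prefix; a minimal sketch using torch_xla's debug metrics API:

```python
import torch_xla.debug.metrics as met

# Counters whose names start with "aten::" correspond to ops that could not be
# lowered to XLA and therefore fell back to the CPU.
for name in met.counter_names():
    if name.startswith('aten::'):
        print(name, met.counter_value(name))
```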

I am not sure which op's backward triggers nonzero (or whether nonzero comes from the forward pass), but unless you can remove it from the model, you will need to wait for the dynamic shape feature to be pushed out.
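
One common way to remove a nonzero from a model is to keep the mask in tensor form instead of indexing with it, so every shape stays static; a minimal sketch (the masked_mean helper is hypothetical, not taken from the model in this issue):

```python
import torch

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Boolean indexing such as x[mask] lowers to aten::nonzero and produces a
    # tensor whose shape depends on the data, which XLA cannot bake into a
    # fixed graph. Multiplying by the mask keeps all shapes static.
    mask = mask.to(x.dtype)
    return (x * mask).sum() / mask.sum().clamp(min=1)
```

Along the same lines, the aten::prelu_backward fallback can presumably be avoided by swapping nn.PReLU for an activation that already has an XLA lowering, such as nn.LeakyReLU, at the cost of losing the learnable slope.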

JackCaoG · Aug 05 '22 19:08

Thanks @JackCaoG. Can you estimate approximately when this feature will be released?

farhadinima75 · Aug 10 '22 06:08

We are aiming for an experimental release of dynamic shapes at the end of this year. The context is that we already have a way to make the nonzero op dynamic, but we then need to make every other op accept and correctly handle the dynamic input generated by nonzero.

JackCaoG · Aug 10 '22 17:08