Backward pass takes ~44 s with torch_xla (vs. ~0.8 s on GPU) for huge models
🐛 Bug
I am trying to train a segmentation model on Colab (free tier) using torch_xla on a single core.
The model is huge, and the backward step takes a very long time.
It stays stuck here for ~44 s on the backward computation alone:

For smaller models, by contrast, a full iteration takes about 3 s.
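Because torch_xla records operations lazily and only runs them when the graph is materialized (for example at xm.mark_step(), or when a CPU-fallback op needs concrete values), a wall-clock measurement around loss.backward() alone may attribute deferred work to the wrong step. The helper below is a minimal, framework-agnostic sketch for timing a step together with an explicit synchronization hook; the names timed and sync are hypothetical, not torch_xla APIs.

```python
import time
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def timed(fn: Callable[[], T], sync: Optional[Callable[[], None]] = None) -> Tuple[T, float]:
    """Run fn() and return (result, elapsed_seconds).

    With a lazy backend, fn() may only record work. Passing a `sync`
    callback that forces execution makes the measured time include
    the deferred computation as well.
    """
    start = time.perf_counter()
    result = fn()
    if sync is not None:
        sync()  # block until the recorded work has actually run
    return result, time.perf_counter() - start
```

With torch_xla, one plausible sync hook (an assumption, not something from this thread) is a function that calls xm.mark_step() and then xm.wait_device_ops(), where xm is torch_xla.core.xla_model.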
Input shape: (16, 3, 256, 256); output shape: (16, 2, 256, 256)
Here is the metrics report after the first iteration:
Metric: CompileTime
TotalSamples: 19
Accumulator: 722ms446.172us
ValueRate: 008ms298.252us / second
Rate: 0.21824 / second
Percentiles: 1%=002ms861.164us; 5%=002ms861.164us; 10%=002ms450.221us; 20%=008ms803.322us; 50%=016ms614.016us; 80%=071ms913.140us; 90%=158ms419.131us; 95%=180ms318.931us; 99%=180ms318.931us
Metric: DeviceLockWait
TotalSamples: 20
Accumulator: 001ms129.113us
ValueRate: 012.969us / second
Rate: 0.229728 / second
Percentiles: 1%=004.208us; 5%=004.441us; 10%=004.588us; 20%=004.711us; 50%=005.662us; 80%=007.207us; 90%=023.808us; 95%=998.744us; 99%=998.744us
Metric: ExecuteTime
TotalSamples: 19
Accumulator: 58s180ms918.458us
ValueRate: 710ms706.824us / second
Rate: 0.231771 / second
Percentiles: 1%=002ms028.711us; 5%=002ms028.711us; 10%=003ms653.155us; 20%=028ms359.028us; 50%=02s147ms330.984us; 80%=05s991ms335.959us; 90%=06s014ms470.205us; 95%=14s408ms616.776us; 99%=14s408ms616.776us
Metric: InboundData
TotalSamples: 19
Accumulator: 1.78GB
ValueRate: 22.73MB / second
Rate: 0.236406 / second
Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=1.00MB; 50%=64.00MB; 80%=128.00MB; 90%=256.00MB; 95%=512.00MB; 99%=512.00MB
Metric: InputOutputAliasCount
TotalSamples: 18
Accumulator: 600.00
ValueRate: 6.89 / second
Rate: 0.206849 / second
Percentiles: 1%=1.00; 5%=1.00; 10%=1.00; 20%=3.00; 50%=3.00; 80%=3.00; 90%=3.00; 95%=555.00; 99%=555.00
Metric: IrValueTensorToXlaData
TotalSamples: 245
Accumulator: 15s192ms657.388us
ValueRate: 165ms391.118us / second
Rate: 2.66731 / second
Percentiles: 1%=001ms366.195us; 5%=001ms478.816us; 10%=002ms575.362us; 20%=002ms687.775us; 50%=002ms018.077us; 80%=020ms163.611us; 90%=104ms163.374us; 95%=213ms295.221us; 99%=02s582ms906.310us
Metric: OutboundData
TotalSamples: 270
Accumulator: 1.13GB
ValueRate: 12.61MB / second
Rate: 2.9308 / second
Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=8.00B; 50%=1.00KB; 80%=512.00KB; 90%=9.00MB; 95%=16.00MB; 99%=128.00MB
Metric: ReleaseDataHandlesTime
TotalSamples: 31
Accumulator: 01s012ms027.950us
ValueRate: 013ms613.706us / second
Rate: 0.386378 / second
Percentiles: 1%=001ms430.383us; 5%=001ms434.815us; 10%=002ms621.342us; 20%=003ms946.111us; 50%=005ms690.084us; 80%=013ms002.678us; 90%=071ms239.845us; 95%=212ms512.513us; 99%=312ms274.607us
Metric: TensorToData
TotalSamples: 270
Accumulator: 30s664ms358.184us
ValueRate: 322ms072.148us / second
Rate: 2.93145 / second
Percentiles: 1%=001ms236.247us; 5%=001ms461.173us; 10%=002ms563.105us; 20%=002ms693.867us; 50%=002ms037.846us; 80%=014ms010.957us; 90%=104ms878.930us; 95%=213ms287.579us; 99%=02s608ms018.549us
Metric: TensorsGraphSize
TotalSamples: 19
Accumulator: 9367.00
ValueRate: 107.59 / second
Rate: 0.218241 / second
Percentiles: 1%=4.00; 5%=4.00; 10%=4.00; 20%=71.00; 50%=325.00; 80%=509.00; 90%=628.00; 95%=4094.00; 99%=4094.00
Metric: TransferFromServerTime
TotalSamples: 19
Accumulator: 14s511ms862.286us
ValueRate: 168ms107.536us / second
Rate: 0.236406 / second
Percentiles: 1%=002ms546.927us; 5%=002ms546.927us; 10%=017ms201.227us; 20%=020ms422.553us; 50%=251ms316.779us; 80%=01s175ms365.453us; 90%=02s993ms617.187us; 95%=04s273ms417.892us; 99%=04s273ms417.892us
Metric: TransferToServerTime
TotalSamples: 270
Accumulator: 30s644ms355.238us
ValueRate: 322ms818.428us / second
Rate: 2.93111 / second
Percentiles: 1%=001ms222.412us; 5%=001ms442.896us; 10%=002ms542.858us; 20%=002ms665.676us; 50%=002ms989.946us; 80%=014ms968.429us; 90%=104ms846.026us; 95%=213ms229.207us; 99%=02s608ms968.433us
Metric: TransferToServerTransformTime
TotalSamples: 270
Accumulator: 778ms546.591us
ValueRate: 008ms440.128us / second
Rate: 2.9308 / second
Percentiles: 1%=037.569us; 5%=046.240us; 10%=052.238us; 20%=065.222us; 50%=129.463us; 80%=551.956us; 90%=003ms573.463us; 95%=010ms633.565us; 99%=061ms107.966us
Metric: UnwrapXlaData
TotalSamples: 5001
Accumulator: 003ms520.459us
ValueRate: 020.561us / second
Rate: 66.9731 / second
Percentiles: 1%=000.046us; 5%=000.047us; 10%=000.050us; 20%=000.053us; 50%=000.080us; 80%=000.250us; 90%=000.476us; 95%=000.703us; 99%=002.198us
Metric: WrapXlaData
TotalSamples: 871
Accumulator: 004ms757.688us
ValueRate: 040.792us / second
Rate: 9.45533 / second
Percentiles: 1%=000.289us; 5%=000.295us; 10%=000.302us; 20%=000.312us; 50%=000.518us; 80%=009.572us; 90%=013.214us; 95%=015.614us; 99%=027.176us
Counter: CreateCompileHandles
Value: 19
Counter: CreateDataHandles
Value: 871
Counter: CreateXlaTensor
Value: 2569
Counter: DestroyDataHandles
Value: 299
Counter: DestroyXlaTensor
Value: 2022
Counter: DeviceDataCacheMiss
Value: 24
Counter: MarkStep
Value: 1
Counter: ReleaseDataHandles
Value: 299
Counter: UncachedCompile
Value: 19
Counter: XRTAllocateFromTensor_Empty
Value: 65
Counter: XrtCompile_Empty
Value: 387
Counter: XrtExecuteChained_Empty
Value: 384
Counter: XrtExecute_Empty
Value: 387
Counter: XrtMemoryInfo_Empty
Value: 384
Counter: XrtRead_Empty
Value: 393
Counter: XrtReleaseAllocationHandle_Empty
Value: 387
Counter: XrtReleaseCompileHandle_Empty
Value: 384
Counter: XrtSessionCount
Value: 4
Counter: XrtSubTuple_Empty
Value: 384
Counter: aten::_local_scalar_dense
Value: 1
Counter: aten::count_nonzero.dim_IntList
Value: 1
Counter: aten::nonzero
Value: 2
Counter: aten::prelu_backward
Value: 14
Counter: xla::_copy_from
Value: 951
Counter: xla::_to_cpu
Value: 18
Counter: xla::adaptive_max_pool2d
Value: 4
Counter: xla::adaptive_max_pool2d_backward
Value: 4
Counter: xla::add
Value: 261
Counter: xla::addcdiv_
Value: 103
Counter: xla::addcmul
Value: 103
Counter: xla::argmax
Value: 1
Counter: xla::bernoulli_
Value: 14
Counter: xla::binary_cross_entropy
Value: 1
Counter: xla::binary_cross_entropy_backward
Value: 1
Counter: xla::cat
Value: 10
Counter: xla::convolution_backward_overrideable
Value: 40
Counter: xla::convolution_overrideable
Value: 70
Counter: xla::div
Value: 131
Counter: xla::empty
Value: 516
Counter: xla::eq
Value: 1
Counter: xla::expand
Value: 11
Counter: xla::fill_
Value: 3
Counter: xla::ge
Value: 1
Counter: xla::index_put_
Value: 2
Counter: xla::lt
Value: 1
Counter: xla::max
Value: 5
Counter: xla::max_pool2d
Value: 8
Counter: xla::mean
Value: 9
Counter: xla::mul
Value: 369
Counter: xla::native_batch_norm
Value: 19
Counter: xla::native_batch_norm_backward
Value: 19
Counter: xla::neg
Value: 2
Counter: xla::nonzero
Value: 2
Counter: xla::prelu
Value: 14
Counter: xla::relu
Value: 34
Counter: xla::rsub
Value: 1
Counter: xla::scatter
Value: 5
Counter: xla::select
Value: 2
Counter: xla::sigmoid
Value: 16
Counter: xla::sigmoid_backward
Value: 11
Counter: xla::slice
Value: 22
Counter: xla::sqrt
Value: 103
Counter: xla::squeeze
Value: 3
Counter: xla::sum
Value: 12
Counter: xla::threshold_backward
Value: 8
Counter: xla::unbind
Value: 2
Counter: xla::unsqueeze
Value: 2
Counter: xla::view
Value: 3
Counter: xla::zero_
Value: 212
Metric: XrtAllocateFromTensor
TotalSamples: 5131
Accumulator: 18s494ms546.177us
Mean: 001ms449.975us
StdDev: 007ms678.546us
Rate: 1.23659 / second
Percentiles: 25%=233.842us; 50%=305.679us; 80%=002ms038.856us; 90%=003ms306.599us; 95%=004ms711.623us; 99%=012ms759.300us
Metric: XrtCompile
TotalSamples: 151
Accumulator: 06m19s885ms728.558us
Mean: 03s509ms170.388us
StdDev: 05s203ms439.975us
Rate: 0.0269124 / second
Percentiles: 25%=008ms589.342us; 50%=162ms228.809us; 80%=04s569ms581.619us; 90%=10s525ms976.274us; 95%=14s169ms839.647us; 99%=25s328ms783.977us
Metric: XrtExecute
TotalSamples: 1194
Accumulator: 33m45s293ms512.954us
Mean: 01s483ms222.346us
StdDev: 02s462ms682.989us
Rate: 0.211562 / second
Percentiles: 25%=044ms045.291us; 50%=228ms794.137us; 80%=03s764ms507.283us; 90%=05s712ms269.956us; 95%=08s216ms841.280us; 99%=09s716ms908.911us
Metric: XrtReadLiteral
TotalSamples: 3024
Accumulator: 02m04s700ms007.964us
Mean: 040ms054.487us
StdDev: 082ms408.927us
Rate: 0.350906 / second
Percentiles: 25%=252.296us; 50%=007ms790.780us; 80%=047ms164.740us; 90%=140ms525.371us; 95%=213ms949.222us; 99%=414ms844.775us
Metric: XrtReleaseAllocation
TotalSamples: 3590
Accumulator: 35s101ms726.085us
Mean: 916.331us
StdDev: 008ms892.063us
Rate: 0.448014 / second
Percentiles: 25%=036.827us; 50%=295.057us; 80%=700.869us; 90%=001ms308.927us; 95%=002ms882.938us; 99%=007ms698.175us
Any hints on speeding this up would be appreciated.
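To see where the time in a report like this goes, it helps to convert the Accumulator durations to seconds. The sketch below assumes the duration format visible in this report (an optional minutes field, then s, ms and us parts); parse_duration and accumulated_times are hypothetical helper names, not torch_xla APIs.

```python
import re
from typing import Dict, Optional

# Durations in the report look like "58s180ms918.458us" or "06m19s885ms728.558us".
# The lookahead keeps the minutes group from swallowing the "m" of "ms".
_DURATION = re.compile(
    r"^(?:(?P<m>\d+)m(?=\d))?(?:(?P<s>\d+)s)?(?:(?P<ms>\d+)ms)?(?:(?P<us>\d+(?:\.\d+)?)us)?$"
)


def parse_duration(text: str) -> float:
    """Convert a report-style duration string to seconds."""
    match = _DURATION.match(text.strip())
    if match is None or not any(match.groupdict().values()):
        raise ValueError(f"unrecognized duration: {text!r}")
    parts = {k: float(v) if v else 0.0 for k, v in match.groupdict().items()}
    return parts["m"] * 60 + parts["s"] + parts["ms"] / 1e3 + parts["us"] / 1e6


def accumulated_times(report: str) -> Dict[str, float]:
    """Map each time-valued metric to its Accumulator value in seconds."""
    totals: Dict[str, float] = {}
    current: Optional[str] = None
    for line in report.splitlines():
        line = line.strip()
        if line.startswith("Metric: "):
            current = line[len("Metric: "):]
        elif line.startswith("Accumulator: ") and current is not None:
            try:
                totals[current] = parse_duration(line[len("Accumulator: "):])
            except ValueError:
                pass  # non-time metrics such as "1.78GB" or "600.00"
            current = None
    return totals
```

On the metrics at the top of the report, this gives about 58.2 s for ExecuteTime versus about 0.72 s for CompileTime, so in this iteration far more time is spent executing than compiling.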
Ah, this is actually caused by some ops falling back to CPU:
Counter: aten::_local_scalar_dense
Value: 1
Counter: aten::count_nonzero.dim_IntList
Value: 1
Counter: aten::nonzero
Value: 2
Counter: aten::prelu_backward
Value: 14
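The PyTorch/XLA troubleshooting guide describes counters prefixed with aten:: as ops that had no XLA lowering and were executed on CPU (aten::_local_scalar_dense typically comes from calling .item() on a tensor rather than from a missing lowering). In a live session the report text comes from torch_xla.debug.metrics.metrics_report(); the helper below is a hypothetical sketch that picks these counters out of that text so they can be checked after each change to the model.

```python
from typing import Dict, Optional


def cpu_fallback_counters(report: str) -> Dict[str, int]:
    """Return every counter whose name starts with 'aten::', with its value."""
    fallbacks: Dict[str, int] = {}
    pending: Optional[str] = None
    for line in report.splitlines():
        line = line.strip()
        if line.startswith("Counter: "):
            name = line[len("Counter: "):]
            # Only remember aten:: counters; xla:: counters are lowered ops.
            pending = name if name.startswith("aten::") else None
        elif line.startswith("Value: ") and pending is not None:
            fallbacks[pending] = int(line[len("Value: "):])
            pending = None
    return fallbacks
```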
nonzero and count_nonzero.dim_IntList will likely require dynamic shape support, which we are actively working on with the PyTorch team. @miladm can give you more updates on that front. prelu_backward is something we should be able to support, but until the dynamic shape feature lands, your model will take a significant performance hit.
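One possible interim workaround for the prelu_backward fallback (an assumption, not a fix proposed in this thread) is to express PReLU with learnable slope $a$ through ops that already appear as xla:: counters in the report (xla::relu, xla::neg, xla::mul, and xla::threshold_backward for ReLU's gradient), using the identity and gradients below:

$$
\begin{aligned}
\operatorname{PReLU}_a(x) &= \max(0, x) + a\,\min(0, x) = \operatorname{ReLU}(x) - a\,\operatorname{ReLU}(-x) \\
\frac{\partial\,\operatorname{PReLU}_a}{\partial x} &= \begin{cases} 1, & x > 0 \\ a, & x < 0 \end{cases}
\qquad
\frac{\partial\,\operatorname{PReLU}_a}{\partial a} = \min(0, x) = -\operatorname{ReLU}(-x)
\end{aligned}
$$

Whether autograd on the decomposed form avoids the CPU fallback entirely would need to be confirmed by checking that the aten::prelu_backward counter no longer appears in the metrics report.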
I am not sure which op's backward triggers nonzero (or whether nonzero comes from the forward pass), but unless you can remove it from the model, you will need to wait for the dynamic shape feature to be released.
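One common source of nonzero is boolean-mask indexing such as x[mask], whose output length depends on the data; whether that is where it comes from in this model is not known. A static-shape alternative multiplies by the mask instead, e.g. (x * mask).sum() / mask.sum().clamp(min=1) in place of x[mask].mean(). The sketch below uses plain Python lists as stand-ins for tensors so it runs without torch; both function names are hypothetical.

```python
from typing import List


def masked_mean_dynamic(values: List[float], mask: List[bool]) -> float:
    # Stand-in for x[mask].mean(): the intermediate `selected` has a length
    # that depends on the mask contents, i.e. a data-dependent shape.
    # Returns 0.0 for an empty mask to match the static version below.
    selected = [v for v, keep in zip(values, mask) if keep]
    return sum(selected) / len(selected) if selected else 0.0


def masked_mean_static(values: List[float], mask: List[bool]) -> float:
    # Stand-in for (x * mask).sum() / mask.sum().clamp(min=1): every
    # intermediate has the same length as the input, whatever the mask.
    weights = [1.0 if keep else 0.0 for keep in mask]
    weighted = [v * w for v, w in zip(values, weights)]
    count = max(sum(weights), 1.0)  # avoid division by zero on an empty mask
    return sum(weighted) / count
```

The two versions compute the same value; only the static one keeps every intermediate at a fixed size, which is what lets a compiler like XLA trace it without dynamic shapes.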
Thanks @JackCaoG. Do you have an approximate timeline for when this feature will be released?
We are aiming for an experimental release of dynamic shape support at the end of this year. For context, we already have a way to make the nonzero op dynamic, but we then need to make every other op accept and correctly handle the dynamic input generated by nonzero.
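To make the last point concrete: a common way to represent a data-dependent shape is a buffer padded to a static upper bound plus a runtime size, and XLA's own dynamic-shape support is built around dimensions with a static upper bound. The toy sketch below is in plain Python and is not how PyTorch/XLA implements it; nonzero_padded and gather_sum are hypothetical names. It shows why each downstream op must be taught about the runtime size.

```python
from typing import List, Tuple


def nonzero_padded(values: List[float], fill: int = -1) -> Tuple[List[int], int]:
    """Indices of nonzero entries, padded to len(values), plus the valid count.

    The returned list always has the input's length (its upper bound),
    so its shape is static; only `count` depends on the data.
    """
    idx = [i for i, v in enumerate(values) if v != 0]
    count = len(idx)
    return idx + [fill] * (len(values) - count), count


def gather_sum(values: List[float], indices: List[int], count: int) -> float:
    """A downstream op that must honor `count` and ignore the padding.

    Iterating over all of `indices` instead would use the fill value -1,
    which in Python silently reads the last element and gives a wrong sum.
    """
    return sum(values[indices[k]] for k in range(count))
```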