Slow compilation during diffusion sampling
🐛 Bug
I've adapted https://github.com/crowsonkb/k-diffusion/ for torch_xla use. In the evaluation step (diffusion sampling), there is a very long (>7 minute) compilation after a number of sampling steps. It happens in this function specifically: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L118 (and also in the other sampling functions in that file).
Strangely, sampling tends to be fast at first and then intermittently slows down by about 7 minutes for a single iteration of the for loop. For example: the first 22 iterations are fast, the 23rd takes 7 minutes, the 24th is fast, and the 25th is slow again.
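For context, the inner loop of the sampler has the following shape. This is a toy scalar analogue of k-diffusion's Euler sampler, not the actual code: the `model` stand-in and the noise schedule are made up for illustration, and the real sampler operates on tensors on the XLA device.

```python
def model(x, sigma):
    # Stand-in denoiser: pretend the denoised estimate is exactly 0,
    # so d = (x - denoised) / sigma reduces to x / sigma.
    return 0.0

# Decreasing noise schedule ending at 0, as in k-diffusion samplers.
sigmas = [10.0 * (0.5 ** i) for i in range(10)] + [0.0]

x = 5.0
for i in range(len(sigmas) - 1):
    denoised = model(x, sigmas[i])
    d = (x - denoised) / sigmas[i]   # derivative estimate (to_d in k-diffusion)
    dt = sigmas[i + 1] - sigmas[i]   # step size (negative: sigma decreases)
    x = x + d * dt                   # the Euler update in question
```

Under torch_xla's lazy tracing, every iteration of such a loop records ops into a graph; compilation only happens when something forces execution, which is why the slowdown shows up at particular iterations rather than uniformly.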
Diff between the metric reports of a slow and a fast step: https://www.diffchecker.com/ehgDrMqa/ (left is slow, right is fast). Here are the full reports. For a slow step:
Metric: DeviceLockWait
TotalSamples: 6
Accumulator: 065.150us
ValueRate: 000.160us / second
Rate: 0.0147674 / second
Percentiles: 1%=004.250us; 5%=004.250us; 10%=004.250us; 20%=005.150us; 50%=015.360us; 80%=016.040us; 90%=019.170us; 95%=019.170us; 99%=019.170us
Metric: LazyTracing
TotalSamples: 3937
Accumulator: 07m46s391ms626.443us
ValueRate: 1000ms936.804us / second
Rate: 2.52092 / second
Percentiles: 1%=000.190us; 5%=001.480us; 10%=003.300us; 20%=004.230us; 50%=007.020us; 80%=130.449us; 90%=155.690us; 95%=176.530us; 99%=247.650us
Metric: TensorsGraphSize
TotalSamples: 3
Accumulator: 104289.00
ValueRate: 256.68 / second
Rate: 0.00738372 / second
Percentiles: 1%=5.00; 5%=5.00; 10%=5.00; 20%=5.00; 50%=5.00; 80%=104279.00; 90%=104279.00; 95%=104279.00; 99%=104279.00
Metric: UnwrapXlaData
TotalSamples: 5
Accumulator: 080.790us
ValueRate: 000.199us / second
Rate: 0.0123062 / second
Percentiles: 1%=001.890us; 5%=001.890us; 10%=001.890us; 20%=002.030us; 50%=006.060us; 80%=064.510us; 90%=064.510us; 95%=064.510us; 99%=064.510us
Metric: WrapXlaData
TotalSamples: 3
Accumulator: 000.380us
ValueRate: 000.001us / second
Rate: 0.00738242 / second
Percentiles: 1%=000.090us; 5%=000.090us; 10%=000.090us; 20%=000.090us; 50%=000.110us; 80%=000.180us; 90%=000.180us; 95%=000.180us; 99%=000.180us
Counter: CreateXlaTensor
Value: 3779
Counter: DestroyLtcTensor
Value: 3779
Counter: DestroyXlaTensor
Value: 3779
Counter: TrimIrGraph
Value: 1
Counter: UncachedCompile
Value: 3
Counter: xla::_copy_from
Value: 96
Counter: xla::_propagate_xla_data
Value: 106
Counter: xla::_softmax
Value: 24
Counter: xla::_to_cpu
Value: 2
Counter: xla::_unsafe_view
Value: 150
Counter: xla::add
Value: 266
Counter: xla::bmm
Value: 48
Counter: xla::cat
Value: 32
Counter: xla::convolution_overrideable
Value: 1
Counter: xla::copy
Value: 96
Counter: xla::cos
Value: 50
Counter: xla::div
Value: 5
Counter: xla::empty_strided_symint
Value: 11
Counter: xla::empty_symint
Value: 12
Counter: xla::expand_copy_symint
Value: 98
Counter: xla::ge
Value: 1
Counter: xla::gelu
Value: 26
Counter: xla::le
Value: 1
Counter: xla::lerp
Value: 2
Counter: xla::linspace
Value: 2
Counter: xla::log
Value: 1
Counter: xla::logical_not
Value: 8
Counter: xla::masked_fill
Value: 8
Counter: xla::mean
Value: 55
Counter: xla::mm
Value: 158
Counter: xla::mul
Value: 531
Counter: xla::normal_
Value: 1
Counter: xla::permute_copy
Value: 116
Counter: xla::pow
Value: 106
Counter: xla::reciprocal
Value: 2
Counter: xla::roll
Value: 32
Counter: xla::rsqrt
Value: 101
Counter: xla::select_copy
Value: 3
Counter: xla::sin
Value: 50
Counter: xla::slice_copy
Value: 412
Counter: xla::slice_scatter
Value: 96
Counter: xla::split_copy
Value: 27
Counter: xla::sqrt
Value: 24
Counter: xla::stack
Value: 1
Counter: xla::sub
Value: 51
Counter: xla::sum
Value: 48
Counter: xla::t_copy
Value: 130
Counter: xla::transpose_copy
Value: 66
Counter: xla::unbind_copy
Value: 24
Counter: xla::unsqueeze_copy
Value: 214
Counter: xla::view_copy_symint
Value: 632
Counter: xla::zero_
Value: 11
Metric: CompileTime
TotalSamples: 3
Accumulator: 07m44s971ms632.663us
ValueRate: 995ms954.253us / second
Rate: 0.00738881 / second
Percentiles: 1%=025ms227.868us; 5%=025ms227.868us; 10%=025ms227.868us; 20%=025ms227.868us; 50%=027ms990.198us; 80%=07m44s918ms414.597us; 90%=07m44s918ms414.597us; 95%=07m44s918ms414.597us; 99%=07m44s918ms414.597us
Metric: ExecuteTime
TotalSamples: 2
Accumulator: 003ms782.130us
ValueRate: 087ms927.958us / second
Rate: 62.4902 / second
Percentiles: 1%=001ms302.180us; 5%=001ms302.180us; 10%=001ms302.180us; 20%=001ms302.180us; 50%=001ms479.950us; 80%=001ms479.950us; 90%=001ms479.950us; 95%=001ms479.950us; 99%=001ms479.950us
Metric: InboundData
TotalSamples: 2
Accumulator: 2.00B
ValueRate: 62.64B / second
Rate: 62.6403 / second
Percentiles: 1%=1.00B; 5%=1.00B; 10%=1.00B; 20%=1.00B; 50%=1.00B; 80%=1.00B; 90%=1.00B; 95%=1.00B; 99%=1.00B
Metric: TransferFromServerTime
TotalSamples: 2
Accumulator: 002ms992.370us
ValueRate: 062ms409.340us / second
Rate: 62.6483 / second
Percentiles: 1%=918.840us; 5%=918.840us; 10%=918.840us; 20%=918.840us; 50%=001ms073.530us; 80%=001ms073.530us; 90%=001ms073.530us; 95%=001ms073.530us; 99%=001ms073.530us
Counter: CreateCompileHandles
Value: 3
Counter: CreateDataHandles
Value: 3
Counter: aten::_local_scalar_dense
Value: 2
For a fast step:
Metric: DeviceLockWait
TotalSamples: 4
Accumulator: 024.091us
ValueRate: 009.316us / second
Rate: 1.54684 / second
Percentiles: 1%=001.711us; 5%=001.711us; 10%=001.711us; 20%=001.711us; 50%=006.800us; 80%=011.460us; 90%=011.460us; 95%=011.460us; 99%=011.460us
Metric: LazyTracing
TotalSamples: 3937
Accumulator: 03s801ms148.970us
ValueRate: 778ms711.898us / second
Rate: 18086.1 / second
Percentiles: 1%=000.180us; 5%=001.350us; 10%=003.210us; 20%=003.590us; 50%=005.880us; 80%=123.620us; 90%=146.770us; 95%=164.610us; 99%=208.420us
Metric: TensorsGraphSize
TotalSamples: 2
Accumulator: 10.00
ValueRate: 3.87 / second
Rate: 0.773416 / second
Percentiles: 1%=5.00; 5%=5.00; 10%=5.00; 20%=5.00; 50%=5.00; 80%=5.00; 90%=5.00; 95%=5.00; 99%=5.00
Metric: UnwrapXlaData
TotalSamples: 4
Accumulator: 004.879us
ValueRate: 001.886us / second
Rate: 1.54633 / second
Percentiles: 1%=000.560us; 5%=000.560us; 10%=000.560us; 20%=000.560us; 50%=001.800us; 80%=001.829us; 90%=001.829us; 95%=001.829us; 99%=001.829us
Metric: WrapXlaData
TotalSamples: 2
Accumulator: 000.149us
ValueRate: 000.058us / second
Rate: 0.773277 / second
Percentiles: 1%=000.069us; 5%=000.069us; 10%=000.069us; 20%=000.069us; 50%=000.080us; 80%=000.080us; 90%=000.080us; 95%=000.080us; 99%=000.080us
Counter: CreateXlaTensor
Value: 3779
Counter: DestroyLtcTensor
Value: 3779
Counter: DestroyXlaTensor
Value: 3779
Counter: UncachedCompile
Value: 2
Counter: xla::_copy_from
Value: 96
Counter: xla::_propagate_xla_data
Value: 106
Counter: xla::_softmax
Value: 24
Counter: xla::_to_cpu
Value: 2
Counter: xla::_unsafe_view
Value: 150
Counter: xla::add
Value: 266
Counter: xla::bmm
Value: 48
Counter: xla::cat
Value: 32
Counter: xla::convolution_overrideable
Value: 1
Counter: xla::copy
Value: 96
Counter: xla::cos
Value: 50
Counter: xla::div
Value: 5
Counter: xla::empty_strided_symint
Value: 11
Counter: xla::empty_symint
Value: 12
Counter: xla::expand_copy_symint
Value: 98
Counter: xla::ge
Value: 1
Counter: xla::gelu
Value: 26
Counter: xla::le
Value: 1
Counter: xla::lerp
Value: 2
Counter: xla::linspace
Value: 2
Counter: xla::log
Value: 1
Counter: xla::logical_not
Value: 8
Counter: xla::masked_fill
Value: 8
Counter: xla::mean
Value: 55
Counter: xla::mm
Value: 158
Counter: xla::mul
Value: 531
Counter: xla::normal_
Value: 1
Counter: xla::permute_copy
Value: 116
Counter: xla::pow
Value: 106
Counter: xla::reciprocal
Value: 2
Counter: xla::roll
Value: 32
Counter: xla::rsqrt
Value: 101
Counter: xla::select_copy
Value: 3
Counter: xla::sin
Value: 50
Counter: xla::slice_copy
Value: 412
Counter: xla::slice_scatter
Value: 96
Counter: xla::split_copy
Value: 27
Counter: xla::sqrt
Value: 24
Counter: xla::stack
Value: 1
Counter: xla::sub
Value: 51
Counter: xla::sum
Value: 48
Counter: xla::t_copy
Value: 130
Counter: xla::transpose_copy
Value: 66
Counter: xla::unbind_copy
Value: 24
Counter: xla::unsqueeze_copy
Value: 214
Counter: xla::view_copy_symint
Value: 632
Counter: xla::zero_
Value: 11
Metric: CompileTime
TotalSamples: 2
Accumulator: 063ms809.026us
ValueRate: 024ms288.720us / second
Rate: 0.773415 / second
Percentiles: 1%=026ms037.518us; 5%=026ms037.518us; 10%=026ms037.518us; 20%=026ms037.518us; 50%=037ms771.508us; 80%=037ms771.508us; 90%=037ms771.508us; 95%=037ms771.508us; 99%=037ms771.508us
Metric: ExecuteTime
TotalSamples: 3
Accumulator: 05s261ms744.984us
ValueRate: 03m01s526ms912.953us / second
Rate: 102.947 / second
Percentiles: 1%=001ms215.130us; 5%=001ms215.130us; 10%=001ms215.130us; 20%=001ms215.130us; 50%=03s558ms071.496us; 80%=03s701ms458.358us; 90%=03s701ms458.358us; 95%=03s701ms458.358us; 99%=03s701ms458.358us
Metric: InboundData
TotalSamples: 2
Accumulator: 2.00B
ValueRate: 69.71B / second
Rate: 69.7133 / second
Percentiles: 1%=1.00B; 5%=1.00B; 10%=1.00B; 20%=1.00B; 50%=1.00B; 80%=1.00B; 90%=1.00B; 95%=1.00B; 99%=1.00B
Metric: TransferFromServerTime
TotalSamples: 2
Accumulator: 03s559ms142.285us
ValueRate: 01m29s210ms316.985us / second
Rate: 69.7189 / second
Percentiles: 1%=716.229us; 5%=716.229us; 10%=716.229us; 20%=716.229us; 50%=03s558ms426.056us; 80%=03s558ms426.056us; 90%=03s558ms426.056us; 95%=03s558ms426.056us; 99%=03s558ms426.056us
Counter: CreateCompileHandles
Value: 2
Counter: CreateDataHandles
Value: 2
Counter: aten::_local_scalar_dense
Value: 2
Is there something about this function that induces extra recompilation?
To Reproduce
I don't have a quick way for others to reproduce this; I'm hoping the code sample linked above suffices.
Expected behavior
No extra compilation during sampling.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
- torch_xla version: torch-xla==2.2.0
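The reports above come from torch_xla's metrics API. A minimal sketch of how a per-step report can be captured, guarded so it is a no-op when torch_xla is not installed (the module path and `metrics_report`/`clear_all` names are from the 2.x releases; verify against your version):

```python
def step_metrics_report():
    """Return torch_xla's metrics report as a string, or None when
    torch_xla is not installed (e.g. running on plain CPU)."""
    try:
        import torch_xla.debug.metrics as met
    except ImportError:
        return None
    report = met.metrics_report()
    met.clear_all()  # reset, so the next report covers only the next step
    return report

report = step_metrics_report()
```

Clearing between steps is what makes the slow-step/fast-step diff meaningful: each report then covers exactly one iteration.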
If I comment out the line
x = x + d * dt
every step is fast, so it seems like extra compilation is being triggered by that line on certain specific iterations of the loop.
Can you run with PT_XLA_DEBUG=1
? It will tell you what triggered the compilation. There are two possibilities:
- You are somehow accessing the value of a tensor accidentally, and computing that value requires a compilation.
- There is some conditional/value-dependent op in your code that introduces a slightly different graph at certain steps.
If it is (1), PT_XLA_DEBUG
should tell you where the trigger comes from (if you are using the 2.2 release). If it is (2), you can dump the IR/HLO graphs and compare the graph from a normal step with the one from a step that recompiles. Check https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#common-debugging-environment-variables-combinations
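To illustrate possibility (2): under lazy tracing, any host-side Python branch on a value can change which ops get recorded, so two steps that look identical can trace different graphs and defeat the compilation cache. A schematic sketch — plain Python standing in for the lazy trace, where `trace` is just a list of op names, not a real torch_xla graph:

```python
def traced_step(sigma, threshold=1.0):
    """Record which ops a step would trace. A host-side branch on a
    runtime value (here `sigma`) changes the op sequence, and with it
    the graph hash that the compilation cache is keyed on."""
    trace = ["mul", "add"]               # ops traced on every step
    if sigma > threshold:                # value-dependent Python branch
        trace.append("noise_injection")  # extra op only on some steps
    trace.append("euler_update")
    return trace

step_a = traced_step(sigma=2.0)  # takes the branch
step_b = traced_step(sigma=0.5)  # skips it -> different graph, recompile
```

In a real torch_xla program the branch on a tensor value would additionally force a `_local_scalar_dense` transfer, which is why that counter is worth watching in the metric reports.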
Thanks, compilation is triggered on this line: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/models/image_transformer_v2.py#L384
This does not access a tensor value, if I understand einops.rearrange
correctly.
I logged the shapes of the variables used on that line (x
and skip
), and they are the same across iterations, so I'm not sure (2) applies?
I dumped the HLO/IR graphs, but how can I tell which graph belongs to which step? The dumps are ~1M lines long.
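One way to navigate dumps that large: each graph saved via XLA_SAVE_TENSORS_FILE carries an identifying hash line, so you can list the graph boundaries with their line numbers and jump straight to the hash you care about. A sketch against a fabricated two-graph dump file (the `Hash:` line format here is an assumption; match it against what your actual dump contains):

```shell
# Fabricate a tiny dump with two graphs to demonstrate the search.
cat > /tmp/fake_xla_dump.txt <<'EOF'
[ScheduleSyncTensorsGraph]
Hash: aaaa1111
... HLO for step N ...
[ScheduleSyncTensorsGraph]
Hash: 507d309b2d23e1c0ff316e65fbb8d47a
... HLO for step N+1 ...
EOF

# List every graph boundary with its line number...
grep -n "Hash:" /tmp/fake_xla_dump.txt
# ...then locate the specific graph by its hash.
grep -n "507d309b2d23e1c0ff316e65fbb8d47a" /tmp/fake_xla_dump.txt
```

With the line numbers in hand, `sed -n 'START,ENDp'` can extract just the one graph for diffing against a fast step's graph.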
Can you dump the PT_XLA_DEBUG=1
output? I want to know how einops.rearrange
triggers the recompilation.
By the way, you can also enable the persistent cache following https://github.com/pytorch/xla/blob/master/API_GUIDE.md#compilation-caching. That way it compiles once and the result is saved for future runs.
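The persistent-cache setup from that guide, sketched with a guard so it degrades to a no-op when torch_xla is unavailable. `initialize_cache` is the API named in the linked guide; the cache path here is arbitrary, and the exact signature should be checked against your release:

```python
import os

def enable_persistent_cache(path="/tmp/xla_compile_cache"):
    """Enable torch_xla's persistent compilation cache if torch_xla is
    importable. Returns True when the cache was initialized."""
    try:
        import torch_xla.runtime as xr
    except ImportError:
        return False
    os.makedirs(path, exist_ok=True)
    # readonly=False lets this process both read and write cache entries.
    xr.initialize_cache(path, readonly=False)
    return True

enabled = enable_persistent_cache()
```

Note the cache only hides the cost of recompiling graphs seen in earlier runs; it does not remove the underlying cause of the extra recompilation.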
Interestingly, it triggers on a different line each time I run it. This is the latest:
Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis: most likely user code trying to access tensor value before mark_step
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: 507d309b2d23e1c0ff316e65fbb8d47a
Compilation Analysis: Number of Graph Inputs: 229
Compilation Analysis: Number of Graph Outputs: 1
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: _apply_rotary_emb_inplace (REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:196)
Compilation Analysis: _fn (REDACTED/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:489)
Compilation Analysis: __call__ (REDACTED/k-diffusion/k_diffusion/models/flags.py:56)
Compilation Analysis: forward (REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:205)
Compilation Analysis: apply (REDACTED/.local/lib/python3.10/site-packages/torch/autograd/function.py:553)
Compilation Analysis: apply_rotary_emb_ (REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:222)
Compilation Analysis: forward (REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:380)
Compilation Analysis: _call_impl (REDACTED/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520)
Compilation Analysis: ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================
Which corresponds to this line: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/models/image_transformer_v2.py#L196
From what I can tell, REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:196
is trying to access the value of a torch_xla tensor before the execution. If you look at the output graph
Compilation Analysis: Number of Graph Inputs: 229
Compilation Analysis: Number of Graph Outputs: 1
it means it is trying to compute the value of one tensor... but this really confuses me, since y1 = x1_ * cos - x2_ * sin
doesn't seem to be doing that. I started wondering whether the way I grab the Python frame is somehow problematic. Can you enable XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1
and dump the HLO graph with XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"
? Then can you find the graph with hash 507d309b2d23e1c0ff316e65fbb8d47a
(or whatever the new hash is) and share it here?