
Slow compilation during diffusion sampling

Status: Open. nom opened this issue 1 year ago (8 comments).

🐛 Bug

I've adapted https://github.com/crowsonkb/k-diffusion/ for torch_xla use. In the evaluation step (diffusion sampling), compilation takes a very long time (>7 minutes) after a number of sampling steps. It happens specifically in this function: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L118 (and also in other sampling functions in that file).

Strangely, sampling runs fine at first, and then some iterations of the for loop slow down by about 7 minutes. For example, the first 22 iterations are fast, the 23rd takes 7 minutes, the 24th is fast, and the 25th is slow again.
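
To pin down exactly which iterations compile, one option is to read the UncachedCompile counter (it appears in both metric reports in this issue) before and after every iteration. The sketch below is a hedged illustration, not part of k-diffusion or torch_xla: find_compiling_steps and its parameters are hypothetical names, and the counter lookup is passed in as a function so the sketch runs without torch_xla.

```python
from typing import Callable, List, Optional


def find_compiling_steps(
    step_fn: Callable[[int], None],
    n_steps: int,
    counter_value: Callable[[str], Optional[int]],
    counter_name: str = "UncachedCompile",
) -> List[int]:
    """Run step_fn(i) for each iteration and return the iterations during
    which the named compilation counter increased."""
    compiling_steps: List[int] = []
    for i in range(n_steps):
        # Assumption: a counter that does not exist yet may be reported as None.
        before = counter_value(counter_name) or 0
        step_fn(i)
        after = counter_value(counter_name) or 0
        if after > before:
            compiling_steps.append(i)
    return compiling_steps
```

With torch_xla, a call might look like find_compiling_steps(run_one_step, 50, met.counter_value), where met is torch_xla.debug.metrics and run_one_step is a hypothetical wrapper around one iteration of the sampling loop. Because tracing is lazy, that wrapper should end by forcing execution (for example with xm.mark_step()); otherwise a compilation may be counted in a later iteration than the one that caused it.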

Diff between the metric reports of a slow and a fast step: https://www.diffchecker.com/ehgDrMqa/ (left is slow, right is fast). Here are the full reports. For a slow step:

Metric: DeviceLockWait
  TotalSamples: 6
  Accumulator: 065.150us
  ValueRate: 000.160us / second
  Rate: 0.0147674 / second
  Percentiles: 1%=004.250us; 5%=004.250us; 10%=004.250us; 20%=005.150us; 50%=015.360us; 80%=016.040us; 90%=019.170us; 95%=019.170us; 99%=019.170us
Metric: LazyTracing
  TotalSamples: 3937
  Accumulator: 07m46s391ms626.443us
  ValueRate: 1000ms936.804us / second
  Rate: 2.52092 / second
  Percentiles: 1%=000.190us; 5%=001.480us; 10%=003.300us; 20%=004.230us; 50%=007.020us; 80%=130.449us; 90%=155.690us; 95%=176.530us; 99%=247.650us
Metric: TensorsGraphSize
  TotalSamples: 3
  Accumulator: 104289.00
  ValueRate: 256.68 / second
  Rate: 0.00738372 / second
  Percentiles: 1%=5.00; 5%=5.00; 10%=5.00; 20%=5.00; 50%=5.00; 80%=104279.00; 90%=104279.00; 95%=104279.00; 99%=104279.00
Metric: UnwrapXlaData
  TotalSamples: 5
  Accumulator: 080.790us
  ValueRate: 000.199us / second
  Rate: 0.0123062 / second
  Percentiles: 1%=001.890us; 5%=001.890us; 10%=001.890us; 20%=002.030us; 50%=006.060us; 80%=064.510us; 90%=064.510us; 95%=064.510us; 99%=064.510us
Metric: WrapXlaData
  TotalSamples: 3
  Accumulator: 000.380us
  ValueRate: 000.001us / second
  Rate: 0.00738242 / second
  Percentiles: 1%=000.090us; 5%=000.090us; 10%=000.090us; 20%=000.090us; 50%=000.110us; 80%=000.180us; 90%=000.180us; 95%=000.180us; 99%=000.180us
Counter: CreateXlaTensor
  Value: 3779
Counter: DestroyLtcTensor
  Value: 3779
Counter: DestroyXlaTensor
  Value: 3779
Counter: TrimIrGraph
  Value: 1
Counter: UncachedCompile
  Value: 3
Counter: xla::_copy_from
  Value: 96
Counter: xla::_propagate_xla_data
  Value: 106
Counter: xla::_softmax
  Value: 24
Counter: xla::_to_cpu
  Value: 2
Counter: xla::_unsafe_view
  Value: 150
Counter: xla::add
  Value: 266
Counter: xla::bmm
  Value: 48
Counter: xla::cat
  Value: 32
Counter: xla::convolution_overrideable
  Value: 1
Counter: xla::copy
  Value: 96
Counter: xla::cos
  Value: 50
Counter: xla::div
  Value: 5
Counter: xla::empty_strided_symint
  Value: 11
Counter: xla::empty_symint
  Value: 12
Counter: xla::expand_copy_symint
  Value: 98
Counter: xla::ge
  Value: 1
Counter: xla::gelu
  Value: 26
Counter: xla::le
  Value: 1
Counter: xla::lerp
  Value: 2
Counter: xla::linspace
  Value: 2
Counter: xla::log
  Value: 1
Counter: xla::logical_not
  Value: 8
Counter: xla::masked_fill
  Value: 8
Counter: xla::mean
  Value: 55
Counter: xla::mm
  Value: 158
Counter: xla::mul
  Value: 531
Counter: xla::normal_
  Value: 1
Counter: xla::permute_copy
  Value: 116
Counter: xla::pow
  Value: 106
Counter: xla::reciprocal
  Value: 2
Counter: xla::roll
  Value: 32
Counter: xla::rsqrt
  Value: 101
Counter: xla::select_copy
  Value: 3
Counter: xla::sin
  Value: 50
Counter: xla::slice_copy
  Value: 412
Counter: xla::slice_scatter
  Value: 96
Counter: xla::split_copy
  Value: 27
Counter: xla::sqrt
  Value: 24
Counter: xla::stack
  Value: 1
Counter: xla::sub
  Value: 51
Counter: xla::sum
  Value: 48
Counter: xla::t_copy
  Value: 130
Counter: xla::transpose_copy
  Value: 66
Counter: xla::unbind_copy
  Value: 24
Counter: xla::unsqueeze_copy
  Value: 214
Counter: xla::view_copy_symint
  Value: 632
Counter: xla::zero_
  Value: 11
Metric: CompileTime
  TotalSamples: 3
  Accumulator: 07m44s971ms632.663us
  ValueRate: 995ms954.253us / second
  Rate: 0.00738881 / second
  Percentiles: 1%=025ms227.868us; 5%=025ms227.868us; 10%=025ms227.868us; 20%=025ms227.868us; 50%=027ms990.198us; 80%=07m44s918ms414.597us; 90%=07m44s918ms414.597us; 95%=07m44s918ms414.597us; 99%=07m44s918ms414.597us
Metric: ExecuteTime
  TotalSamples: 2
  Accumulator: 003ms782.130us
  ValueRate: 087ms927.958us / second
  Rate: 62.4902 / second
  Percentiles: 1%=001ms302.180us; 5%=001ms302.180us; 10%=001ms302.180us; 20%=001ms302.180us; 50%=001ms479.950us; 80%=001ms479.950us; 90%=001ms479.950us; 95%=001ms479.950us; 99%=001ms479.950us
Metric: InboundData
  TotalSamples: 2
  Accumulator: 2.00B
  ValueRate: 62.64B / second
  Rate: 62.6403 / second
  Percentiles: 1%=1.00B; 5%=1.00B; 10%=1.00B; 20%=1.00B; 50%=1.00B; 80%=1.00B; 90%=1.00B; 95%=1.00B; 99%=1.00B
Metric: TransferFromServerTime
  TotalSamples: 2
  Accumulator: 002ms992.370us
  ValueRate: 062ms409.340us / second
  Rate: 62.6483 / second
  Percentiles: 1%=918.840us; 5%=918.840us; 10%=918.840us; 20%=918.840us; 50%=001ms073.530us; 80%=001ms073.530us; 90%=001ms073.530us; 95%=001ms073.530us; 99%=001ms073.530us
Counter: CreateCompileHandles
  Value: 3
Counter: CreateDataHandles
  Value: 3
Counter: aten::_local_scalar_dense
  Value: 2

For a fast step:

Metric: DeviceLockWait
  TotalSamples: 4
  Accumulator: 024.091us
  ValueRate: 009.316us / second
  Rate: 1.54684 / second
  Percentiles: 1%=001.711us; 5%=001.711us; 10%=001.711us; 20%=001.711us; 50%=006.800us; 80%=011.460us; 90%=011.460us; 95%=011.460us; 99%=011.460us
Metric: LazyTracing
  TotalSamples: 3937
  Accumulator: 03s801ms148.970us
  ValueRate: 778ms711.898us / second
  Rate: 18086.1 / second
  Percentiles: 1%=000.180us; 5%=001.350us; 10%=003.210us; 20%=003.590us; 50%=005.880us; 80%=123.620us; 90%=146.770us; 95%=164.610us; 99%=208.420us
Metric: TensorsGraphSize
  TotalSamples: 2
  Accumulator: 10.00
  ValueRate: 3.87 / second
  Rate: 0.773416 / second
  Percentiles: 1%=5.00; 5%=5.00; 10%=5.00; 20%=5.00; 50%=5.00; 80%=5.00; 90%=5.00; 95%=5.00; 99%=5.00
Metric: UnwrapXlaData
  TotalSamples: 4
  Accumulator: 004.879us
  ValueRate: 001.886us / second
  Rate: 1.54633 / second
  Percentiles: 1%=000.560us; 5%=000.560us; 10%=000.560us; 20%=000.560us; 50%=001.800us; 80%=001.829us; 90%=001.829us; 95%=001.829us; 99%=001.829us
Metric: WrapXlaData
  TotalSamples: 2
  Accumulator: 000.149us
  ValueRate: 000.058us / second
  Rate: 0.773277 / second
  Percentiles: 1%=000.069us; 5%=000.069us; 10%=000.069us; 20%=000.069us; 50%=000.080us; 80%=000.080us; 90%=000.080us; 95%=000.080us; 99%=000.080us
Counter: CreateXlaTensor
  Value: 3779
Counter: DestroyLtcTensor
  Value: 3779
Counter: DestroyXlaTensor
  Value: 3779
Counter: UncachedCompile
  Value: 2
Counter: xla::_copy_from
  Value: 96
Counter: xla::_propagate_xla_data
  Value: 106
Counter: xla::_softmax
  Value: 24
Counter: xla::_to_cpu
  Value: 2
Counter: xla::_unsafe_view
  Value: 150
Counter: xla::add
  Value: 266
Counter: xla::bmm
  Value: 48
Counter: xla::cat
  Value: 32
Counter: xla::convolution_overrideable
  Value: 1
Counter: xla::copy
  Value: 96
Counter: xla::cos
  Value: 50
Counter: xla::div
  Value: 5
Counter: xla::empty_strided_symint
  Value: 11
Counter: xla::empty_symint
  Value: 12
Counter: xla::expand_copy_symint
  Value: 98
Counter: xla::ge
  Value: 1
Counter: xla::gelu
  Value: 26
Counter: xla::le
  Value: 1
Counter: xla::lerp
  Value: 2
Counter: xla::linspace
  Value: 2
Counter: xla::log
  Value: 1
Counter: xla::logical_not
  Value: 8
Counter: xla::masked_fill
  Value: 8
Counter: xla::mean
  Value: 55
Counter: xla::mm
  Value: 158
Counter: xla::mul
  Value: 531
Counter: xla::normal_
  Value: 1
Counter: xla::permute_copy
  Value: 116
Counter: xla::pow
  Value: 106
Counter: xla::reciprocal
  Value: 2
Counter: xla::roll
  Value: 32
Counter: xla::rsqrt
  Value: 101
Counter: xla::select_copy
  Value: 3
Counter: xla::sin
  Value: 50
Counter: xla::slice_copy
  Value: 412
Counter: xla::slice_scatter
  Value: 96
Counter: xla::split_copy
  Value: 27
Counter: xla::sqrt
  Value: 24
Counter: xla::stack
  Value: 1
Counter: xla::sub
  Value: 51
Counter: xla::sum
  Value: 48
Counter: xla::t_copy
  Value: 130
Counter: xla::transpose_copy
  Value: 66
Counter: xla::unbind_copy
  Value: 24
Counter: xla::unsqueeze_copy
  Value: 214
Counter: xla::view_copy_symint
  Value: 632
Counter: xla::zero_
  Value: 11
Metric: CompileTime
  TotalSamples: 2
  Accumulator: 063ms809.026us
  ValueRate: 024ms288.720us / second
  Rate: 0.773415 / second
  Percentiles: 1%=026ms037.518us; 5%=026ms037.518us; 10%=026ms037.518us; 20%=026ms037.518us; 50%=037ms771.508us; 80%=037ms771.508us; 90%=037ms771.508us; 95%=037ms771.508us; 99%=037ms771.508us
Metric: ExecuteTime
  TotalSamples: 3
  Accumulator: 05s261ms744.984us
  ValueRate: 03m01s526ms912.953us / second
  Rate: 102.947 / second
  Percentiles: 1%=001ms215.130us; 5%=001ms215.130us; 10%=001ms215.130us; 20%=001ms215.130us; 50%=03s558ms071.496us; 80%=03s701ms458.358us; 90%=03s701ms458.358us; 95%=03s701ms458.358us; 99%=03s701ms458.358us
Metric: InboundData
  TotalSamples: 2
  Accumulator: 2.00B
  ValueRate: 69.71B / second
  Rate: 69.7133 / second
  Percentiles: 1%=1.00B; 5%=1.00B; 10%=1.00B; 20%=1.00B; 50%=1.00B; 80%=1.00B; 90%=1.00B; 95%=1.00B; 99%=1.00B
Metric: TransferFromServerTime
  TotalSamples: 2
  Accumulator: 03s559ms142.285us
  ValueRate: 01m29s210ms316.985us / second
  Rate: 69.7189 / second
  Percentiles: 1%=716.229us; 5%=716.229us; 10%=716.229us; 20%=716.229us; 50%=03s558ms426.056us; 80%=03s558ms426.056us; 90%=03s558ms426.056us; 95%=03s558ms426.056us; 99%=03s558ms426.056us
Counter: CreateCompileHandles
  Value: 2
Counter: CreateDataHandles
  Value: 2
Counter: aten::_local_scalar_dense
  Value: 2
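
Instead of an external diff tool, two reports like these (for example, the text returned by torch_xla.debug.metrics.metrics_report()) can be compared programmatically. This is a minimal sketch that assumes the plain-text layout shown above: a "Counter:" or "Metric:" line followed by an indented "Value:" or "TotalSamples:" line. It compares only counts, not the time accumulators. On the two reports above it would flag, among others, UncachedCompile (3 vs 2) and TrimIrGraph (1 in the slow report, absent in the fast one).

```python
import re
from typing import Dict, Tuple

# "Counter: NAME" followed by an indented "Value: N" line.
_COUNTER_RE = re.compile(r"^Counter: (\S+)\n\s+Value: (\d+)", re.MULTILINE)
# "Metric: NAME" followed by an indented "TotalSamples: N" line.
_METRIC_RE = re.compile(r"^Metric: (\S+)\n\s+TotalSamples: (\d+)", re.MULTILINE)


def parse_report(report: str) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Return (counter values, metric sample counts) from a metrics report."""
    counters = {name: int(v) for name, v in _COUNTER_RE.findall(report)}
    samples = {name: int(v) for name, v in _METRIC_RE.findall(report)}
    return counters, samples


def diff_reports(slow: str, fast: str) -> Dict[str, Tuple[int, int]]:
    """Return {label: (slow_value, fast_value)} for every counter or metric
    sample count that differs; a missing entry counts as 0."""
    slow_c, slow_m = parse_report(slow)
    fast_c, fast_m = parse_report(fast)
    diffs: Dict[str, Tuple[int, int]] = {}
    for kind, a, b in (("Counter", slow_c, fast_c), ("Metric", slow_m, fast_m)):
        for name in sorted(set(a) | set(b)):
            va, vb = a.get(name, 0), b.get(name, 0)
            if va != vb:
                diffs[f"{kind}: {name}"] = (va, vb)
    return diffs
```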

Is there something about this function that induces extra recompilation?

To Reproduce

I don't have a quick way for others to reproduce this. I hope the function linked above is enough to go on.

Expected behavior

No extra compilation during sampling.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: torch-xla==2.2.0

nom, Feb 26 '24 17:02

If I comment out the line x = x + d * dt, every step is fast, so it seems that extra compilation happens on that line at certain iterations of the loop.

nom, Feb 26 '24 18:02
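
One hypothesis consistent with this observation (not confirmed in the thread) is that, without a graph cut inside the loop, every x = x + d * dt extends the same lazy graph, so whatever finally forces x to be computed must compile the whole accumulated chain. The slow report's TensorsGraphSize reaches 104279, while the fast report only shows graphs of size 5. The toy model below only counts pending ops and is not how torch_xla is implemented; ToyLazyTensor and largest_compiled_graph are hypothetical names.

```python
class ToyLazyTensor:
    """Toy stand-in for a lazily traced tensor. It computes nothing and only
    counts the pending ops its value depends on."""

    def __init__(self, pending_ops: int = 0) -> None:
        self.pending_ops = pending_ops

    def _op(self, other: "ToyLazyTensor") -> "ToyLazyTensor":
        # A new node depends on both inputs' pending graphs plus itself.
        return ToyLazyTensor(self.pending_ops + other.pending_ops + 1)

    __add__ = _op
    __mul__ = _op


def largest_compiled_graph(n_steps: int, cut_every_step: bool) -> int:
    """Size of the biggest toy graph that must be compiled to produce x."""
    x = ToyLazyTensor()
    largest = 0
    for _ in range(n_steps):
        d, dt = ToyLazyTensor(), ToyLazyTensor()  # fresh per-step inputs
        x = x + d * dt  # the same form of update as the sampler's Euler step
        if cut_every_step:
            # Stand-in for a per-iteration cut (xm.mark_step() in torch_xla):
            # the pending graph is compiled/executed and x becomes plain data.
            largest = max(largest, x.pending_ops)
            x = ToyLazyTensor()
    # Anything still pending is compiled when x's value is finally needed.
    return max(largest, x.pending_ops)
```

Without a cut, the toy graph grows by two ops per step (50 ops after 25 steps); with a cut every step, it never exceeds 2. In torch_xla, a per-iteration xm.mark_step() from torch_xla.core.xla_model plays the role of the cut.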

Can you run with PT_XLA_DEBUG=1? It will tell you what triggers the compilation. There are two possibilities:

  1. You are accidentally accessing the value of a tensor, and computing that value requires a compilation.
  2. There are conditional or value-dependent ops in your code that produce a slightly different graph at certain steps.

If it is 1, PT_XLA_DEBUG should tell you where the trigger comes from (if you are using the 2.2 release). If it is 2, you can dump the IR/HLO graphs and compare the graph of a normal step with the graph of the step that recompiles. See https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#common-debugging-environment-variables-combinations

JackCaoG, Feb 26 '24 19:02
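
Possibility 2 can be illustrated with a toy tracer in which compiled graphs are cached by a hash of the traced graph, similar in spirit to the "Graph Hash" that PT_XLA_DEBUG reports. A Python branch on a concrete value changes the traced graph only at some steps, and each new variant is a cache miss, i.e. a compilation. Everything here (trace_step, count_compiles, the op names, the low/high thresholds) is hypothetical and does not model torch_xla's real cache.

```python
import hashlib
from typing import List, Set, Tuple


def trace_step(sigma: float, low: float, high: float) -> Tuple[str, ...]:
    """Toy tracer: the op sequence one sampling step would record.

    The branch reads a concrete Python value, so the recorded graph differs
    only at the steps where the condition holds.
    """
    ops = ["op_a", "op_b", "op_c"]
    if low <= sigma <= high:  # value-dependent control flow (possibility 2)
        ops.append("extra_op")
    return tuple(ops)


def count_compiles(sigmas: List[float], low: float, high: float) -> int:
    """Count cache misses when compiled graphs are keyed by a graph hash."""
    seen: Set[str] = set()
    compiles = 0
    for sigma in sigmas:
        graph = trace_step(sigma, low, high)
        graph_hash = hashlib.sha256(repr(graph).encode()).hexdigest()
        if graph_hash not in seen:  # miss: this is where a slow compile happens
            seen.add(graph_hash)
            compiles += 1
    return compiles
```

In this toy each graph variant compiles only once. A compilation that keeps recurring suggests instead a graph that is different every time, for example one that keeps growing.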

Thanks, compilation is triggered on this line: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/models/image_transformer_v2.py#L384

This does not access a tensor value if I understand einops.rearrange correctly.

I logged the shapes of the variables used in that line (x and skip), and they are the same across iterations, so I'm not sure whether (2) applies.

nom, Feb 26 '24 20:02

I dumped the HLO/IR graphs, but how can I see which belongs to which step? The dumps are ~1M lines long.

nom, Feb 26 '24 21:02
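
One way to orient yourself in a ~1M-line dump is to split it into per-graph sections and look at their sizes, since an outsized graph (like the 104279-node one in the slow report) should stand out. This sketch assumes two things you should verify against your own dump: that each saved graph starts with a line that is just a bracketed tag such as "[SomeTag]", and that sections are appended in execution order, so a section's index reflects when it ran. looks_like_header, section_sizes and largest_sections are hypothetical helpers.

```python
from typing import Callable, Iterable, List, Tuple


def looks_like_header(line: str) -> bool:
    """Assumed section marker: a line consisting only of a bracketed tag.
    Adjust this predicate if your dump is laid out differently."""
    s = line.strip()
    return len(s) > 2 and s.startswith("[") and s.endswith("]")


def section_sizes(
    lines: Iterable[str], is_header: Callable[[str], bool] = looks_like_header
) -> List[Tuple[int, int]]:
    """Return (section_index, line_count) for each section, in file order.
    Lines before the first header form section 0. Streams the input."""
    sizes: List[Tuple[int, int]] = []
    for line in lines:
        if is_header(line) or not sizes:
            sizes.append((len(sizes), 0))
        index, count = sizes[-1]
        sizes[-1] = (index, count + 1)
    return sizes


def largest_sections(lines: Iterable[str], top: int = 5) -> List[Tuple[int, int]]:
    """The `top` biggest sections, largest first."""
    return sorted(section_sizes(lines), key=lambda s: s[1], reverse=True)[:top]
```

For example, with open("/tmp/save1.hlo") as f: print(largest_sections(f)) lists the biggest sections and their positions in the file.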

Can you dump the PT_XLA_DEBUG=1 output? Want to know how einops.rearrange triggers the recompilation.

JackCaoG, Feb 26 '24 21:02

btw, you can also enable the persistent cache by following https://github.com/pytorch/xla/blob/master/API_GUIDE.md#compilation-caching. This way it only recompiles once, and the result is saved for future runs.

JackCaoG, Feb 26 '24 21:02
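
The persistent cache helps because a compiled result can be looked up by its graph. The toy below (not torch_xla's implementation; the function name, file layout and JSON format are all assumptions) shows why it saves time across runs but cannot remove recompilations caused by a graph that changes: an identical hash loads from disk, while any new hash still pays for a compile. According to the linked API guide, the real cache is enabled with torch_xla.runtime.initialize_cache(path, readonly=False) before any computation runs; check the guide for your torch_xla version.

```python
import json
import os
from typing import Callable


def load_or_compile(graph_hash: str, compile_fn: Callable[[], str], cache_dir: str) -> str:
    """Toy persistent compilation cache keyed by graph hash.

    On a hit, the saved artifact is read from cache_dir; on a miss, compile_fn
    (the slow path) runs once and its result is written to disk for later runs.
    """
    path = os.path.join(cache_dir, f"{graph_hash}.json")
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)["executable"]
    executable = compile_fn()
    os.makedirs(cache_dir, exist_ok=True)
    with open(path, "w") as f:
        json.dump({"executable": executable}, f)
    return executable
```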

Interestingly, it triggers on a different line each time I run it. This is the latest:

Compilation Analysis: ================================================================================
Compilation Analysis: Compilation Cause
Compilation Analysis:   most likely user code trying to access tensor value before mark_step
Compilation Analysis: Graph Info:
Compilation Analysis:   Graph Hash: 507d309b2d23e1c0ff316e65fbb8d47a
Compilation Analysis:   Number of Graph Inputs: 229
Compilation Analysis:   Number of Graph Outputs: 1
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis:   _apply_rotary_emb_inplace (REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:196)
Compilation Analysis:   _fn (REDACTED/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:489)
Compilation Analysis:   __call__ (REDACTED/k-diffusion/k_diffusion/models/flags.py:56)
Compilation Analysis:   forward (REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:205)
Compilation Analysis:   apply (REDACTED/.local/lib/python3.10/site-packages/torch/autograd/function.py:553)
Compilation Analysis:   apply_rotary_emb_ (REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:222)
Compilation Analysis:   forward (REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py:380)
Compilation Analysis:   _call_impl (REDACTED/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1520)
Compilation Analysis:   ..........
Compilation Analysis: --------------------------------------------------------------------------------
Compilation Analysis: ================================================================================

This corresponds to this line: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/models/image_transformer_v2.py#L196

nom, Feb 26 '24 23:02
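
Since the trigger moves between runs, it may help to collect the Compilation Analysis blocks from several PT_XLA_DEBUG=1 logs and compare their hashes, input/output counts and frames side by side. A minimal sketch, assuming the line layout shown in the output above (every line carries the "Compilation Analysis:" prefix, and the frame list ends at a dotted or dashed line); CompilationEvent and parse_compilation_analysis are hypothetical names, and the "Compilation Cause" text is not captured.

```python
from dataclasses import dataclass, field
from typing import List, Optional

PREFIX = "Compilation Analysis:"


@dataclass
class CompilationEvent:
    graph_hash: str
    num_inputs: int = -1   # -1 means "not found in the log"
    num_outputs: int = -1
    frames: List[str] = field(default_factory=list)


def parse_compilation_analysis(log: str) -> List[CompilationEvent]:
    """Extract one event per 'Graph Hash:' entry in a PT_XLA_DEBUG=1 log."""
    events: List[CompilationEvent] = []
    current: Optional[CompilationEvent] = None
    in_frames = False
    for raw in log.splitlines():
        pos = raw.find(PREFIX)
        if pos < 0:
            continue  # unrelated output interleaved with the analysis
        line = raw[pos + len(PREFIX):].strip()
        if not line:
            continue
        if line.startswith("Graph Hash:"):
            current = CompilationEvent(graph_hash=line.split(":", 1)[1].strip())
            events.append(current)
            in_frames = False
        elif current is None:
            continue  # lines before the first hash (e.g. the cause) are skipped
        elif line.startswith("Number of Graph Inputs:"):
            current.num_inputs = int(line.split(":", 1)[1])
        elif line.startswith("Number of Graph Outputs:"):
            current.num_outputs = int(line.split(":", 1)[1])
        elif line.startswith("Python Frame Triggered Execution:"):
            in_frames = True
        elif in_frames:
            if line.startswith(("---", "===", "...")):
                in_frames = False  # separator or elision ends the frame list
            else:
                current.frames.append(line)
    return events
```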

From what I can tell, line 196 of REDACTED/k-diffusion/k_diffusion/models/image_transformer_v2.py is trying to access the value of a torch_xla tensor before execution. If you look at the graph info in the output:

Compilation Analysis:   Number of Graph Inputs: 229
Compilation Analysis:   Number of Graph Outputs: 1

it means it is trying to compute the value of one tensor. But this really confuses me, since y1 = x1_ * cos - x2_ * sin doesn't seem to be doing that. I'm starting to wonder whether the way I grab the Python frame is somehow problematic. Can you dump the HLO graph with XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo", find the graph with hash 507d309b2d23e1c0ff316e65fbb8d47a (or whichever new hash shows up), and share it here?

JackCaoG, Feb 26 '24 23:02
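
Finding one hash in a dump this large is easier with a streaming search that prints a window of lines around each match, so the file never has to be loaded at once. This sketch assumes the hash text appears literally in the matching part of the dump, as the request implies; find_with_context and its before/after defaults are hypothetical, and after may need to be much larger for a graph with ~100k nodes.

```python
from collections import deque
from typing import Deque, Iterable, List, Tuple


def find_with_context(
    lines: Iterable[str], needle: str, before: int = 5, after: int = 50
) -> List[List[str]]:
    """Return one snippet per line containing `needle`: up to `before` lines of
    preceding context, the matching line, and up to `after` following lines."""
    snippets: List[List[str]] = []
    history: Deque[str] = deque(maxlen=before)  # rolling window of prior lines
    pending: List[Tuple[List[str], int]] = []  # (snippet, trailing lines still wanted)
    for raw in lines:
        line = raw.rstrip("\n")
        still_pending: List[Tuple[List[str], int]] = []
        for snippet, wanted in pending:
            snippet.append(line)
            if wanted > 1:
                still_pending.append((snippet, wanted - 1))
        pending = still_pending
        if needle in line:
            snippet = list(history) + [line]
            snippets.append(snippet)
            if after > 0:
                pending.append((snippet, after))
        history.append(line)
    return snippets
```

For example: with open("/tmp/save1.hlo") as f: for s in find_with_context(f, "507d309b2d23e1c0ff316e65fbb8d47a"): print("\n".join(s)).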