
Lowering `unfold`

Open ibeltagy opened this issue 4 years ago • 25 comments

🚀 Feature

Add a lowering for unfold.

Motivation

I want to run Longformer (model code in the HF repo) on pytorch-xla, and this requires an overlapping sliding-window operation, which needs a lowering for unfold.

Pitch

Add a lowering for unfold

Alternatives

Use as_strided, but its current implementation is limited, as discussed in this issue.

Additional context

Below is the metric report for the forward pass of Longformer with unfold. It has entries for aten::unfold.

Metric: CompileTime
  TotalSamples: 40                                                                                                                                                                                        
  Accumulator: 06m12s060ms761.186us                                                                                                                                                                       
  ValueRate: 985ms703.787us / second                                                                                                                                                                      
  Rate: 0.105865 / second                                                                                                                                                                                 
  Percentiles: 1%=002ms604.019us; 5%=002ms103.276us; 10%=002ms209.085us; 20%=031ms487.158us; 50%=11s482ms222.482us; 80%=14s789ms136.836us; 90%=14s427ms259.848us; 95%=15s075ms200.017us; 99%=15s212ms201.$
81us                                                                                                                                                                                                      
Metric: DeviceLockWait                                                                                                                                                                                    
  TotalSamples: 73                                                                                                                                                                                        
  Accumulator: 277.621us                                                                                                                                                                                  
  ValueRate: 000.765us / second                                                                                                                                                                           
  Rate: 0.201229 / second                                                                                                                                                                                 
  Percentiles: 1%=002.159us; 5%=002.515us; 10%=002.707us; 20%=002.944us; 50%=003.671us; 80%=004.275us; 90%=004.708us; 95%=004.854us; 99%=015.004us                                                        
Metric: ExecuteTime                                                                                                                                                                                       
  TotalSamples: 73                                                                                                                                                                                        
  Accumulator: 03s919ms069.706us                                                                                                                                                                          
  ValueRate: 008ms722.713us / second                                                                                                                                                                      
  Rate: 0.193129 / second                                                                                                                                                                                 
  Percentiles: 1%=001ms485.104us; 5%=002ms714.332us; 10%=002ms000.342us; 20%=002ms237.048us; 50%=003ms337.952us; 80%=098ms610.960us; 90%=126ms721.599us; 95%=139ms781.481us; 99%=154ms800.680us           
Metric: InboundData                                                                                                                                                                                       
  TotalSamples: 72                                                                                                                                                                                        
  Accumulator: 234.19MB                                                                                                                                                                                   
  ValueRate: 634.49KB / second                                                                                                                                                                            
  Rate: 0.190499 / second                                                                                                                                                                                 
  Percentiles: 1%=1.00B; 5%=1.00B; 10%=1.00B; 20%=8.00KB; 50%=6.00MB; 80%=6.00MB; 90%=7.50MB; 95%=7.50MB; 99%=7.50MB                                                                                      
Metric: InputOutputAliasCount                                                                                                                                                                             
  TotalSamples: 1                                                                                                                                                                                         
  Accumulator: 271.00                                                                                                                                                                                     
  Percentiles: 1%=271.00; 5%=271.00; 10%=271.00; 20%=271.00; 50%=271.00; 80%=271.00; 90%=271.00; 95%=271.00; 99%=271.00                                                                                   
Metric: IrValueTensorToXlaData                                                                                                                                                                            
  TotalSamples: 331                                                                                                                                                                                       
  Accumulator: 03s006ms264.150us                                                                                                                                                                          
  ValueRate: 008ms922.872us / second                                                                                                                                                                      
  Rate: 0.872335 / second                                                                                                                                                                                 
  Percentiles: 1%=863.555us; 5%=967.491us; 10%=001ms069.569us; 20%=001ms215.703us; 50%=002ms606.635us; 80%=007ms211.581us; 90%=022ms513.355us; 95%=029ms074.835us; 99%=067ms409.847us                     
Metric: OutboundData                                                                                                                                                                                      
  TotalSamples: 335                                                                                                                                                                                       
  Accumulator: 1.01GB                                                                                                                                                                                     
  ValueRate: 2.73MB / second
  Rate: 0.881721 / second
  Percentiles: 1%=3.00KB; 5%=3.00KB; 10%=3.00KB; 20%=3.00KB; 50%=14.00KB; 80%=2.25MB; 90%=10.50MB; 95%=10.50MB; 99%=18.00MB
Metric: ReleaseDataHandlesTime
  TotalSamples: 81
  Accumulator: 333ms705.496us
  ValueRate: 880.219us / second
  Rate: 0.214297 / second
  Percentiles: 1%=382.511us; 5%=474.639us; 10%=522.986us; 20%=611.054us; 50%=001ms050.138us; 80%=001ms216.637us; 90%=003ms012.896us; 95%=031ms474.989us; 99%=038ms143.816us
Metric: TensorsGraphSize
  TotalSamples: 73
  Accumulator: 83903.00
  ValueRate: 222.06 / second
  Rate: 0.193203 / second
  Percentiles: 1%=4.00; 5%=4.00; 10%=4.00; 20%=23.00; 50%=67.00; 80%=2874.00; 90%=3673.00; 95%=4075.00; 99%=4474.00
Metric: TransferFromServerTime
  TotalSamples: 72
  Accumulator: 850ms040.762us
  ValueRate: 002ms249.054us / second
  Rate: 0.190499 / second
  Percentiles: 1%=850.857us; 5%=001ms079.063us; 10%=001ms135.278us; 20%=001ms285.135us; 50%=015ms444.166us; 80%=021ms375.969us; 90%=027ms938.459us; 95%=030ms432.630us; 99%=046ms339.680us
Metric: TransferToServerTime
  TotalSamples: 335
  Accumulator: 03s025ms272.057us
  ValueRate: 008ms972.967us / second
  Rate: 0.882877 / second
  Percentiles: 1%=857.302us; 5%=959.191us; 10%=001ms060.268us; 20%=001ms210.822us; 50%=002ms606.569us; 80%=007ms260.753us; 90%=021ms492.181us; 95%=029ms982.476us; 99%=067ms384.995us
Metric: TransferToServerTransformTime
  TotalSamples: 335
  Accumulator: 460ms996.455us
  ValueRate: 001ms210.712us / second
  Rate: 0.881721 / second
  Percentiles: 1%=087.734us; 5%=094.554us; 10%=099.654us; 20%=107.230us; 50%=268.367us; 80%=612.733us; 90%=003ms313.737us; 95%=006ms138.063us; 99%=009ms517.447us
Counter: CachedCompile
  Value: 33
Counter: CreateCompileHandles
  Value: 40
Counter: CreateDataHandles
  Value: 692
Counter: CreateXlaTensor
  Value: 3897
Counter: DestroyDataHandles
  Value: 343
Counter: DestroyXlaTensor
  Value: 3608
Counter: MarkStep
  Value: 1
Counter: ReleaseDataHandles
  Value: 343
Counter: UncachedCompile
  Value: 40
Counter: XRTAllocateFromTensor_Empty
  Value: 20
Counter: XrtCompile_Empty
  Value: 144
Counter: XrtExecuteChained_Empty
  Value: 144
Counter: XrtExecute_Empty
  Value: 144
Counter: XrtRead_Empty
  Value: 144
Counter: XrtReleaseAllocationHandle_Empty
  Value: 144
Counter: XrtReleaseCompileHandle_Empty
  Value: 144
Counter: XrtSessionCount
  Value: 10
Counter: XrtSubTuple_Empty
  Value: 144
Counter: aten::_local_scalar_dense
  Value: 12
Counter: aten::unfold
  Value: 60
Counter: xla::_softmax
  Value: 12
Counter: xla::_unsafe_view
  Value: 72
Counter: xla::add
  Value: 27
Counter: xla::add_
  Value: 84
Counter: xla::addcmul
  Value: 25
Counter: xla::addmm
  Value: 1
Counter: xla::as_strided
  Value: 271
Counter: xla::bmm
  Value: 36
Counter: xla::clone
  Value: 24
Counter: xla::constant_pad_nd
  Value: 48
Counter: xla::copy_
  Value: 394
Counter: xla::cumsum
  Value: 1
Counter: xla::div_
  Value: 12
Counter: xla::embedding
  Value: 3
Counter: xla::empty
  Value: 359
Counter: xla::empty_strided
  Value: 271
Counter: xla::eq
  Value: 48
Counter: xla::expand
  Value: 120
Counter: xla::fill_
  Value: 36
Counter: xla::flip
  Value: 48
Counter: xla::gelu
  Value: 12
Counter: xla::gt
  Value: 12
Counter: xla::index_select
  Value: 3
Counter: xla::le
  Value: 12
Counter: xla::lt
  Value: 12
Counter: xla::masked_fill_
  Value: 72
Counter: xla::max
  Value: 12
Counter: xla::mm
  Value: 72
Counter: xla::mul
  Value: 2
Counter: xla::native_batch_norm
  Value: 25
Counter: xla::native_layer_norm
  Value: 25
Counter: xla::ne
  Value: 13
Counter: xla::permute
  Value: 180
Counter: xla::rsub
  Value: 1
Counter: xla::select
  Value: 97
Counter: xla::slice
  Value: 999
Counter: xla::squeeze
  Value: 24
Counter: xla::sum
  Value: 12
Counter: xla::t
  Value: 73
Counter: xla::tanh
  Value: 1
Counter: xla::transpose
  Value: 240
Counter: xla::tril
  Value: 24
Counter: xla::unsqueeze
  Value: 170
Counter: xla::view
  Value: 644
Counter: xla::zero_
  Value: 1
Metric: XrtAllocateFromTensor
  TotalSamples: 48135
  Accumulator: 01m10s487ms137.203us
  Mean: 002ms504.791us
  StdDev: 006ms961.073us
  Rate: 1.03083 / second
  Percentiles: 25%=295.798us; 50%=458.079us; 80%=002ms686.172us; 90%=003ms916.758us; 95%=004ms695.148us; 99%=008ms407.314us
Metric: XrtCompile
  TotalSamples: 2122
  Accumulator: 10m56s974ms699.040us
  Mean: 505ms763.352us
  StdDev: 02s338ms482.396us
  Rate: 0.114957 / second
  Percentiles: 25%=008ms570.206us; 50%=008ms862.980us; 80%=008ms259.798us; 90%=009ms638.784us; 95%=611ms324.713us; 99%=13s233ms291.015us
Metric: XrtExecute
  TotalSamples: 20796
  Accumulator: 02m59s103ms661.768us
  Mean: 004ms131.993us
  StdDev: 017ms650.704us
  Rate: 0.114971 / second
  Percentiles: 25%=851.542us; 50%=956.518us; 80%=001ms210.393us; 90%=002ms377.763us; 95%=006ms024.523us; 99%=110ms012.002us
Metric: XrtExecutorEvict
  TotalSamples: 0
  Accumulator: nanB
  Mean: nanB
  StdDev: nanB
  Percentiles:
Metric: XrtReadLiteral
  TotalSamples: 10335
  Accumulator: 05s641ms262.404us
  Mean: 774.616us
  StdDev: 002ms725.146us
  Rate: 0.114966 / second
  Percentiles: 25%=269.442us; 50%=343.087us; 80%=470.896us; 90%=583.015us; 95%=005ms062.565us; 99%=010ms496.053us
Metric: XrtReleaseAllocation
  TotalSamples: 34172
  Accumulator: 02s911ms410.970us
  Mean: 185.145us
  StdDev: 322.759us
  Rate: 0.115061 / second
  Percentiles: 25%=020.634us; 50%=033.172us; 80%=338.061us; 90%=648.733us; 95%=861.389us; 99%=002ms549.868us
Metric: XrtReleaseCompilation
  TotalSamples: 518
  Accumulator: 002ms770.287us
  Mean: 003.418us
  StdDev: 002.299us
  Rate: 81.2152 / second
  Percentiles: 25%=002.889us; 50%=003.118us; 80%=003.383us; 90%=003.659us; 95%=003.945us; 99%=019.823us

ibeltagy, Jun 18 '20 05:06

Thanks for reporting, @ibeltagy. We will take a look.

dlibenzi, Jun 18 '20 14:06

Thanks, @dlibenzi.

ibeltagy, Jun 18 '20 15:06

While-loop-based patch extraction is likely slower than convolution tricks like the one used in TensorFlow's XLA kernel for image patch extraction:

https://github.com/tensorflow/tensorflow/blob/6116b7f9114f28dcffd685222285a8c5f7db3daa/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc#L43

dlibenzi, Jun 20 '20 00:06

@dlibenzi, sorry, I am not sure I follow how this link relates to unfold.

ibeltagy, Jun 20 '20 04:06

I see. This is the C++ TensorFlow version of torch.unfold. But it is not something that can be called from pytorch-xla, right?

ibeltagy, Jun 20 '20 04:06

It cannot be called directly, but we can use the same idea (convolutions with kernels that pick up one element at a time) for the forward pass.
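A minimal sketch of this convolution trick in plain PyTorch on CPU; it is not the torch_xla lowering itself. It assumes a 3-D float input of shape [B, L, D] unfolded along dim 1, as in the Longformer case, and the helper name `unfold_via_conv` is hypothetical. The kernel is built by comparing two index ranges (the PyTorch analogue of an XLA iota + eq), so output channel k copies element k of every window:

```python
import torch
import torch.nn.functional as F


def unfold_via_conv(t: torch.Tensor, size: int, step: int) -> torch.Tensor:
    """Hypothetical helper: same values as t.unfold(1, size, step) for a [B, L, D] float tensor."""
    B, L, D = t.shape
    # One-hot kernels from two index ranges (analogue of iota + eq):
    # weight[k, 0, i] == 1 exactly when i == k.
    idx = torch.arange(size, device=t.device)
    weight = (idx[:, None] == idx[None, :]).to(t.dtype).unsqueeze(1)  # [size, 1, size]
    # Treat every (batch, feature) pair as an independent 1-channel signal of length L.
    x = t.permute(0, 2, 1).reshape(B * D, 1, L)
    # Output channel k at position j picks x[..., j * step + k].
    y = F.conv1d(x, weight, stride=step)  # [B * D, size, n]
    n = y.shape[-1]
    # Rearrange to unfold's layout: [B, n, D, size].
    return y.reshape(B, D, size, n).permute(0, 3, 1, 2)


t = torch.randn(2, 10, 3)
print(torch.allclose(unfold_via_conv(t, 4, 2), t.unfold(1, 4, 2)))  # True
```

Unlike the native op, this always materializes the full [B, n, D, size] result rather than returning a view.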

dlibenzi, Jun 20 '20 14:06

Hey @JackCaoG, just curious whether there are any updates here.

ibeltagy, Jun 29 '20 20:06

Hi @ibeltagy, I am working on the lowering, but it is a bit tricky. You will see the PR linked in this issue when it is ready 😄.

JackCaoG, Jun 29 '20 20:06

Thanks, @JackCaoG, for the forward function in your PR here. I ran your code and now get Counter: xla::unfold, and, as expected, still get Counter: aten::unfold_backward. There are a few issues, though:

  • The following unfold takes close to 1 hour to compile the first time, though it gets faster afterward. I know the first step is slow, but is 1 hour expected? Note that this is the trivial case where seqlen == window size, so the output contains only one slice and no data copying is needed.
t.unfold(1, 512, 256)  # t.shape = torch.Size([12, 512, 64]), t.unfold(1, 512, 256).shape = torch.Size([12, 1, 64, 512])
  • The following unfold operation OOMs even though it doesn't copy much data. Is this expected? By comparison, a 12GB GPU has enough memory to process it.
t.unfold(1, 512, 256)  # t.shape = torch.Size([12, 2048, 64]), t.unfold(1, 512, 256).shape = torch.Size([12, 7, 64, 512])
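As a quick sanity check on the shapes quoted in these bullets: unfolding a dimension of length L with window `size` and stride `step` yields n = (L - size) // step + 1 windows and appends `size` as a new last dimension. The pure-Python helper below (hypothetical name) computes the output shape and, assuming float32 (4 bytes per element), the byte size of a materialized result:

```python
from math import prod


def unfold_output(shape, dim, size, step, bytes_per_elem=4):
    """Hypothetical helper: shape and byte count of a materialized x.unfold(dim, size, step)."""
    n = (shape[dim] - size) // step + 1  # number of windows
    out = list(shape)
    out[dim] = n        # window index replaces the unfolded dimension
    out.append(size)    # window contents become a new trailing dimension
    return tuple(out), prod(out) * bytes_per_elem


print(unfold_output((12, 512, 64), 1, 512, 256))   # ((12, 1, 64, 512), 1572864)
print(unfold_output((12, 2048, 64), 1, 512, 256))  # ((12, 7, 64, 512), 11010048)
```

So the second output is only about 10.5 MiB in float32; if it OOMs, the memory presumably goes to intermediates rather than to the result itself.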

Thanks

ibeltagy, Jul 15 '20 21:07

Hi @ibeltagy, I am not sure whether 1 hour is too long; it really depends on your model size. Do you remember how long it took before the unfold change?

For the second question, I think I have an idea. When lowering unfold for an input of shape [12, 2048, 64] with size=512 and step=256, the current implementation generates two iota vectors of size [12 * 2048 * 64 - 512, 1, 12 * 2048 * 64 - 512] and a filter of the same size. It then uses slice to shrink the filter according to step. As you can see, this filter and the two intermediate vectors are huge. I might be able to find a way to perform the slice on the iota vectors first instead of afterward. That should save some space, but the filter itself is still huge.
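To put "huge" in numbers: taking the quoted shape literally, with N = 12 * 2048 * 64 - 512, a dense [N, 1, N] tensor has N² elements. The arithmetic below assumes 4-byte (float32) elements; that element type is an assumption, and the real lowering may use a different one:

```python
N = 12 * 2048 * 64 - 512      # length of each side of the quoted shape
elements = N * 1 * N          # dense [N, 1, N] tensor
tib = elements * 4 / 2**40    # assuming 4 bytes per element, in TiB
print(N, elements, round(tib, 2))  # 1572352 2472290811904 8.99
```

Roughly 9 TiB for a single such tensor, far beyond any one device's memory.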

I chose this lowering because the convolution trick is likely much faster than the loop-based approach. In native PyTorch on GPU, unfold just manipulates the pointer and the strides, but on XLA we actually need to compute the output and store it (unfold is not a view op on XLA). This is the downside of not being able to access the storage. Does this OOM issue block you from using XLA on this model?
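To illustrate the "pointer and stride" point: on backends with addressable storage, unfold can be expressed as a zero-copy `as_strided` view. A minimal CPU sketch, assuming a [B, L, D] input unfolded along dim 1 (the helper name is hypothetical). XLA tensors cannot take this route because, as noted, the storage is not accessible:

```python
import torch


def unfold_as_strided(t: torch.Tensor, size: int, step: int) -> torch.Tensor:
    """Hypothetical helper: a view equal to t.unfold(1, size, step) for a [B, L, D] tensor."""
    B, L, D = t.shape
    n = (L - size) // step + 1
    s0, s1, s2 = t.stride()
    # Element [b, j, d, k] lives at offset b*s0 + (j*step + k)*s1 + d*s2,
    # i.e. it is t[b, j*step + k, d]; no data is copied.
    return t.as_strided((B, n, D, size), (s0, step * s1, s2, s1))


t = torch.randn(2, 10, 3)
v = unfold_as_strided(t, 4, 2)
print(torch.equal(v, t.unfold(1, 4, 2)))  # True
print(v.data_ptr() == t.data_ptr())       # True: same storage, nothing copied
```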

JackCaoG, Jul 15 '20 21:07

Did you remember how much time it takes prior to the unfold change?

Around 5 minutes.

[12 * 2048 * 64 - 512, 1, 12 * 2048 * 64 -512 ]

Yeah, this is huge and won't work.

convolution trick is likely much faster than the loop base approach.

Can you elaborate on what the loop-based approach is? Is it a loop with multiple slice operations? If I implement this on the PyTorch side, will it be as fast (or as slow) as implementing it on the C++ side?

Does this OOM issue block you from using XLA on this model?

Yes, and the actual input is even larger, something like [16, 4096, 64] with size=1024 and step=512. The 4096 dimension is the sequence length, which is very long in Longformer.

ibeltagy, Jul 15 '20 22:07

Hi @ibeltagy

Going from 5 minutes to 1 hour is a big jump. One possibility is that unfold was not lowered before this change, and the new lowering is fairly complex (transpose + two iotas + eq + a couple of reshapes + convolution + transpose). If the metrics show only one extra compilation, the time most likely comes from unfold. If you can dump the HLO graph, I can double-check.

For the loop-based approach, yes, I was thinking of multiple slice operations. A PyTorch slice is created as a view here. If this is implemented in C++ and you don't need the view property of unfold, I think using xla::Slice directly here will be faster (I haven't tested it, but maintaining a viewInfo is pretty complex).
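A minimal sketch of the loop-based approach at the PyTorch level, assuming it is acceptable to materialize the result (the output is a new tensor, not a view). `unfold_by_slicing` is a hypothetical name, not the function from the Longformer code or the PR: it takes one `narrow` slice per window and stacks them into unfold's layout.

```python
import torch


def unfold_by_slicing(t: torch.Tensor, dim: int, size: int, step: int) -> torch.Tensor:
    """Hypothetical helper: same values as t.unfold(dim, size, step), built from slices."""
    dim = dim % t.dim()  # normalize negative dims
    n = (t.shape[dim] - size) // step + 1  # number of windows
    # One slice per window; each has `size` elements along `dim`.
    windows = [t.narrow(dim, i * step, size) for i in range(n)]
    # Stack the windows at `dim` -> [..., n, size, ...], then move the
    # window-content axis to the end, which is where unfold puts it.
    return torch.stack(windows, dim=dim).movedim(dim + 1, -1)


t = torch.randn(2, 10, 3)
print(torch.equal(unfold_by_slicing(t, 1, 4, 2), t.unfold(1, 4, 2)))  # True
```

Under a lazy-tensor backend, the Python loop is traced into n separate slice ops, so the graph grows with the number of windows.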

JackCaoG, Jul 15 '20 22:07

If it is possible to implement unfold as a view, that would be the ideal solution because it wouldn't waste any memory, and memory is the bottleneck in the Longformer model.

ibeltagy, Jul 15 '20 23:07

Would you mind trying the idea of splitting the tensor before unfold and concatenating the results afterward? Something like:

>>> torch.arange(12).reshape([2,2,3]).unfold(1, 2, 1)
tensor([[[[ 0,  3],
          [ 1,  4],
          [ 2,  5]]],


        [[[ 6,  9],
          [ 7, 10],
          [ 8, 11]]]])
>>> torch.arange(12).reshape([2,2,3]).split(1)[0].unfold(1, 2, 1)
tensor([[[[0, 3],
          [1, 4],
          [2, 5]]]])
>>> torch.arange(12).reshape([2,2,3]).split(1)[1].unfold(1, 2, 1)
tensor([[[[ 6,  9],
          [ 7, 10],
          [ 8, 11]]]])
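The split-then-concatenate idea from the session above can be wrapped as a small helper. A sketch, assuming the unfolded dimension is not the one being split; here dim 0 is split into chunks of `chunk` rows, and both the helper and its `chunk` parameter are hypothetical. On XLA each piece would still go through the same unfold lowering, but presumably the point is that per-call intermediates are sized by the chunk rather than by the full batch.

```python
import torch


def unfold_split_cat(t: torch.Tensor, dim: int, size: int, step: int, chunk: int = 1) -> torch.Tensor:
    """Hypothetical helper: unfold each slice of dim 0 separately, then concatenate."""
    if dim % t.dim() == 0:
        raise ValueError("dim 0 is used for splitting; unfold a different dimension")
    # Unfold each chunk of rows independently...
    pieces = [p.unfold(dim, size, step) for p in t.split(chunk, dim=0)]
    # ...and glue the results back together along the split dimension.
    return torch.cat(pieces, dim=0)


t = torch.arange(12).reshape([2, 2, 3])
print(torch.equal(unfold_split_cat(t, 1, 2, 1), t.unfold(1, 2, 1)))  # True
```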

I will try to see whether I can reduce the memory usage of the current implementation, and think a bit more about the slice approach.

JackCaoG, Jul 16 '20 00:07

I pushed a new change to the unfold PR; peak memory usage should be reduced to 1/3 when step > 3.

JackCaoG, Jul 16 '20 02:07

Will try both and let you know. Thanks.

ibeltagy, Jul 16 '20 03:07

If you guys can post a simple repro and dump the HLO graph, we can see what is going on.

import torch_xla
print(torch_xla._XLAC._get_xla_tensors_hlo([unfold_result]))

dlibenzi, Jul 16 '20 16:07

I tried the iterative slicing that you suggested and found that it works well. The memory usage is low enough that I can run the model on long sequences, and the model is fast enough (1.7x slower than a GPU that uses as_strided) to be usable. Therefore, I don't think I will need the current lowering of unfold, especially since it is memory-expensive.

Here's another thing I could use your help with; please let me know if I should move it to a separate issue. Right now the model is 1.7x slower than on a GPU. If you guys have any insights on how to make it faster, that would be great. I don't think the iterative unfold (vs. as_strided) is the main contributor to the slowdown: I tried the model with this part of the code removed, and it was still slower than on a GPU. The model code is here. It is the same as RoBERTa, the only difference being the self-attention operation. In particular, the two matrix multiplications here and here are replaced with the two functions _sliding_chunks_matmul_qk and _sliding_chunks_matmul_pv. I am also attaching the debug output, which includes a dump of the HLO graph: debug.tar.gz.

ibeltagy, Jul 23 '20 07:07

Hi @ibeltagy, glad to hear you got unfold working. Let's keep this thread about unfold and open a new issue for the performance optimization 😄.

JackCaoG, Jul 23 '20 17:07

Sure. I will move the model optimization to a separate issue. One thing that's still relevant here is finding out if unfold can be lowered as a view without additional memory. The iterative unfold that I am using is just a temporary hack.

ibeltagy, Jul 23 '20 18:07

For sure, we still want unfold to be lowered in a way that is usable for you. We are a small team, so we have to pick tasks carefully; since this is not a blocker for you, it is likely to be a lower priority (compared to your optimization, for example). I will keep this thread alive and keep you updated.

JackCaoG, Jul 23 '20 19:07

Hi @ibeltagy, I have similar issues when using unfold. Would you mind elaborating on how iterative slicing works, maybe with an example?

JunwenBai, Aug 03 '20 05:08

@JunwenBai I believe it is this function

JackCaoG, Aug 03 '20 05:08

Hi, is aten::unfold lowered? I am getting the message below; is there a workaround?

 UserWarning: 0The operator aten::unfold appears to be a view operator, but it has no implementation for the backend "xla:0". View operators don't support falling back to run on the CPU, since the tensor's storage cannot be shared across devices. (Triggered internally at  ../aten/src/ATen/native/CPUFallback.cpp:175.)

coleridge72, Aug 01 '22 12:08

@coleridge72 I think unfold now gets dispatched to im2col, but we don't have a lowering for that yet either. You can follow up in https://github.com/pytorch/xla/issues/2932. This message is mostly benign: you will get the correct result, but we fall back to CPU to execute unfold, which incurs a speed penalty.

JackCaoG, Aug 01 '22 17:08