
Training with pytorch 2.0

Open: desh2608 opened this issue 2 years ago • 5 comments

PyTorch 2.0 is now available, and it includes two major features that may be useful for faster training:

  1. TorchDynamo, which can be used by calling torch.compile(model).
  2. FLASH attention, which can be enabled with torch.backends.cuda.enable_flash_sdp(True).
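A minimal sketch of the second item, assuming PyTorch 2.x and arbitrary placeholder shapes. enable_flash_sdp() only sets a global preference: whether the FLASH kernel actually runs depends on the device, dtype, and arguments, and on CPU the same call simply uses a different backend.

```python
import torch
import torch.nn.functional as F

# Set the global preference for the FLASH SDP backend. This is only a flag:
# the flash kernel is used only on supported CUDA GPUs and inputs, and on CPU
# scaled_dot_product_attention() uses a different backend instead.
torch.backends.cuda.enable_flash_sdp(True)

# Inputs are laid out as (batch, num_heads, seq_len, head_dim).
q = torch.randn(2, 4, 10, 16)
k = torch.randn(2, 4, 10, 16)
v = torch.randn(2, 4, 10, 16)

# The fused entry point that may dispatch to the FLASH kernel.
out = F.scaled_dot_product_attention(q, k, v)
print(torch.backends.cuda.flash_sdp_enabled(), tuple(out.shape))
```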

I have built k2 against the new PyTorch (which requires at least CUDA 11.7) and am trying to see whether we can take advantage of the speed-ups. For this, I am training a small streaming Zipformer-transducer.

Flash attention seems to work out of the box without any changes. TorchDynamo, on the other hand, has some issues. Custom k2 objects (such as RaggedTensor) clearly cannot be "compiled", so instead of torch.compile(model), I only call torch.compile(model.encoder), since the encoder does not contain any k2 objects (these are mainly used in the decoder and the loss computation). With this, I get the following error:

Traceback (most recent call last):
  File "pruned_transducer_stateless7_streaming/train.py", line 1268, in <module>
    main()
  File "pruned_transducer_stateless7_streaming/train.py", line 1261, in main
    run(rank=0, world_size=1, args=args)
  File "pruned_transducer_stateless7_streaming/train.py", line 1139, in run
    train_one_epoch(
  File "pruned_transducer_stateless7_streaming/train.py", line 834, in train_one_epoch
    scaler.scale(loss).backward()
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2348, in backward
    out = call_compiled_backward()
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2319, in call_compiled_backward
    CompiledFunction.compiled_bw = aot_config.bw_compiler(
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 38, in _wrapped_bw_compiler
    return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 441, in bw_compiler
    return inner_compile(
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 176, in compile_fx_inner
    graph.run(*example_inputs)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/graph.py", line 194, in run
    return super().run(*args)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
    self.env[node] = self.run_node(node)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/graph.py", line 466, in run_node
    if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3932, in has_exceeded_max_reads
    self.num_reads() > config.realize_acc_reads_threshold
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/utils.py", line 212, in wrapper
    setattr(self, key, fn(self))
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3970, in num_reads
    read_writes = ComputedBuffer(
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/utils.py", line 212, in wrapper
    setattr(self, key, fn(self))
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2095, in get_read_writes
    return extract_read_writes(
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/dependencies.py", line 300, in extract_read_writes
    fn(*args)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 434, in store_output
    return ops.store(output_name, indexer(vars), self.inner_fn(vars))
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 1593, in inner_fn
    src_val = ops.masked(
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/virtualized.py", line 104, in inner
    line = getattr(self.parent_handler, name)(*args, **kwargs)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/virtualized.py", line 75, in masked
    return f"masked({mask}, {body()}, {other})"
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 1595, in <lambda>
    lambda: src_loader(src_idx),
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 762, in fn
    return functools.reduce(
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 765, in <genexpr>
    value_fn(index, rindex)
  File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 3398, in loader
    assert all(index[i] == 0 for i in reduced_idx)
AssertionError: While executing %slice_scatter : [#users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full, %sum_3, 0, 864, 9223372036854775807), kwargs = {})
Original traceback:
  File "/exp/draj/mini_scale_2022/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py", line 1548, in forward
    src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])

I will try to debug this a bit more, but I wanted to open this issue in case someone has already come across it.
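The workaround described above, compiling only the encoder while the k2-dependent parts stay in eager mode, can be sketched as follows. ToyTransducer is a hypothetical stand-in for the icefall model, not its actual class, and its layers are placeholders.

```python
import torch
import torch.nn as nn


class ToyTransducer(nn.Module):
    """Hypothetical stand-in for the icefall transducer model (not its real class)."""

    def __init__(self, dim: int = 16, vocab_size: int = 10) -> None:
        super().__init__()
        # Plain PyTorch layers only: this part can be traced by TorchDynamo.
        self.encoder = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        # In the real recipe, the decoder/joiner and the loss work with k2 objects
        # such as k2.RaggedTensor, which TorchDynamo cannot compile; this
        # embedding is just a placeholder for that eager-mode part.
        self.decoder = nn.Embedding(vocab_size, dim)


model = ToyTransducer()
# Compile only the encoder submodule instead of calling torch.compile(model).
model.encoder = torch.compile(model.encoder)
# The decoder stays an ordinary eager-mode module.
print(type(model.encoder).__name__, type(model.decoder).__name__)
```

Compilation is deferred until the first call, e.g. model.encoder(torch.randn(8, 16)), which uses the default Inductor backend. The wrapper keeps the original module as a submodule named _orig_mod, so parameter names in state_dict() may gain an "_orig_mod." component, which is worth keeping in mind when saving or loading checkpoints.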

desh2608, Mar 31 '23 22:03

I could avoid the error by changing that line to:

src_extra = src[-1, ...].unsqueeze(dim=0).expand(pad, -1, -1)

I suppose TorchDynamo does not handle slicing with indices computed from tensor shapes here. However, I keep getting OOM errors during compilation, even after decreasing the batch size significantly.
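A quick check, assuming a CPU tensor and made-up sizes, that the rewritten line produces the same values as the original one, so the workaround only changes how the slice is expressed, not the result:

```python
import torch

# Hypothetical sizes; the layout (seq_len, batch, channels) follows the
# indexing in the original Zipformer line.
seq_len, batch, channels = 5, 2, 3
src = torch.randn(seq_len, batch, channels)
pad = 4  # hypothetical number of frames to append

# Original line: slice from a start index computed from src.shape.
src_extra_orig = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])

# Workaround: take the last frame, re-add the time axis, then expand.
src_extra_new = src[-1, ...].unsqueeze(dim=0).expand(pad, -1, -1)

# Both repeat the last frame `pad` times (as expanded views, without copying).
print(torch.equal(src_extra_orig, src_extra_new), tuple(src_extra_new.shape))
```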

desh2608, Apr 01 '23 12:04

However, I think that even without torch.compile(), just enabling FLASH attention (by setting torch.backends.cuda.enable_flash_sdp(True)) improves memory efficiency: for the LibriSpeech training setup described in my original post, I am able to increase the batch size from 500 to 800 on RTX GPUs.

desh2608, Apr 01 '23 17:04

After digging deeper, it seems flash attention was not being used after all. Here is a summary:

  • In PyTorch 2.0, scaled_dot_product_attention() contains three implementations internally: FLASH attention, memory-efficient attention, and a C++ implementation of regular SDP attention.
  • The first two implementations provide memory savings, but they do not yet work when an arbitrary attention mask is provided.
  • In that case, I suppose the function falls back to the C++ implementation.
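To make the fallback concrete, here is a sketch (assuming PyTorch 2.x) that checks scaled_dot_product_attention() with an arbitrary boolean mask against a hand-written softmax(Q K^T / sqrt(d) + mask) V reference. It does not reveal which backend was selected; it only illustrates what the masked computation is. The random mask is a placeholder.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 2, 6, 8) for _ in range(3))

# An arbitrary boolean mask (True = may attend). The diagonal is always
# allowed so that no query row is fully masked out.
mask = torch.rand(6, 6) > 0.5
mask |= torch.eye(6, dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Reference "math" computation: masked scores -> softmax -> weighted values.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
scores = scores.masked_fill(~mask, float("-inf"))
ref = scores.softmax(dim=-1) @ v

print(torch.allclose(out, ref, atol=1e-5))
```

On a CUDA machine, PyTorch 2.0's torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False) context manager can presumably be used to test whether the flash kernel accepts a given call: with the math fallback disabled, an unsupported call should raise an error instead of silently falling back.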

It seems support for attention masks may be available soon. Even then, it would not be directly usable in the Zipformer encoder, because FLASH attention does not explicitly compute the attention weights, which Zipformer requires (it reuses these weights in the second self-attention in each block). I think it should be fine to simply recompute the SDP attention a second time with the fused kernel, since the kernel is apparently quite fast (mainly because it needs far fewer reads/writes between GPU HBM and SRAM). In any case, I will try this once flash attention supports attention masks.
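A simplified sketch of the recompute idea, assuming PyTorch 2.x: applying explicitly materialized weights to a second value tensor gives the same result as calling the fused function again with the same q and k. Here v1 and v2 are hypothetical placeholders for the two value projections, and the sketch omits masks and the positional terms that Zipformer's real attention weights include.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, num_heads, seq_len, head_dim); v1/v2 stand in for two value projections.
q, k = torch.randn(1, 2, 6, 8), torch.randn(1, 2, 6, 8)
v1, v2 = torch.randn(1, 2, 6, 8), torch.randn(1, 2, 6, 8)

# Option A: materialize the attention weights once and reuse them.
weights = (q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))).softmax(dim=-1)
out1_a = weights @ v1
out2_a = weights @ v2  # second attention reuses the same weights

# Option B: never materialize the weights; call the fused function twice
# with the same q and k, so the weights are recomputed internally each time.
out1_b = F.scaled_dot_product_attention(q, k, v1)
out2_b = F.scaled_dot_product_attention(q, k, v2)

print(torch.allclose(out1_a, out1_b, atol=1e-5), torch.allclose(out2_a, out2_b, atol=1e-5))
```

Option B trades an extra QK^T/softmax computation for not storing the full (seq_len x seq_len) weight matrix, which is the trade-off the paragraph above argues is acceptable given how fast the fused kernel is.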

desh2608, Apr 03 '23 01:04

@desh2608 Which implementation are you referring to here? I don't see any scaled_dot_product_attention() or MultiheadAttention() APIs being used in https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py

uni-sagar-raikar, Nov 30 '23 10:11

Yeah, I was talking about their implementation in PyTorch, not in icefall.

desh2608, Nov 30 '23 14:11