Training with PyTorch 2.0
PyTorch 2.0 is now available and includes two major features that may be useful for faster training (a minimal usage sketch follows the list):
- TorchDynamo, which can be used by calling torch.compile(model)
- FLASH attention, by using torch.backends.cuda.enable_flash_sdp(True)
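Here is a minimal sketch of turning on both switches. This is not the icefall recipe itself; the toy model and tensor shapes are purely illustrative.

```python
import torch
import torch.nn as nn

# 1) Opt in to the fused FLASH kernel behind scaled_dot_product_attention
#    (PyTorch >= 2.0, effective on CUDA devices).
torch.backends.cuda.enable_flash_sdp(True)

# 2) Wrap the model (or a submodule) with TorchDynamo / TorchInductor.
#    `toy_model` is a stand-in for the real encoder.
toy_model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 512))
toy_model = torch.compile(toy_model)

x = torch.randn(4, 100, 80)
y = toy_model(x)  # the first call triggers compilation
print(y.shape)
```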
I have built k2 with the new PyTorch (it requires CUDA 11.7 at minimum), and I am trying to see whether we can leverage the speed-ups by training a small streaming Zipformer-transducer in this setup.
Flash attention seems to work out of the box without any changes. TorchDynamo, on the other hand, has some issues. It is clear that custom k2 objects (such as RaggedTensor) cannot be "compiled", so instead of torch.compile(model), I only do torch.compile(model.encoder), since the encoder does not contain any k2 objects (they are mainly used in the decoder and loss computation). I get the following error:
Traceback (most recent call last):
File "pruned_transducer_stateless7_streaming/train.py", line 1268, in <module>
main()
File "pruned_transducer_stateless7_streaming/train.py", line 1261, in main
run(rank=0, world_size=1, args=args)
File "pruned_transducer_stateless7_streaming/train.py", line 1139, in run
train_one_epoch(
File "pruned_transducer_stateless7_streaming/train.py", line 834, in train_one_epoch
scaler.scale(loss).backward()
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/autograd/function.py", line 274, in apply
return user_fn(self, *args)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2348, in backward
out = call_compiled_backward()
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2319, in call_compiled_backward
CompiledFunction.compiled_bw = aot_config.bw_compiler(
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 38, in _wrapped_bw_compiler
return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 441, in bw_compiler
return inner_compile(
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 176, in compile_fx_inner
graph.run(*example_inputs)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/graph.py", line 194, in run
return super().run(*args)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
self.env[node] = self.run_node(node)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/graph.py", line 466, in run_node
if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3932, in has_exceeded_max_reads
self.num_reads() > config.realize_acc_reads_threshold
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/utils.py", line 212, in wrapper
setattr(self, key, fn(self))
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 3970, in num_reads
read_writes = ComputedBuffer(
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/utils.py", line 212, in wrapper
setattr(self, key, fn(self))
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 2095, in get_read_writes
return extract_read_writes(
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/dependencies.py", line 300, in extract_read_writes
fn(*args)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 434, in store_output
return ops.store(output_name, indexer(vars), self.inner_fn(vars))
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 1593, in inner_fn
src_val = ops.masked(
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/virtualized.py", line 104, in inner
line = getattr(self.parent_handler, name)(*args, **kwargs)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/virtualized.py", line 75, in masked
return f"masked({mask}, {body()}, {other})"
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 1595, in <lambda>
lambda: src_loader(src_idx),
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 762, in fn
return functools.reduce(
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/ir.py", line 765, in <genexpr>
value_fn(index, rindex)
File "/home/hltcoe/draj/.conda/envs/torch2/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 3398, in loader
assert all(index[i] == 0 for i in reduced_idx)
AssertionError: While executing %slice_scatter : [#users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full, %sum_3, 0, 864, 9223372036854775807), kwargs = {})
Original traceback:
File "/exp/draj/mini_scale_2022/icefall/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py", line 1548, in forward
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
I will try to debug this a bit more, but just wanted to open this issue here in case someone has already come across this.
I could avoid the error by changing that line to:
src_extra = src[-1, ...].unsqueeze(dim=0).expand(pad, -1, -1)
I suppose TorchDynamo doesn't allow indexing with values read from tensor shapes (a quick equivalence check of the two forms is below). However, I keep getting OOM during compilation even after decreasing the batch size significantly.
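For reference, here is a toy check (not from the recipe; the shapes are arbitrary) that the rewritten indexing produces the same result as the original shape-based slice in zipformer.py:

```python
import torch

# (seq_len, batch, channels) stand-in for `src` inside the Zipformer block.
src = torch.randn(7, 3, 5)
pad = 4

# Original formulation, which trips up TorchInductor when compiled.
original = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])

# Rewritten formulation that avoids indexing with shape values.
rewritten = src[-1, ...].unsqueeze(dim=0).expand(pad, -1, -1)

assert torch.equal(original, rewritten)
```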
That said, I think that even without torch.compile(), just using FLASH attention (by setting torch.backends.cuda.enable_flash_sdp(True)) improves memory efficiency. For the LibriSpeech training setup (described in my original post), I am able to increase the batch size from 500 to 800 on RTX GPUs.
After digging deeper, it seems flash attention was not being used after all. Here is a summary:
- scaled_dot_product_attention() in PyTorch 2.0 contains three implementations internally: FLASH attention, memory-efficient attention, and a C++ implementation of the regular SDP attention.
- The first two implementations provide memory savings, but they do not yet work when an arbitrary attention mask is provided.
- In that case, I suppose the function falls back on the C++ implementation (see the sketch after this list).
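A rough way to see this fallback (assuming a CUDA device with PyTorch 2.0, and assuming the torch.backends.cuda.sdp_kernel context manager behaves as documented) is to restrict SDPA to the FLASH kernel and then pass an arbitrary mask:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim) in fp16, which the FLASH kernel supports.
q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
mask = torch.zeros(2, 8, 128, 128, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    F.scaled_dot_product_attention(q, k, v)  # OK: FLASH kernel can run
    try:
        # With an arbitrary attn_mask and only FLASH enabled, no kernel applies.
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as e:
        print("FLASH cannot handle an arbitrary mask:", e)
```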
It seems support for attention masks may be available soon. But even then, it would not be usable directly in the Zipformer encoder, because flash attention does not explicitly compute the attention weights, which the Zipformer requires (it re-uses these weights in the second self-attention in the block). I think it should be fine to just re-compute the attention a second time with the fused kernel (sketched below), since it is apparently quite fast (mainly due to much less reading and writing between HBM and SRAM). Anyway, I will try this once flash attention supports attention masks.
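To illustrate the idea only (this is not the Zipformer code; the shapes and the manual softmax formulation are just for the comparison), recomputing with the fused kernel would look like this:

```python
import torch
import torch.nn.functional as F

q = k = torch.randn(2, 8, 100, 64)  # shared query/key projections
v1 = torch.randn(2, 8, 100, 64)     # values for the first attention
v2 = torch.randn(2, 8, 100, 64)     # values for the second attention

# Current pattern: materialize the attention weights once and reuse them.
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / (64 ** 0.5), dim=-1)
out1 = attn_weights @ v1
out2 = attn_weights @ v2

# Fused pattern: weights are never materialized, so recompute them per call.
out1_fused = F.scaled_dot_product_attention(q, k, v1)
out2_fused = F.scaled_dot_product_attention(q, k, v2)

assert torch.allclose(out1, out1_fused, atol=1e-5)
assert torch.allclose(out2, out2_fused, atol=1e-5)
```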
@desh2608 Which implementation are you referring to here? I don't see any scaled_dot_product_attention() or MultiheadAttention() APIs being used in https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
Yeah, I was talking about their implementation in PyTorch, not in icefall.