Reduce no_autocast overhead from 3.1 µs to 0.5 µs (6x improvement)
This PR reduces the overhead incurred by using the torch.autocast decorator to disable PyTorch's autocast context.
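For context, the gist of the change is to stop entering a `torch.autocast(..., enabled=False)` context around the call and instead toggle the autocast state directly via the getters and setters. A minimal sketch of that idea (not the actual implementation in `thunder/executors/torchex.py`; the real decorator may save and restore more state):

```python
import torch
from functools import wraps

def no_autocast_sketch(fn):
    # Sketch only: instead of entering torch.autocast(..., enabled=False),
    # flip the global autocast flag around the call and restore it afterwards.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        previous = torch.is_autocast_enabled()
        torch.set_autocast_enabled(False)
        try:
            return fn(*args, **kwargs)
        finally:
            torch.set_autocast_enabled(previous)
    return wrapper
```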
Here's the script I used to measure the impact:
```python
from thunder.executors.torchex import no_autocast
import timeit

def f(): pass

no_autocast_f = no_autocast(f)

iterations, timing = timeit.Timer("no_autocast_f()", globals=globals()).autorange()
print(f"Average time no_autocast(f): {(timing / iterations * 1e6):.3f} µs")
```
Before this change:
Average time no_autocast(f): 3.119 µs
With this change:
Average time no_autocast(f): 0.513 µs
For an empty jitted function this change gives a 1.3x speedup, decreasing the overhead by 7 µs:
```python
import thunder
import timeit

def f():
    pass

tf = thunder.jit(f)
tf()

iterations, timing = timeit.Timer("tf()", globals=globals()).autorange()
print(f"Average time thunder.jit(f): {(timing / iterations * 1e6):.3f} µs")
```
Can we get the big picture here, please? How much are we looking to save, and what role does this play in the discussion in #169?
How does this interact with https://github.com/Lightning-AI/lightning-thunder/pull/169#discussion_r1577824718?
How does this interact with #169 (comment)?
This change impacts all usages of the no_autocast decorator, including the prologue and computation functions, but not the functions passed to torch.compile. We shouldn't optimize for torch.compile's limitations: it's a lot easier not to apply the decorators to functions that should be accelerated by torch.compile, and to apply them to the compiled callables instead. In fact, that's already done in the linked PR (https://github.com/Lightning-AI/lightning-thunder/pull/169#discussion_r1577973901): https://github.com/Lightning-AI/lightning-thunder/blob/fceb64efc93a80a27d38b8e84f0e2b5f132f3d2f/thunder/executors/torch_compile.py#L102-L108
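A rough sketch of that pattern (the region function and variable names are illustrative, not the executor's actual code):

```python
import torch
from thunder.executors.torchex import no_autocast

# Illustrative stand-in for a Thunder-generated fusion region.
def region_fn(x, y):
    return x + y

# The function handed to torch.compile carries no no_autocast decorator...
compiled_region = torch.compile(region_fn, fullgraph=True)

# ...and the decorator is applied to the already-compiled callable instead,
# so dynamo never has to trace through the autocast getters and setters.
wrapped_region = no_autocast(compiled_region)
```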
> Can we get a big picture here, please? How much are we looking to save
Sure, I'll create an issue.
My concern here, from #169, is that I was under the impression that using the original decorators was specific to torch.compile. If that's not the case, I'm happy to merge this. @carmocca do you remember how exactly this worked?
Hi friends.
You should be able to check this by jitting a model with the torch.compile executor enabled, making sure that fullgraph=True is set (disallowing recompiles). It should be an easy test to run.
Back when I did #169, inductor had hard-coded support for torch.autocast but not for user-defined context managers. Whether that has changed or been extended in recent PyTorch releases, I couldn't say.
But it seems like this PR doesn't add back a custom context manager (like we had in https://github.com/Lightning-AI/lightning-thunder/pull/169/files#diff-3c6a6ca64f7cd3508bcd348612f5aadea83a0506e521c1c8a232553f047d2321L148); it simply calls getters and setters, which is more likely to be dynamo-friendly.
From my 5-minute review, and trusting that Ivan has checked this code, this LGTM.
Hope that helps!
Thank you @carmocca !
```
FAILED thunder/tests/test_autocast.py::test_torch_compile_autocast - torch._dynamo.exc.Unsupported:
Graph break due to unsupported builtin torch.set_autocast_enabled.
This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
```
Seems like dynamo still doesn't like that.
Let's take a look at the source code of this test:
```python
@no_autocast
def fn(x, y):
    return x + y

cfn = torch.compile(fn, fullgraph=True)
```
This test exists because, in Thunder's torch.compile executor, the flow is to first create a Python callable from Thunder's TraceCtx, which by default includes the no_autocast decorator (as in the test above), and then pass the generated callable to torch.compile. But in #169 Carlos disabled the generation of these decorators for the functions we send to torch.compile: https://github.com/Lightning-AI/lightning-thunder/blob/908be577730fc4f3855a4e644cf69b8defd8e607/thunder/executors/torch_compile.py#L95
test_torch_compile_autocast should be removed. Thunder doesn't use the pattern that is tested there.
> You should be able to check this by jitting a model with the torch.compile executor enabled, making sure that fullgraph=True is set (disallowing recompiles). It should be an easy test to run.
Thank you, Carlos. We use fullgraph=True by default: https://github.com/Lightning-AI/lightning-thunder/blob/908be577730fc4f3855a4e644cf69b8defd8e607/thunder/executors/torch_compile.py#L99-L101
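For completeness, this is roughly what the suggested check could look like. The import path of the torch.compile executor is assumed here and may differ in the actual codebase:

```python
import torch
import thunder
# Assumed import path/name for the torch.compile executor; adjust if it differs.
from thunder.executors.torch_compile import torch_compile_ex

def model(x, y):
    return x @ y + y

# Jit the model with the torch.compile executor enabled; its regions are
# compiled with fullgraph=True, so any graph break caused by the autocast
# getters/setters would surface as an error when the function runs.
jitted = thunder.jit(model, executors=[torch_compile_ex])

x = torch.randn(4, 4)
y = torch.randn(4, 4)
out = jitted(x, y)
```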