lightning-thunder
Expose `torch.compile` arguments as compile options
🚀 Feature
Motivation
The `torch.compile` call is internal and has `fullgraph=True` hardcoded. It would be useful to allow customizing it, especially after #140 lands.
https://github.com/Lightning-AI/lightning-thunder/blob/3a2b27e9b38e639133913c42c03acdf0e136ec4f/thunder/executors/torch_compile.py#L84
Pitch
```python
from thunder.executors.torch_compile import torch_compile_ex

... = thunder.jit(..., executors=[torch_compile_ex], torch_compile_fullgraph=True, torch_compile_backend="reduce-overhead")
```
Use `get_compile_option` to allow customizing `fullgraph=` and `backend=`:
https://github.com/Lightning-AI/lightning-thunder/blob/3a2b27e9b38e639133913c42c03acdf0e136ec4f/thunder/core/compile_data.py#L57
cc @apaz-cli
Hey @carmocca! I am working on this issue and want to understand exactly how things work. Once I have updated the code, is there a way to verify that it works as expected? I am not sure how to test it, since the following code works both with and without the modification:
```python
import thunder
import torch
from thunder.executors.torch_compile import torch_compile_executor

def foo(x):
    return torch.abs(x)

input = torch.randn(3, 3)
input = input.to(torch.int32)

jfn = thunder.jit(foo, executors=[torch_compile_executor], torch_compile_fullgraph=False)
thunder_output = jfn(input)
```
(cc. @mruberry )
You could maybe write a test that mocks the `torch.compile` call here: https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/executors/torch_compile.py#L84 and asserts that the correct values were passed.
You could also force a graph break in your function and run it with the environment variable `TORCH_LOGS="recompiles"` set to check it.
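As a sketch of the mocking suggestion above (using a stand-in module so it runs without torch or thunder installed; a real test would patch `torch.compile` at the call site in `torch_compile.py` instead):

```python
from unittest.mock import patch
import types

# Stand-in for the module whose compile() call we want to intercept; in the
# real test this would be torch.compile as invoked inside thunder.
backend = types.SimpleNamespace(compile=lambda fn, **kwargs: fn)

def make_compiled(fn, fullgraph=True, backend_name=None):
    # Hypothetical shape of the call site: forwards the compile options.
    kwargs = {"fullgraph": fullgraph}
    if backend_name is not None:
        kwargs["backend"] = backend_name
    return backend.compile(fn, **kwargs)

def test_compile_options_are_forwarded():
    with patch.object(backend, "compile") as mock_compile:
        make_compiled(lambda x: x, fullgraph=False, backend_name="reduce-overhead")
        _, kwargs = mock_compile.call_args
        assert kwargs == {"fullgraph": False, "backend": "reduce-overhead"}

test_compile_options_are_forwarded()
```

The point is that the test asserts on the keyword arguments the mock received, so it passes only when the options actually reach the compile call.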
Hi @carmocca, thanks for the suggestions. However, I am still unable to test this and would appreciate a more detailed explanation. Let me ask some specific questions; please bear with me, as they may be trivial. I am trying to understand how compilation works in more detail.
- When will `get_compile_option` be invoked when we `jit`? I am going through each step and cannot figure out when `get_compile_option` comes into play.
- When do we use `make_compiled` in `executors/torch_compile.py`? I am trying to mock what we have in line 84, as you mentioned; however, it is unclear how this is linked to `jit`.
Would appreciate your help :)
`get_compile_option` doesn't get invoked automatically; you need to add the call in your PR, right before `torch.compile()` gets called, if I'm not mistaken.
`make_compiled` gets called on the first call of the model/function, after it has been `thunder.jit`ted.
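A minimal illustration of that "compile on first call" flow (not thunder's actual code; a counter stands in for the real `make_compiled` / `torch.compile` step to show it runs exactly once):

```python
def lazy_jit(fn):
    """Defer compilation until the wrapped function is first invoked."""
    state = {"compiled": None, "compile_calls": 0}

    def make_compiled():
        # Stand-in for the real compilation step.
        state["compile_calls"] += 1
        return fn

    def wrapper(*args, **kwargs):
        if state["compiled"] is None:  # first call triggers compilation
            state["compiled"] = make_compiled()
        return state["compiled"](*args, **kwargs)

    wrapper.state = state
    return wrapper

jfn = lazy_jit(lambda x: x + 1)
jfn(1); jfn(2); jfn(3)
assert jfn.state["compile_calls"] == 1  # compiled exactly once, on the first call
```

This is why mocking only takes effect once the jitted function is actually called, not at `thunder.jit` time.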
Don't hesitate to open a draft PR! It's probably easier to comment and help directly there