lightning-thunder

Expose `torch.compile` arguments as compile options

Status: Open. carmocca opened this issue 1 year ago • 4 comments

🚀 Feature

Motivation

The torch.compile call inside the torch_compile executor is internal and has fullgraph=True hardcoded. It would be useful to allow customizing it, especially after #140 lands.

https://github.com/Lightning-AI/lightning-thunder/blob/3a2b27e9b38e639133913c42c03acdf0e136ec4f/thunder/executors/torch_compile.py#L84

Pitch

from thunder.executors.torch_compile import torch_compile_ex

... = thunder.jit(..., executors=[torch_compile_ex], torch_compile_fullgraph=True, torch_compile_backend="reduce-overhead")

Use get_compile_option to let users customize the fullgraph= and backend= arguments.

https://github.com/Lightning-AI/lightning-thunder/blob/3a2b27e9b38e639133913c42c03acdf0e136ec4f/thunder/core/compile_data.py#L57
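A minimal sketch of the pitch, written in plain Python so it runs without thunder or torch installed. CompileData, get_compile_option and torch_compile_kwargs below are hypothetical stand-ins that only mimic the idea of thunder's option lookup (extra keyword arguments to thunder.jit stored by name and read back with a description); the real helper lives in thunder/core/compile_data.py and its exact signature may differ. The defaults assume fullgraph=True (today's hardcoded value) and torch.compile's default backend, "inductor".

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class CompileData:
    """Hypothetical stand-in for thunder's compile data: holds the extra
    keyword arguments that were passed to thunder.jit(...)."""

    compile_options: dict[str, Any] = field(default_factory=dict)
    # Records which options were queried, with a human-readable description.
    options_used: dict[str, str] = field(default_factory=dict)

    def get_compile_option(self, name: str, description: str) -> Any:
        # Remember that some component understands this option.
        self.options_used[name] = description
        # None means "the user did not set it"; the caller picks the default.
        return self.compile_options.get(name)


def torch_compile_kwargs(cd: CompileData) -> dict[str, Any]:
    """Build the keyword arguments the executor would forward to torch.compile."""
    fullgraph = cd.get_compile_option(
        "torch_compile_fullgraph", "Forwarded to torch.compile(fullgraph=...)."
    )
    backend = cd.get_compile_option(
        "torch_compile_backend", "Forwarded to torch.compile(backend=...)."
    )
    return {
        # Keep today's hardcoded behaviour when the option is absent.
        "fullgraph": True if fullgraph is None else fullgraph,
        "backend": "inductor" if backend is None else backend,
    }


if __name__ == "__main__":
    print(torch_compile_kwargs(CompileData()))
    print(torch_compile_kwargs(CompileData({"torch_compile_fullgraph": False})))
```

Comparing against None instead of writing `fullgraph or True` matters here: an explicit torch_compile_fullgraph=False must not be silently replaced by the default.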

cc @apaz-cli

carmocca, Apr 25 '24 21:04

Hey @carmocca! I am working on this issue and want to understand exactly how things work. Once I have updated the code, is there a way to verify that it works as expected? I am not sure how to test this, because the following code runs both with and without my modification:

import thunder
import torch
from thunder.executors.torch_compile import torch_compile_executor

def foo(x):
    return torch.abs(x)

input = torch.randn(3,3)
input = input.to(torch.int32)
jfn = thunder.jit(foo, executors=[torch_compile_executor], torch_compile_fullgraph=False)
thunder_output = jfn(input)

(cc @mruberry)

k223kim, Apr 30 '24 02:04

You could write a test that mocks the torch.compile call here: https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/executors/torch_compile.py#L84 and asserts that it receives the expected arguments.
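The mocking approach can be sketched with the standard library's unittest.mock. To keep the example runnable without torch or thunder, fake_torch and make_compiled below are hypothetical stand-ins for the torch module and the executor's helper. In the real test you would instead patch torch.compile itself (for example with mock.patch("torch.compile"), assuming the executor looks it up as torch.compile at call time), call a jitted function once, and make the same assertion.

```python
import types
from unittest import mock

# Hypothetical stand-in for the torch module; only .compile matters here.
fake_torch = types.SimpleNamespace(compile=lambda fn, **kwargs: fn)


def make_compiled(fn, fullgraph=True, backend="inductor"):
    """Hypothetical executor helper that forwards options to compile()."""
    return fake_torch.compile(fn, fullgraph=fullgraph, backend=backend)


def test_options_are_forwarded():
    def fn(x):
        return abs(x)

    # Temporarily replace compile with a MagicMock that records its arguments.
    with mock.patch.object(fake_torch, "compile") as compile_mock:
        make_compiled(fn, fullgraph=False, backend="eager")

    # Only the forwarded arguments are checked, not the compiled result.
    compile_mock.assert_called_once_with(fn, fullgraph=False, backend="eager")


if __name__ == "__main__":
    test_options_are_forwarded()
    print("ok")
```

The patch has to target the name the code actually resolves at call time: if the executor did `from torch import compile` instead of calling torch.compile, you would patch the name in the executor's own module rather than on torch.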

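You could also force a graph break in your function and run it with the environment variable TORCH_LOGS="graph_breaks" set to check it. With fullgraph=True, torch.compile raises an error at a graph break instead of splitting the graph.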

carmocca, Apr 30 '24 12:04

Hi @carmocca, thanks for the suggestions. However, I am still unable to test this, and I would appreciate a more detailed explanation. Let me ask some specific questions; please bear with me, as some of them may be trivial. I am trying to understand in more detail how compilation works.

  • When is get_compile_option invoked during thunder.jit? I have stepped through the code and cannot figure out where get_compile_option comes into play.
  • When is make_compiled in executors/torch_compile.py used? I am trying to mock the call on line 84 as you suggested, but it is unclear how it is linked to jit.

Would appreciate your help :)

k223kim, May 02 '24 09:05

get_compile_option isn't invoked automatically; you need to add the call in your PR. If I'm not mistaken, it should go right before torch.compile() is called.

make_compiled is called on the first call of the model/function, after it has been passed through thunder.jit.
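The order described above (options read right before the torch.compile call, and make_compiled running only on the first call) can be illustrated with a small self-contained model. LazyJit, make_compiled and fake_torch_compile here are hypothetical and deliberately simplified; they are not thunder's implementation, which also traces the function and caches per input. The point is only that nothing is compiled when jit returns: the option lookup and the compile call both happen on the first invocation.

```python
from typing import Any, Callable, Optional

compile_calls: list[dict[str, Any]] = []


def fake_torch_compile(fn: Callable, **kwargs: Any) -> Callable:
    # Hypothetical stand-in for torch.compile: records the kwargs it received.
    compile_calls.append(kwargs)
    return fn


def make_compiled(fn: Callable, compile_options: dict[str, Any]) -> Callable:
    # Read the option right before the compile call, as suggested in the thread.
    fullgraph = compile_options.get("torch_compile_fullgraph", True)
    return fake_torch_compile(fn, fullgraph=fullgraph)


class LazyJit:
    """Hypothetical jit wrapper that defers compilation to the first call."""

    def __init__(self, fn: Callable, **compile_options: Any) -> None:
        self.fn = fn
        self.compile_options = compile_options
        self.compiled: Optional[Callable] = None

    def __call__(self, *args: Any) -> Any:
        if self.compiled is None:  # first call: compile once and cache
            self.compiled = make_compiled(self.fn, self.compile_options)
        return self.compiled(*args)


if __name__ == "__main__":
    jfn = LazyJit(abs, torch_compile_fullgraph=False)
    print(len(compile_calls))  # nothing compiled yet
    print(jfn(-3), jfn(-4))
    print(compile_calls)
```

This is also why the user's earlier snippet looked the same with and without the change: only the arguments passed at that first compile call differ, which the mocking test checks directly.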

Don't hesitate to open a draft PR! It's probably easier to comment and help directly there.

carmocca, May 06 '24 11:05