lightning-thunder

Additional ThunderFX benchmark backend options

Open · tfogal opened this issue 1 year ago · 3 comments

🚀 Feature

Add more backends to the ThunderFX benchmarking options: thunder.jit, torch.compile, torch.compile(backend="eager"), and thunder with its CUDA graph transform, plus the ability to wrap any of these options in CUDA graphs.

Motivation

  • thunder.jit but without the torch.compile (i.e. inductor) executor. This can help identify if there are "switching costs" between executors.
  • torch.compile (no thunder at all). For small graphs, this is a close-enough proxy for inductor to be interesting from a generated-kernels perspective.
  • torch.compile(backend="eager"). This hints at whether dynamo or thunder is more relevant.
  • thunder with thunder.transforms.cudagraph.CUDAGraphTransform in the transforms arg.
  • The ability to apply a post-compilation torch.cuda.make_graphed_callables on the generated function, regardless of which backend it came from above. This lets us e.g. compare CUDA graphs for thunder.jit'd code as well as torch.compile'd code, which can help identify if we have a bug or a perf problem.

Pitch

I've been using code like this:

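  import os
  import torch
  import thunder
  from thunder.transforms.cudagraph import CUDAGraphTransform
  from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

  # DynamoModule, _execs and inputs are defined elsewhere in the benchmark script.
  backend = os.getenv("BACKEND")
  if backend is None or backend == "thunder":
    fqn = thunder.jit(DynamoModule(), transforms=[NvtxProfileTransform()])
  elif backend == "thunder-no-t.c":
    # _execs: the executor list without thunder's torch.compile executor
    fqn = thunder.jit(DynamoModule(), executors=_execs)
  elif backend == "t.c":
    fqn = torch.compile(DynamoModule())
  elif backend == "dynamo-eager":
    fqn = torch.compile(DynamoModule(), backend="eager")
  elif backend == "thunder-cugraph":
    xform = CUDAGraphTransform()
    fqn = thunder.jit(DynamoModule(), transforms=[xform])
  else:
    raise ValueError(f"unknown BACKEND: {backend}")
  post_graph = os.getenv("POST_GRAPH", "0")
  if int(post_graph) > 0:
    # inputs must be a tuple of sample tensors
    fqn = torch.cuda.make_graphed_callables(
      fqn, inputs,
      num_warmup_iters=1, allow_unused_input=True
    )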

for this part. I don't mean to imply that we should use BACKEND and/or POST_GRAPH env vars (or even that we should use environment variables at all), but something with similar functionality would be useful.
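As one way to get similar functionality without environment variables, here is a minimal sketch of a name-to-compile-function registry. It assumes each backend is a callable that takes a module and returns a compiled callable. The names select_compile_fn and fake_compile are hypothetical, and the stand-in entries replace thunder.jit and torch.compile so the sketch runs without a GPU.

```python
from functools import partial
from typing import Callable, Dict, Optional

# A compile function takes a module (any callable) and returns a callable.
CompileFn = Callable[[Callable], Callable]


def select_compile_fn(name: Optional[str], registry: Dict[str, CompileFn],
                      default: str = "thunder") -> CompileFn:
    """Look up a compile function by name, using `default` when name is None."""
    key = default if name is None else name
    try:
        return registry[key]
    except KeyError:
        known = ", ".join(sorted(registry))
        raise ValueError(f"unknown backend {key!r}; expected one of: {known}") from None


# Hypothetical stand-in: in a real script the registry values would be, e.g.,
# partial(thunder.jit, transforms=[NvtxProfileTransform()]) or
# partial(torch.compile, backend="eager").
def fake_compile(module: Callable, tag: str = "") -> Callable:
    def compiled(*args):
        # Tag the result so we can see which "backend" produced it.
        return (tag, module(*args))
    return compiled


registry: Dict[str, CompileFn] = {
    "thunder": partial(fake_compile, tag="thunder"),
    "t.c": partial(fake_compile, tag="t.c"),
    "dynamo-eager": partial(fake_compile, tag="dynamo-eager"),
}

compiled = select_compile_fn("t.c", registry)(lambda x: x + 1)
print(compiled(41))  # ('t.c', 42)
```

Unknown names fail loudly instead of silently leaving the compiled function undefined, which is the main thing a plain if/elif chain without an else branch gets wrong.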

cc @crcrpar

tfogal, Oct 03 '24 20:10

Hey @kiya00, Ivan and I signed you up for this in a comment on #1066 ;-). But please let us know if you're not keen on this!

tfogal, Oct 03 '24 20:10

Hi @tfogal, thank you for filing this. Let me think about how to add these options and get back to you.

kiya00, Oct 04 '24 06:10

FYI, #1249 adds ThunderFX to the targets.py benchmarks.

riccardofelluga, Oct 04 '24 07:10

Hi @tfogal, currently each GraphModule produced by the splitter can be run with 3 backends: thunder.jit with the specified options, torch.compile with the specified options, and eager. For example:

  import torch
  import thunder
  from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking

  def test_func(benchmark):
      import thunder.tests.litgpt_model as litgpt_model
      from thunder.transforms.cudagraph import CUDAGraphTransform

      config = litgpt_model.Config.from_name("Llama-2-7b-hf")
      m = litgpt_model.LLaMAMLP(config).to(device="cuda", dtype=torch.bfloat16).requires_grad_()

      cgtransform = CUDAGraphTransform()
      # Mirrors the ThunderCompiler constructor: torch_inductor_options are passed
      # to torch.compile, and all other options are passed to thunder.jit.
      backend = ThunderCompilerGraphBenchmarking(
          benchmark,
          executors=["thunder", "eager", "inductor"],
          transforms=[cgtransform],
          torch_inductor_options={"backend": "eager"},
      )

      # Normally this would be torch.compile(m, backend=backend).
      compiled = torch._dynamo.optimize(backend=backend)(m)

      shape = (4, config.block_size, config.n_embd)
      x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
      thunder_result = compiled(x)

Running pytest thunder/tests/script.py gives the following result:

--------------------------------------------------------------------------------------------------------- benchmark: 3 tests ---------------------------------------------------------------------------------------------------------
Name (time in ms)                                                          Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder]      40.4653 (1.0)      43.4994 (1.0)      41.5081 (1.0)      0.7748 (1.25)     41.3242 (1.0)      0.6007 (1.0)           6;5  24.0917 (1.0)          33           1
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[eager]        41.5841 (1.03)     43.9790 (1.01)     42.4784 (1.02)     0.7016 (1.13)     42.3968 (1.03)     1.0052 (1.67)         16;0  23.5414 (0.98)         39           1
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[inductor]     42.0556 (1.04)     44.4405 (1.02)     43.1582 (1.04)     0.6217 (1.0)      43.0654 (1.04)     1.0168 (1.69)         13;0  23.1706 (0.96)         39           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
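A note on reading this pytest-benchmark table: the number in parentheses is each value relative to the best value in its column (the smallest for the timing columns, the largest for OPS), and OPS is calls per second, i.e. 1 / Mean. The short check below recomputes those numbers from the Mean column above; the dictionary simply copies the table's values.

```python
# Mean times (ms) copied from the table above.
means_ms = {
    "thunder": 41.5081,
    "eager": 42.4784,
    "inductor": 43.1582,
}

best = min(means_ms.values())  # lower is better for timing columns
for name, mean in means_ms.items():
    ops = 1000.0 / mean  # convert ms per call into calls per second
    print(f"{name:9s} mean ratio {mean / best:.2f}  OPS {ops:.4f}")
```

The recomputed ratios (1.00, 1.02, 1.04) and OPS values (24.0917, 23.5414, 23.1706) match the Mean and OPS columns, so the three executors are within about 4% of each other on this graph.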

This setup supports both thunder.jit ("thunder") and torch.compile ("inductor") with arbitrary options passed through. Do you think this approach is sufficient for our needs?
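The option routing described in the snippet's comment (torch_inductor_options go to torch.compile, everything else goes to thunder.jit) can be sketched as a small helper. This is an illustration under that assumption only, not ThunderCompilerGraphBenchmarking's actual code; split_compile_options is a hypothetical name, and the benchmark-level executors argument is left out.

```python
from typing import Any, Dict, Tuple


def split_compile_options(**options: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Separate torch.compile options from thunder.jit options.

    Returns (thunder_options, torch_compile_options). The key
    "torch_inductor_options" holds a dict destined for torch.compile;
    every other keyword argument is treated as a thunder.jit option.
    """
    # **options is already a fresh dict, so popping from it is safe.
    torch_compile_options = dict(options.pop("torch_inductor_options", None) or {})
    return options, torch_compile_options


thunder_opts, tc_opts = split_compile_options(
    transforms=["<CUDAGraphTransform instance>"],  # placeholder for a real transform
    torch_inductor_options={"backend": "eager"},
)
print(thunder_opts)  # {'transforms': ['<CUDAGraphTransform instance>']}
print(tc_opts)       # {'backend': 'eager'}
```

Keeping torch.compile's options under a single nested key avoids name clashes between the two APIs (both accept a backend-like setting), at the cost of one extra level of nesting for callers.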

> The ability to apply a post-compilation torch.cuda.make_graphed_callables on the generated function

I can add a flag, e.g. ThunderCompilerGraphBenchmarking(bench: BenchmarkFixture, executors: Sequence[str], post_graph: bool, **thunder_options). When it is set to True, torch.cuda.make_graphed_callables will be applied to the compiled module.
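A minimal sketch of how such a post_graph flag could wrap any compiled callable, regardless of which backend produced it. The helper name maybe_graph is hypothetical. To keep the sketch runnable on CPU, the graphing function is passed in; in real use it would be something like lambda fn, args: torch.cuda.make_graphed_callables(fn, args, num_warmup_iters=1, allow_unused_input=True), where args is a tuple of sample CUDA tensors.

```python
from typing import Callable, Optional, Tuple

# graph_fn(fn, sample_args) -> graphed callable (e.g. torch.cuda.make_graphed_callables)
GraphFn = Callable[[Callable, Tuple], Callable]


def maybe_graph(compiled: Callable, sample_args: Tuple, post_graph: bool,
                graph_fn: Optional[GraphFn] = None) -> Callable:
    """Optionally apply a post-compilation graphing step to a compiled callable."""
    if not post_graph:
        return compiled  # leave the backend's output untouched
    if graph_fn is None:
        raise ValueError("post_graph=True requires a graph_fn")
    return graph_fn(compiled, sample_args)


# CPU-only stand-in for make_graphed_callables: marks results as "graphed".
def fake_graph_fn(fn: Callable, sample_args: Tuple) -> Callable:
    def graphed(*args):
        return ("graphed", fn(*args))
    return graphed


def compiled_fn(x):
    return x * 2


plain = maybe_graph(compiled_fn, (3,), post_graph=False)
graphed = maybe_graph(compiled_fn, (3,), post_graph=True, graph_fn=fake_graph_fn)
print(plain(3), graphed(3))  # 6 ('graphed', 6)
```

Because the flag is applied after compilation, the same code path covers thunder.jit, torch.compile and eager outputs, which is what makes the CUDA-graph comparison across backends possible.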

Sorry for the delayed response. We can also discuss this in more detail over a video call if you'd like.

kiya00, Oct 16 '24 09:10

> The ability to apply a post-compilation torch.cuda.make_graphed_callables on the generated function

> I can add a flag,

A flag sounds fine for now.

> Currently for each GraphModule after splitting, we can run with 3 backends: thunder.jit with the specified options, torch.compile with the specified options and eager, e.g.:

Great! That mostly hits it. I think the only thing that's missing is the ability to disable thunder's torch.compile executor (which we use for cat etc.; I think you've just called it "inductor" above), which then forces those few ops to go to nvFuser or eager. This helps us understand whether the "cost to switch executors" exceeds the "time saved due to better kernels on the other executor".

tfogal, Oct 16 '24 16:10

Hi @tfogal, since we want to run multiple backends, I think your use case would be better served if we modify the interface as follows. WDYT? We pass in a dict of all the compile functions we want to benchmark instead of the 3 fixed kinds (thunder, torch.compile, eager). cc: @IvanYashchuk

  from functools import partial
  import torch
  import thunder
  from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking
  from thunder.executors.torch_compile import torch_compile_cat_ex

  def test_func(benchmark):
      import thunder.tests.litgpt_model as litgpt_model
      from thunder.transforms.cudagraph import CUDAGraphTransform
      from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

      config = litgpt_model.Config.from_name("Llama-2-7b-hf")
      m = litgpt_model.LLaMAMLP(config).to(device="cuda", dtype=torch.bfloat16).requires_grad_()

      cgtransform = CUDAGraphTransform()
      # Each key is the executor name shown in the test name; each value is the
      # compile function to apply (None: run the split module eagerly).
      # Note: keys must not contain '-', because '-' breaks the group-by function.
      bench_executors_dict = {}
      bench_executors_dict["thunder"] = partial(thunder.jit, transforms=[NvtxProfileTransform()])
      bench_executors_dict["thunder_no_t.c"] = partial(
          thunder.jit,
          executors=[ex for ex in thunder.get_default_executors() if ex != torch_compile_cat_ex],
      )
      bench_executors_dict["t.c"] = torch.compile
      bench_executors_dict["dynamo_eager"] = partial(torch.compile, backend="eager")
      bench_executors_dict["thunder_cugraph"] = partial(thunder.jit, transforms=[cgtransform])
      bench_executors_dict["eager"] = None

      backend = ThunderCompilerGraphBenchmarking(benchmark, executors=bench_executors_dict)

      # Normally this would be torch.compile(m, backend=backend).
      compiled = torch._dynamo.optimize(backend=backend)(m)

      shape = (4, config.block_size, config.n_embd)
      x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
      thunder_result = compiled(x)

The output will be:

----------------------------------------------------------------------------- benchmark 'GraphID=GraphID[1] SplitModuleName=SplitModuleName[thunder_1]': 6 tests ----------------------------------------------------------------------------
Name (time in ms)                                                                 Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder]             39.4736 (1.0)      41.1901 (1.0)      40.2788 (1.0)      0.4671 (1.0)      40.3809 (1.01)     0.7684 (1.22)         10;0  24.8270 (1.0)          41           1
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder_no_t.c]      39.4927 (1.00)     42.0066 (1.02)     40.5177 (1.01)     0.6498 (1.39)     40.1213 (1.0)      0.6293 (1.0)           6;4  24.6806 (0.99)         29           1
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[t.c]                 40.9270 (1.04)     42.5549 (1.03)     41.6281 (1.03)     0.5423 (1.16)     41.5933 (1.04)     1.1297 (1.80)         20;0  24.0222 (0.97)         39           1
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[dynamo_eager]        40.9625 (1.04)     43.4503 (1.05)     42.1136 (1.05)     0.6384 (1.37)     42.0806 (1.05)     0.9188 (1.46)         12;0  23.7453 (0.96)         39           1
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder_cugraph]     41.0254 (1.04)     43.3476 (1.05)     41.8409 (1.04)     0.6388 (1.37)     41.6595 (1.04)     1.0024 (1.59)         11;0  23.9001 (0.96)         29           1
test_func-GraphID[1]-SplitModuleName[thunder_1]-executor[eager]               41.9157 (1.06)     44.1695 (1.07)     42.9516 (1.07)     0.6395 (1.37)     42.8717 (1.07)     1.2707 (2.02)         21;0  23.2820 (0.94)         39           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

If the graph is segmented by Dynamo and the splitter, the results can be grouped by GraphID and split module name.
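To illustrate why executor keys must avoid '-', here is a sketch of a naming and grouping scheme consistent with the test names in the output above: one test ID per executor entry, grouped by its GraphID and SplitModuleName fields. It assumes grouping works by splitting the test name on '-'; the helpers make_test_ids and group_by_graph_and_module are hypothetical, and the real implementation may group differently (for example via pytest-benchmark's own group-by).

```python
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple


def make_test_ids(test_name: str, graph_id: int, split_name: str,
                  executors: Dict[str, Optional[Callable]]) -> List[str]:
    """Build one test ID per executor key, in the style of the output above."""
    ids = []
    for key in executors:
        if "-" in key:
            # '-' is the field separator, so a key containing it breaks grouping.
            raise ValueError(f"executor key {key!r} must not contain '-'")
        ids.append(f"{test_name}-GraphID[{graph_id}]-SplitModuleName[{split_name}]-executor[{key}]")
    return ids


def group_by_graph_and_module(test_ids: List[str]) -> Dict[Tuple[str, str], List[str]]:
    """Group test IDs by their (GraphID, SplitModuleName) fields."""
    groups: Dict[Tuple[str, str], List[str]] = defaultdict(list)
    for tid in test_ids:
        _, graph, module, executor = tid.split("-")
        groups[(graph, module)].append(executor)
    return dict(groups)


# The dict values (compile functions) are irrelevant for naming, so use None.
ids = make_test_ids("test_func", 1, "thunder_1", {"thunder": None, "t.c": None, "eager": None})
ids += make_test_ids("test_func", 2, "thunder_1", {"thunder": None})
for (graph, module), execs in group_by_graph_and_module(ids).items():
    print(graph, module, execs)
```

With this scheme a key such as "thunder-no-t.c" would produce extra fields after splitting, which is why the snippet uses underscores ("thunder_no_t.c") instead.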

kiya00, Oct 16 '24 17:10

> I'm thinking that your use case is probably better served if we [accept executors as a dict of functors]

Yes, that looks great!

tfogal, Oct 16 '24 17:10