Add graph-by-graph benchmarking of dynamo.ThunderCompiler
Part of https://github.com/Lightning-AI/lightning-thunder/issues/915.
After https://github.com/Lightning-AI/lightning-thunder/pull/947, torch.compile can use thunder.dynamo.ThunderCompiler as a backend. This PR adds graph-by-graph benchmarking of different executors (eager, inductor, thunder) on each torch.fx.GraphModule produced after splitting, reporting both runtime and peak allocated memory. In addition, a customized grouping option --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName' is supported to group the test cases by GraphID and SplitModuleName, allowing performance comparison between different executors (see the examples in the comments below).
The added benchmarking structure is thunder.benchmarks.ThunderCompilerGraphBenchmarking; an example of how to use it is thunder.benchmarks.targets.test_dynamo_LlamaMLPBenchmark. A minimal sketch in the same spirit follows.
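For orientation, here is a minimal sketch of such a test. The module, its sizes, and the test name are illustrative stand-ins (not the actual LlamaMLPBenchmark definition); only the ThunderCompilerGraphBenchmarking usage mirrors the real test.

import torch
import torch.nn as nn
from thunder.benchmarks import ThunderCompilerGraphBenchmarking

class TinyMLP(nn.Module):
    # Illustrative stand-in for the MLP block used by the real benchmark.
    def __init__(self, hidden=1024, intermediate=4096):
        super().__init__()
        self.up = nn.Linear(hidden, intermediate, bias=False)
        self.down = nn.Linear(intermediate, hidden, bias=False)

    def forward(self, x):
        return self.down(torch.nn.functional.silu(self.up(x)))

def test_dynamo_tiny_mlp(benchmark):
    # The backend records per-split-module timing and memory for each listed executor.
    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])
    model = TinyMLP().cuda()
    compiled = torch.compile(model, backend=backend)
    x = torch.randn(8, 1024, device="cuda")
    compiled(x)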
Example outputs in https://github.com/Lightning-AI/lightning-thunder/pull/1066#issuecomment-2317778178.
Examples: Dynamo segments the program into subgraphs, each identified by the 'GraphID[id]' field in the test name. Each subgraph can contain multiple split modules, produced by the Thunder-defined splitter, which correspond to the 'SplitModuleName[split_module_name]' field. The executor used for a measurement is indicated by the trailing 'executor[...]' field, e.g. 'executor[thunder]'.
Running pytest thunder/benchmarks/targets.py -k test_dynamo -vs outputs:
------------------------------------------------------------------------------------------------------------------- benchmark: 3 tests -------------------------------------------------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_dynamo_LlamaMLPBenchmark-GraphID[1]-SplitModuleName[thunder_1]-executor[eager] 19.9992 (1.0) 21.5106 (1.01) 20.6026 (1.0) 0.4169 (1.29) 20.5916 (1.0) 0.6059 (1.0) 26;0 48.5377 (1.0) 77 1
test_dynamo_LlamaMLPBenchmark-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder] 20.1475 (1.01) 21.4838 (1.00) 20.8078 (1.01) 0.3730 (1.15) 20.7868 (1.01) 0.6459 (1.07) 15;0 48.0589 (0.99) 78 1
test_dynamo_LlamaMLPBenchmark-GraphID[1]-SplitModuleName[thunder_1]-executor[inductor] 20.2960 (1.01) 21.3948 (1.0) 20.8300 (1.01) 0.3232 (1.0) 20.6199 (1.00) 0.6218 (1.03) 33;0 48.0076 (0.99) 77 1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
A standalone example (script.py):

import torch
from thunder.benchmarks import ThunderCompilerGraphBenchmarking


def func1(x, y):
    x = torch.sin(x)
    # The data-dependent branch triggers a graph break, so Dynamo produces multiple subgraphs.
    if x.sum() > 0:
        x = x.exp()
        y = torch.sinc(x) + torch.cos(y)
        return y + 1
    else:
        y = y.exp()
        x = torch.sinc(y) + torch.cos(x)
        return x - 1


def func2(x):
    x = torch.sin(x)
    if x.sum() > 0:
        return x + 1
    else:
        return x - 1


def test_func1(benchmark):
    # Every torch.fx.GraphModule produced after splitting is benchmarked with each listed executor.
    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"])
    compiled = torch.compile(backend=backend)(func1)
    x = torch.ones(2, requires_grad=True).cuda()
    y = torch.ones(2, requires_grad=True).cuda()
    compiled(x, y)


def test_func2(benchmark):
    backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder"])
    compiled = torch.compile(backend=backend)(func2)
    x = torch.randn(2, requires_grad=True).cuda()
    compiled(x)
Running the above script with pytest script.py -k func1 --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName' --benchmark-json=gyg.json outputs the tables below. (Note: thunder/benchmarks/conftest.py must be present in the same directory as script.py to provide the grouping customization.)
---------------------------------------------------------------------------- benchmark 'GraphID=GraphID[1] SplitModuleName=SplitModuleName[thunder_1]': 3 tests ----------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_func1-GraphID[1]-SplitModuleName[thunder_1]-executor[eager] 12.4976 (1.0) 14.6263 (1.0) 12.7523 (1.0) 0.3502 (1.0) 12.6651 (1.0) 0.0972 (1.0) 43;71 78.4170 (1.0) 801 100
test_func1-GraphID[1]-SplitModuleName[thunder_1]-executor[inductor] 16.2459 (1.30) 18.5344 (1.27) 16.6045 (1.30) 0.3870 (1.11) 16.5051 (1.30) 0.1385 (1.43) 39;48 60.2246 (0.77) 620 100
test_func1-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder] 66.4015 (5.31) 98.0684 (6.70) 69.9598 (5.49) 4.4720 (12.77) 68.7676 (5.43) 1.4640 (15.06) 113;125 14.2939 (0.18) 1517 10
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
---------------------------------------------------------------------------- benchmark 'GraphID=GraphID[2] SplitModuleName=SplitModuleName[inductor_2]': 3 tests ----------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_func1-GraphID[2]-SplitModuleName[inductor_2]-executor[eager] 4.8365 (1.0) 6.8269 (1.0) 4.9870 (1.0) 0.3311 (1.0) 4.9086 (1.0) 0.0515 (1.10) 102;136 200.5213 (1.0) 2070 100
test_func1-GraphID[2]-SplitModuleName[inductor_2]-executor[inductor] 8.2432 (1.70) 10.2169 (1.50) 8.3910 (1.68) 0.3599 (1.09) 8.3026 (1.69) 0.0470 (1.0) 69;86 119.1749 (0.59) 1230 100
test_func1-GraphID[2]-SplitModuleName[inductor_2]-executor[thunder] 44.4527 (9.19) 65.2626 (9.56) 47.5590 (9.54) 4.0214 (12.14) 46.4217 (9.46) 1.1875 (25.26) 168;193 21.0265 (0.10) 2267 10
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
.....
The max allocated memory is recorded in the JSON file (when run with --benchmark-json); a short snippet for reading these values back out of the report is sketched after the excerpt below:
...
"name": "test_func1-GraphID[1]-SplitModuleName[thunder_1]-executor[eager]",
"fullname": "thunder/benchmarks/gbyg_test.py::test_func1",
"params": null,
"param": null,
"extra_info": {
    "test_func1-GraphID[1]-SplitModuleName[thunder_1]-executor[eager]_max_allocated_memory_MB": 0.0029296875,
    "test_func1-GraphID[1]-SplitModuleName[thunder_1]-executor[inductor]_max_allocated_memory_MB": 0.0029296875,
    "test_func1-GraphID[1]-SplitModuleName[thunder_1]-executor[thunder]_max_allocated_memory_MB": 0.00244140625,
    "test_func1-GraphID[2]-SplitModuleName[thunder_1]-executor[eager]_max_allocated_memory_MB": 0.00244140625,
    "test_func1-GraphID[2]-SplitModuleName[thunder_1]-executor[inductor]_max_allocated_memory_MB": 0.00244140625,
    "test_func1-GraphID[2]-SplitModuleName[thunder_1]-executor[thunder]_max_allocated_memory_MB": 0.00244140625,
    "test_func1-GraphID[2]-SplitModuleName[inductor_2]-executor[eager]_max_allocated_memory_MB": 0.00244140625,
    "test_func1-GraphID[2]-SplitModuleName[inductor_2]-executor[inductor]_max_allocated_memory_MB": 0.00244140625,
    "test_func1-GraphID[2]-SplitModuleName[inductor_2]-executor[thunder]_max_allocated_memory_MB": 0.00244140625,
    "test_func1-GraphID[2]-SplitModuleName[thunder_3]-executor[eager]_max_allocated_memory_MB": 0.00341796875,
    "test_func1-GraphID[2]-SplitModuleName[thunder_3]-executor[inductor]_max_allocated_memory_MB": 0.00341796875,
    "test_func1-GraphID[2]-SplitModuleName[thunder_3]-executor[thunder]_max_allocated_memory_MB": 0.0029296875
},
...
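The per-executor memory numbers can be pulled back out of the report with a few lines of Python. This is only a sketch: it assumes the pytest-benchmark JSON layout shown above (a top-level "benchmarks" list whose entries carry the extra_info dict) and the gyg.json filename from the command above.

import json

# Print every "*_max_allocated_memory_MB" entry recorded by the benchmarking backend.
with open("gyg.json") as f:
    report = json.load(f)

for bench in report["benchmarks"]:
    for key, value in bench.get("extra_info", {}).items():
        if key.endswith("_max_allocated_memory_MB"):
            print(f"{key}: {value:.6f} MB")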
cc: @IvanYashchuk
Hi @tfogal @IvanYashchuk @crcrpar, can we merge this initial PR first? We can discuss the additional requirements in the new issue (https://github.com/Lightning-AI/lightning-thunder/issues/1258).
Hi @IvanYashchuk @t-vi, I think it's ready to merge. Do you want to have another look?
@t-vi, could you please approve this PR to merge it?