[Issue]: Lower than expected torch.compile speedup with torch_migraphx for SD1.5 and SD2.0
Problem Description
Trying out torch.compile via torch_migraphx with the example code in torch_migraphx/examples/dynamo/stable_diffusion (but compiling only the unet) does not seem to give a performance increase. Passing in exhaustive_tune gives a small performance increase, but compilation then takes roughly 20 minutes.
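For reference, here is a minimal sketch of the setup being benchmarked (the SD1.5 repo id and fp16 dtype are assumptions on my part; the actual run used the example script from the repo with only the unet compiled):

```python
import torch
from diffusers import StableDiffusionPipeline

import torch_migraphx  # registers the 'migraphx' torch.compile backend

# Assumed SD1.5 checkpoint and dtype; the real run used the example script's defaults.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# Compile only the unet. exhaustive_tune is optional and added roughly
# 20 minutes of compile time on this system.
pipe.unet = torch.compile(
    pipe.unet,
    backend="migraphx",
    options={"exhaustive_tune": True},
)

# First call triggers compilation; the timings below exclude this pass.
image = pipe(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    height=512,
    width=512,
    num_inference_steps=75,
).images[0]
```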
Below are the performance results (resolution: 512x512, steps: 75, excluding the first pass):
| Model | Compile option | Speed |
|---|---|---|
| SD1.5 | None | ~8.5it/s |
| SD1.5 | migraphx | ~8.2it/s |
| SD1.5 | migraphx exhaustive_tune | ~8.7it/s |
| SD2.0 | None | ~9.2it/s |
| SD2.0 | migraphx | ~9.17it/s |
| SD2.0 | migraphx exhaustive_tune | ~8.87it/s |
Paging @sunway513
Operating System
Ubuntu 22.04.3 LTS
CPU
Ryzen 7 7700X
GPU
AMD Radeon RX 6800 XT
Other
No response
ROCm Version
ROCm 6.0.0
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
I had to patch in this commit to solve `static_compute_shape: reshape_lazy on axis that is not packed` when running the example.
> I had to patch in this commit to solve `static_compute_shape: reshape_lazy on axis that is not packed` when running the example.
There are perf issues with that commit, which is why it was never merged. That could very likely be what is causing the perf issue.
Also, what version/commit of migraphx are you using for this? We did fix a reshape_lazy issue in #2721, so I don't know if that fixes your issue or if there is another issue to be addressed.
> Also, what version/commit of migraphx are you using for this? We did fix a reshape_lazy issue in #2721, so I don't know if that fixes your issue or if there is another issue to be addressed.
I initially tried ROCm 6.0 with MIGraphX 2.8; after encountering the reshape_lazy error, I forked MIGraphX at commit ddc991d and applied the patch described above, unaware of the fix. I'll try out ROCm 6.1 with MIGraphX 2.9 and report back.
> I'll try out ROCm 6.1 with MIGraphX 2.9 and report back.
Below is the performance with ROCm 6.1 and MIGraphX 2.9 (resolution: 512x512, steps: 75, excluding the first pass):
| Model | Compile Option | Speed |
|---|---|---|
| SD1.5 | None | ~8.67it/s |
| SD1.5 | migraphx | ~9.98it/s |
| SD1.5 | migraphx exhaustive_tune | ~10.85it/s |
| SD2.0 | None | ~9.52it/s |
| SD2.0 | migraphx | ~9.98it/s |
| SD2.0 | migraphx exhaustive_tune | ~10.7it/s |
The performance increase does seem pretty nice (up to 25%). Are these performance improvements to be expected, or should they be higher?
As for SDXL, 1024x1024 compilation fails with the following error: `RuntimeError: /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/AMDMIGraphX/src/targets/gpu/hip.cpp:109: allocate_gpu: Memory not available to allocate buffer: 41943040` (a 40 MiB buffer).
Reducing the resolution to 768x768 reduces the size of the failing allocation, but the error still occurs. Should I open another issue for this, or is high VRAM consumption expected when compiling SDXL?
Full error log:
Traceback (most recent call last):
File "/home/k/Desktop/Hallucinogen/TMGX/sdxl.py", line 79, in <module>
run(args)
File "/home/k/Desktop/Hallucinogen/TMGX/sdxl.py", line 66, in run
image = pipe(prompt=args.prompts,
File "/home/k/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1174, in __call__
noise_pred = self.unet(
File "/home/k/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
return fn(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 939, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 802, in _convert_frame
result = inner_convert(
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 265, in time_wrapper
r = func(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 541, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1078, in transform_code_object
transformations(instructions, code_options)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 503, in transform
tracer.run()
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2202, in run
super().run()
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 843, in run
while self.step():
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 757, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2338, in RETURN_VALUE
self._return(inst)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2323, in _return
self.output.compile_subgraph(
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 972, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1159, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 265, in time_wrapper
r = func(*args, **kwargs)
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1232, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1213, in call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/k/.local/lib/python3.10/site-packages/torch/__init__.py", line 1778, in __call__
return self.compiler_fn(model_, inputs_, **self.kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/dynamo/backends.py", line 44, in migraphx_backend
return migraphx_aot_backend(gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/dynamo/backends.py", line 67, in migraphx_aot_backend
compiled_gm = lower_aten_to_mgx(aten_gm, example_inputs, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/dynamo/lower_dynamo.py", line 76, in lower_aten_to_mgx
mgx_mod = lower_subgraph(mod, partition_inputs, name=name, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/dynamo/lower_dynamo.py", line 117, in lower_subgraph
mgx_module = MGXModule(program=interpreter.program,
File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/fx/mgx_module.py", line 64, in __init__
self._initialize()
File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/fx/mgx_module.py", line 73, in _initialize
self.program.compile(migraphx.get_target('gpu'),
torch._dynamo.exc.BackendCompilerFailed: backend='migraphx' raised:
RuntimeError: /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/AMDMIGraphX/src/targets/gpu/hip.cpp:109: allocate_gpu: Memory not available to allocate buffer: 41943040
You can run the PyTorch SDXL pipeline on its own on your system, right? In general, we try to avoid duplicating weights when compiling, but sometimes the compilation steps in MIGraphX can use extra memory, so you might see this error even if the total model size is theoretically within the device memory limit.
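One way to see how close you are to the limit is to print device-wide free memory around the first (compiling) call. This is only a rough sketch using standard PyTorch APIs; MIGraphX's compile-time allocations live outside torch's caching allocator, so only `mem_get_info` will reflect them:

```python
import torch

def report_vram(tag: str) -> None:
    # mem_get_info reports (free, total) bytes for the whole device,
    # including memory held by non-torch allocators such as MIGraphX.
    free, total = torch.cuda.mem_get_info()
    reserved = torch.cuda.memory_reserved()  # torch's caching allocator only
    print(f"[{tag}] free={free / 2**30:.2f} GiB / {total / 2**30:.2f} GiB, "
          f"torch reserved={reserved / 2**30:.2f} GiB")

report_vram("before first call")
# image = pipe(**inputs).images[0]   # first call triggers MIGraphX compilation
report_vram("after first call")
```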
I was able to get SDXL to run at full size on my gfx1030 (32 GB) using this script:
import torch
from diffusers import DiffusionPipeline
import torch_migraphx


def benchmark(func, iters, *args, **kwargs):
    # Warm up
    for _ in range(1):
        func(*args, **kwargs)
    # Start benchmark.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        out = func(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    # in ms
    return (start_event.elapsed_time(end_event)) / iters


if __name__ == '__main__':
    # torch.random.manual_seed(10)
    model_repo = 'stabilityai/stable-diffusion-xl-base-1.0'
    prompts = [
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
    ]
    num_steps = 30
    fname = 'benchmark_output.png'

    pipe = DiffusionPipeline.from_pretrained(model_repo,
                                             torch_dtype=torch.float16,
                                             use_safetensors=True,
                                             variant="fp16").to("cuda")

    pipe.unet = torch.compile(
        pipe.unet,
        backend='migraphx',
        options={"exhaustive_tune": True},
    )

    inputs = {
        "prompt": prompts,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": num_steps,
        "num_images_per_prompt": 1,
    }

    image = pipe(**inputs).images[0]
    image.save(fname)

    print("Benchmarking...")
    t = benchmark(pipe, 10, **inputs)
    print(f"sd e2e: {t} ms")
> The performance increase does seem pretty nice (up to 25%). Are these performance improvements to be expected, or should they be higher?
I think this is about the expected increase on a Navi 2 system for stable diffusion models. For stable diffusion, more tuning has been done on Instinct cards, so you'd expect to see a much bigger boost on those cards.
> You can run the PyTorch SDXL pipeline on its own on your system, right?
Yes, it does.
> In general, we try to avoid duplicating weights when compiling, but sometimes the compilation steps in MIGraphX can use extra memory, so you might see this error even if the total model size is theoretically within the device memory limit.
Ah, gotcha. Do you want me to open an issue to track this?
> I was able to get SDXL to run at full size on my gfx1030 (32 GB) using this script:
Yep, that's roughly how I tested it (loading fp16, compiling only the unet). Also probably worth mentioning: the compilation fails with or without exhaustive_tune.
> The performance increase does seem pretty nice (up to 25%). Are these performance improvements to be expected, or should they be higher?
>
> I think this is about the expected increase on a Navi 2 system for stable diffusion models. For stable diffusion, more tuning has been done on Instinct cards, so you'd expect to see a much bigger boost on those cards.
Thanks for confirming, I'll close the issue then as it's been resolved. Cheers.