
[Issue]: Lower than expected torch.compile speedup with torch_migraphx for SD1.5 and SD2.0

Open · FumoTime opened this issue · 5 comments

Problem Description

Trying out torch.compile via torch_migraphx using the example code in torch_migraphx/examples/dynamo/stable_diffusion (but compiling only the UNet) does not seem to give a performance increase. Passing in exhaustive_tune gives a small increase, but compilation then takes roughly 20 minutes.
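For context, the UNet-only compile looked roughly like the sketch below (a minimal, hedged example based on the script shared later in this thread; the `runwayml/stable-diffusion-v1-5` checkpoint name is my assumption, not something the example code confirms):

```python
import torch
from diffusers import StableDiffusionPipeline
import torch_migraphx  # importing this makes the 'migraphx' torch.compile backend available

# Checkpoint name is an assumption for illustration; the example code loads its own model.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

# Compile only the UNet. exhaustive_tune adds roughly 20 minutes of compile time
# in exchange for a small runtime gain in the measurements below.
pipe.unet = torch.compile(
    pipe.unet,
    backend="migraphx",
    options={"exhaustive_tune": True},
)

image = pipe(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    height=512,
    width=512,
    num_inference_steps=75,
).images[0]
```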

Below are the performance results (resolution 512x512, 75 steps, excluding the first pass):

| Model | Compile option | Speed |
|-------|----------------|-------|
| SD1.5 | None | ~8.5 it/s |
| SD1.5 | migraphx | ~8.2 it/s |
| SD1.5 | migraphx + exhaustive_tune | ~8.7 it/s |
| SD2.0 | None | ~9.2 it/s |
| SD2.0 | migraphx | ~9.17 it/s |
| SD2.0 | migraphx + exhaustive_tune | ~8.87 it/s |

Paging @sunway513

Operating System

Ubuntu 22.04.3 LTS

CPU

Ryzen 7 7700X

GPU

AMD Radeon RX 6800 XT

Other

No response

ROCm Version

ROCm 6.0.0

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

I had to patch in this commit to resolve the `static_compute_shape: reshape_lazy on axis that is not packed` error when running the example.

FumoTime · May 07 '24 18:05

> I had to patch in this commit to resolve the `static_compute_shape: reshape_lazy on axis that is not packed` error when running the example.

There are perf issues with that commit, which is why it was never merged. That could very likely be what is causing the perf issue.

pfultz2 · May 07 '24 23:05

Also, what version/commit of migraphx are you using for this? We did fix a reshape_lazy issue in #2721, so I don't know whether that fixes your issue or there is another issue to be addressed.

pfultz2 · May 08 '24 14:05

> Also, what version/commit of migraphx are you using for this? We did fix a reshape_lazy issue in #2721, so I don't know whether that fixes your issue or there is another issue to be addressed.

I initially tried ROCm 6.0 with MIGraphX 2.8. After encountering the reshape_lazy error, I forked MIGraphX at commit ddc991d and applied the patch described above, unaware of the fix. I'll try out ROCm 6.1 with MIGraphX 2.9 and report back.

FumoTime · May 08 '24 15:05

> I'll try out ROCm 6.1 with MIGraphX 2.9 and report back.

Below is the performance of ROCm 6.1 with MIGraphX 2.9 (resolution 512x512, 75 steps, excluding the first pass):

| Model | Compile option | Speed |
|-------|----------------|-------|
| SD1.5 | None | ~8.67 it/s |
| SD1.5 | migraphx | ~9.98 it/s |
| SD1.5 | migraphx + exhaustive_tune | ~10.85 it/s |
| SD2.0 | None | ~9.52 it/s |
| SD2.0 | migraphx | ~9.98 it/s |
| SD2.0 | migraphx + exhaustive_tune | ~10.7 it/s |

The performance increase does seem pretty nice (up to 25%). Are these improvements about what's expected, or should they be higher?

FumoTime · May 08 '24 18:05

As for SDXL, 1024x1024 compilation fails with the following error: `RuntimeError: /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/AMDMIGraphX/src/targets/gpu/hip.cpp:109: allocate_gpu: Memory not available to allocate buffer: 41943040`.

Reducing the resolution to 768x768 reduces the size of the requested allocation, but the error still occurs. Should I open another issue for this, or is high VRAM consumption expected when compiling SDXL?
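For what it's worth, the failed allocation is only 41943040 bytes (40 MiB), so the 16 GB card is presumably already close to full by the time MIGraphX requests that buffer. A quick, hedged way to check headroom around the compile and first inference (not part of the example script; `report_gpu_memory` is a hypothetical helper):

```python
import torch

def report_gpu_memory(tag: str) -> None:
    """Print driver-reported free/total memory plus PyTorch's own allocation stats."""
    free, total = torch.cuda.mem_get_info()    # bytes free/total as reported by the driver
    allocated = torch.cuda.memory_allocated()  # bytes currently held by PyTorch tensors
    reserved = torch.cuda.memory_reserved()    # bytes held by PyTorch's caching allocator
    print(f"[{tag}] free {free / 2**30:.2f} GiB of {total / 2**30:.2f} GiB | "
          f"torch allocated {allocated / 2**30:.2f} GiB, reserved {reserved / 2**30:.2f} GiB")

# Call report_gpu_memory("before compile") right after loading the pipeline and again
# just before the first pipe(...) call to see how much headroom is left for MIGraphX.
```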

Full error log:

Traceback (most recent call last):
  File "/home/k/Desktop/Hallucinogen/TMGX/sdxl.py", line 79, in <module>
    run(args)
  File "/home/k/Desktop/Hallucinogen/TMGX/sdxl.py", line 66, in run
    image = pipe(prompt=args.prompts,
  File "/home/k/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1174, in __call__
    noise_pred = self.unet(
  File "/home/k/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 939, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 802, in _convert_frame
    result = inner_convert(
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 541, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1078, in transform_code_object
    transformations(instructions, code_options)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 503, in transform
    tracer.run()
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2202, in run
    super().run()
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 843, in run
    while self.step():
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 757, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2338, in RETURN_VALUE
    self._return(inst)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2323, in _return
    self.output.compile_subgraph(
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 972, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1159, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 265, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1232, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1213, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/k/.local/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/k/.local/lib/python3.10/site-packages/torch/__init__.py", line 1778, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/dynamo/backends.py", line 44, in migraphx_backend
    return migraphx_aot_backend(gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/dynamo/backends.py", line 67, in migraphx_aot_backend
    compiled_gm = lower_aten_to_mgx(aten_gm, example_inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/dynamo/lower_dynamo.py", line 76, in lower_aten_to_mgx
    mgx_mod = lower_subgraph(mod, partition_inputs, name=name, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/dynamo/lower_dynamo.py", line 117, in lower_subgraph
    mgx_module = MGXModule(program=interpreter.program,
  File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/fx/mgx_module.py", line 64, in __init__
    self._initialize()
  File "/usr/local/lib/python3.10/dist-packages/torch_migraphx/fx/mgx_module.py", line 73, in _initialize
    self.program.compile(migraphx.get_target('gpu'),
torch._dynamo.exc.BackendCompilerFailed: backend='migraphx' raised:
RuntimeError: /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/AMDMIGraphX/src/targets/gpu/hip.cpp:109: allocate_gpu: Memory not available to allocate buffer: 41943040

FumoTime · May 08 '24 18:05

You can run the PyTorch SDXL on its own on your system, right? In general, we try to avoid duplicating weights when compiling, but sometimes the compilation steps in MIGraphX can use extra memory, so you might see this error even if the total model size is theoretically within the device memory limit.

I was able to get SDXL to run at full size on my gfx1030 (32 GB) using this script:

import torch
from diffusers import DiffusionPipeline
import torch_migraphx


def benchmark(func, iters, *args, **kwargs):
    # Warm up
    for _ in range(1):
        func(*args, **kwargs)

    # Start benchmark.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(iters):
        out = func(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()
    # in ms
    return (start_event.elapsed_time(end_event)) / iters


if __name__ == '__main__':
    # torch.random.manual_seed(10)
    model_repo = 'stabilityai/stable-diffusion-xl-base-1.0'
    prompts = [
        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
    ]
    num_steps = 30
    fname = 'benchmark_output.png'

    pipe = DiffusionPipeline.from_pretrained(model_repo,
                                             torch_dtype=torch.float16,
                                             use_safetensors=True,
                                             variant="fp16").to("cuda")

    pipe.unet = torch.compile(
        pipe.unet,
        backend='migraphx',
        options={"exhaustive_tune": True},
    )

    inputs = {
        "prompt": prompts,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": num_steps,
        "num_images_per_prompt": 1,
    }
    image = pipe(**inputs).images[0]
    image.save(fname)

    print("Benchmarking...")
    t = benchmark(pipe, 10, **inputs)

    print(f"sd e2e: {t} ms")

> The performance increase does seem pretty nice (up to 25%). Are these improvements about what's expected, or should they be higher?

I think this is about the expected increase on a Navi2 system for stable diffusion models. More tuning has been done on Instinct cards for stable diffusion, so you'd expect to see a much better boost on those cards.

shivadbhavsar · May 13 '24 20:05

> You can run the PyTorch SDXL on its own on your system, right?

Yes, it runs fine on its own.

> In general, we try to avoid duplicating weights when compiling, but sometimes the compilation steps in MIGraphX can use extra memory, so you might see this error even if the total model size is theoretically within the device memory limit.

Ah, gotcha. Do you want me to open an issue to track this?

> I was able to get SDXL to run at full size on my gfx1030 (32 GB) using this script:

Yep, that's roughly how I tested it (loading fp16, compiling the UNet only). Also probably worth mentioning: the compile fails with or without exhaustive_tune.

> The performance increase does seem pretty nice (up to 25%). Are these improvements about what's expected, or should they be higher?

> I think this is about the expected increase on a Navi2 system for stable diffusion models. More tuning has been done on Instinct cards for stable diffusion, so you'd expect to see a much better boost on those cards.

Thanks for confirming. I'll close the issue then, as it's been resolved. Cheers.

FumoTime · May 14 '24 02:05