
Enable ONNX export of GPU-loaded SVD/SVD-XT UNet models

Open · rajeevsrao opened this pull request 1 year ago • 19 comments

What does this PR do?

Unpack the num_frames scalar if it was created as a (CPU) tensor in the forward path. This avoids mixed use of CPU and CUDA tensors, which is unsupported by torch.nn ops:

File "/usr/local/lib/python3.10/dist-packages/diffusers/models/unet_spatio_temporal_condition.py", line 422, in forward
    emb = emb.repeat_interleave(num_frames, dim=0)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
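
For illustration, a minimal sketch of the guard described above, as it might look inside UNetSpatioTemporalConditionModel.forward (my paraphrase of the intended change, not the actual PR diff):

# Just before the repeat_interleave call: unpack num_frames to a Python int
# when tracing has produced it as a 0-dim CPU tensor, so the op no longer
# mixes CPU and CUDA tensors.
if torch.is_tensor(num_frames):
    num_frames = num_frames.item()
emb = emb.repeat_interleave(num_frames, dim=0)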

rajeevsrao avatar Jan 13 '24 07:01 rajeevsrao

@patrickvonplaten @sayakpaul please review.

rajeevsrao avatar Jan 13 '24 07:01 rajeevsrao

It should reside in optimum. Cc: @echarlaix

sayakpaul avatar Jan 13 '24 08:01 sayakpaul

Hi @rajeevsrao, could you share the script you used for the export?

echarlaix avatar Jan 15 '24 10:01 echarlaix

> It should reside in optimum. Cc: @echarlaix

You mean patching the model in optimum? Depending on the modifications needed, it could make sense to have it in diffusers instead.

echarlaix avatar Jan 15 '24 10:01 echarlaix

> It should reside in optimum. Cc: @echarlaix

> You mean patching the model in optimum? Depending on the modifications needed, it could make sense to have it in diffusers instead.

ONNX is a popular interchange format. @sayakpaul I think that diffusers should also support exporting models into ONNX, especially given that this is an easy/harmless fix.

rajeevsrao avatar Jan 22 '24 23:01 rajeevsrao

> Hi @rajeevsrao, could you share the script you used for the export?

Here is the ONNX export script for reference:

from diffusers.models import UNetSpatioTemporalConditionModel
import torch

# SVD-XT generates 25-frame clips; base SVD generates 14-frame clips.
model_name = "svd"
if model_name == "svd-xt":
    pipeline = 'stabilityai/stable-video-diffusion-img2vid-xt'
    num_frames = 25
else:
    pipeline = 'stabilityai/stable-video-diffusion-img2vid'
    num_frames = 14

# Load only the UNet, in fp16, directly on the GPU.
device = 'cuda'
dtype = torch.float16
model = UNetSpatioTemporalConditionModel.from_pretrained(pipeline,
    subfolder="unet",
    use_safetensors=True,
    variant='fp16',
    torch_dtype=dtype).to(device)

batch_size = 2
out_channels = 4
cross_attention_dim = 1024
latent_height = 576 // 8
latent_width = 1024 // 8

# sample has 2*out_channels channels: noise latents concatenated with the
# image-conditioning latents along the channel dimension.
input_names = ['sample', 'timestep', 'encoder_hidden_states', 'added_time_ids']
inputs = (
    torch.randn(batch_size, num_frames, 2*out_channels, latent_height, latent_width, dtype=dtype, device=device),
    torch.tensor([1.], dtype=torch.float32, device=device),
    torch.randn(batch_size, 1, cross_attention_dim, dtype=dtype, device=device),
    torch.randn(batch_size, 3, dtype=dtype, device=device),
)
output_names = ['latent']
# Mark batch, frame count, and spatial dims as dynamic in the exported graph.
dynamic_axes = {
    'sample': {0: '2B', 1: 'num_frames', 3: 'H', 4: 'W'},
    'encoder_hidden_states': {0: '2B'},
    'added_time_ids': {0: '2B'}
}

with torch.inference_mode(), torch.autocast(device):
    torch.onnx.export(model,
        inputs,
        model_name+"_unet.onnx",
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )
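
As a quick sanity check of the exported graph, one option (a suggestion on my part, assuming the onnx package is installed) is to run the ONNX checker:

import onnx

# Passing the file path (rather than a loaded ModelProto) lets the checker
# also handle models above the 2 GB protobuf limit, which are saved with
# external data files.
onnx.checker.check_model(model_name + "_unet.onnx")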



rajeevsrao avatar Jan 22 '24 23:01 rajeevsrao

@sayakpaul @echarlaix please suggest next steps. Thanks.

rajeevsrao avatar Jan 25 '24 17:01 rajeevsrao

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Feb 19 '24 15:02 github-actions[bot]

@rajeevsrao do you still plan to work on this?

sayakpaul avatar Feb 19 '24 15:02 sayakpaul

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Reviving this PR.

The main issue observed is that the type of num_frames changes depending on whether the model runs eagerly or is traced for export.

The two cases below elaborate; print type(num_frames) here to investigate further.

Case 1: Inference

During inference, the type of num_frames is <class 'int'>. Inference script used:

import torch
from diffusers.utils import load_image
from diffusers import StableVideoDiffusionPipeline
pipe = StableVideoDiffusionPipeline.from_pretrained('stabilityai/stable-video-diffusion-img2vid', torch_dtype=torch.float16, variant="fp16").to("cuda")
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
image = image.resize((1024, 576))
frames = pipe(image, decode_chunk_size=8).frames[0]
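
To confirm the type, a temporary print can be dropped into the UNet forward (a hypothetical debugging aid, not part of the pipeline):

# Temporary line inside UNetSpatioTemporalConditionModel.forward, right
# before the repeat_interleave call; during eager inference this prints
# <class 'int'>.
print(type(num_frames))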

Case 2: ONNX export

As the error lies specifically with the UNet, I'm exporting just the UNet using the script below. While tracing for the ONNX export, num_frames is created as a <class 'torch.Tensor'> on the CPU; running the export on the GPU then results in the error from the description.

import torch
from diffusers.models import UNetSpatioTemporalConditionModel

dtype = torch.float16
device = 'cuda'
model = UNetSpatioTemporalConditionModel.from_pretrained('stabilityai/stable-video-diffusion-img2vid',
    subfolder="unet",
    use_safetensors=True,
    variant='fp16',
    torch_dtype=dtype).to(device)

batch_size = 2
num_frames = 14
out_channels = 4
cross_attention_dim = 1024
latent_height = 576 // 8
latent_width = 1024 // 8

inputs = (
    torch.randn(batch_size, num_frames, 2*out_channels, latent_height, latent_width, dtype=dtype, device=device),
    torch.tensor([1.], dtype=torch.float32, device=device),
    torch.randn(batch_size, 1, cross_attention_dim, dtype=dtype, device=device),
    torch.randn(batch_size, 3, dtype=dtype, device=device),
)

# Note: the output path assumes an existing svd/ directory.
with torch.inference_mode(), torch.autocast(device):
    torch.onnx.export(model,
        inputs,
        "svd/svd_unet.onnx",
        export_params=True,
        opset_version=18,
        do_constant_folding=True
    )
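
For context on why the type changes: num_frames is read from sample.shape inside the UNet forward, and under torch.jit tracing (which torch.onnx.export uses) such shape accesses are recorded as 0-dim CPU tensors so that the exported graph stays shape-dynamic. A minimal standalone sketch of that behavior (my illustration, not diffusers code):

import torch

def f(sample):
    num_frames = sample.shape[1]  # int in eager mode; 0-dim CPU tensor under tracing
    print(type(num_frames))
    return sample * 1.0

f(torch.randn(2, 14, 8))                      # prints <class 'int'>
torch.jit.trace(f, (torch.randn(2, 14, 8),))  # prints <class 'torch.Tensor'>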

This PR aims to correct the inconsistency in the type of num_frames between inference and tracing.

asfiyab-nvidia avatar Mar 01 '24 00:03 asfiyab-nvidia

Cc: @yiyixuxu. I am okay with the changes here since ONNX is very popular. LMK. @DN6, you too.

sayakpaul avatar Mar 01 '24 01:03 sayakpaul

@asfiyab-nvidia would you suggest anything being done differently in this PR?

sayakpaul avatar Mar 01 '24 01:03 sayakpaul

> @asfiyab-nvidia would you suggest anything being done differently in this PR?

The main goal is to align the types. An alternative to the change suggested in the PR is to unconditionally cast the variable to a torch tensor:

num_frames = torch.tensor(num_frames).to(sample.device)

However, since num_frames is also used as a scalar elsewhere, I'd vote for the approach in the PR: unpack it to a scalar if it arrives as a tensor.

asfiyab-nvidia avatar Mar 01 '24 01:03 asfiyab-nvidia

cc @echarlaix here again

I'm fine with the change if we agree it's the best way to support ONNX export.

yiyixuxu avatar Mar 01 '24 01:03 yiyixuxu

Hi, following up on this PR.

asfiyab-nvidia avatar Mar 07 '24 18:03 asfiyab-nvidia

@echarlaix a gentle ping.

sayakpaul avatar Mar 07 '24 19:03 sayakpaul

Hi @echarlaix @sayakpaul, requesting updates based on the latest comments. Thanks.

asfiyab-nvidia avatar Mar 27 '24 00:03 asfiyab-nvidia

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar May 03 '24 15:05 github-actions[bot]