SkyReels-A2 icon indicating copy to clipboard operation
SkyReels-A2 copied to clipboard

RTX5090 - OOM issue even with offload

Open emerdem opened this issue 10 months ago • 1 comments

Hi,

I have a RTX5090 and normally able to run Wan 2.1 14B using diffsynth studio however I wasn't able to run your model. I am using the below revised script as per instructions, however ran into OOM issue as below.

I would appreciate your support. Thank you

Error: $ python infer_MGPU.py Fetching 32 files: 100%|██████████████████████| 32/32 [00:00<00:00, 8723.93it/s] Loading checkpoint shards: 100%|██████████████████| 7/7 [00:02<00:00, 3.21it/s] Traceback (most recent call last): File "/home/emre/Documents/FluxVideoGen/SkyReels-A2/infer_MGPU.py", line 51, in transformer.to(device, dtype=dtype) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/emre/anaconda3/envs/diffstudio2/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 1353, in to return super().to(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/emre/anaconda3/envs/diffstudio2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1355, in to return self._apply(convert) ^^^^^^^^^^^^^^^^^^^^ File "/home/emre/anaconda3/envs/diffstudio2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 915, in _apply module._apply(fn) File "/home/emre/anaconda3/envs/diffstudio2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 915, in _apply module._apply(fn) File "/home/emre/anaconda3/envs/diffstudio2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 915, in _apply module._apply(fn) [Previous line repeated 1 more time] File "/home/emre/anaconda3/envs/diffstudio2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 942, in _apply param_applied = fn(param) ^^^^^^^^^ File "/home/emre/anaconda3/envs/diffstudio2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1341, in convert return t.to( ^^^^^ torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 50.00 MiB. GPU 0 has a total capacity of 31.36 GiB of which 91.00 MiB is free. Including non-PyTorch memory, this process has 30.67 GiB memory in use. Of the allocated memory 30.18 GiB is allocated by PyTorch, and 12.04 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

import torch 
torch.cuda.empty_cache()
import os
from PIL import Image 
import numpy as np 
from diffusers import AutoencoderKLWan
from transformers import CLIPVisionModel 
from diffusers.video_processor import VideoProcessor
from diffusers import UniPCMultistepScheduler 
from diffusers.utils import export_to_video, load_image 
from diffusers.image_processor import VaeImageProcessor
from diffusers.training_utils import free_memory

from models.transformer_a2 import A2Model 
from models.pipeline_a2_parallel import WanA2Pipeline 
from models.utils import _crop_and_resize_pad, _crop_and_resize, write_mp4
from huggingface_hub import snapshot_download

# Removed: import torch.distributed as dist
# Removed: from para_attn.context_parallel import init_context_parallel_mesh
# Removed: from para_attn.context_parallel.diffusers_adapters import parallelize_pipe

prompt = "A man is holding a teddy bear in the forest." 
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

refer_images = ['assets/human.png', 'assets/thing.png', 'assets/env.png'] 
width = 832
height = 480 
seed = 42 
# if RTX4090, set True
# offload_switch = True
offload_switch = True

# model parameters 
device = "cuda"
video_path = "output.mp4"
pipeline_path = "Skywork/SkyReels-A2"
dtype = torch.float16

# download models
snapshot_download(repo_id="Skywork/SkyReels-A2", local_dir="Skywork/SkyReels-A2")

# load models 
image_encoder = CLIPVisionModel.from_pretrained(pipeline_path, subfolder="image_encoder", torch_dtype=torch.float32, device_map="cpu")
vae = AutoencoderKLWan.from_pretrained(pipeline_path, subfolder="vae", torch_dtype=torch.float32, device_map="cpu")

# print("load transformer...")
model_path = os.path.join(pipeline_path, 'transformer')
transformer = A2Model.from_pretrained(model_path, torch_dtype=torch.float16, use_safetensors=True)
# # transformer.save_pretrained("transformer", max_shard_size="5GB") 
transformer.to(device, dtype=dtype) 

print(torch.cuda.memory_summary())

pipe = WanA2Pipeline.from_pretrained(pipeline_path, transformer=transformer, vae=vae, image_encoder=image_encoder, torch_dtype=dtype)

scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=8)
pipe.scheduler = scheduler 
# Removed mesh and parallelize_pipe setup
# mesh = init_context_parallel_mesh(
#         pipe.device.type,
#         max_ring_dim_size=1,
#         max_batch_dim_size=2,
#     )
# parallelize_pipe(pipe, mesh=mesh)
transformer.to(device, dtype=dtype) 
pipe.to(device)

# for RTX4090
if offload_switch:
    Offload.offload(
        pipeline=pipe,
        config=OffloadConfig(
            high_cpu_memory=True,
            parameters_level=True,
            compiler_transformer=False,
        ),
    )


VAE_SCALE_FACTOR_SPATIAL = 8
video_processor = VideoProcessor(vae_scale_factor=VAE_SCALE_FACTOR_SPATIAL)

# prepare reference images
clip_image_list = []
vae_image_list = []
for image_id, image_path in enumerate(refer_images): 
    image = load_image(image=image_path).convert("RGB")
    # for clip 
    image_clip = _crop_and_resize_pad(image, height=512, width=512) 
    clip_image_list.append(image_clip)
    
    # for vae 
    if image_id == 0 or image_id == 1: 
        image_vae = _crop_and_resize_pad(image, height=height, width=width) # ref image
    else:
        image_vae = _crop_and_resize(image, height=height, width=width) # background image
    
    image_vae = video_processor.preprocess(image_vae, height=height, width=width).to(memory_format=torch.contiguous_format) # (1, 3, 480, 320)
    image_vae = image_vae.unsqueeze(2).to(device, dtype=torch.float32)
    vae_image_list.append(image_vae) #.to(device, dtype=dtype))

# forward
generator = torch.Generator(device).manual_seed(seed) 
video_pt = pipe(
    image_clip=clip_image_list, 
    image_vae=vae_image_list,
    prompt=prompt, 
    negative_prompt=negative_prompt, 
    height=480, 
    width=width, 
    num_frames=81, 
    guidance_scale=5.0,
    generator=generator,
    output_type="pt",
    num_inference_steps=50,
    vae_combine="before",
).frames

dist.barrier()
free_memory()


# combine results
batch_size = video_pt.shape[0]
batch_video_frames = []
for batch_idx in range(batch_size):
    pt_image = video_pt[batch_idx]
    pt_image = torch.stack([pt_image[i] for i in range(pt_image.shape[0])])
    pt_image = pt_image[12:]
    image_np = VaeImageProcessor.pt_to_numpy(pt_image)
    image_pil = VaeImageProcessor.numpy_to_pil(image_np)
    batch_video_frames.append(image_pil)

video_generate = batch_video_frames[0] 
final_images = []
for q in range(len(video_generate)): 
    frame1 = _crop_and_resize_pad(load_image(image=refer_images[0]), height, width) 
    frame2 = _crop_and_resize_pad(load_image(image=refer_images[1]), height, width) 
    frame3 = _crop_and_resize_pad(load_image(image=refer_images[2]), height, width) 
    frame4 = Image.fromarray(np.array(video_generate[q])).convert("RGB")
    result = Image.new('RGB', (width * 4, height),color="white")
    result.paste(frame1, (0, 0)) 
    result.paste(frame2, (width, 0)) 
    result.paste(frame3, (width*2, 0)) 
    result.paste(frame4, (width*3, 0)) 
    final_images.append(np.array(result))
# Removed dist.barrier() calls
# dist.barrier()

# Unconditionally write output video
write_mp4(video_path, final_images, fps=15) 

emerdem avatar Apr 27 '25 22:04 emerdem

Same here, tried to run 8 RTX 3090s, but "torch.OutOfMemoryError: CUDA out of memory" error pops up everytime...

rwang5203 avatar Apr 28 '25 07:04 rwang5203