
CUDA memory for 3D ViT

Open JesseZZZZZ opened this issue 11 months ago • 2 comments

[screenshot] This 356 GiB is a little stunning... I don't think I changed the original code much, so does anyone know whether this is my mistake, or whether the original itself really needs this much CUDA memory? Thanks a lot!

JesseZZZZZ avatar Mar 20 '24 15:03 JesseZZZZZ

@JesseZZZZZ

try

import torch
from vit_pytorch.simple_flash_attn_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    use_flash_attn = True
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

should help with memory, but you'll still face the compute cost
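For a rough sense of where the memory goes: naive attention materializes a full token-by-token score matrix per layer, so memory grows quadratically with the number of patches. A back-of-the-envelope sketch, using the config from the snippet above (the formula for the token count follows from the patching scheme; your actual usage will also include activations, weights, and gradients):

```python
# Estimate the attention-matrix footprint for the SimpleViT config above.
image_size = 128
frames = 16
image_patch_size = 16
frame_patch_size = 2
heads = 8
batch = 4

# Tokens = spatial patches per frame-chunk * temporal chunks.
tokens = (image_size // image_patch_size) ** 2 * (frames // frame_patch_size)

# Naive attention builds a (batch, heads, tokens, tokens) matrix per layer.
bytes_per_float = 4  # fp32
attn_matrix_bytes = batch * heads * tokens * tokens * bytes_per_float

print(tokens)                          # 512 tokens
print(attn_matrix_bytes / 2**20)       # 32.0 MiB per layer
```

At these settings the matrix is small, so a 356 GiB request usually means a much larger image/frame count or much smaller patch sizes; flash attention avoids materializing this matrix entirely, which is why it helps.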

lucidrains avatar May 02 '24 15:05 lucidrains


Thank you so much! It does fix my problem to some extent!

JesseZZZZZ avatar May 06 '24 06:05 JesseZZZZZ