vit-pytorch
CUDA memory for 3D ViT
This 356 GiB figure is a little stunning... I don't think I changed the original code much, so does anyone know whether this is my mistake or whether the original model really needs this much CUDA memory? Thanks a lot!
@JesseZZZZZ
try

```python
import torch
from vit_pytorch.simple_flash_attn_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,       # image size
    frames = 16,            # number of frames
    image_patch_size = 16,  # image patch size
    frame_patch_size = 2,   # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    use_flash_attn = True
)

video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)

preds = v(video)  # (4, 1000)
```
should help with memory, but you'll still face the compute cost
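As a rough illustration of why flash attention helps: naive attention materializes a full `seq_len × seq_len` score matrix per head per layer, and in a 3D ViT the sequence length is the product of the spatial and temporal patch counts, so memory grows quadratically as you shrink the patch sizes or enlarge the input. The sketch below is a back-of-the-envelope estimate (the helper function is hypothetical, fp32 assumed, activations and gradients not counted), not an exact accounting of what the model allocates:

```python
def attn_matrix_bytes(batch, heads, seq_len, bytes_per_el=4):
    """Bytes for one fully materialized attention score matrix
    (what naive attention stores; flash attention avoids this)."""
    return batch * heads * seq_len * seq_len * bytes_per_el

# Sequence length for the config above:
# (128 / 16)^2 spatial patches * (16 / 2) temporal patches
seq_len = (128 // 16) ** 2 * (16 // 2)
print(seq_len)  # 512 tokens

mib = attn_matrix_bytes(batch=4, heads=8, seq_len=seq_len) / 2**20
print(f"{mib:.0f} MiB per layer")  # 32 MiB

# Halving both patch sizes multiplies seq_len by 8, and this term by 64:
seq_len_small = (128 // 8) ** 2 * (16 // 1)
gib = attn_matrix_bytes(batch=4, heads=8, seq_len=seq_len_small) / 2**30
print(f"{gib:.0f} GiB per layer")  # 2 GiB
```

So a seemingly small change to the patching can blow up the attention memory by orders of magnitude, which is one plausible way to land at hundreds of GiB with naive attention.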
Thank you so much! It does fix my problem to some extent!