se3-transformer-pytorch
CUDA out of memory
Thanks for your great work!
The SE3-Transformer is powerful, but it seems to be very memory hungry.
I built a model with the following parameters and got a "CUDA out of memory" error when running it on a GPU (Nvidia V100, 32 GB).
import torch
from se3_transformer_pytorch import SE3Transformer

model = SE3Transformer(dim = 20, heads = 4, depth = 2, dim_head = 5, num_degrees = 2, valid_radius = 5).cuda()

num_points = 512
feats = torch.randn(1, num_points, 20).cuda()
coors = torch.randn(1, num_points, 3).cuda()
mask = torch.ones(1, num_points).bool().cuda()

out = model(feats, coors, mask = mask)  # the CUDA out of memory error is raised during this forward pass
Does this error relate to the PyTorch version, and how can I fix it?
Same problem here: our GPU is an A100 (80 GB), but it still cannot run the above code.
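For anyone hitting this, below is a minimal sketch of two common memory-reduction levers, not a confirmed fix. It assumes the num_neighbors constructor argument described in the repo's README is available in your installed version; the mixed-precision part uses generic torch.cuda.amp and is not specific to this library.

import torch
from se3_transformer_pytorch import SE3Transformer

# restrict attention to the k nearest neighbors instead of every point
# inside valid_radius; num_neighbors is documented in the repo README,
# but verify it exists in the version you have installed
model = SE3Transformer(
    dim = 20,
    heads = 4,
    depth = 2,
    dim_head = 5,
    num_degrees = 2,
    num_neighbors = 32
).cuda()

num_points = 512
feats = torch.randn(1, num_points, 20).cuda()
coors = torch.randn(1, num_points, 3).cuda()
mask = torch.ones(1, num_points).bool().cuda()

# running the forward pass under autocast stores most activations in fp16,
# which roughly halves activation memory; this is standard PyTorch mixed
# precision, not something provided by se3-transformer-pytorch itself
with torch.cuda.amp.autocast():
    out = model(feats, coors, mask = mask)

If it still runs out of memory, the usual next steps are reducing depth, heads, or num_points, or adding gradient checkpointing in the training loop.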