
Rework forward pass to remove old gradients

Arkay92 opened this issue 2 years ago • 16 comments

Uses the torch.cuda.device_of() function to determine whether the input tensors are on the GPU or CPU, and then chooses the appropriate layer implementations for better performance. Also uses the torch.no_grad() context manager to prevent the model from tracking gradients in the forward pass.

Arkay92 avatar Dec 27 '22 18:12 Arkay92
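For context, a minimal sketch of the two ideas in the description (the function name is illustrative, not the PR's actual code). Note that torch.cuda.device_of() is a context manager rather than a query, so this sketch checks placement via x.is_cuda instead:

```python
import torch
import torch.nn as nn

def forward_no_grad(block: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # torch.no_grad() stops autograd from recording a graph, so no stale
    # gradients or intermediate activations are retained between sampling
    # steps; this is what frees up memory on small GPUs.
    with torch.no_grad():
        # x.is_cuda reports whether the input already lives on the GPU;
        # a GPU-optimised layer implementation could be chosen on this
        # branch, with the stock implementation as the CPU fallback.
        if x.is_cuda:
            layer = nn.LayerNorm(x.shape[-1], device=x.device)
        else:
            layer = nn.LayerNorm(x.shape[-1])
        return block(layer(x))
```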

This may be linked to issue #27

Arkay92 avatar Dec 27 '22 18:12 Arkay92

This is awesome - without this change I cannot run any of the examples on my GeForce GTX 1650 with 4 GB of dedicated GPU memory. With this change I can run the 40M-textvec model. This takes sampling time from nearly one hour (CPU) to a couple of minutes (GPU) on my laptop. Thank you so much! I hope it is accepted into the repo.

dancergraham avatar Dec 29 '22 09:12 dancergraham

This also relates to issue #36

dancergraham avatar Dec 29 '22 10:12 dancergraham

Hello, using the pointcloud2mesh.ipynb notebook I get an error:

AttributeError: module 'torch.nn' has no attribute 'CUDALayerNorm'

I am using PyTorch version 1.13.1+cu117.

dancergraham avatar Dec 29 '22 11:12 dancergraham

Good spot yet again @dancergraham. I have switched over to NVIDIA Apex for FusedLayerNorm on GPU; can you try this now?

Arkay92 avatar Dec 29 '22 11:12 Arkay92
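For reference, a sketch of what a guarded swap could look like, assuming Apex is installed; apex.normalization.FusedLayerNorm is designed as a drop-in replacement for nn.LayerNorm:

```python
try:
    # Apex's fused CUDA kernel; same constructor and call signature
    # as the stock layer norm.
    from apex.normalization import FusedLayerNorm as LayerNorm
except ImportError:
    # Fall back to the built-in implementation when Apex is unavailable.
    from torch.nn import LayerNorm

norm = LayerNorm(512)  # used exactly like nn.LayerNorm(512)
```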

NB: this requires the external Apex lib to work, but it should speed up rendering once it fires up. Any issues, let me know and I'll rework @dancergraham (have added it to setup.py install_requires).

Arkay92 avatar Dec 29 '22 14:12 Arkay92

I was not able to install apex with pip on my Windows machine - I got a lot of errors about "filename too long"

I tried python -m pip install "apex @ git+https://github.com/NVIDIA/apex.git"

dancergraham avatar Dec 29 '22 21:12 dancergraham

Looking into other alternatives to LayerNorm, and it seems instance or group normalisation may help speed things up here! Will push another refactored PR soon.

Arkay92 avatar Dec 29 '22 22:12 Arkay92
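A rough sketch of those alternatives, assuming (batch, channels, points) inputs; note that neither is numerically identical to LayerNorm, so swapping them in would change the model's behaviour:

```python
import torch
import torch.nn as nn

width = 512
x = torch.randn(2, width, 1024)  # (batch, channels, points)

# GroupNorm with num_groups=1 normalises across all channels together,
# the closest GroupNorm configuration to a layer norm.
group_norm = nn.GroupNorm(num_groups=1, num_channels=width)

# InstanceNorm normalises each channel independently per sample.
instance_norm = nn.InstanceNorm1d(width)

y1 = group_norm(x)
y2 = instance_norm(x)
```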

Hmm, this looks rather complex - I will try it out on my machine, but if I were running the point-e repo I don't think I would want to adopt a complicated dependency, especially one marked as "experimental" on Windows...

It might be good to add it as an optional dependency, in the same way that the code currently works with or without CUDA; that adds complexity to the library, so it is the maintainers' call whether or not to accept that approach.

dancergraham avatar Dec 30 '22 14:12 dancergraham
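One way to do that, sketched here as a hypothetical setup.py fragment (not the repo's actual packaging), is an extras_require entry so a plain pip install never pulls in Apex:

```python
# setup.py (illustrative sketch, not point-e's real setup file)
from setuptools import setup

setup(
    name="point-e",
    install_requires=["torch"],
    extras_require={
        # Installed only via `pip install point-e[apex]`.
        "apex": ["apex @ git+https://github.com/NVIDIA/apex.git"],
    },
)
```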

Shall remove the Apex lib but keep the forward pass change; this should still preserve performance without the lib dependency.

Arkay92 avatar Dec 30 '22 14:12 Arkay92

@dancergraham try this now; textvec rendering should still be significantly faster whilst I find a native way of speeding up layer norm on GPU / CUDA.

Arkay92 avatar Dec 30 '22 14:12 Arkay92

I now get an error TypeError: ResidualCrossAttentionBlock.forward() missing 1 required positional argument: 'device' when I try to run pointcloud2mesh:

perceiver.py:154, in SimplePerceiver.forward(self, x, data)
    152 with torch.no_grad():
    153     for block in self.resblocks:
--> 154         x = block(x, data)
    155 return x

dancergraham avatar Jan 01 '23 20:01 dancergraham

My bad @dancergraham, forgot I added it as a param. Changed it back so .to() uses torch.device directly rather than by reference from the param list.

Arkay92 avatar Jan 02 '23 10:01 Arkay92

still not working for me - I get errors with pointcloud2mesh:

File ...\point_e\models\perceiver.py:154, in SimplePerceiver.forward(self, x, data)
    152 with torch.no_grad():
    153     for block in self.resblocks:
--> 154         x = block(x, data)
    155 return x

File ...\lib\site-packages\torch\nn\modules\module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ...\point-e\point_e\models\perceiver.py:106, in ResidualCrossAttentionBlock.forward(self, x, data)
    103 def forward(self, x: torch.Tensor, data: torch.Tensor):
    104     with torch.no_grad():
    105         # Use the to() method to move the input tensors to the specified device
--> 106         x = x.to(torch.device)
    107         data = data.to(torch.device)
    109         # Normalize input tensors and pass them through the attention and MLP layers

TypeError: to() received an invalid combination of arguments - got (type), but expected one of:
 * (torch.device device, torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
 * (torch.dtype dtype, bool non_blocking, bool copy, *, torch.memory_format memory_format)
 * (Tensor tensor, bool non_blocking, bool copy, *, torch.memory_format memory_format)

dancergraham avatar Jan 02 '23 20:01 dancergraham
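The TypeError above comes from passing the torch.device class itself to .to(), which needs a device instance. A minimal sketch of the fix, deriving the device from an input tensor so no extra parameter is required:

```python
import torch

def move_to_same_device(x: torch.Tensor, data: torch.Tensor):
    # x.device is an actual torch.device instance (e.g. cuda:0), unlike
    # the bare torch.device type that triggered the TypeError above.
    return x, data.to(x.device)
```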

We have liftoff 🚀 I can now run at grid_size=128 in 45 seconds per model on my GPU - many thanks again !

dancergraham avatar Jan 03 '23 06:01 dancergraham

Thank you so much for the testing support @dancergraham! LFG!

Arkay92 avatar Jan 03 '23 11:01 Arkay92