point-transformer-pytorch

Costs too much memory

Open · JLU-Neal opened this issue 4 years ago · 9 comments

I'm not sure whether I used the point transformer correctly: I implemented just one block for training, and the shapes of (x, pos) on each GPU are both [16, 2048, 3]. Training then fails because my GPU runs out of memory (11.77 GB total capacity).

JLU-Neal · Feb 10 '21 09:02

Oh yup, this type of vector attention is quite expensive. In the paper, they did kNN on each point and only attended to the local neighbors.

lucidrains · Feb 10 '21 17:02
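
A minimal sketch of that idea (illustrative tensor names, assuming plain PyTorch; this is not the library's internals):

```python
import torch

pos = torch.randn(16, 2048, 3)   # (batch, num_points, xyz), as in the issue
k = 16

# pairwise Euclidean distances between all points: (batch, n, n)
dist = torch.cdist(pos, pos)

# indices of the k nearest neighbors of every point: (batch, n, k)
knn_idx = dist.topk(k, dim=-1, largest=False).indices

# attention is then computed only over these k neighbors per point,
# instead of over all n points
```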

You can use this great library to do the clustering https://github.com/rusty1s/pytorch_cluster

lucidrains · Feb 10 '21 17:02
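
A minimal usage example with that library (its knn_graph helper, shown here on a single un-batched cloud):

```python
import torch
from torch_cluster import knn_graph

pos = torch.randn(2048, 3)                     # one point cloud, no batch dim
edge_index = knn_graph(pos, k=16, loop=False)  # (2, num_edges) edge list
# each column of edge_index links a point to one of its 16 nearest neighbors
```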

@JLU-Neal Hi, I decided to add in the feature for attending only to the k nearest neighbors, which can be set with the num_neighbors keyword argument. Let me know if it works for you! https://github.com/lucidrains/point-transformer-pytorch/commit/2ec1322cbecced477826652c567ed8bc2d31c952

lucidrains · Feb 11 '21 01:02
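
For reference, usage along the lines of the repository README (argument names as of that commit; double-check against the current README):

```python
import torch
from point_transformer_pytorch import PointTransformerLayer

attn = PointTransformerLayer(
    dim = 128,
    pos_mlp_hidden_dim = 64,
    attn_mlp_hidden_mult = 4,
    num_neighbors = 16         # attend only to the 16 nearest neighbors
)

feats = torch.randn(1, 2048, 128)
pos = torch.randn(1, 2048, 3)

out = attn(feats, pos)         # (1, 2048, 128)
```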

Wow, I'm surprised by your efficient work! Wish you a happy Chinese New Year!

JLU-Neal · Feb 11 '21 03:02

kNN is only specified for the "transition down" layer. They don't seem to mention it for the general point transformer layer. So is this just an added bonus, or am I missing something from the original paper?

zimonitrome · Feb 22 '21 11:02

In Section 3.2, "Point Transformer Layer", the authors mention that "the subset X(i) ⊆ X is a set of points in a local neighborhood (specifically, k nearest neighbors) of x_i."

JLU-Neal · Feb 23 '21 01:02
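
For reference, the layer under discussion is the paper's vector self-attention (Eq. 3), evaluated over exactly that kNN neighborhood:

```latex
y_i = \sum_{x_j \in \mathcal{X}(i)}
      \rho\bigl(\gamma(\varphi(x_i) - \psi(x_j) + \delta)\bigr)
      \odot \bigl(\alpha(x_j) + \delta\bigr),
\qquad \delta = \theta(p_i - p_j)
```

where \varphi, \psi, \alpha produce the queries, keys, and values, \gamma is the attention MLP, \rho is a normalization such as softmax, and \delta is the relative position encoding.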

On that note, do you think they mean k nearest neighbours in point space? They always denote the coordinates as p, not X, throughout the paper. I've always read that as kNN in feature space, which, although it might be less stable, could increase the receptive field quite a bit.

ouenal · Feb 24 '21 10:02
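
A quick sketch of the two readings (illustrative tensors, not the library's internals):

```python
import torch

feats = torch.randn(1, 2048, 128)  # per-point features x
pos   = torch.randn(1, 2048, 3)    # per-point coordinates p
k = 16

# reading 1: neighborhoods in point space, fixed by the input geometry
idx_point = torch.cdist(pos, pos).topk(k, dim=-1, largest=False).indices

# reading 2: neighborhoods in feature space, which change as the features
# evolve during training and can widen the effective receptive field
idx_feat = torch.cdist(feats, feats).topk(k, dim=-1, largest=False).indices
```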

@lucidrains I have the same issue and cannot get even a single layer to run within 12 GB of GPU memory. Maybe my understanding of the layer is incorrect, but I have made the following observation:

I have a point cloud of shape (160000, 3), which I feed into the layer as both the features and the positions, with nearest neighbors k=8. However, I noticed (using smaller input data) that in the forward function the relative positions take up almost n**2 memory and do not change with k. Is the implementation here correct?

rel_pos = pos[:, :, None, :] - pos[:, None, :, :]  # broadcasts to (b, n, n, 3): quadratic in n

L-Reichardt · Jul 17 '22 18:07
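
Back-of-the-envelope, that broadcasted subtraction materializes an (n, n, 3) tensor regardless of k, which at this input size cannot fit on any GPU:

```python
n = 160_000
elems = n * n * 3          # 7.68e10 elements in rel_pos
gib = elems * 4 / 2**30    # fp32: roughly 286 GiB for a single cloud
print(f"{gib:.0f} GiB")
```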

I believe this code is designed to support both global and local attention by switching kNN on and off. The following line checks whether kNN is on and then selects the specific qk_rel, v, and position-embedding entries based on the indices from kNN.

https://github.com/lucidrains/point-transformer-pytorch/blob/f9d4e56a26ceee70deb60da230fef40c656396e6/point_transformer_pytorch/point_transformer_pytorch.py#L78

kidpaul94 · Dec 31 '22 01:12
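
Since rel_pos is formed before that gather, its footprint stays quadratic in n even when num_neighbors is set, which matches the observation above. A memory-friendlier variant (a sketch only, not the library's code) would select the k neighbors first and only then form the relative positions:

```python
import torch

def rel_pos_knn(pos, k):
    # pos: (b, n, 3) -> rel_pos: (b, n, k, 3), so the result is O(n*k).
    # Note the distance matrix below is still (b, n, n); for very large
    # clouds you would chunk the queries or use torch_cluster.knn instead.
    b, n, _ = pos.shape
    idx = torch.cdist(pos, pos).topk(k, dim=-1, largest=False).indices  # (b, n, k)
    batch = torch.arange(b, device=pos.device).view(b, 1, 1)
    neighbors = pos[batch, idx]          # gather neighbor coords: (b, n, k, 3)
    return pos.unsqueeze(2) - neighbors  # relative position per neighbor
```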