RandLA-Net in pytorch geometric's examples
The paper: RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds
Context
There is currently no good PyTorch implementation of RandLA-Net that leverages PyTorch Geometric standards and modules. In torch-points3d, the relevant modules are outdated, which causes some confusion among users.
The implementation with the most stars on GitHub is aRI0U/RandLA-Net-pytorch, which has heavy dependencies (torch_points or torch_points_kernels), performs slow back-and-forth transfers between CPU and GPU when calling its KNNs, and only accepts fixed-size point clouds.
Proposal
I would like to implement RandLA-Net as part of pyg's examples. For now I would tackle the ModelNet classification task, and would follow the structure of other examples (pointnet2_classification in particular).
The RandLA-Net paper focuses on segmentation, but for classification I would simply add an MLP + global max pooling after the first DilatedResidualBlocks.
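For reference, a minimal sketch of what such a classification head could look like (the class name and channel sizes are illustrative, not the final example's):

```python
import torch
from torch_geometric.nn import MLP, global_max_pool


class ClassificationHead(torch.nn.Module):
    """Pools per-point features into one embedding per cloud and classifies it."""
    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.mlp = MLP([in_channels, 128, num_classes], dropout=0.5)

    def forward(self, x, batch):
        x = global_max_pool(x, batch)  # [num_clouds, in_channels]
        return self.mlp(x).log_softmax(dim=-1)
```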
RandLA-Net's architecture is conceptually close to PointNet++, augmented with several tricks to speed things up (random sampling instead of farthest point sampling), use more context (a sort of dilated KNN), and encode local information better (by explicitly encoding positions, relative positions, and Euclidean distances between points in a neighborhood, and by applying self-attention to these features).
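For the local spatial encoding mentioned above, the paper concatenates, for each neighbor, the center position, the neighbor position, their offset, and the Euclidean distance; a minimal sketch for a single neighborhood of k points (shapes assumed):

```python
import torch

def relative_point_position_encoding(pos_i: torch.Tensor, pos_j: torch.Tensor) -> torch.Tensor:
    # pos_i: [k, 3] center position repeated k times, pos_j: [k, 3] neighbor positions.
    delta = pos_i - pos_j                                  # relative positions
    dist = delta.norm(dim=-1, keepdim=True)                # Euclidean distances
    return torch.cat([pos_i, pos_j, delta, dist], dim=-1)  # [k, 10], fed to a shared MLP
```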
If I have some success, I will take on the segmentation task as well (which is what interests me anyway for my own project).
Where I am at
I have a working implementation at examples/randlanet_classification.py. I still have to review it to make sure that I am following the paper as closely as possible, but I think I am on the right track.
I would love some guidance on how to move forward. In particular:
- Am I using MessagePassing modules correctly?
- What should I aim for in terms of accuracy on ModelNet?
- Should I stick strictly to the paper, or adapt the architecture to ModelNet? The hyperparameters were chosen by the authors for large-scale lidar data rather than small objects, which could make convergence much longer than needed.
- With 4 DilatedResidualBlocks (as in the paper), we reach ~57% accuracy at epoch 200.
- With 3 DilatedResidualBlocks, we reach up to 75% accuracy by the 20th epoch.
- With only 2 DilatedResidualBlocks, we reach 90% accuracy at the 81st epoch, getting closer to the leaderboard for the ModelNet10 challenge.
Codecov Report
Merging #5117 (545b2cb) into master (07ba384) will decrease coverage by 1.86%. The diff coverage is 100.00%.
:exclamation: Current head 545b2cb differs from pull request most recent head cb66f4b. Consider uploading reports for the commit cb66f4b to get more accurate results.
```diff
@@            Coverage Diff             @@
##           master    #5117      +/-   ##
==========================================
- Coverage   86.20%   84.34%   -1.87%
==========================================
  Files         362      363       +1
  Lines       20477    20487      +10
==========================================
- Hits        17653    17279     -374
- Misses       2824     3208     +384
```
Impacted Files | Coverage Δ |
---|---|
torch_geometric/nn/pool/decimation.py | 100.00% <100.00%> (ø) |
torch_geometric/nn/models/dimenet_utils.py | 0.00% <0.00%> (-75.52%) :arrow_down: |
torch_geometric/nn/models/dimenet.py | 14.90% <0.00%> (-52.76%) :arrow_down: |
torch_geometric/profile/profile.py | 36.27% <0.00%> (-26.48%) :arrow_down: |
torch_geometric/nn/conv/utils/typing.py | 81.25% <0.00%> (-17.50%) :arrow_down: |
torch_geometric/nn/pool/asap.py | 92.10% <0.00%> (-7.90%) :arrow_down: |
torch_geometric/nn/inits.py | 67.85% <0.00%> (-7.15%) :arrow_down: |
torch_geometric/nn/dense/linear.py | 87.40% <0.00%> (-5.93%) :arrow_down: |
torch_geometric/transforms/add_self_loops.py | 94.44% <0.00%> (-5.56%) :arrow_down: |
torch_geometric/nn/models/attentive_fp.py | 95.83% <0.00%> (-4.17%) :arrow_down: |
... and 13 more |
I implemented RandLA-Net for segmentation as well and did some small refactoring. The model seems to learn quite well, reaching 70% accuracy after 3 epochs. It takes about 1 s to run on CPU.
Unfortunately, I am not able to fully test it on ShapeNet's Airplane task due to PyTorch conflicts that prevent me from using CUDA :/
I need to install the master branch of PyG to run the segmentation training. When I follow the instructions to build PyTorch Geometric from source, I see that I am working on a machine with CUDA 11.4, and that I therefore have to build PyTorch against CUDA 11.4 before installing the dependencies (`torch_scatter`, etc.) and then `pytorch_geometric` directly from the master branch. However, it seems that PyTorch with CUDA 11.4 is not really supported: I could not find how to build it from source, and using `cudatoolkit=11.4` in PyTorch's conda install does not work.
Maybe I am missing something here... Any help would be appreciated :)
I suggest simply installing the wheels built for CUDA 11.3; they should work even for CUDA 11.4.
Thank you, this worked like a charm. I think this is ready for review. :)
Right now both scripts follow the paper's architecture (in terms of hyperparameters, depth, and number of channels in the MLPs). These were chosen by the authors for large-scale aerial lidar, not for ModelNet and ShapeNet. For ModelNet, removing a few layers is enough to reach good accuracy. I was not able to replicate this for ShapeNet, which quickly plateaus around 70% train accuracy (vs. 90% train accuracy / 79% test IoU for PointNet++).
I think it is cleaner to keep everything as it is and follow the paper. We could also change the benchmark for this model (to e.g. S3DIS), but I am not sure that this is worth the extra work for an example.
I identified some differences between my implementation and the original paper (in particular in terms of batch norms, activations, and number of channels). Will come back with fixes and modifications!
Thanks! Sorry for the delay in review. Please keep me posted.
Here comes the update! I made a few changes:
- Set the BatchNorm and LeakyReLU parameters to the values from the paper (in particular 0.99 momentum for BN and 0.2 negative slope for LeakyReLU, as in the TensorFlow implementation).
- Made a `default_MLP` class with `plain_last=False`, because by default `torch_geometric.nn.MLP` disables BN, activation, and dropout for the last layer (see the sketch after this list).
- Added a missing Linear layer at the beginning of the network.
- Changed the meaning of "d_out" in DilatedResidualBlock so that it is the actual number of channels that this block produces. This makes the dimensionality of successive layers easier to read.
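To illustrate the `plain_last` point, a small example (channel sizes are arbitrary):

```python
from torch_geometric.nn import MLP

# With plain_last=True (the default), the last Linear layer is not followed by
# batch norm, activation, or dropout; plain_last=False keeps them, which is what
# we want for the shared MLPs inside the network.
mlp = MLP([16, 32, 64], plain_last=False, act='leaky_relu')
```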
For the classification task, I commented out two of the DilatedResidualBlocks because, as mentioned earlier, performance on ModelNet is much better with a smaller network.
With these modifications:
- ModelNet: 85%-88% accuracy after the ~50th epoch, with only two DilatedResidualBlocks instead of 4 (target on the leaderboard: >=90%).
- ShapeNet: Loss: 0.6704, Train Acc: 0.7500, Test IoU: 0.4746 at epoch 30 (target: 90% train accuracy / 79% test IoU for PointNet++).
There may be an issue with my implementation for the segmentation task, but maybe it is simply a matter of model size/configuration vs. the task (as is the case for classification, where a smaller model works much better).
@rusty1s Ready for review. :)
I realized that the flow direction of the MessagePassing class responsible for summarizing the local neighborhood was inverted. Once fixed, I got an immediate improvement for the segmentation task. ShapeNet: Loss: 0.3035, Train Acc: 0.8881, Test IoU: 0.6739 at epoch 30 (target: 90% train accuracy / 79% test IoU for PointNet++).
EDIT: I am not so sure it was inverted; I reverted back.
https://github.com/pyg-team/pytorch_geometric/pull/5117/commits/1e544ca19300ac07d4b62713e132067e9cc82b25 : I also fixed the KNN operation: the first LFA aggregated only over neighborhood points, which means that the second LFA was working on a mixture of input features (for other points) and LFA outputs (for neighborhood centers). Correcting this leads to much faster convergence. Crazy how everything can still run in the presence of logic errors.
--> Reaching 92.1% test accuracy on ModelNet at epoch 129 (with two ResidualBlocks) :1st_place_medal:
I think one last thing that I have to change is linked to the last upsampling :)
In the paper, this diagram seems to imply that the output of the first FC layer is used for the last skip connection.
But in the authors' reference TensorFlow implementation, as well as in aRI0U's implementation (they are structurally very similar), it is different: the output of the first dilated residual block (before decimation) is used instead. This makes sense, and I just realized that it is mentioned in the paper as well:
> Next, the upsampled feature maps are concatenated with the intermediate feature maps produced by encoding layers through skip connections,
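In code, the last decoder step then looks roughly like this (names are illustrative; `knn_interpolate` does the upsampling):

```python
import torch
from torch_geometric.nn import knn_interpolate

def last_decoder_step(x, pos, batch, block1_out, pos_full, batch_full, mlp):
    # Upsample decoder features back to the full-resolution positions...
    x = knn_interpolate(x, pos, pos_full, batch, batch_full, k=1)
    # ...and concatenate them with the output of the *first* dilated residual
    # block (before decimation), not with the output of the initial FC layer.
    x = torch.cat([x, block1_out], dim=-1)
    return mlp(x)
```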
And voilà :100:! We reach PointNet++-comparable performance on ShapeNet's Airplane class after a few epochs :) ShapeNet: Loss: 0.2036, Train Acc: 92%, Test IoU: 82.7% at epoch 30 (target: 90% train accuracy / 79% test IoU for PointNet++).
@rusty1s :)
I also made an implementation and tested it for 60 epochs. I reach similar results; the highest is at epoch 58 with an IoU of 0.83.
```
Epoch: 58, Test IoU: 0.8301
[10/40] Loss: 0.2139 Train Acc: 0.9232
[20/40] Loss: 0.2094 Train Acc: 0.9241
[30/40] Loss: 0.2041 Train Acc: 0.9247
[40/40] Loss: 0.2030 Train Acc: 0.9235
```
I just put it here for simplicity. The code is relatively simple thanks to the point_transformer_segmentation example.
```python
import os.path as osp
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Sequential, ReLU, Dropout, LeakyReLU

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import Linear, knn_interpolate, MessagePassing, knn_graph
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.utils import intersection_and_union as i_and_u
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax


class SharedMLP(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bn=False, act=None):
        super(SharedMLP, self).__init__()
        self.lin = Linear(in_channels, out_channels)
        if bn:
            self.bn = torch.nn.BatchNorm1d(out_channels, eps=1e-6, momentum=0.99)
        else:
            self.bn = None
        self.act = act
        self.lin.reset_parameters()
        if self.bn:
            self.bn.reset_parameters()

    def forward(self, x):
        x = self.lin(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x


class RandlaConv(MessagePassing):
    def __init__(self, d, d_slash, add_self_loops: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(RandlaConv, self).__init__(**kwargs)
        self.mlp_rppe = SharedMLP(10, d, bn=True, act=ReLU())  # Relative Point Position Encoding
        self.mlp_att = Linear(2 * d, 2 * d, bias=False)
        self.mlp_post_pool = SharedMLP(2 * d, d_slash, bn=True, act=ReLU())
        self.add_self_loops = add_self_loops
        self.reset_parameters()

    def reset_parameters(self):
        reset(self.mlp_att)

    def forward(self, x: Union[Tensor, PairTensor],
                pos: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)
        if isinstance(pos, Tensor):
            pos: PairTensor = (pos, pos)
        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))
        # propagate_type: (x: PairTensor, pos: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, size=None)
        f_tilde = self.mlp_post_pool(out)
        return f_tilde

    def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor,
                index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor:
        delta = pos_i - pos_j
        dist = delta.norm(dim=-1, keepdim=True)
        rp = torch.cat([pos_i, pos_j, delta, dist], dim=-1)
        r = self.mlp_rppe(rp)
        # Make sure r and x_j both have dimension d and the same shape.
        assert x_j.shape == r.shape
        f_hat = torch.cat([x_j, r], dim=-1)
        # g(f, W)
        f_hat = self.mlp_att(f_hat)
        # Softmax over the k nearest neighbors of each center point.
        s = softmax(f_hat, index, ptr, size_i)
        f_hat = s * f_hat
        # propagate() performs the sum aggregation after attentive pooling over the KNN.
        return f_hat


class DilatedResidualBlock(torch.nn.Module):
    def __init__(self, d_in, d_out, k=16):
        super(DilatedResidualBlock, self).__init__()
        self.mlp_start = SharedMLP(d_in, d_out // 2, act=LeakyReLU(0.2))
        self.mlp_end = SharedMLP(d_out, 2 * d_out)
        self.mlp_skip = SharedMLP(d_in, 2 * d_out, bn=True)
        self.locse_ap1 = RandlaConv(d_out // 2, d_out // 2)
        self.locse_ap2 = RandlaConv(d_out // 2, d_out)
        self.lrelu = torch.nn.LeakyReLU()
        self.k = k

    def reset_parameters(self):
        reset(self.lrelu)

    def forward(self, x, pos, batch):
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x_skip = self.mlp_skip(x)  # 2 * d_out
        x = self.mlp_start(x)  # d_out / 2
        x = self.locse_ap1(x, pos, edge_index)  # d_out / 2
        x = self.locse_ap2(x, pos, edge_index)  # d_out
        x = self.mlp_end(x)  # 2 * d_out
        x = x + x_skip
        x = self.lrelu(x)  # 2 * d_out
        return x


class FPModule(torch.nn.Module):
    """Upsampling with a skip connection."""
    def __init__(self, k, nn):
        super().__init__()
        self.k = k
        self.nn = nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.nn(x)
        return x, pos_skip, batch_skip


def inverse_permutation(perm):
    # O(n) inverse of a permutation, without sorting.
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv


class Net(torch.nn.Module):
    def __init__(self, num_features, num_classes, k=16, decimation=4):
        super(Net, self).__init__()
        self.k = k  # number of nearest neighbors
        self.decimation = decimation  # decimation ratio
        self.fc_start = SharedMLP(num_features, 8, bn=True, act=LeakyReLU(0.2))
        # encoder
        self.module_down = torch.nn.ModuleList([
            DilatedResidualBlock(8, 16, self.k),
            DilatedResidualBlock(32, 64, self.k),
            DilatedResidualBlock(128, 128, self.k),
            DilatedResidualBlock(256, 256, self.k)
        ])
        self.mlp_summit = SharedMLP(512, 512, act=ReLU())
        # decoder
        self.module_up = torch.nn.ModuleList([
            FPModule(1, SharedMLP(512 + 512, 256, bn=True, act=ReLU())),
            FPModule(1, SharedMLP(256 + 256, 128, bn=True, act=ReLU())),
            FPModule(1, SharedMLP(128 + 128, 32, bn=True, act=ReLU())),
            FPModule(1, SharedMLP(32 + 32, 8, bn=True, act=ReLU()))
        ])
        self.mlp_cls = Sequential(
            SharedMLP(8, 64, bn=True, act=ReLU()),
            SharedMLP(64, 32, bn=True, act=ReLU()),
            Dropout(),
            SharedMLP(32, num_classes)
        )

    def forward(self, x, pos, batch, ptr):
        # Create a random permutation for each point cloud in the batch.
        B = ptr.size()[0] - 1
        batch_sizes = torch.Tensor([ptr[i + 1] - ptr[i] for i in range(B)]).long()
        indices = [
            torch.randperm(batch_sizes[i], dtype=torch.int64, device=ptr.device) + ptr[i]
            for i in range(B)
        ]
        indices = torch.cat(indices)
        inverse_indices = inverse_permutation(indices)
        x = x[indices, :]
        pos = pos[indices, :]
        batch = batch[indices]

        out_x = []
        out_pos = []
        out_batch = []
        x = self.fc_start(x)
        for encoder in self.module_down:
            # KNN aggregation
            x = encoder(x, pos, batch)
            out_x.append(x)
            out_pos.append(pos)
            out_batch.append(batch)
            # Keep the first 1/decimation points of each cloud
            # (valid because the points were shuffled once above).
            batch_sizes = batch_sizes // self.decimation
            indices_decimated = torch.cat([
                torch.arange(batch_sizes[i], device=ptr.device) + ptr[i] for i in range(B)
            ])
            # Update ptr so it matches the decimated batch (modified in place).
            for i in range(B):
                ptr[i + 1] = ptr[i] + batch_sizes[i]
            x, pos, batch = x[indices_decimated, :], pos[indices_decimated, :], batch[indices_decimated]

        x = self.mlp_summit(x)
        for decoder in self.module_up:
            skip_x = out_x.pop()
            skip_pos = out_pos.pop()
            skip_batch = out_batch.pop()
            x, pos, batch = decoder(x, pos, batch, skip_x, skip_pos, skip_batch)
        x = self.mlp_cls(x)
        # Restore the original point order.
        x = x[inverse_indices, :]
        return x.log_softmax(dim=-1)


if __name__ == '__main__':
    category = 'Airplane'  # Pass in `None` to train on all categories.
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
    transform = T.Compose([
        T.RandomTranslate(0.01),
        T.RandomRotate(15, axis=0),
        T.RandomRotate(15, axis=1),
        T.RandomRotate(15, axis=2)
    ])
    pre_transform = T.NormalizeScale()
    train_dataset = ShapeNet(path, category, split='trainval', transform=transform,
                             pre_transform=pre_transform)
    test_dataset = ShapeNet(path, category, split='test',
                            pre_transform=pre_transform)
    train_loader = DataLoader(train_dataset, batch_size=60, shuffle=True,
                              num_workers=6)
    test_loader = DataLoader(test_dataset, batch_size=60, shuffle=False,
                             num_workers=6)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(3, train_dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    def train():
        model.train()
        total_loss = correct_nodes = total_nodes = 0
        for i, data in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data.x, data.pos, data.batch, data.ptr)
            loss = F.nll_loss(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()
            total_nodes += data.num_nodes
            if (i + 1) % 10 == 0:
                print(f'[{i + 1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} '
                      f'Train Acc: {correct_nodes / total_nodes:.4f}')
                total_loss = correct_nodes = total_nodes = 0

    @torch.no_grad()
    def test(loader):
        model.eval()
        y_mask = loader.dataset.y_mask
        ious = [[] for _ in range(len(loader.dataset.categories))]
        for data in loader:
            data = data.to(device)
            pred = model(data.x, data.pos, data.batch, data.ptr).argmax(dim=1)
            i, u = i_and_u(pred, data.y, loader.dataset.num_classes, data.batch)
            iou = i.cpu().to(torch.float) / u.cpu().to(torch.float)
            iou[torch.isnan(iou)] = 1
            # Find and filter the relevant classes for each category.
            for iou, category in zip(iou.unbind(), data.category.unbind()):
                ious[category.item()].append(iou[y_mask[category]])
        # Compute mean IoU.
        ious = [torch.stack(iou).mean(0).mean(0) for iou in ious]
        return torch.tensor(ious).mean().item()

    for epoch in range(1, 61):
        train()
        iou = test(test_loader)
        print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}')
        scheduler.step()
```
Thank you very much for sharing your implementation @saedrna. It was really helpful for giving the final touches of simplification and readability. I cherry-picked the following elements:
- SharedMLP as a class instead of a function
- knn_graph instead of knn
- Use of pyg's softmax instead of torch's softmax
- Use of ptr for decimation
- Decimation happening in the main forward loop, which removes a special case where we needed the non-decimated block output.
I prefer to avoid loops over encoders/decoders in order to keep everything easy to read, which is important in an example; I thus keep the structure of the PointNet++ example.
I think we are a go :)
@saedrna As a side comment, if you still want to use your implementation with a loop, I have the following remarks:
- The authors of RandLA-Net used `LeakyReLU(negative_slope=0.2)` instead of standard `ReLU()` activations.
- I think you would always want self-loops in your `knn_graph`, so that the features of the neighborhood centroid are also considered as part of the local feature encoding. You can do this by specifying `loop=True` in `knn_graph` (see the snippet below) :)
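For instance (assuming `pos` and `batch` come from the usual mini-batch):

```python
from torch_geometric.nn import knn_graph

# Include each point as its own neighbor so that the centroid's features also
# take part in the local feature aggregation.
edge_index = knn_graph(pos, k=16, batch=batch, loop=True)
```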
Looks good to me~
Just one more thing: one of the beautiful things about RandLA is that, although the paper describes it as selecting random decimation points at each transition-down, in the implementation we do not actually need to permute every time. Permuting the input once and then taking the first slice of each decimated level is more elegant (like `x=x[perm], x=x[:N/d], x=x[inv_perm]`). Just don't forget to permute back.
I learned this trick from the PyTorch implementation, but it uses a sort to permute back, which is O(n log n); an O(n) solution exists with a temporary array.
https://github.com/aRI0U/RandLA-Net-pytorch/blob/057035adf587a4e377e431f54654a090e53740f2/model.py#L233-L300
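A minimal sketch of that permute-once trick for a single cloud of N points (variable names are mine):

```python
import torch

N, d = 4096, 4                    # number of points and decimation ratio
perm = torch.randperm(N)
inv_perm = torch.empty_like(perm)
inv_perm[perm] = torch.arange(N)  # O(N) inverse permutation, no sort needed

x = torch.randn(N, 8)
x = x[perm]                       # shuffle once
x = x[:N // d]                    # each "random" decimation is now just a slice
# ... encoder/decoder; once the features are upsampled back to N points:
# x = x[inv_perm]                 # restore the original point order
```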
Looks good to me
Nice. :)
> Just one more thing: one of the beautiful things about RandLA is that, although the paper describes it as selecting random decimation points at each transition-down, in the implementation we do not actually need to permute every time. Permuting the input once and then taking the first slice of each decimated level is more elegant (like `x=x[perm], x=x[:N/d], x=x[inv_perm]`). Just don't forget to permute back. I learned this trick from the PyTorch implementation, but it uses a sort to permute back, which is O(n log n); an O(n) solution exists with a temporary array. https://github.com/aRI0U/RandLA-Net-pytorch/blob/057035adf587a4e377e431f54654a090e53740f2/model.py#L233-L300
I see the appeal as well. I actually encountered empty clouds on my own data with this solution: I made sure that num_nodes in each cloud was slightly above decimation**4, but the operation `x=x[:N/d]` would not decimate the permuted clouds homogeneously. This could result in empty clouds and errors.
I thus prefer to resort to either your concatenation, i.e.
```python
indices_decimated = torch.cat([
    torch.arange(batch_sizes[i], device=ptr.device) + ptr[i] for i in range(B)
])
```
or mine, which leverages a randperm:
```python
idx_decim = torch.cat(
    [
        (ptr[i] + torch.randperm(decimated_num_nodes[i], device=ptr.device))
        for i in range(batch_size)
    ],
    dim=0,
)
```
And if we have to cat things, might as well avoid a confusing permutation + inverse permutation.
I am actually thinking of adding a `+1` to `decimated_num_nodes[i]` so that there is always at least one node, even if the input cloud is too small. That would be a nice trick to avoid unforeseen errors and more complex preprocessing. Any objection to that strategy? :)
EDIT: that might be a good idea for my use cases, but it is outside the scope of this example.
I think everyone has his own flavor for coding and I'm fine with this.
I found an issue: the `momentum` parameter has a different meaning in PyTorch and TensorFlow, so it should be 0.01 rather than 0.99.
> I found an issue: the `momentum` parameter has a different meaning in PyTorch and TensorFlow, so it should be 0.01 rather than 0.99.
Indeed. Good catch, thank you, I'll update it soon. It is a shame because I just got my best model so far on my own data with this implementation. Hope it lasts with the corrected momentum ^^
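For reference, PyTorch's `momentum` weights the new batch statistics while TensorFlow's weights the running average, so the two values are complementary; a minimal sketch of the corrected layer (channel count arbitrary):

```python
import torch

# TensorFlow: running_mean = momentum * running_mean + (1 - momentum) * batch_mean
# PyTorch:    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# So a TensorFlow momentum of 0.99 corresponds to a PyTorch momentum of 0.01.
bn = torch.nn.BatchNorm1d(32, eps=1e-6, momentum=0.01)
```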
@rusty1s You asked to be kept posted, so here is a gentle bump for review :).
Super, let me do a final review tomorrow :) Thanks for the amazing progress!
Thanks for adding this. This looks great already! Left some comments. Let me know if you have any questions.
My pleasure. Thanks for the great review, I'll be able to look into it next week. :)
Super, please ping me :)
Hi @rusty1s, I took most of your comments into account and made a few replies for clarification.
The last item on the list would be moving the decimation pooling function into `nn.pool` as you suggested. I will have time to do so mid-October and hopefully finalize this PR :100:
@rusty1s Hi!
Back to the office! I moved the function that computes decimation indices to its own `decimation` module under `torch_geometric.nn.pool`, along with a few simple tests.
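To sketch what such a helper does (an illustration under my own naming, not necessarily the exact signature added in this PR): given the CSR `ptr` of a mini-batch and a decimation factor, keep the first `1/decimation` points of each (already shuffled) cloud and return the updated `ptr`:

```python
import torch

def decimation_indices_sketch(ptr: torch.Tensor, decimation: int):
    batch_size = ptr.numel() - 1
    counts = (ptr[1:] - ptr[:-1]) // decimation  # decimated cloud sizes
    idx = torch.cat([
        ptr[i] + torch.arange(int(counts[i]), device=ptr.device)
        for i in range(batch_size)
    ])
    new_ptr = torch.cat([ptr.new_zeros(1), counts.cumsum(0)])
    return idx, new_ptr
```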
On a side note, I always wonder what `ptr` stands for (maybe pointer?). Feel free to edit its docstring if needed (currently: `ptr (LongTensor): indices of samples in the batch.`).
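For what it's worth, `ptr` is the CSR-style pointer of a mini-batch: cloud `i` occupies rows `ptr[i]:ptr[i+1]` of the concatenated node tensors, i.e. it is the cumulative sum of the per-cloud sizes. A tiny check:

```python
import torch
from torch_geometric.data import Batch, Data

batch = Batch.from_data_list([Data(pos=torch.randn(n, 3)) for n in (4, 2, 3)])
print(batch.batch)  # tensor([0, 0, 0, 0, 1, 1, 2, 2, 2])
print(batch.ptr)    # tensor([0, 4, 6, 9])
```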
@rusty1s @saedrna The gentlest bump on this :)
Yes, I will merge this over the weekend. Sorry for the delay.
@rusty1s Resolved the conflict in Changelog :)
@rusty1s Small bump :)