
RandLA-Net in pytorch geometric's examples

Open CharlesGaydon opened this issue 2 years ago • 24 comments

The paper: RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds

Context

There is currently no good PyTorch implementation of RandLA-Net that leverages PyTorch Geometric standards and modules. In torch-points3d, the current modules are outdated, which leads to some confusion among users.

The implementation with the most stars on GitHub is aRI0U/RandLA-Net-pytorch, which has nasty dependencies (torch_points or torch_points_kernels), makes slow round trips between CPU and GPU when calling knns, and only accepts fixed-size point clouds.

Proposal

I would like to implement RandLA-Net as part of pyg's examples. For now I would tackle the ModelNet classification task, and would follow the structure of other examples (pointnet2_classification in particular).

The RandLA-Net paper focuses on segmentation, but for classification I would simply add an MLP + global max pooling after the first DilatedResidualBlocks.

The RandLA-Net architecture is conceptually close to PointNet++, augmented with several tricks to speed things up (random sampling instead of FPS), use more context (with a sort of dilated KNN), and encode local information better (by explicitly encoding the positions, relative positions, and Euclidean distances between points in a neighborhood, and by using self-attention on these features).
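To make the local encoding concrete, here is a minimal, dependency-free sketch of the relative point position encoding for a single (center, neighbor) pair, assuming 3D coordinates. The function name `relative_position_encoding` is hypothetical and chosen for illustration; in the network this 10-value vector would then go through a shared MLP.

```python
import math


def relative_position_encoding(center, neighbor):
    """Concatenate center xyz, neighbor xyz, their offset, and the
    Euclidean distance between them (3 + 3 + 3 + 1 = 10 values)."""
    delta = [c - n for c, n in zip(center, neighbor)]
    dist = math.sqrt(sum(d * d for d in delta))
    return list(center) + list(neighbor) + delta + [dist]


# One center at the origin and one neighbor at (3, 4, 0).
encoding = relative_position_encoding((0.0, 0.0, 0.0), (3.0, 4.0, 0.0))
print(len(encoding), encoding[-1])
```

The distance is redundant with the offset, but feeding it explicitly is part of what the paper describes.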

If I have some success, I will take on the segmentation task as well (which is what interests me anyway for my own project)

Where I am at

I have a working implementation at examples/randlanet_classification.py. I still have to review it to make sure that I am following the paper as closely as possible, but I think I am on the right track.

I would love some guidance on how to move forward. In particular:

  • Am I using MessagePassing modules correctly?
  • What should I aim for in terms of accuracy on ModelNet?
  • Should I stick strictly to the paper, or adapt the architecture to ModelNet?

Indeed, the hyperparameters were not chosen by the authors for small objects but rather for large-scale lidar data, which could make convergence much slower than needed.

With 4 DilatedResidualBlocks (like in the paper), we reach ~57% accuracy at epoch 200.

With 3 DilatedResidualBlocks, we reach up to 75% accuracy at the 20th epoch

With only 2 DilatedResidualBlocks, we reach 90% accuracy at the 81st epoch, getting closer to the leaderboard for the ModelNet10 challenge.

CharlesGaydon avatar Aug 02 '22 16:08 CharlesGaydon

Codecov Report

Merging #5117 (545b2cb) into master (07ba384) will decrease coverage by 1.86%. The diff coverage is 100.00%.

:exclamation: Current head 545b2cb differs from pull request most recent head cb66f4b. Consider uploading reports for the commit cb66f4b to get more accurate results

@@            Coverage Diff             @@
##           master    #5117      +/-   ##
==========================================
- Coverage   86.20%   84.34%   -1.87%     
==========================================
  Files         362      363       +1     
  Lines       20477    20487      +10     
==========================================
- Hits        17653    17279     -374     
- Misses       2824     3208     +384     
Impacted Files Coverage Δ
torch_geometric/nn/pool/decimation.py 100.00% <100.00%> (ø)
torch_geometric/nn/models/dimenet_utils.py 0.00% <0.00%> (-75.52%) :arrow_down:
torch_geometric/nn/models/dimenet.py 14.90% <0.00%> (-52.76%) :arrow_down:
torch_geometric/profile/profile.py 36.27% <0.00%> (-26.48%) :arrow_down:
torch_geometric/nn/conv/utils/typing.py 81.25% <0.00%> (-17.50%) :arrow_down:
torch_geometric/nn/pool/asap.py 92.10% <0.00%> (-7.90%) :arrow_down:
torch_geometric/nn/inits.py 67.85% <0.00%> (-7.15%) :arrow_down:
torch_geometric/nn/dense/linear.py 87.40% <0.00%> (-5.93%) :arrow_down:
torch_geometric/transforms/add_self_loops.py 94.44% <0.00%> (-5.56%) :arrow_down:
torch_geometric/nn/models/attentive_fp.py 95.83% <0.00%> (-4.17%) :arrow_down:
... and 13 more


codecov[bot] avatar Aug 02 '22 16:08 codecov[bot]

I implemented RandLA-Net for segmentation as well, and did some small refactoring. The model seems to learn quite well, and reaches 70% accuracy after 3 epochs. It takes ~1s to run on CPU.

Unfortunately, I am not able to fully test it out on ShapeNet's Airplane task due to PyTorch conflicts that prevent me from using CUDA :/

I need to install the master branch of pyg to run segmentation training. When I follow instructions to build pytorch geometric from source, I see that I am working on a machine with CUDA 11.4, and that I therefore have to build pytorch against CUDA 11.4 before installing dependencies (torch_scatter, etc), and then pytorch_geometric from master branch directly. However, it seems that pytorch + CUDA 11.4 is not really supported - I could not find how to build it from source, and using cudatoolkit=11.4 in pytorch's conda install does not work.

Maybe I am missing something here... Any help would be appreciated :)

CharlesGaydon avatar Aug 04 '22 17:08 CharlesGaydon

I suggest simply installing the wheels built for CUDA 11.3; this should work even with CUDA 11.4.

rusty1s avatar Aug 05 '22 13:08 rusty1s

Thank you, this worked like a charm. I think this is ready for review. :)

Right now both scripts follow the paper's architecture (in terms of hyperparameters, depth, and number of channels in MLPs). Those were chosen by the authors for large-scale aerial lidar, not for ModelNet and ShapeNet. For ModelNet, removing a few layers makes it possible to reach good accuracy. I was not able to replicate this for ShapeNet, which quickly plateaus around 70% train accuracy (vs. 90% train accuracy / 79% test IoU for PointNet++).

I think it is cleaner to keep everything as it is to follow the paper. We could also change the benchmark for this model (with e.g. S3DIS), but I am not really sure that this is worth the extra work for an example.

CharlesGaydon avatar Aug 06 '22 10:08 CharlesGaydon

I identified some differences between my implementation and the original paper (in particular in terms of batch norms, activations, and number of channels). Will come back with fixes and modifications!

CharlesGaydon avatar Aug 25 '22 16:08 CharlesGaydon

Thanks! Sorry for the delay in review. Please keep me posted.

rusty1s avatar Aug 26 '22 05:08 rusty1s

Here comes the update! I made a few changes:

  • Set the BatchNorm and LeakyReLU with parameters from the paper (in particular: 0.99 momentum for BN, 0.2 alpha for lrelu, like in tensorflow).
  • Made a default_MLP class with plain_last=False, because by default class torch_geometric.nn.MLP disables BN, activation, and dropout for the last layer.
  • Added a missing Linear layer at the beginning of the network.
  • Changed the meaning of "d_out" in DilatedResidualBlock so that it is the actual number of channels that this block produces. This makes the dimensionality of successive layers easier to read.

For the classification task, I commented out two of the DilatedResidualBlocks because, as mentioned earlier, performance is much better on ModelNet with a smaller network.

CharlesGaydon avatar Aug 29 '22 12:08 CharlesGaydon

With these modifications :

  • ModelNet : 85%-88% accuracy after the ~50th epoch, with only two DilatedResidualBlocks instead of 4. (target in leaderboard: >=90%)
  • ShapeNet : Loss: 0.6704 Train Acc: 0.7500 Test IoU: 0.4746 at epoch 30 (target: 90% train accuracy / 79% test IoU for PointNet++)

There may be an issue with my implementation for the segmentation task, but maybe it is simply a matter of model size / configuration vs. the task (as is the case for classification, where a smaller model works much better).

@rusty1s Ready for review. :)

CharlesGaydon avatar Aug 29 '22 13:08 CharlesGaydon

I realized that the flow direction of the MessagePassing class responsible for summarizing the local neighborhood was inverted. I got an immediate improvement on the segmentation task once this was fixed. ShapeNet : Loss: 0.3035 Train Acc: 0.8881 Test IoU: 0.6739 at epoch 30 (target: 90% train accuracy / 79% test IoU for PointNet++)

EDIT: not so sure it was inverted, I reverted back.

CharlesGaydon avatar Aug 30 '22 09:08 CharlesGaydon

https://github.com/pyg-team/pytorch_geometric/pull/5117/commits/1e544ca19300ac07d4b62713e132067e9cc82b25 : I also fixed the knn operation: the first LFA aggregated only on neighborhood points, which means that the second LFA was working on a mixture of input features (for other points) and LFA outputs (for neighborhood centers). Correcting this leads to much faster convergence. Crazy how everything can still run in the presence of logic errors.

--> Reaching 92.1% Test IoU for ModelNet at epoch 129 (with two ResidualBlocks) :1st_place_medal:

CharlesGaydon avatar Sep 01 '22 12:09 CharlesGaydon

I think one last thing that I have to change is linked to the last upsampling :)

In the paper, this diagram seems to mean that the output of the first FC layer is used for the last skip connection. [image] But in the authors' reference TensorFlow implementation, as well as in aRI0U's implementation (they are structurally very similar), it is different: the non-decimated output of the first dilated residual block is used instead. This makes sense, and I just realized that this is mentioned in the paper as well:

Next, the upsampled feature maps are concatenated with the intermediate feature maps produced by encoding layers through skip connections,

CharlesGaydon avatar Sep 01 '22 13:09 CharlesGaydon

And voilà :100: ! We reach PointNet++-comparable performances on ShapeNet's Airplane class after a few epochs :) ShapeNet : Loss: 0.2036 Train Acc: 92%, Test IoU: 82.7% at epoch 30 (target: 90% train accuracy / 79% test IoU for PointNet++)

@rusty1s :)

CharlesGaydon avatar Sep 01 '22 14:09 CharlesGaydon

I also made an implementation and tested it for 60 epochs. It reaches similar results. The best one is at epoch 58, with an IoU of 0.83.

Epoch: 58, Test IoU: 0.8301
[10/40] Loss: 0.2139 Train Acc: 0.9232
[20/40] Loss: 0.2094 Train Acc: 0.9241
[30/40] Loss: 0.2041 Train Acc: 0.9247
[40/40] Loss: 0.2030 Train Acc: 0.9235

I am just putting it here for simplicity. The code is relatively simple thanks to the point_transformer_segmentation example.

import os.path as osp
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Sequential, ReLU, Dropout, LeakyReLU

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import Linear, knn_interpolate, MessagePassing, knn_graph
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.utils import intersection_and_union as i_and_u
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax


class SharedMLP(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bn=False, act=None):
        super(SharedMLP, self).__init__()
        self.lin = Linear(in_channels, out_channels)
        if bn:
            self.bn = torch.nn.BatchNorm1d(out_channels, eps=1e-6, momentum=0.99)
        else:
            self.bn = None

        self.act = act

        self.lin.reset_parameters()
        if self.bn:
            self.bn.reset_parameters()

    def forward(self, x):
        x = self.lin(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x


class RandlaConv(MessagePassing):
    def __init__(self, d, d_slash, add_self_loops: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(RandlaConv, self).__init__(**kwargs)
        self.mlp_rppe = SharedMLP(10, d, bn=True, act=ReLU())  # Relative Point Position Encoding
        self.mlp_att = Linear(2 * d, 2 * d, bias=False)
        self.mlp_post_pool = SharedMLP(2 * d, d_slash, bn=True, act=ReLU())
        self.add_self_loops = add_self_loops

        self.reset_parameters()

    def reset_parameters(self):
        reset(self.mlp_att)

    def forward(self,
                x: Union[Tensor, PairTensor],
                pos: Union[Tensor, PairTensor],
                edge_index: Adj
                ) -> Tensor:

        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        if isinstance(pos, Tensor):
            pos: PairTensor = (pos, pos)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(
                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))

        # propagate_type: (x: PairTensor, pos: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, size=None)
        f_tilde = self.mlp_post_pool(out)
        return f_tilde

    def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor,
                index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor:
        delta = pos_i - pos_j
        dist = delta.norm(dim=-1, keepdim=True)
        rp = torch.cat([pos_i, pos_j, delta, dist], dim=-1)
        r = self.mlp_rppe(rp)

        # make sure r and x_j both have d channels and the same shape
        assert x_j.shape == r.shape
        f_hat = torch.cat([x_j, r], dim=-1)

        # g(f,W)
        f_hat = self.mlp_att(f_hat)
        # do softmax along the KNN
        s = softmax(f_hat, index, ptr, size_i)
        f_hat = s * f_hat
        # propagate will do sum aggregation after attentive pooling between KNN
        return f_hat


class DilatedResidualBlock(torch.nn.Module):
    def __init__(self, d_in, d_out, k=16):
        super(DilatedResidualBlock, self).__init__()
        self.mlp_start = SharedMLP(d_in, d_out // 2, act=LeakyReLU(0.2))
        self.mlp_end = SharedMLP(d_out, 2 * d_out)
        self.mlp_skip = SharedMLP(d_in, 2 * d_out, bn=True)

        self.locse_ap1 = RandlaConv(d_out // 2, d_out // 2)
        self.locse_ap2 = RandlaConv(d_out // 2, d_out)

        self.lrelu = torch.nn.LeakyReLU()
        self.k = k

    def reset_parameters(self):
        reset(self.lrelu)

    def forward(self, x, pos, batch):
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        x_skip = self.mlp_skip(x)  # 2*dout

        x = self.mlp_start(x)  # dout / 2
        x = self.locse_ap1(x, pos, edge_index)  # dout / 2
        x = self.locse_ap2(x, pos, edge_index)  # dout
        x = self.mlp_end(x)  # 2*dout
        x = x + x_skip
        x = self.lrelu(x)  # 2 * dout
        return x


class FPModule(torch.nn.Module):
    """Upsampling with a skip connection."""

    def __init__(self, k, nn):
        super().__init__()
        self.k = k
        self.nn = nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.nn(x)
        return x, pos_skip, batch_skip


def inverse_permutation(perm):
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(perm.size(0), device=perm.device)
    return inv


class Net(torch.nn.Module):
    def __init__(self, num_features, num_classes, k=16, decimation=4):
        super(Net, self).__init__()
        self.k = k  # knn
        self.decimation = decimation  # decimation ratio

        self.fc_start = SharedMLP(num_features, 8, bn=True, act=LeakyReLU(0.2))

        # encoder
        self.module_down = torch.nn.ModuleList([
            DilatedResidualBlock(8, 16, self.k),
            DilatedResidualBlock(32, 64, self.k),
            DilatedResidualBlock(128, 128, self.k),
            DilatedResidualBlock(256, 256, self.k)
        ])

        self.mlp_summit = SharedMLP(512, 512, act=ReLU())

        self.module_up = torch.nn.ModuleList([
            FPModule(1, SharedMLP(512 + 512, 256, bn=True, act=ReLU())),
            FPModule(1, SharedMLP(256 + 256, 128, bn=True, act=ReLU())),
            FPModule(1, SharedMLP(128 + 128, 32, bn=True, act=ReLU())),
            FPModule(1, SharedMLP(32 + 32, 8, bn=True, act=ReLU()))
        ])
        self.mlp_cls = Sequential(
            SharedMLP(8, 64, bn=True, act=ReLU()),
            SharedMLP(64, 32, bn=True, act=ReLU()),
            Dropout(),
            SharedMLP(32, num_classes)
        )

    def forward(self, x, pos, batch, ptr):
        # create random permutation for each batch
        B = ptr.size()[0] - 1
        batch_sizes = (ptr[1:] - ptr[:-1]).cpu()  # number of points per cloud

        indices = [torch.randperm(batch_sizes[i], dtype=torch.int64, device=ptr.device) + ptr[i] for i in range(B)]
        indices = torch.cat(indices)
        inverse_indices = inverse_permutation(indices)

        x = x[indices, :]
        pos = pos[indices, :]
        batch = batch[indices,]

        out_x = []
        out_pos = []
        out_batch = []

        x = self.fc_start(x)

        for encoder in self.module_down:
            # KNN aggregation
            x = encoder(x, pos, batch)
            out_x.append(x)
            out_pos.append(pos)
            out_batch.append(batch)

            # select first part of ::decimated indices in each batch
            batch_sizes = batch_sizes // self.decimation
            indices_decimated = torch.cat([
                torch.arange(batch_sizes[i], device=ptr.device) + ptr[i] for i in range(B)
            ])
            # update ptr
            for i in range(B):
                ptr[i + 1] = ptr[i] + batch_sizes[i]
            x, pos, batch = x[indices_decimated, :], pos[indices_decimated, :], batch[indices_decimated,]

        x = self.mlp_summit(x)

        for i, decoder in enumerate(self.module_up):
            skip_x = out_x.pop()
            skip_pos = out_pos.pop()
            skip_batch = out_batch.pop()

            x, pos, batch = decoder(x, pos, batch, skip_x, skip_pos, skip_batch)

        x = self.mlp_cls(x)
        x = x[inverse_indices, :]
        return x.log_softmax(dim=-1)


if __name__ == '__main__':

    category = 'Airplane'  # Pass in `None` to train on all categories.
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')
    transform = T.Compose([
        T.RandomTranslate(0.01),
        T.RandomRotate(15, axis=0),
        T.RandomRotate(15, axis=1),
        T.RandomRotate(15, axis=2)
    ])
    pre_transform = T.NormalizeScale()
    train_dataset = ShapeNet(path, category, split='trainval', transform=transform,
                             pre_transform=pre_transform)
    test_dataset = ShapeNet(path, category, split='test',
                            pre_transform=pre_transform)
    train_loader = DataLoader(train_dataset, batch_size=60, shuffle=True,
                              num_workers=6)
    test_loader = DataLoader(test_dataset, batch_size=60, shuffle=False,
                             num_workers=6)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net(3, train_dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)


    def train():
        model.train()

        total_loss = correct_nodes = total_nodes = 0
        for i, data in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data.x, data.pos, data.batch, data.ptr)
            loss = F.nll_loss(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()
            total_nodes += data.num_nodes

            if (i + 1) % 10 == 0:
                print(f'[{i + 1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} '
                      f'Train Acc: {correct_nodes / total_nodes:.4f}')
                total_loss = correct_nodes = total_nodes = 0


    @torch.no_grad()
    def test(loader):
        model.eval()

        y_mask = loader.dataset.y_mask
        ious = [[] for _ in range(len(loader.dataset.categories))]

        for data in loader:
            data = data.to(device)
            pred = model(data.x, data.pos, data.batch, data.ptr).argmax(dim=1)

            i, u = i_and_u(pred, data.y, loader.dataset.num_classes, data.batch)
            iou = i.cpu().to(torch.float) / u.cpu().to(torch.float)
            iou[torch.isnan(iou)] = 1

            # Find and filter the relevant classes for each category.
            for iou, category in zip(iou.unbind(), data.category.unbind()):
                ious[category.item()].append(iou[y_mask[category]])

        # Compute mean IoU.
        ious = [torch.stack(iou).mean(0).mean(0) for iou in ious]
        return torch.tensor(ious).mean().item()


    for epoch in range(1, 61):
        train()
        iou = test(test_loader)
        print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}')
        scheduler.step()

saedrna avatar Sep 02 '22 02:09 saedrna

Thank you very much for sharing your implementation @saedrna. That was really helpful for the final touches: simplification and better readability. I cherry-picked the following elements:

  • SharedMLP as a class instead of a function
  • knn_graph instead of knn
  • Use of pyg's softmax instead of torch's softmax
  • Use of ptr for decimation
  • Decimation that happens in main forward loop, which removes a special case where we needed non-decimated block output.
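On the softmax point: messages from all neighborhoods are stacked along a single dimension, so the attention scores have to be normalised separately per target point rather than over a fixed tensor dimension. The sketch below illustrates that grouped normalisation in plain Python for scalar scores. It is not the `torch_geometric.utils.softmax` implementation, and `grouped_softmax` is a hypothetical name.

```python
import math
from collections import defaultdict


def grouped_softmax(scores, index):
    """Softmax over `scores`, normalised separately within each group
    identified by the matching entry of `index`."""
    # Subtract the per-group maximum for numerical stability.
    group_max = defaultdict(lambda: -math.inf)
    for score, group in zip(scores, index):
        group_max[group] = max(group_max[group], score)
    exps = [math.exp(s - group_max[g]) for s, g in zip(scores, index)]
    group_sum = defaultdict(float)
    for value, group in zip(exps, index):
        group_sum[group] += value
    return [value / group_sum[g] for value, g in zip(exps, index)]


# Five messages going to three target points (0, 1 and 2).
weights = grouped_softmax([1.0, 1.0, 2.0, 0.0, 0.0], [0, 0, 1, 2, 2])
print(weights)
```

Weights within each neighborhood sum to one, so a sum aggregation afterwards yields the attentive pooling.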

I prefer to avoid loops for encoders/decoders in order to keep everything easy to read, which is important in an example - I thus keep the structure of the PointNet++ example.

I think we are a go :)

CharlesGaydon avatar Sep 02 '22 10:09 CharlesGaydon

@saedrna As a side comment, if you still want to use your implementation with a loop, I have the following comments:

  • Authors of RandLA-Net used LeakyReLU(negative_slope=0.2) instead of standard ReLU() activations.
  • I think you would always want to have self_loops in your knn_graph, so the features of the neighborhood centroid are also considered as part of the local feature encoding. You can do this by specifying loop=True in knn_graph :)
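The effect of self loops can be seen on a toy example. This is a brute-force, plain-Python sketch rather than the PyG `knn_graph` call itself; `knn_neighbors` is a hypothetical helper whose `loop` flag is only meant to mimic the option mentioned above.

```python
import math


def knn_neighbors(points, k, loop=False):
    """Brute-force k nearest neighbors of every point.
    With loop=True a point may be listed as its own neighbor."""
    result = []
    for i, point in enumerate(points):
        candidates = [j for j in range(len(points)) if loop or j != i]
        candidates.sort(key=lambda j: math.dist(point, points[j]))
        result.append(candidates[:k])
    return result


points = [(0.0, 0.0), (1.0, 0.0), (3.0, 0.0)]
without_loops = knn_neighbors(points, k=2)
with_loops = knn_neighbors(points, k=2, loop=True)
print(without_loops)
print(with_loops)
```

With `loop=True`, each point sits at distance zero from itself and is therefore always part of its own neighborhood, so the centroid's features contribute to the local encoding.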

CharlesGaydon avatar Sep 02 '22 10:09 CharlesGaydon

Looks good to me~ Just one more thing: one of the beautiful things about RandLA is that, although the transition-down step is described as a random selection of decimated points, the implementation does not actually need to permute at every step. Permuting the input once and then keeping the leading slice at each decimation should be more elegant (like x=x[perm], x=x[:N/d], x=x[inv_perm]). Just don't forget to permute back. I learned this trick from the PyTorch implementation. However, it uses a sort to permute back, which is O(n log n), whereas an O(n) solution exists with a temporary array.

https://github.com/aRI0U/RandLA-Net-pytorch/blob/057035adf587a4e377e431f54654a090e53740f2/model.py#L233-L300
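The "permute once, permute back" idea and the O(n) inverse can be sketched without any tensor library. This is an illustrative sketch only: the point labels and the fixed seed are made up, and `inverse_permutation` mirrors the role of the helper of the same name in the code posted earlier in this thread.

```python
import random


def inverse_permutation(perm):
    """Return inv such that inv[perm[i]] == i, using one O(n) scatter
    pass into a temporary array."""
    inv = [0] * len(perm)
    for position, source in enumerate(perm):
        inv[source] = position
    return inv


rng = random.Random(0)
points = ["a", "b", "c", "d", "e", "f", "g", "h"]
perm = list(range(len(points)))
rng.shuffle(perm)

shuffled = [points[i] for i in perm]       # x = x[perm]
inv = inverse_permutation(perm)
restored = [shuffled[i] for i in inv]      # x = x[inv_perm]

# The sort-based alternative (an argsort of perm) gives the same result
# in O(n log n).
inv_by_sort = sorted(range(len(perm)), key=perm.__getitem__)
print(restored == points, inv == inv_by_sort)
```

Between the two indexing steps, the encoder can keep the leading slice at each decimation, since any prefix of a random permutation is itself a random sample.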

saedrna avatar Sep 02 '22 14:09 saedrna

Looks good to me

Nice. :)

Just one more thing: one of the beautiful things about RandLA is that, although the transition-down step is described as a random selection of decimated points, the implementation does not actually need to permute at every step. Permuting the input once and then keeping the leading slice at each decimation should be more elegant (like x=x[perm], x=x[:N/d], x=x[inv_perm]). Just don't forget to permute back. I learned this trick from the PyTorch implementation. However, it uses a sort to permute back, which is O(n log n), whereas an O(n) solution exists with a temporary array.

https://github.com/aRI0U/RandLA-Net-pytorch/blob/057035adf587a4e377e431f54654a090e53740f2/model.py#L233-L300

I see the appeal as well. I actually encountered empty clouds on my data with this solution: I made sure that num_nodes in each cloud was slightly above decimation**4, but the operation x=x[:N/d] would not decimate permuted clouds homogeneously. This would result in empty clouds and errors.

I thus prefer to resort to either your concatenation i.e.

            indices_decimated = torch.cat([
                torch.arange(batch_sizes[i], device=ptr.device) + ptr[i] for i in range(B)
            ])

or mine, which leverages a randperm:

    idx_decim = torch.cat(
        [
            (ptr[i] + torch.randperm(decimated_num_nodes[i], device=ptr.device))
            for i in range(batch_size)
        ],
        dim=0,
    )

And if we have to cat things, might as well avoid a confusing permutation + inverse permutation.
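To illustrate the difference, here is a small plain-Python sketch, assuming `ptr` holds the cumulative start offset of each cloud in the batch. `decimation_indices` is a hypothetical helper written for this illustration, not the function in the PR, and the "unlucky" permutation is hand-built to show one way a small cloud can vanish when a globally permuted batch is sliced.

```python
def decimation_indices(ptr, decimation):
    """Keep the first n // decimation points of every cloud and return
    the kept indices together with the updated offsets."""
    kept = []
    new_ptr = [0]
    for start, end in zip(ptr[:-1], ptr[1:]):
        n_keep = (end - start) // decimation
        kept.extend(range(start, start + n_keep))
        new_ptr.append(new_ptr[-1] + n_keep)
    return kept, new_ptr


ptr = [0, 8, 108]  # two clouds, with 8 and 100 points
kept, new_ptr = decimation_indices(ptr, decimation=4)

# Slicing a globally permuted batch instead: a permutation that happens
# to put cloud 1 first leaves no point of cloud 0 in the first quarter.
unlucky_perm = list(range(8, 108)) + list(range(0, 8))
global_kept = unlucky_perm[: 108 // 4]
cloud0_survivors = [i for i in global_kept if i < 8]
print(new_ptr, cloud0_survivors)
```

Per-cloud decimation keeps 2 and 25 points here, whereas the global slice keeps 27 points that all come from the larger cloud.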

I am actually thinking of adding a +1 to decimated_num_nodes[i] to always have at least one node even if the input cloud is too small. That would be a nice trick to avoid any unforeseen error and more complex preprocessing. Any objection to that strategy? :)

EDIT: that might be a good idea for my use cases but outside of the scope of this example.

CharlesGaydon avatar Sep 05 '22 12:09 CharlesGaydon

I think everyone has his own flavor for coding and I'm fine with this.

saedrna avatar Sep 07 '22 15:09 saedrna

I found an issue: the momentum parameter has opposite meanings in PyTorch and TensorFlow, so it should be 0.01 rather than 0.99.
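The two conventions can be compared on a single running-statistic update. This is a scalar sketch with made-up numbers and hypothetical function names: it only reproduces the update rules, not either framework's BatchNorm layer.

```python
def update_tensorflow_style(running, batch_stat, momentum):
    """Keras/TensorFlow convention: momentum weights the old running value."""
    return momentum * running + (1.0 - momentum) * batch_stat


def update_pytorch_style(running, batch_stat, momentum):
    """PyTorch convention: momentum weights the new batch statistic."""
    return (1.0 - momentum) * running + momentum * batch_stat


running, batch_stat = 0.0, 10.0
tf_value = update_tensorflow_style(running, batch_stat, momentum=0.99)
torch_matching = update_pytorch_style(running, batch_stat, momentum=0.01)
torch_mistaken = update_pytorch_style(running, batch_stat, momentum=0.99)
print(tf_value, torch_matching, torch_mistaken)
```

With 0.99 passed to the PyTorch-style rule, the running statistic jumps almost entirely to the latest batch value instead of moving slowly.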

saedrna avatar Sep 10 '22 10:09 saedrna

I found an issue: the momentum parameter has opposite meanings in PyTorch and TensorFlow, so it should be 0.01 rather than 0.99.

Indeed. Good catch, thank you, I'll update it soon. It is a shame because I just got my best model so far on my own data with this implementation. I hope it lasts with more momentum ^^

CharlesGaydon avatar Sep 12 '22 12:09 CharlesGaydon

@rusty1s You asked to be kept posted, so here is a gentle bump for review :).

CharlesGaydon avatar Sep 20 '22 15:09 CharlesGaydon

Super, let me do a final review tomorrow:) Thanks for the amazing progress!

rusty1s avatar Sep 20 '22 17:09 rusty1s

Thanks for adding this. This looks great already! Left some comments. Let me know if you have any questions.

My pleasure. Thanks for the great review, I'll be able to look into it next week. :)

CharlesGaydon avatar Sep 22 '22 08:09 CharlesGaydon

Super, please ping me :)

rusty1s avatar Sep 22 '22 09:09 rusty1s

Hi @rusty1s, I took most of your comments into account. I made a few replies as well for clarification. The last item on the list would be moving the decimation pooling function into nn.pool as you suggested. I will have the time to do so mid-October and hopefully finalize this PR :100:

CharlesGaydon avatar Sep 29 '22 15:09 CharlesGaydon

@rusty1s Hi! Back at the office! I moved the function that computes decimation indices to its own decimation module under torch_geometric.nn.pool, along with a few simple tests. On a side note, I always wonder what ptr stands for (maybe pointer?). Feel free to edit its docstring if needed (currently: ptr (LongTensor): indices of samples in the batch.).
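For what it is worth, my understanding is that `ptr` holds cumulative node offsets: entry i is the index of the first node of sample i in the batch, and the last entry is the total number of nodes. The plain-Python sketch below rebuilds such offsets from a sorted batch-assignment vector; `ptr_from_batch` is a hypothetical helper used only for illustration.

```python
def ptr_from_batch(batch):
    """Build cumulative offsets from a sorted batch-assignment vector
    (one graph id per node, ids starting at 0)."""
    num_graphs = batch[-1] + 1 if batch else 0
    counts = [0] * num_graphs
    for graph_id in batch:
        counts[graph_id] += 1
    ptr = [0]
    for count in counts:
        ptr.append(ptr[-1] + count)
    return ptr


batch = [0, 0, 0, 1, 1, 2, 2, 2, 2]
ptr = ptr_from_batch(batch)
sizes = [end - start for start, end in zip(ptr[:-1], ptr[1:])]
print(ptr, sizes)
```

Under that reading, the nodes of sample i are the slice `ptr[i]:ptr[i + 1]`, which is what makes per-cloud decimation straightforward.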

CharlesGaydon avatar Oct 14 '22 13:10 CharlesGaydon

@rusty1s @saedrna The gentlest bump on this :)

CharlesGaydon avatar Oct 21 '22 07:10 CharlesGaydon

Yes, I will merge this over the weekend. Sorry for the delay.

rusty1s avatar Oct 21 '22 07:10 rusty1s

@rusty1s Resolved the conflict in Changelog :)

CharlesGaydon avatar Nov 03 '22 10:11 CharlesGaydon

@rusty1s Small bump :)

CharlesGaydon avatar Nov 14 '22 15:11 CharlesGaydon