Open3D-ML

RandLaNet `update_probs` label smoothing logic is faulty

Status: Open. chingyulin opened this issue 3 years ago; 0 comments.

Checklist

Describe the issue

`RandLaNet.update_probs` is used during inference to exponentially smooth the predictions for each point. The intended behaviour is that every time a point receives a new prediction, its entire softmax distribution is updated as a weighted average of the old (accumulated) distribution and the new one, and at the end of inference the point's label is the argmax of this smoothed distribution. In the current implementation, however, `test_probs` is smoothed but `test_labels` is overwritten with the argmax of the most recent prediction only, so the final label ignores the smoothed distribution.
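In symbols, as a restatement of the paragraph above: write $\alpha$ for `test_smooth` (0.95 in the code), $p_i$ for the stored distribution of point $i$ (a row of `test_probs`) and $q_i$ for the softmax of a new prediction that covers that point. The intended update and final label are

$$
p_i \leftarrow \alpha\, p_i + (1 - \alpha)\, q_i, \qquad \hat{y}_i = \arg\max_c \, p_{i,c},
$$

whereas the current code effectively stores $\hat{y}_i = \arg\max_c \, q_{i,c}$ for the most recent $q_i$.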

    def update_probs(self, inputs, results, test_probs, test_labels):
        """Update test probabilities with probs from current tested patch.
        Args:
            inputs: input to the model.
            results: output of the model.
            test_probs: probabilities for whole pointcloud
            test_labels: ground truth for whole pointcloud.
        Returns:
            updated probabilities and labels
        """
        self.test_smooth = 0.95

        for b in range(results.size()[0]):

            result = torch.reshape(results[b], (-1, self.cfg.num_classes))
            probs = torch.nn.functional.softmax(result, dim=-1)
            probs = probs.cpu().data.numpy()
            labels = np.argmax(probs, 1)  # <--- HERE: argmax of the current patch only
            inds = inputs['data']['point_inds'][b]

            test_probs[inds] = self.test_smooth * test_probs[inds] + (
                1 - self.test_smooth) * probs
            test_labels[inds] = labels  # <--- HERE: overwrites with the unsmoothed labels

        return test_probs, test_labels
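A minimal sketch of a possible fix, written as a hypothetical standalone function (`update_probs_fixed` and `softmax` are illustrative names, not Open3D-ML API). To keep it dependency-light it takes logits as a NumPy array and applies a NumPy softmax in place of `torch.nn.functional.softmax`; the only substantive change from the method above is that the label is taken from the smoothed `test_probs[inds]`.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis (NumPy stand-in for torch)."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def update_probs_fixed(logits, point_inds, test_probs, test_labels,
                       num_classes, test_smooth=0.95):
    """Hypothetical fixed variant of RandLaNet.update_probs.

    logits:      (B, N, num_classes) raw model outputs for B patches
    point_inds:  B index arrays mapping patch points to the whole cloud
    test_probs:  (num_points, num_classes) running smoothed probabilities
    test_labels: (num_points,) predicted labels, updated in place
    """
    for b in range(logits.shape[0]):
        probs = softmax(logits[b].reshape(-1, num_classes))
        inds = point_inds[b]

        # Exponential smoothing of the full distribution (same as the original).
        test_probs[inds] = test_smooth * test_probs[inds] + (
            1 - test_smooth) * probs
        # Fix: label = argmax of the smoothed distribution, not of `probs`.
        test_labels[inds] = np.argmax(test_probs[inds], axis=1)

    return test_probs, test_labels
```

Recomputing the argmax for `inds` after each update keeps `test_labels` consistent with `test_probs` at every step; an alternative is to drop label bookkeeping from this method and take `np.argmax(test_probs, axis=1)` once after inference finishes.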

Steps to reproduce the bug

It is a logic bug that can be seen by reading the code above; no particular dataset or script is needed.
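A toy reproduction of the discrepancy, assuming `test_probs` starts at zero and using made-up softmax outputs for a single point that appears in two successive patches:

```python
import numpy as np

test_smooth = 0.95  # same constant as RandLaNet.update_probs

# Hypothetical softmax outputs for ONE point seen in two successive patches.
first = np.array([0.90, 0.05, 0.05])   # confident: class 0
second = np.array([0.30, 0.40, 0.30])  # weak: class 1

smoothed = np.zeros(3)  # assumed zero initialisation of test_probs
for probs in (first, second):
    smoothed = test_smooth * smoothed + (1 - test_smooth) * probs

label_current = int(np.argmax(second))     # what update_probs stores today
label_expected = int(np.argmax(smoothed))  # argmax of the smoothed distribution
print(label_current, label_expected)  # prints "1 0"
```

The smoothed distribution ends at roughly `[0.058, 0.022, 0.017]`, so the expected label is class 0, while the current code keeps class 1 from the last, weaker prediction.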

Error message

No response

Expected behavior

No response

Additional information

No response

chingyulin, Aug 23 '22 13:08