Open3D-ML
RandLaNet `update_probs` label smoothing logic is faulty
Checklist
- [X] I have searched for similar issues.
- [X] I have tested with the latest development wheel.
- [X] I have checked the release documentation and the latest documentation (for `master` branch).
Describe the issue
`RandLaNet.update_probs` is used to exponentially smooth the predicted results during inference. The correct approach is that every time a point receives a new prediction, its entire softmax distribution is updated as a weighted average of the old (smoothed) distribution and the new one. At the end of inference, the label of each point should be the argmax of this smoothed distribution. However, the current code assigns the label from the argmax of the most recent prediction's softmax distribution instead of the smoothed one, so the smoothing has no effect on the returned labels.
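A small NumPy sketch of why this matters, using made-up probabilities for a single point with two classes and the same smoothing factor (0.95) as the code below. The helper `smooth` is hypothetical and only mirrors the weighted-average line in `update_probs`:

```python
import numpy as np

def smooth(old, new, test_smooth=0.95):
    """Exponential moving average of two softmax distributions (hypothetical helper)."""
    return test_smooth * old + (1 - test_smooth) * new

# Accumulated (smoothed) distribution for one point: strongly favours class 0.
accumulated = np.array([0.9, 0.1])
# A single new, noisy prediction for the same point favours class 1.
new_probs = np.array([0.4, 0.6])

smoothed = smooth(accumulated, new_probs)         # [0.875, 0.125]
label_from_last = int(np.argmax(new_probs))       # what update_probs currently stores: 1
label_from_smoothed = int(np.argmax(smoothed))    # what smoothing is meant to give: 0
print(smoothed, label_from_last, label_from_smoothed)
```

One noisy patch is enough to flip the stored label even though the smoothed distribution still clearly prefers the other class.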
```python
import numpy as np
import torch


def update_probs(self, inputs, results, test_probs, test_labels):
    """Update test probabilities with probs from current tested patch.
    Args:
        inputs: input to the model.
        results: output of the model.
        test_probs: probabilities for whole pointcloud
        test_labels: ground truth for whole pointcloud.
    Returns:
        updated probabilities and labels
    """
    self.test_smooth = 0.95
    for b in range(results.size()[0]):
        result = torch.reshape(results[b], (-1, self.cfg.num_classes))
        probs = torch.nn.functional.softmax(result, dim=-1)
        probs = probs.cpu().data.numpy()
        labels = np.argmax(probs, 1)  # <--- HERE: argmax of the latest prediction only
        inds = inputs['data']['point_inds'][b]
        test_probs[inds] = self.test_smooth * test_probs[inds] + (
            1 - self.test_smooth) * probs
        test_labels[inds] = labels  # <--- HERE: smoothed test_probs is ignored
    return test_probs, test_labels
```
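A minimal sketch of the change this issue suggests: take the label from the smoothed `test_probs[inds]` after the update, rather than from `probs`. It is written with NumPy arrays instead of the model's torch tensors so it runs standalone; the function name `update_probs_smoothed`, its flat signature, and the `softmax` helper are hypothetical, not part of Open3D-ML:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax over the class axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def update_probs_smoothed(logits, point_inds, test_probs, test_labels,
                          num_classes, test_smooth=0.95):
    """Hypothetical NumPy version of update_probs with the proposed fix.

    logits: (batch, n_points, num_classes) raw model outputs.
    point_inds: (batch, n_points) indices into the whole point cloud.
    test_probs: (N, num_classes) smoothed probabilities, updated in place.
    test_labels: (N,) predicted labels, updated in place.
    """
    for b in range(logits.shape[0]):
        probs = softmax(np.reshape(logits[b], (-1, num_classes)))
        inds = point_inds[b]
        test_probs[inds] = test_smooth * test_probs[inds] + (1 - test_smooth) * probs
        # Fix: derive labels from the smoothed distribution, not from `probs`.
        test_labels[inds] = np.argmax(test_probs[inds], axis=1)
    return test_probs, test_labels

# Usage: three points that currently favour class 0; one patch covering
# points 0 and 2 confidently predicts class 1.
num_classes = 2
test_probs = np.tile([0.9, 0.1], (3, 1))
test_labels = np.zeros(3, dtype=np.int64)
logits = np.array([[[0.0, 5.0], [0.0, 5.0]]])
point_inds = np.array([[0, 2]])
test_probs, test_labels = update_probs_smoothed(
    logits, point_inds, test_probs, test_labels, num_classes)
print(test_labels)
```

With the fix, all three labels stay at class 0 because the smoothed distribution still favours it, whereas the current code would relabel points 0 and 2 as class 1 after this single patch.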
Steps to reproduce the bug
It is a logic bug; it can be seen by reading the code above and does not raise an error.
Error message
No response
Expected behavior
No response
Open3D, Python and System information
- Operating system: (e.g. OSX 10.15, Ubuntu 18.04, Windows 10 64-bit)
- Python version: (e.g. Python 3.8 / output from `import sys; print(sys.version)`)
- Open3D version: (output from python: `print(open3d.__version__)`)
- System type: (x86 / arm64 / apple-silicon / jetson / rpi)
- Is this remote workstation?: yes or no
- How did you install Open3D?: (e.g. pip, conda, build from source)
- Compiler version (if built from source): (e.g. gcc 7.5, clang 7.0)
Additional information
No response