KPConv-PyTorch

Custom Dataset regression task

Open Pippo809 opened this issue 2 years ago • 2 comments

Good morning, I want to use this repo to perform a regression task on a set of point clouds. My dataset consists of point clouds, each with a regression score between 0 and 1. I was thinking of adapting the classification task for this purpose, but I don't fully understand how my dataset has to be organized. I tried to reverse engineer the ModelNet40 class, but I can't quite see how __getitem__ works or how it can be adapted for a dataset with only one class (but a regression target).

Pippo809, Nov 17 '22 14:11

Hi @Pippo809,

First of all you will need to change the loss of the network here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/1defcd75cf7c0399704a6a9f63d3a550bfb8c1c9/models/architectures.py#L151-L171

You can adapt this function for a regression task. I am not aware of the best ways to do regression with a deep network, but in any case I assume the labels tensor would be a float32 tensor with values between 0 and 1, instead of an int32 tensor of class indices.
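As a rough illustration of what such an adapted loss could look like, here is a minimal sketch. It is not the repo's code: the function name regression_loss, the single-output head, and the choice of sigmoid followed by mean squared error are all assumptions made for the example, and other losses (L1, Huber, binary cross-entropy on the score) may suit a given dataset better.

```python
import torch
import torch.nn.functional as F


def regression_loss(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the classification loss.

    outputs: raw network outputs, shape (B, 1) or (B,) (assumes one output unit)
    labels:  regression targets in [0, 1], shape (B,)
    """
    # Squash the raw outputs into (0, 1) so they live in the same range as the targets
    preds = torch.sigmoid(outputs.reshape(-1))
    # Make sure targets are floats of the same dtype as the predictions
    targets = labels.reshape(-1).to(preds.dtype)
    # Mean squared error between predicted and target scores
    return F.mse_loss(preds, targets)


# Tiny usage example with made-up values
logits = torch.zeros(4, 1, requires_grad=True)   # sigmoid(0) = 0.5
targets = torch.tensor([0.5, 0.5, 0.5, 0.5])
loss = regression_loss(logits, targets)
loss.backward()  # gradients flow back to the network outputs
```

If the network keeps its regularization terms for deformable convolutions, they would still need to be added to this value in the same way the existing loss function does.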

From there you can work backwards. The loss function is called here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/1defcd75cf7c0399704a6a9f63d3a550bfb8c1c9/utils/trainer.py#L187-L190

and therefore the labels you need to modify are the batch.labels which are created here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/1defcd75cf7c0399704a6a9f63d3a550bfb8c1c9/datasets/ModelNet40.py#L681-L704

where the input list is the one provided by the __getitem__ function. If you track the labels of this list to their origin, they are defined here:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/1defcd75cf7c0399704a6a9f63d3a550bfb8c1c9/datasets/ModelNet40.py#L175-L178

So from there, there are two things you need to do:

  1. Modify self.input_labels so that it contains your regression values instead of classes.
  2. Verify that, all the way to the loss function, you remove any .astype(np.int32) or similar calls applied to the labels, so that the float values are kept intact.
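To make the two points concrete, here is a small NumPy sketch. The variable names (scores, input_labels, batch_labels) and the values are made up for illustration and do not reproduce the ModelNet40 dataset class; it only shows why an integer cast along the way would destroy the regression targets.

```python
import numpy as np

# Hypothetical regression scores, one per point cloud, all in [0, 1]
scores = [0.12, 0.87, 0.50]

# Point 1: store the scores as float32 instead of integer class indices
input_labels = np.array(scores, dtype=np.float32)

# Point 2: this is the kind of cast to remove; it truncates every score to 0
truncated = input_labels.astype(np.int32)

# When gathering labels for a batch, keep the float dtype all the way through
batch_indices = [0, 2]
batch_labels = np.array([input_labels[i] for i in batch_indices], dtype=np.float32)

print(input_labels.dtype, truncated, batch_labels)
```

The same care applies on the PyTorch side: a conversion to an integer tensor type anywhere between the dataset and the loss would truncate the targets in the same way.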

HuguesTHOMAS, Nov 17 '22 14:11

Perfect! Thank you very much for the quick response. I'll see how it goes.

Pippo809, Nov 17 '22 14:11