
Toronto3D process problems

GeoSur opened this issue 2 years ago • 55 comments

Hello Mr. Thomas, thank you for your wonderful work!! I have run into a new problem: I converted this dataset to the same format as S3DIS, but when I run the training process it raises "ERROR: It seems that the calibration have not reached convergence. Here are some plot to understand why: If you notice unstability, reduce the expected_N value. If convergece is too slow, increase the expected_N value". I carefully compared my modified data with S3DIS; is this because the XYZ coordinate values of my input data are too large? (I have reduced the XYZ coordinate values as the dataset instructions suggest, but they are still greater than 100.)

Another question: how do I select an area for validation? Is it simply a matter of changing validation_split as in S3DIS?

GeoSur avatar Mar 22 '22 11:03 GeoSur

it raised "ERROR: It seems that the calibration have not reached convergence. Here are some plot to understand why:If you notice unstability, reduce the expected_N value If convergece is too slow, increase the expected_N value"

This error happens because the number of points per batch is too big or too small. There should be plots appearing when this error happens; could you show them here?

is this because the xyz coordinate value of my input data is too large?

No, it should be because of the number of points per batch. The in_radius and first_subsampling_dl parameters control the number of points per sphere. And the batch_num parameter controls the number of spheres per batch.
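For reference, these three parameters sit in the dataset configuration class (a train_S3DIS.py-style config); the excerpt below is only a sketch with placeholder values, not recommended settings.

# Illustrative excerpt of a KPConv-style configuration class (values are placeholders)
class Config:
    # Radius of the input spheres fed to the network (meters)
    in_radius = 1.5
    # Size of the first subsampling grid (meters); larger means fewer points per sphere
    first_subsampling_dl = 0.03
    # Average number of input spheres per batch
    batch_num = 6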

Another question: How to select an area for validation? Simply change the validation_split in s3dis?

Yes, indeed.

HuguesTHOMAS avatar Mar 22 '22 13:03 HuguesTHOMAS

This error happens because the number of points per batch is too big or too small. There should be plots appearing when this error happens; could you show them here?

Thanks for your reply! Here are the results: [two screenshots attached]

GeoSur avatar Mar 23 '22 01:03 GeoSur

The first screenshot shows the dataset, which I reformatted to match S3DIS.

GeoSur avatar Mar 23 '22 01:03 GeoSur

Ok, from what I see, during calibration the batch limit (which is the maximum number of points allowed in a batch) is very low, around 7,000, where it should be between 50,000 and 100,000. In my opinion, this can be for two reasons:

  • If you modified in_radius or first_subsampling_dl, meaning you have smaller spheres or a larger subsampling grid, this automatically reduces the number of points per sphere and thus per batch.
  • If you did not, which I think is the case, the reason for fewer points is simply that the average density of your dataset is much lower than that of S3DIS.

If this is about density, you have two ways to solve the issue (see the short sketch after this list):

  1. Increase in_radius, which will automatically increase the number of points per batch. This is also a good thing if you have large objects like cars to classify: the network will have a larger volume to process. Also, I assume your dataset has varying densities like Semantic3D, for example, so you want your first_subsampling_dl to stay small. The general rule for this parameter is to make it small enough to capture details of the shape of the smallest object you are trying to classify. If you are only interested in cars and pedestrians, you can make it larger, 0.12 m for example. But if you are interested in smaller objects like bollards, low vegetation, or road markings, you should probably keep it small, between 0.03 and 0.05.

  2. If you are happy with the size of your spheres and the subsampling grid, but still have a low number of points per sphere, you can simply increase the number of spheres per batch with the parameter batch_num. This is the simplest way to control the number of points per batch and to optimize the network memory consumption on your GPU. Generally speaking, you can go as high as 100,000 points per batch with a good GPU (12 GB), and even higher if you have a bigger one. But this depends on a lot of things, so just look at what happens in your case.
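To get a quick sanity check of these numbers before training, you can estimate the average number of points per input sphere directly on one of your subsampled clouds. The sketch below is only illustrative; the file name and the parameter values are assumptions, not values taken from this issue.

# Rough estimate of points per sphere and per batch on one cloud (sketch)
import numpy as np
from sklearn.neighbors import KDTree

# Assumed input: an (N, 3) array of points already subsampled at first_subsampling_dl
points = np.load('my_subsampled_cloud.npy')  # hypothetical file

in_radius = 4.0   # assumed value, use your own configuration
batch_num = 6     # assumed value, use your own configuration

tree = KDTree(points)
centers = points[np.random.choice(len(points), 100, replace=False)]
counts = tree.query_radius(centers, r=in_radius, count_only=True)

print('average points per sphere:', counts.mean())
print('estimated points per batch:', counts.mean() * batch_num)

If the estimate is far below the 50,000–100,000 range mentioned above, increasing in_radius or batch_num (or lowering first_subsampling_dl) should bring it back up.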

Best, Hugues

HuguesTHOMAS avatar Mar 23 '22 13:03 HuguesTHOMAS

Also I assume your dataset has varying densities like Semantic3D for example

As you assumed, this dataset is a mobile LiDAR dataset too, so I will modify the code in the two ways you mentioned.

Increase in_radius and first_subsampling_dl

Actually, I am not very certain about the logic behind setting the in_radius parameter: how does it relate to the batch limit, and can it be adjusted in multiples of the batch limit?

Increase the number of batch_num

And I will also try to increase batch_num on a good GPU, and compare the results of these two modification methods.

Thanks a lot for your patient and professional reply!!!

GeoSur avatar Mar 24 '22 01:03 GeoSur

After my attempts, I modified in_radius and first_subsampling_dl as the screenshot shows. The trainer ran successfully, but a few minutes later this error was raised: [screenshot of the traceback attached] I assume this error happens in the validation process, but I do not know why. Do you have any advice?

GeoSur avatar Mar 24 '22 03:03 GeoSur

Is this because I modified the .ply files of the Toronto3D dataset directly instead of pre-processing them from .txt files? The original format of this dataset is .ply, so I just converted the files to the S3DIS layout; maybe that caused this error? How can I solve it? Another difference is that Toronto3D has an "unlabeled" class.

GeoSur avatar Mar 24 '22 07:03 GeoSur

Actually, I am not very certain about the logic behind setting the in_radius parameter: how does it relate to the batch limit, and can it be adjusted in multiples of the batch limit?

in_radius is the radius of the input spheres that the network processes. Bigger spheres mean larger portions of the dataset are fed to the network; it also means more points and thus more memory consumption.

batch_num is the number of distinct spheres you want in your batch. It is actually an average batch size, as explained on page 11 of our paper.

The two parameters are not correlated, but together they determine the total number of points your batch contains (the sum of the number of points of each sphere).
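As a rough illustrative example (the numbers are assumed, not taken from this thread): with batch_num = 6 and an average of about 10,000 points per sphere, a batch contains roughly 6 × 10,000 = 60,000 points, which falls in the 50,000–100,000 range mentioned earlier.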

After my attempts, I modified in_radius and first_subsampling_dl as the screenshot shows. The trainer ran successfully, but a few minutes later this error was raised:

It seems that you have an empty truth vector: the truth values are stored in a 0-D array.

In trainer.py, just before line 539, try to print the shapes of truth and preds. If both are [0], then it means you have an empty input sphere.
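Something like the following (a minimal sketch; the variable names truth and preds are assumed to match the ones used around that line of utils/trainer.py):

# Temporary debug prints, placed just before the confusion computation (sketch)
print('truth shape:', truth.shape)
print('preds shape:', preds.shape)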

HuguesTHOMAS avatar Mar 24 '22 13:03 HuguesTHOMAS

I tried to print the shape of "truth", and I found a "0" in the last line. Which factor may cause this issue?

Are there non-zero values for this shape before the 0 happens? If yes, then there is a problem when selecting a particular input sphere; if no, then there is a problem with the whole validation split.

HuguesTHOMAS avatar Mar 24 '22 13:03 HuguesTHOMAS

In trainer.py, just before line 539, try to print the shapes of truth and preds. If both are [0], then it means you have an empty input sphere.

I have tried to print the shape of "truth" and I find a "0" in the last line, but it is not always "0". I have tested this error several times and found that it occurs mostly within the first three validation passes, but it seems to happen randomly.

What should I do to solve it? Is this because of a problem with the calibration process?

GeoSur avatar Mar 24 '22 13:03 GeoSur

I have tried to print the shape of "truth" and I find a "0" in the last line, but it is not always "0". I have tested this error several times and found that it occurs mostly within the first three validation passes, but it seems to happen randomly.

Ok, so this confirms that the error happens when you select a particular area of the validation point cloud. The spheres are picked randomly, which is why it does not always happen at the same time.

What you can do:

If you go back to trainer.py and follow where truth comes from, it is stored here:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/utils/trainer.py#L507-L509

Place an if statement to track the error and print everything you can to understand it:


if target.shape[0] < 1:
    # Print everything useful about the offending batch element
    print('batch_ind', b_i)
    print('length', length)
    print('probs shape', probs.shape)
    print('point indices', inds)
    print('cloud index', c_i)

You can even print the values of the points themselves to know where in the dataset the sphere was:


if target.shape[0] < 1:

    # Recover the input points of this batch element from the stacked batch
    in_pts = batch.points[0].cpu().numpy()
    in_pts = in_pts[i0:i0 + length]

    print('in_pts shape', in_pts.shape)
    print('in_pts mean', np.mean(in_pts, axis=0))

Depending on what is shown in these messages, you should be able to find the solution to your problem.

HuguesTHOMAS avatar Mar 24 '22 13:03 HuguesTHOMAS

Thank you for your patient reply!! I even suspected that it was caused by too large a difference between this data and S3DIS, and I had also considered switching to the KITTI format. I am now going to locate the problem as you said. I hope I can solve it.

GeoSur avatar Mar 24 '22 14:03 GeoSur

Hello Mr. Thomas, thank you for your wonderful work!! I have a question: self.all_splits = [0,1,2,3,4,5,6,7,8,9]. If my validation splits are 3, 5 and 6, how can I set self.validation_split = [?]

SC-shendazt avatar Mar 24 '22 15:03 SC-shendazt

Hi @SC-shendazt,

the code only handles one validation split in S3DIS.py, but it is very easy to modify:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/datasets/S3DIS.py#L140-L157

You can define self.validation_splits = [3, 5, 6] and change

self.all_splits[i] == self.validation_split (and its != counterpart)

to something like

self.all_splits[i] in self.validation_splits (and not in, respectively)
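To make this concrete, here is a minimal sketch of what the multi-split selection could look like (the surrounding loop and variable names are assumptions based on the single-split code linked above, not a verbatim patch):

# Sketch of a multi-split selection (names assumed from datasets/S3DIS.py)
self.validation_splits = [3, 5, 6]

if self.set == 'training':
    # Keep every cloud whose split index is NOT a validation split
    selected = [name for i, name in enumerate(self.cloud_names)
                if self.all_splits[i] not in self.validation_splits]
else:
    # Keep only the clouds whose split index IS a validation split
    selected = [name for i, name in enumerate(self.cloud_names)
                if self.all_splits[i] in self.validation_splits]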

The only other place where validation_split is used is:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/utils/trainer.py#L431-L433

But you can change it to something like

for val_split_i in val_loader.dataset.validation_splits:
    if val_split_i not in val_loader.dataset.all_splits: 
        return 

This is just a safety check.

HuguesTHOMAS avatar Mar 24 '22 15:03 HuguesTHOMAS

Thank you very much. I'll have a try first

SC-shendazt avatar Mar 24 '22 15:03 SC-shendazt

Hello, Mr. Thomas! I tried to place this code below the lines

predictions.append(probs)
targets.append(target)
i0 += length

as the screenshot shows, but it printed nothing. When I deleted the if statement, it printed these lines: [two screenshots of the console output attached] Maybe I am missing something obvious; I cannot find a solution in these messages. By the way, could this be because of the differences (like the point density) between Toronto3D and S3DIS?

GeoSur avatar Mar 25 '22 03:03 GeoSur

Mmmh, sorry, there is a mistake here. The line

i0 += length

should be after these lines:

    in_pts = batch.points[0].cpu().numpy()
    in_pts = in_pts[i0:i0 + length]

    print('in_pts shape', in_pts.shape)
    print('in_pts mean', np.mean(in_pts, axis=0))

Otherwise this will not work.

HuguesTHOMAS avatar Mar 25 '22 14:03 HuguesTHOMAS

Thank you for your reply! I changed it, but it still did not work, so I deleted the if statement. The printed messages are as the screenshots show: [two screenshots of the console output attached] I guess this is because of the in_radius or first_subsampling_dl parameters; the input spheres of the validation process may sometimes be empty. So maybe I could change these parameters to solve the problem? But that is hard work because of the memory consumption and the density of the dataset.

GeoSur avatar Mar 26 '22 03:03 GeoSur

@GeoSur Regarding in_radius and first_subsampling_dl: perhaps you could try different values for these two parameters, e.g. in_radius = 16 and first_subsampling_dl = 0.2~0.X.

SC-shendazt avatar Mar 26 '22 07:03 SC-shendazt

@SC-shendazt Thank you!! Actually, I am trying to do this, but the GPU memory of this computer is not large enough, so I will have to try it on a more powerful device later...

GeoSur avatar Mar 26 '22 07:03 GeoSur

If the error is caused by an empty input sphere, you could probably place a safety check in the function that gets these spheres and creates the batch:

https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/datasets/S3DIS.py#L231
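For example, a minimal sketch of such a check inside the sphere-picking logic (the helper pick_random_center and the variable names are assumptions, not the actual code at that line):

# Sketch: re-draw the sphere center until its neighborhood is non-empty
while True:
    center_point = pick_random_center()  # hypothetical helper
    input_inds = input_tree.query_radius(center_point.reshape(1, -1),
                                         r=self.config.in_radius)[0]
    if input_inds.shape[0] > 0:
        break  # non-empty sphere, keep it; otherwise try another center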

I'll let you investigate that.

HuguesTHOMAS avatar Mar 26 '22 17:03 HuguesTHOMAS

@HuguesTHOMAS Thanks for your reply!!! I also think I should find my own solution to this problem, and thanks for your time, I really appreciate it!!! One more question: could this network work on a dataset that only has XYZ coordinates and true_class values?

GeoSur avatar Mar 27 '22 03:03 GeoSur

[two screenshots of the error attached] Hello Mr. Thomas, I have run into a new problem. Could you give me some advice?

SC-shendazt avatar Mar 28 '22 10:03 SC-shendazt

Well, this seems pretty clear: your labels seem to have float64 dtype instead of an integer type. You should be able to solve this yourself by following the code to where the labels are defined.
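For reference, the kind of cast that usually fixes this (a sketch; where exactly to apply it depends on where your labels are loaded):

import numpy as np

# Cast the labels to an integer dtype right after loading them (sketch)
labels = labels.astype(np.int32)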

HuguesTHOMAS avatar Mar 28 '22 13:03 HuguesTHOMAS

Thanks for your reply!

SC-shendazt avatar Mar 28 '22 13:03 SC-shendazt

@HuguesTHOMAS Hello Thomas! It is me again..

Could this network work on a dataset that only has XYZ and true_class values (i.e. without the RGB features that S3DIS has)?

GeoSur avatar Mar 28 '22 15:03 GeoSur

Sure, you just have to change the number of features to 1 in the configuration here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/train_S3DIS.py#L146

If you want more control over the input features you add, this is done here: https://github.com/HuguesTHOMAS/KPConv-PyTorch/blob/e600c1667d085aeb5cf89d8dbe5a97aad4270d88/datasets/S3DIS.py#L402-L411
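Put together, a minimal sketch of both changes (illustrative only; the exact variable names in datasets/S3DIS.py may differ):

# 1) In the configuration (train_S3DIS.py style): keep only the constant feature
in_features_dim = 1

# 2) In the dataset, where the input features are stacked (sketch):
#    with in_features_dim == 1 only the column of ones is used,
#    so missing RGB values are simply never touched.
stacked_features = np.ones((points.shape[0], 1), dtype=np.float32)
if self.config.in_features_dim == 4:
    stacked_features = np.hstack((stacked_features, colors))
elif self.config.in_features_dim == 5:
    stacked_features = np.hstack((stacked_features, colors, points[:, 2:]))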

HuguesTHOMAS avatar Mar 28 '22 17:03 HuguesTHOMAS

@HuguesTHOMAS Thank you for your reply and this wonderful work, it is really cool!!! But why should the parameter "in_features_dim" be changed from 5 to 1? In the S3DIS dataset there are x, y, z, r, g, b, which makes 6 columns. Or is the parameter not related to the columns of the dataset? As far as I can see, this code creates a KDTree from the .ply files, so it may still need the r, g, b columns or features?

GeoSur avatar Mar 29 '22 02:03 GeoSur

The features do not include x, y, z. For S3DIS they are: 1, r, g, b, h, where 1 is just a column of ones and h is the height of the points in global coordinates. Please refer to the paper for more info on the features.

Changing it to 1 will ignore the additional features and just use the column of ones, which is the basic geometric feature.

As far as I can see, this code creates a KDTree from the .ply files, so it may still need the r, g, b columns or features?

The KDTree is only computed on the point coordinates (x, y, z), for the local spherical neighborhoods.

HuguesTHOMAS avatar Mar 29 '22 13:03 HuguesTHOMAS

@HuguesTHOMAS Thanks for your reply! Well, in S3DIS.py around line 736:

else:
    print('\nPreparing KDTree for cloud {:s}, subsampled at {:.3f}'.format(cloud_name, dl))

    # Read ply file
    data = read_ply(file_path)
    points = np.vstack((data['x'], data['y'], data['z'])).T
    colors = np.vstack((data['red'], data['green'], data['blue'])).T
    labels = data['class']

It still processes the RGB features, and there does not seem to be an if statement guarding it.

GeoSur avatar Mar 29 '22 14:03 GeoSur