
Batch ensemble

Open · milliema opened this issue 3 years ago · 8 comments

Thanks for the great work! May I ask how you re-implemented the code for BatchEnsemble? Since the original official implementation is in TensorFlow, are there any PyTorch resources to refer to?

milliema avatar Nov 17 '21 07:11 milliema

To be 100% honest, I asked the paper's authors if they had a draft of their code in PyTorch, and I implemented BatchEnsemble based on that. So I need to thank the authors of BatchEnsemble for their help. It is not an official implementation, but I checked that the results are consistent with their paper.

giannifranchi avatar Nov 17 '21 12:11 giannifranchi

Thank you for your quick response, the information is very helpful!

milliema avatar Nov 17 '21 14:11 milliema

Sorry to disturb you again; I have some questions regarding the repeating operation. There seems to be a difference between the repeating patterns used in the training and testing stages. In training, a tile-style repeat is applied to the images, giving [x1, x1, x1, x2, x2, x2, ...] assuming n_models=3; in testing, `torch.cat` is used instead, giving the pattern [x1, x2, ..., x1, x2, ..., x1, x2, ...] (see `x = torch.cat([x for i in range(self.num_models)], dim=0)` in https://github.com/giannifranchi/LP_BNN/blob/751d1499eb6f794885d050c92ba06d34816bdbda/networks/batchensemble_layers.py). Besides that, in the ensemble layers the pattern [A, A, ..., B, B, ..., C, C, ...] is used to repeat the alpha and gamma parameters. Will this lead to a misalignment in the calculation? I would appreciate your explanation of the details, and please let me know if my understanding is incorrect. Thank you.
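
To make the two patterns concrete, here is a minimal standalone sketch (not code from the repository; the tensor values and the use of `repeat_interleave` to emulate the tile-style repeat are only illustrative):

```python
import torch

num_models = 3
x = torch.arange(1, 5).float().view(4, 1)  # stand-ins for x1..x4, shape (4, 1)

# Tile-style repetition (training): each sample repeated num_models times in a row
x_train = x.repeat_interleave(num_models, dim=0)

# cat-style repetition (testing): whole batch concatenated num_models times
x_test = torch.cat([x for _ in range(num_models)], dim=0)

# The rank-1 vectors are expanded per ensemble member: [A, A, ..., B, B, ..., C, C, ...]
alpha = torch.arange(1, num_models + 1).float().view(num_models, 1)  # members A, B, C
examples_per_model = x_test.shape[0] // num_models
alpha_expanded = alpha.repeat_interleave(examples_per_model, dim=0)

print(x_train.flatten())         # tensor([1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.])
print(x_test.flatten())          # tensor([1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.])
print(alpha_expanded.flatten())  # tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
```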

milliema avatar Nov 19 '21 10:11 milliema

Dear Milliema

Sorry to answer so late. I think there is no problem with the inference phase. If you read the paper carefully, they do not mention that the batch images need to be repeated n times in the training phase. Yet I noticed that the rank-one vectors alpha and gamma do not train perfectly. After multiple experiments and interactions with the authors, I realized that repeating the data during training also improves performance. Regarding the training phase, each pair of vectors alpha and gamma randomly selects data from the batch, which might lead to seeing the same images multiple times. That is probably linked to sampling theory and was out of the scope of my research. As I answered the first time, most of the code for BatchEnsemble does not come from me. Can you explain what you mean by "misalignment in the calculation"?

giannifranchi avatar Jan 02 '22 16:01 giannifranchi

Hi, thanks for the great work you have done.

Following your discussion with Milliema, you mentioned that

Yet I noticed that the rank-one vectors alpha and gamma do not train perfectly. After multiple experiments and interactions with the authors, I realized that repeating the data during training also improves performance.

Could you please explain why repeating the data during training improves the model's performance? I found in this repository that, by repeating the data, the model just receives multiple copies of the same data: the model is trained with [x1, x2, x3, x4] when the data is not repeated, while if the data is repeated (e.g. with n_models=4) the model is trained with [x1, x1, x1, x1, x2, x2, x2, x2, x3, x3, x3, x3, x4, x4, x4, x4]. I am a bit confused about why this helps improve the model's performance.

Thanks again for your contribution and looking forward to seeing your reply!

ShwanMario avatar May 30 '22 15:05 ShwanMario

If you do not repeat the data during training, the weights alpha and gamma will never see all the data. So while the other weights see an entire epoch, these weights only see a fraction of an epoch. That is why repeating the data helps the training. I hope I am clear.

giannifranchi avatar May 31 '22 13:05 giannifranchi

Thanks for your immediate response.

When the data is not repeated, assume we have training samples [$x_1, x_2, x_3, x_4$] and num_models=2; then the weights are [$\alpha_1, \alpha_1, \alpha_2, \alpha_2$], so $\alpha_1$ sees $x_1$ and $x_2$, while $\alpha_2$ sees $x_3$ and $x_4$.

However, my question is that, even with the repeated pattern in the code, alpha and gamma are still only able to see a fraction of an epoch. E.g. if your training data is [$x_1, x_2, x_3, x_4$], it becomes [$x_1, x_1, x_2, x_2, x_3, x_3, x_4, x_4$] after repeating, and the weights alpha become [$\alpha_1, \alpha_1, \alpha_1, \alpha_1, \alpha_2, \alpha_2, \alpha_2, \alpha_2$]. With this pattern, $\alpha_1$ still only sees $x_1$ and $x_2$, while $\alpha_2$ only sees $x_3$ and $x_4$, which I think is not what we want to achieve here.
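
A small sanity check of the assignment I describe (illustrative only; num_models=2, scalar stand-ins for $x_1 \dots x_4$, and the member assignment is assumed to follow the [$\alpha_1, \dots, \alpha_1, \alpha_2, \dots, \alpha_2$] expansion):

```python
import torch

num_models = 2
batch = torch.tensor([1., 2., 3., 4.])  # stand-ins for x1..x4

# Tile-style repetition currently used in training: [x1, x1, x2, x2, x3, x3, x4, x4]
x_rep = batch.repeat_interleave(num_models)

# First half of the expanded batch goes to alpha_1, second half to alpha_2
examples_per_model = x_rep.shape[0] // num_models
member = torch.arange(num_models).repeat_interleave(examples_per_model)

for m in range(num_models):
    print(f"alpha_{m + 1} sees:", sorted(set(x_rep[member == m].tolist())))
# alpha_1 sees: [1.0, 2.0]
# alpha_2 sees: [3.0, 4.0]
```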

I hope I have stated my confusion clearly. Thanks again for your reply.

ShwanMario avatar May 31 '22 21:05 ShwanMario

I agree with you that this is not the intended behavior. I will correct the code by the end of this month and apply `x = torch.cat([x for i in range(num_models)], dim=0)`.
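
For comparison, a sketch of what the cat-based repetition gives under the same illustrative assumptions as above (not the actual patch):

```python
import torch

num_models = 2
batch = torch.tensor([1., 2., 3., 4.])  # stand-ins for x1..x4

# cat-based repetition: concatenate whole-batch copies -> [x1, x2, x3, x4, x1, x2, x3, x4]
x_rep = torch.cat([batch for _ in range(num_models)], dim=0)

examples_per_model = x_rep.shape[0] // num_models
member = torch.arange(num_models).repeat_interleave(examples_per_model)

for m in range(num_models):
    print(f"alpha_{m + 1} sees:", sorted(set(x_rep[member == m].tolist())))
# alpha_1 sees: [1.0, 2.0, 3.0, 4.0]
# alpha_2 sees: [1.0, 2.0, 3.0, 4.0]
```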

Thanks for the correction!

giannifranchi avatar Jun 06 '22 11:06 giannifranchi