The formulas for sigma2 in ISDA and TSA seem to be the same, but the code looks different.

Open ShiyeLi opened this issue 3 years ago • 2 comments

Hello, I read your paper; this is impressive work! I noticed that the calculation of sigma2 differs between TSA and ISDA.

ISDA: sigma2 = (weight_m - NxW_kj).pow(2).mul(CV_temp.view(N, 1, A).expand(N, C, A)).sum(2)

TSA: sigma2 = torch.bmm(torch.bmm(NxW_ij - NxW_kj, t_CV_temp), (NxW_ij - NxW_kj).permute(0, 2, 1))

ISDA uses an element-wise multiplication while your work uses bmm. I would like to know whether there is any meaningful difference between these two implementations. Ignoring the Lambda * datW_x_detaMean_NxC term, the formulas in ISDA and TSA seem to be the same, but the code looks different.

Thanks!

ShiyeLi, Apr 12 '21 11:04

Actually, the sigma2 from ISDA that you quote here is the implementation used when the covariance matrices are approximated by their diagonals. This is described in the "Training details" section of ISDA. Since the datasets we use are not as large as ImageNet, we do not adopt this approximation. The implementation that we use can be found at https://github.com/blackfeather-wang/ISDA-for-Deep-Networks/blob/master/Image%20classification%20on%20CIFAR/ISDA.py
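To illustrate the relationship between the two expressions, here is a minimal PyTorch sketch. It is not code from either repository. The shapes N, C, A, the tensor diff (standing in for NxW_ij - NxW_kj) and the helper names sigma2_full and sigma2_diag are assumptions made for illustration. It also assumes that the per-class values of interest sit on the diagonal of the (N, C, C) result of the two bmm calls. Under those assumptions, the element-wise form matches the bmm form exactly when the covariance has no off-diagonal entries, and generally differs otherwise.

```python
import torch

torch.manual_seed(0)
N, C, A = 4, 3, 5  # hypothetical batch size, number of classes, feature dimension

# Stand-in for NxW_ij - NxW_kj: per-sample weight differences, shape (N, C, A).
diff = torch.randn(N, C, A, dtype=torch.float64)

# Random symmetric positive semi-definite covariance per sample, shape (N, A, A).
B = torch.randn(N, A, A, dtype=torch.float64)
cov_full = torch.bmm(B, B.permute(0, 2, 1))


def sigma2_full(diff, cov):
    # Full quadratic form: entry (c, c) of quad[n] is diff[n, c] @ cov[n] @ diff[n, c].
    quad = torch.bmm(torch.bmm(diff, cov), diff.permute(0, 2, 1))  # (N, C, C)
    return torch.diagonal(quad, dim1=1, dim2=2)  # (N, C)


def sigma2_diag(diff, var):
    # Diagonal approximation: var has shape (N, A), one variance per feature.
    n, c, a = diff.shape
    return diff.pow(2).mul(var.view(n, 1, a).expand(n, c, a)).sum(2)  # (N, C)


# Keep only the per-feature variances, then rebuild a purely diagonal covariance.
var = torch.diagonal(cov_full, dim1=1, dim2=2).contiguous()  # (N, A)
cov_diag_only = torch.diag_embed(var)  # (N, A, A), off-diagonals are zero

# The two forms agree when the covariance is diagonal ...
same = torch.allclose(sigma2_full(diff, cov_diag_only), sigma2_diag(diff, var))
# ... and generally disagree once off-diagonal covariances are present.
differs = not torch.allclose(sigma2_full(diff, cov_full), sigma2_diag(diff, var))
print(same, differs)
```

The diagonal form needs only O(A) memory per covariance instead of O(A^2), which is consistent with the reply's remark that the approximation is aimed at ImageNet-scale training.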

Xiemixue, Apr 16 '21 13:04

Thanks for your reply!

ShiyeLi, Apr 30 '21 09:04