TSA
The formulas for sigma2 in ISDA and TSA seem to be the same, but the code looks different
Hello, I read your paper; this is impressive work! I noticed that sigma2 is calculated differently in TSA and ISDA.
ISDA: sigma2 = (weight_m - NxW_kj).pow(2).mul(CV_temp.view(N, 1, A).expand(N, C, A)).sum(2)
TSA: sigma2 = torch.bmm(torch.bmm(NxW_ij - NxW_kj, t_CV_temp), (NxW_ij - NxW_kj).permute(0, 2, 1))
ISDA uses element-wise multiplication (.mul) while your work uses bmm. Is there any meaningful difference between these two implementations? Ignoring the term Lambda * datW_x_detaMean_NxC, the formulas in ISDA and TSA seem to be the same, but the code looks different.
Thanks!
Actually, the ISDA sigma2 you quote is the implementation used when the covariance matrices are approximated by their diagonals. This is described in the "Training details" section of the ISDA paper. Since the datasets we use are not as large as ImageNet, we do not adopt this approximation. The implementation we use can be found at https://github.com/blackfeather-wang/ISDA-for-Deep-Networks/blob/master/Image%20classification%20on%20CIFAR/ISDA.py
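To make the relationship concrete, here is a minimal PyTorch sketch. It is not the ISDA or TSA code itself. Both forms compute sigma2[n, j] = (w_j - w_y)^T Sigma_y (w_j - w_y) for a sample n with label y. The bmm form uses a full A x A covariance per class and yields an N x C x C matrix whose diagonal holds the needed values. The element-wise form replaces Sigma_y with a diagonal matrix, so the quadratic form reduces to a weighted sum of squares. The function names sigma2_full and sigma2_diag, the sizes, and the random weights and covariances are assumptions for illustration only.

```python
import torch

# Hypothetical sizes: N samples, C classes, A feature dimensions.
N, C, A = 4, 3, 5
torch.manual_seed(0)

weight_m = torch.randn(C, A)        # classifier weights, one row per class
labels = torch.randint(0, C, (N,))  # ground-truth class of each sample


def sigma2_full(weight_m, labels, cov):
    """Full covariance (cov: C x A x A): the bmm form."""
    N = labels.shape[0]
    C, A = weight_m.shape
    NxW_ij = weight_m.expand(N, C, A)                        # w_j for every j
    NxW_kj = weight_m[labels].unsqueeze(1).expand(N, C, A)   # w_y of each sample
    diff = NxW_ij - NxW_kj                                   # N x C x A
    CV_temp = cov[labels]                                    # N x A x A
    quad = torch.bmm(torch.bmm(diff, CV_temp), diff.permute(0, 2, 1))  # N x C x C
    # Only the (j, j) entries are the quadratic forms we need.
    return torch.diagonal(quad, dim1=1, dim2=2)              # N x C


def sigma2_diag(weight_m, labels, var):
    """Diagonal approximation (var: C x A holds the diagonal of each Sigma_y)."""
    N = labels.shape[0]
    C, A = weight_m.shape
    NxW_kj = weight_m[labels].unsqueeze(1).expand(N, C, A)
    CV_temp = var[labels]                                    # N x A
    return (weight_m - NxW_kj).pow(2).mul(CV_temp.view(N, 1, A).expand(N, C, A)).sum(2)


# When every class covariance really is diagonal, the two forms agree.
var = torch.rand(C, A) + 0.1
s_full = sigma2_full(weight_m, labels, torch.diag_embed(var))
s_diag = sigma2_diag(weight_m, labels, var)
print(torch.allclose(s_full, s_diag, atol=1e-5))
```

The design trade-off is memory: a full covariance stores C x A x A values, while the diagonal approximation stores only C x A. That matters on datasets with many classes and wide features, such as ImageNet. With a non-diagonal covariance, the element-wise form drops the cross terms, so the two results generally differ.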
Thanks for your reply!