icoCNN
icoCNN.tools.rotate_signal not working for r>2
I was trying to use icoCNN to run some experiments with SO(3) rotations, but I was wondering how to verify that this network is equivariant to the rotations. On one hand, I rotated the signals first and sent them to the net; on the other hand, I sent the signals directly to the net and rotated the output. I compared the two, but they turned out to be different. I'd appreciate it if you could show me how to test the equivariance of icoCNN.
The icoCNNs are only equivariant to the 60 icosahedral rotations, not to the continuous group SO(3). This can approximate equivariance to arbitrary spherical rotations, but you won't obtain exact results in the experiment you proposed, especially with untrained models. You can find more information about this in section 5.1 of the original paper by Cohen et al.
The library includes a couple of functions to generate and apply the icosahedral rotations. If you use those, you should obtain almost perfect results in your experiment, even with untrained models (of course, numerical errors might keep the results from being exactly equal). You can find the documentation of these functions here: https://github.com/DavidDiazGuerra/icoCNN/blob/master/tools.md
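To make the 60-element group concrete, here is a minimal NumPy sketch that builds it by closing over two generators: a 72° rotation about the vertex axis (0, 1, φ) and a 180° rotation about the z axis, which passes through the midpoint of an edge. The vertex coordinates and the helper names (`axis_angle_matrix`, `icosahedron_vertices`, `icosahedral_group`) are assumptions for illustration. This is not how `icoCNN.tools.random_icosahedral_rotation_matrix` is implemented, and the library's icosahedron may be oriented differently.

```python
import numpy as np

PHI = (1 + 5 ** 0.5) / 2  # golden ratio


def axis_angle_matrix(axis, angle):
    """Rotation by `angle` radians about `axis` (Rodrigues' formula)."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)


def icosahedron_vertices():
    """The 12 vertices (0, ±1, ±PHI), (±1, ±PHI, 0), (±PHI, 0, ±1)."""
    return np.array([p for a in (-1.0, 1.0) for b in (-PHI, PHI)
                     for p in ((0.0, a, b), (a, b, 0.0), (b, 0.0, a))])


def icosahedral_group():
    """All rotations reachable by multiplying the two generators together."""
    gens = [axis_angle_matrix((0, 1, PHI), 2 * np.pi / 5),  # 5-fold, vertex axis
            axis_angle_matrix((0, 0, 1), np.pi)]            # 2-fold, edge-midpoint axis
    key = lambda m: tuple(np.round(m, 6).ravel())           # rounding-tolerant dedup key
    group = {key(np.eye(3)): np.eye(3)}
    frontier = [np.eye(3)]
    while frontier:                                         # breadth-first closure
        new = []
        for m in frontier:
            for g in gens:
                p = g @ m
                if key(p) not in group:
                    group[key(p)] = p
                    new.append(p)
        frontier = new
    return list(group.values())


print(len(icosahedral_group()))  # 60
```

None of these 60 matrices moves a vertex off the vertex set, while a generic SO(3) rotation does, so only these rotations can be compared exactly on the icosahedral grid.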
Thanks for your response. I am actually applying the icosahedral rotations, and approximate equivariance is all I need. Here is my test code for equivariance:
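```python
import torch
import icoCNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

icocnn = icoCNN.ConvIco(4, 1, 1, 6, 6).to(device)
x = torch.randn(1, 6, 5, 16, 32).to(device)
mat = icoCNN.tools.random_icosahedral_rotation_matrix()
rotx = icoCNN.tools.rotate_signal(x, mat).to(device)
y = icocnn(x)        # apply the network to the original signal
roty = icocnn(rotx)  # rotate the input, then apply the network
rotated_y = icoCNN.tools.rotate_signal(y, mat).to(device)  # apply, then rotate the output
result = torch.allclose(rotated_y, roty, atol=1e-1)
```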
The result turns out to be False. Is there a mistake in this test code? Thanks!
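Whether a fixed `atol=1e-1` is tight or loose depends on the scale of the network output, which can make a True/False answer hard to interpret. A scale-free alternative is to report the relative error between the two paths. The sketch below is a hypothetical helper (`equivariance_error`) demonstrated on a toy stand-in, a 1D circular convolution with cyclic shifts, not on icoCNN itself.

```python
import numpy as np


def equivariance_error(f, act_in, act_out, x):
    """Relative difference between f(g.x) and g.f(x) for one transformation g."""
    lhs = f(act_in(x))   # transform the input, then apply f
    rhs = act_out(f(x))  # apply f, then transform the output
    return np.linalg.norm(lhs - rhs) / np.linalg.norm(rhs)


# Toy stand-in: circular convolution of a periodic 1D signal commutes with
# cyclic shifts, so it is exactly shift-equivariant.
rng = np.random.default_rng(0)
kernel = rng.standard_normal(5)
x = rng.standard_normal(32)


def circ_conv(s):
    return sum(w * np.roll(s, k) for k, w in enumerate(kernel))


def shift(s):
    return np.roll(s, 3)


def position_weighting(s):
    # Multiplies each sample by its index: NOT shift-equivariant.
    return s * np.arange(len(s))


print(equivariance_error(circ_conv, shift, shift, x))           # essentially zero
print(equivariance_error(position_weighting, shift, shift, x))  # clearly nonzero
```

In the snippet from the question, the corresponding quantities would be `roty` (first path) and `rotated_y` (second path); for a rotation from the 60-element group, a ratio well above float32 rounding level points to a problem in either the model or the rotation routine.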
I've been doing some tests, and it seems like the problem is in the function icoCNN.tools.rotate_signal. It works well for r=2 but seems to be unstable for higher resolutions. The implementation of the function is not very good (it applies the rotation to the icosahedral grid using the rotation matrix and then tries to find the equivalent pixel permutation to apply to the signal), and there should be better ways to do it, but unfortunately I don't have time to work on this at the moment.
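The approach described above (rotate the grid with the matrix, then look for the pixel permutation that reproduces the rotated grid) can be sketched generically as nearest-neighbour matching. This is a minimal NumPy illustration, not the code in icoCNN.tools: the helpers `rotation_to_permutation`, `rotate_samples` and `rot_z` are hypothetical, the grid is assumed to be an (N, 3) array of 3D points, and the 12 icosahedron vertices stand in for the real grid. The explicit checks show two ways a matching of this kind can break down (a rotated point not landing on the grid within the tolerance, or two points snapping to the same target); whether either is what happens in rotate_signal for r>2 is not established here.

```python
import numpy as np


def rotation_to_permutation(points, R, tol=1e-6):
    """Hypothetical helper: find perm such that R @ points[i] == points[perm[i]].

    Returns None when no valid one-to-one matching exists within `tol`.
    """
    rotated = points @ R.T  # row i is R @ points[i]
    dist = np.linalg.norm(rotated[:, None, :] - points[None, :, :], axis=-1)
    perm = dist.argmin(axis=1)  # nearest original point for each rotated point
    if dist[np.arange(len(points)), perm].max() > tol:
        return None  # some rotated point does not land on the grid
    if len(np.unique(perm)) != len(points):
        return None  # two points snapped to the same target
    return perm


def rotate_samples(values, perm):
    """Rotated signal: the value sampled at points[i] moves to points[perm[i]]."""
    out = np.empty_like(values)
    out[perm] = values
    return out


def rot_z(t):
    """Rotation by t radians about the z axis."""
    c, s = np.cos(t), np.sin(t)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


# Example point set: the 12 icosahedron vertices (a stand-in for the real grid).
PHI = (1 + 5 ** 0.5) / 2
V = np.array([p for a in (-1.0, 1.0) for b in (-PHI, PHI)
              for p in ((0.0, a, b), (a, b, 0.0), (b, 0.0, a))])

perm = rotation_to_permutation(V, rot_z(np.pi))  # a symmetry: valid permutation
signal = V[:, 0]                                 # sample f(p) = p_x at each vertex
print(rotate_samples(signal, perm))              # equals -signal
print(rotation_to_permutation(V, rot_z(0.3)))    # not a symmetry: None
```

Failing loudly when no bijection exists is a design choice: silently keeping the nearest neighbours would still return some permutation, but the resulting signal would not be the rotated one.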