s2cnn
Equivariance error issue
Hi! Thank you for this work!
Following Section 5.1 of your paper, I am trying to check the equivariance of the convolution layers. I have a question about Figure 3: is the error of order 10^-6 in all four plots?
MAIN ISSUE: I ran the code below. It synthesizes 500 random feature maps and 500 random rotations around the Z-axis, then measures the equivariance error: mean(STD(Rotation(Layer(DATA)) - Layer(Rotation(DATA))) / STD(Layer(DATA)))
import torch
import numpy as np
from s2cnn import S2Convolution
from s2cnn import s2_near_identity_grid

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Device
b_in = 26  # Bandwidth of the input data
c = 10  # Number of feature maps
grid_s2 = s2_near_identity_grid(n_alpha=6, max_beta=np.pi/32, n_beta=1)  # S2 convolution
conv_S2 = S2Convolution(c, c, b_in, b_in, grid_s2)
conv_S2.to(device)

def phi(input, conv_S2):
    input = conv_S2(input)
    return input

def rot(input, angle):
    n = round(input.size(3) * angle / 360)  # rotate the signal around the Z axis
    return torch.cat([input[:, :, :, n:], input[:, :, :, :n]], dim=3)

relative_error = 0
n = 500
for i in range(n):
    with torch.no_grad():
        x = torch.randn(1, c, b_in*2, b_in*2).to(device)  # Random input [batch, feature, beta, alpha]
        angle = np.random.randint(0, 360)  # Random rotation around the Z-axis
        y = phi(x, conv_S2)  # Layer(DATA)
        y1 = rot(phi(x, conv_S2), angle=angle)  # Rotation(Layer(DATA))
        y2 = phi(rot(x, angle=angle), conv_S2)  # Layer(Rotation(DATA))
        relative_error += torch.std(y1.data - y2.data) / torch.std(y.data)
relative_error /= n
print('relative error = {}'.format(relative_error))
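A quick way to rule out the test harness itself, independently of s2cnn, is to run the same metric on an operator that is known to commute with shifts along alpha. This is a minimal NumPy-only sketch under that assumption: `toy_layer` (a circular convolution along the alpha axis) and `equivariance_error` are hypothetical names, not part of s2cnn, and `rot` mirrors the shift-based rotation from the script above.

```python
import numpy as np

def rot(x, angle):
    # Same shift-based rotation about Z as in the script above, on a NumPy array
    # laid out as [batch, feature, beta, alpha].
    n = round(x.shape[3] * angle / 360)
    return np.concatenate([x[:, :, :, n:], x[:, :, :, :n]], axis=3)

def toy_layer(x, kernel):
    # Hypothetical stand-in for a layer: circular convolution along alpha (axis 3).
    # Circular convolution commutes with whole-sample circular shifts along alpha.
    X = np.fft.fft(x, axis=3)
    K = np.fft.fft(kernel, n=x.shape[3])  # zero-padded kernel spectrum
    return np.real(np.fft.ifft(X * K, axis=3)).astype(x.dtype)

def equivariance_error(layer, n_trials=100, c=10, b=26, seed=0):
    # Same metric as the script: mean of std(rot(layer(x)) - layer(rot(x))) / std(layer(x)).
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(n_trials):
        x = rng.standard_normal((1, c, 2 * b, 2 * b)).astype(np.float32)
        angle = int(rng.integers(0, 360))
        y = layer(x)
        total += float(np.std(rot(y, angle) - layer(rot(x, angle))) / np.std(y))
    return total / n_trials

kernel = np.random.default_rng(1).standard_normal(5)
print(equivariance_error(lambda x: toy_layer(x, kernel)))
```

For an exactly shift-equivariant operator the metric should stay at floating-point round-off level, while an alpha-dependent reweighting (for example `lambda x: x * w` with `w = np.linspace(0, 1, 52)`) should give a large error, well above 0.1.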
The layer is an S2 convolution, and I set the input and output bandwidths to be equal. The equivariance error depends strongly on that bandwidth:
- For b_in < 25, the error is of order 10^-7.
- For b_in > 25, the error is of order 10^-1.
I understand that the 10^-7 error comes from numerical precision and means that the layer is equivariant. But I don't know what I am doing wrong to get a 10^-1 error at larger bandwidths. I currently work with data of bandwidth 32, and my model does not seem to be equivariant because of this problem. Does it come from numerical precision? Is there something I can do to fix it?
NOTE: I tried running it on both my CPU and my GPU, and the problem persists. Changing the floating-point precision had no effect either.
I also tried setting the output bandwidth to b_out = b_in // 2:
- For b_in < 4, the error is of order 10^-7.
- For b_in > 4, the error is of order 10^-1.
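One thing that may matter for the b_out = b_in // 2 runs, separately from anything inside the layer: the `rot` helper can only shift by whole grid samples, so each random integer angle is rounded to a multiple of 360 / (2 * b) degrees, and that rounding can differ between the input grid (2 * b_in samples along alpha) and the output grid (2 * b_out samples). When it does, Rotation(Layer(DATA)) and Layer(Rotation(DATA)) are compared at different rotation angles, which could contribute to the error in that setting. This sketch, with hypothetical helper names, lists the integer angles for which this happens.

```python
# Hypothetical helpers (not part of s2cnn) for inspecting the shift-based `rot`.

def effective_angle(n_alpha, angle):
    """Rotation (degrees) that `rot` actually applies on a grid with n_alpha samples."""
    n = round(n_alpha * angle / 360)  # same rounding as in the test script
    return n * 360 / n_alpha

def mismatched_angles(b_in, b_out):
    """Integer angles in [0, 360) that end up rotated by different amounts on the two grids."""
    bad = []
    for angle in range(360):
        a_in = effective_angle(2 * b_in, angle) % 360
        a_out = effective_angle(2 * b_out, angle) % 360
        if abs(a_in - a_out) > 1e-9:
            bad.append(angle)
    return bad

print(len(mismatched_angles(26, 26)), len(mismatched_angles(26, 13)))
```

With b_out = b_in the list is always empty, since both grids round identically. When b_out differs from b_in, restricting the test to angles that are multiples of 360 / gcd(2 * b_in, 2 * b_out) degrees makes both shifts exact.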
Thanks for sharing the investigation.
I can't think of anything that would explain this phenomenon.
Could you plot the outputs with plt.imshow? Maybe we can figure out what's going on by eye.
Here it is!
It looks like an aliasing artifact, localized on the equator.
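To check the "localized on the equator" impression numerically rather than by eye, one option is to reduce the difference Rotation(Layer(DATA)) - Layer(Rotation(DATA)) to one spread value per beta row. This sketch assumes the SO(3) output layout [batch, feature, beta, alpha, gamma] (beta on axis 2) and the SOFT-grid latitudes beta_j = pi * (2j + 1) / (4b), under which the equator falls between rows b - 1 and b. `error_by_beta` and `soft_betas` are hypothetical helper names; CUDA tensors need `.cpu().numpy()` first.

```python
import numpy as np

def soft_betas(b):
    # Assumed SOFT-grid latitudes: beta_j = pi * (2j + 1) / (4b), j = 0 .. 2b - 1.
    return np.pi * (2 * np.arange(2 * b) + 1) / (4 * b)

def error_by_beta(y1, y2, beta_dim=2):
    """Standard deviation of (y1 - y2) over every axis except beta, one value per row."""
    diff = np.asarray(y1) - np.asarray(y2)
    diff = np.moveaxis(diff, beta_dim, 0)  # bring the beta axis to the front
    return diff.reshape(diff.shape[0], -1).std(axis=1)

# Usage with the tensors from the script above (b_out = b_in = 26):
# profile = error_by_beta(y1.cpu().numpy(), y2.cpu().numpy())
# print(profile.argmax(), soft_betas(26)[profile.argmax()])
```

A profile that peaks at rows b - 1 and b, and is near zero elsewhere, would support the equator localization.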
Hi Axel
I am running similar experiments and running into issues as well. Were you able to figure out what the problem was?
Thanks Suhas
I have no idea
Hi, I would also like to start by thanking you for making your great work available here.
For me, the issue does not occur on the CPU, only on CUDA.
FIX: Change line 153 of s2_fft from
int l = powf(s, 0.5)
to
int l = sqrtf(s)
There might be an even more appropriate fix for square roots of ints, but I'm not experienced enough with CUDA to know of one.
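On my computer with CUDA 10.1, int l = powf(625, 0.5) yields 24, but int l = sqrtf(625) yields the correct 25.
The issue is also present at l = 49 and l = 63, which points to perfect squares: when s = l^2 the exact root is an integer, so any rounding error just below it is truncated to l - 1 by the cast to int.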
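To see why a tiny rounding error matters here, a Python sketch of the index arithmetic, assuming the kernel stores coefficient (l, m) at flat index s = l*l + l + m. This layout is consistent with the indices used in the minimal example below, where Y_(bw-1),(-bw+1) sits at index -1 - 2*(bw - 1). The exact value powf returned is not known; a float32 one ULP below 25 is used purely for illustration, and all helper names are hypothetical.

```python
import math
import numpy as np

def flat_index(l, m):
    # Assumed layout: degree l occupies indices l*l .. l*l + 2*l, ordered by m.
    return l * l + l + m

def decode_truncating(s, root):
    # Mirrors `int l = <float root of s>;` in C/CUDA: the cast truncates toward zero.
    l = int(root)
    return l, s - l * l - l

def decode_robust(s):
    # Start from the float estimate, then correct it with exact integer comparisons,
    # so a root landing just below an integer cannot lower l.
    l = int(math.sqrt(s))
    while l * l > s:
        l -= 1
    while (l + 1) * (l + 1) <= s:
        l += 1
    return l, s - l * l - l

s = flat_index(25, -25)  # first coefficient of degree 25
just_below = np.nextafter(np.float32(25), np.float32(0))  # float32 one ULP below 25
print(s, decode_truncating(s, just_below), decode_robust(s))
```

Under this assumed layout, the first coefficient of degree 25 (m = -25) is decoded as (l = 24, m = 25), which is not a valid pair; that is consistent with the l = 25, m = -25 coefficient being the one that comes out wrong. The same two integer comparisons could in principle follow sqrtf in the kernel as well.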
Some bug-tracing background:
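I traced the problem to the S2-FFT of a test function parameterized by two constants, A and B (the equation image from the original post is missing).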
For simplicity, consider A = 0 and B = 1, so that the function is sin(beta)^25 * sin(25*alpha). Its FFT should be zero everywhere except for the coefficients with l = 25 and m = ±25. For those coefficients, because 25 is odd, the real part should be zero and the imaginary part should be positive, and that is what happens at (almost all) other odd frequencies. At l = 25, however, the imaginary part of the m = -25 coefficient is negative. See the minimal working example below.
Minimal working example
import torch
import numpy as np
from lie_learn.spaces.S2 import meshgrid as S2_meshgrid
from s2cnn.soft.s2_fft import S2_fft_real

def test_fft_of_highest_order_harmonic(bw):
    beta, alpha = S2_meshgrid(bw, grid_type="SOFT")
    x = (np.sin(beta)**(bw - 1)) * np.sin((bw - 1)*alpha)
    for device in ["cpu", "cuda"]:
        x_fft = S2_fft_real.apply(torch.tensor(x[None, None, ...], dtype=torch.float32, device=device))
        index_negative = -1 - 2*(bw - 1)  # index of Y_(bw-1),(-bw+1)
        index_positive = -1               # index of Y_(bw-1),( bw-1)
        coeff_product = (x_fft[index_negative, 0, 0, 1] * x_fft[index_positive, 0, 0, 1])
        if bw % 2 == 0:
            assert coeff_product > 0, \
                ("The imaginary parts should have the same sign.\n" +
                 f"Device: {device}, bandwidth: {bw}\n" +
                 f"xfft, l=bw-1, m=-bw+1: {x_fft[index_negative, 0, 0, 1]: .2e}\n" +
                 f"      l=bw-1, m= bw-1: {x_fft[index_positive, 0, 0, 1]: .2e}")
        else:
            assert coeff_product < 0, \
                ("The imaginary parts should have opposite signs.\n" +
                 f"Device: {device}, bandwidth: {bw},\n" +
                 f"xfft, l=bw-1, m=-bw+1: {x_fft[index_negative, 0, 0, 1]: .2e}\n" +
                 f"      l=bw-1, m= bw-1: {x_fft[index_positive, 0, 0, 1]: .2e}")

if __name__ == "__main__":
    for bw in [24, 25, 26, 27]:
        test_fft_of_highest_order_harmonic(bw)
    print("Tests passed.")

# AssertionError: The imaginary parts should have the same sign.
# Device: cuda, bandwidth: 26
# xfft, l=bw-1, m=-bw+1: -2.93e-02
#       l=bw-1, m= bw-1:  2.93e-02
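The sign pattern that this test asserts can be derived by hand. A short derivation, assuming complex spherical harmonics with the Condon-Shortley phase (so Y_l^{-m} = (-1)^m conj(Y_l^m), and Y_l^l = c_l sin^l(beta) e^{i l alpha} with a real constant c_l of sign (-1)^l) and coefficients defined by projecting onto Y_l^m. If s2cnn's normalization differs from this only by a positive factor, the signs are unaffected.

```latex
\begin{aligned}
f(\beta,\alpha) &= \sin^l\beta\,\sin(l\alpha)
  = \frac{\sin^l\beta}{2i}\left(e^{il\alpha} - e^{-il\alpha}\right)
  = \frac{1}{2i\,c_l}\left(Y_l^{l} - (-1)^l\,Y_l^{-l}\right), \\
\hat f_l^{\,l} &= \frac{1}{2i\,c_l} = -\frac{i}{2c_l}, \qquad
\hat f_l^{\,-l} = -\frac{(-1)^l}{2i\,c_l} = \frac{(-1)^l\,i}{2c_l}, \\
\operatorname{Im}\hat f_l^{\,l}\,\operatorname{Im}\hat f_l^{\,-l} &= -\frac{(-1)^l}{4c_l^2}
  \;\begin{cases} > 0, & l \text{ odd (same sign)}, \\ < 0, & l \text{ even (opposite signs)}. \end{cases}
\end{aligned}
```

For l = bw - 1 this gives "same sign" exactly when bw is even, which is the branch the test takes. With c_l of sign (-1)^l, both imaginary parts are positive for odd l, matching the expectation stated above for l = 25.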
Thanks a lot for investigating the problem and finding a solution. Would you like to open a pull request?