
Equivariance of SFNO convolution

Open • caspervanbavel opened this issue 1 year ago • 1 comment

In the current implementation, the convolution layers defined here: https://github.com/NVIDIA/torch-harmonics/blob/main/torch_harmonics/examples/sfno/models/layers.py do not appear to be equivariant.

The implementation uses complex coefficients for the learned kernels, whereas they should be purely real. (This should be obvious from the fact that the m = 0 coefficients of a real-valued field cannot even have an imaginary part.)
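Within a degree l, a general rotation mixes the coefficients of all orders m, so a per-degree filter commutes with every rotation only if it multiplies every order (positive, negative and zero) by the same number. The real-valued transform stores only m >= 0 and reconstructs m < 0 by conjugate symmetry, so a complex weight does not act that way. The following is a minimal sketch, not torch-harmonics code: it uses NumPy's `rfft`/`irfft` on a single latitude ring as a stand-in for the longitudinal half of the SHT, and `effective_multipliers` is a hypothetical helper.

```python
import numpy as np


def effective_multipliers(w, nlon=16, seed=0):
    """Apply one (per-degree) weight w to the stored m >= 0 modes of a real
    signal and return the factor it effectively applies at m = 0, +1 and -1."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(nlon)                 # one latitude ring of a real field
    y = np.fft.irfft(w * np.fft.rfft(x), n=nlon)  # irfft silently drops Im of the m = 0 term
    X, Y = np.fft.fft(x), np.fft.fft(y)           # full spectrum; index -1 holds m = -1
    return Y[0] / X[0], Y[1] / X[1], Y[-1] / X[-1]


print(effective_multipliers(2.0))     # real weight: factor 2 at every order
print(effective_multipliers(1 + 1j))  # complex weight: Re(w) at m=0, w at m=+1, conj(w) at m=-1
```

With a real weight the filter is a single scalar per degree, which commutes with rotations. With a complex weight the factors at m = 0, m > 0 and m < 0 differ, so the filter still commutes with rotations about the polar axis (which act diagonally in m) but not with general rotations.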

A somewhat hacky fix is to replace line 274 with:

self.weight = nn.Parameter(scale * torch.randn(*weight_shape))

and then update the contraction: https://github.com/NVIDIA/torch-harmonics/blob/main/torch_harmonics/examples/sfno/models/contractions.py#L46

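import torch

def contract_dhconv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: real view of the complex SHT coefficients, shape (batch, in, l, m, 2)
    # b: real weight, shape (out, in, l); one value per degree l, shared across all m
    ac = torch.view_as_complex(a)
    bc = b.to(ac.dtype)  # cast the real weight to the matching complex dtype for einsum
    res = torch.einsum("bixy,kix->bkxy", ac, bc)
    return torch.view_as_real(res)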

I have only tested this with operator_type == "driscoll-healy", as I'm not sure what the other operator types are supposed to do.
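The modified contraction can be mirrored in NumPy to check its shapes and its behaviour at m = 0. This is a sketch, not the torch-harmonics code: it assumes the layout implied by the einsum string (input coefficients of shape (batch, in_channels, l, m) and a real weight of shape (out_channels, in_channels, l)), works on complex arrays directly instead of the (..., 2) real view, and `contract_dhconv_np` is a hypothetical name.

```python
import numpy as np


def contract_dhconv_np(x, w):
    """NumPy analogue of contract_dhconv with a real weight.
    x: complex (batch, in, l, m) coefficients; w: real (out, in, l) weight."""
    if np.iscomplexobj(w):
        raise ValueError("an equivariant Driscoll-Healy filter needs a real weight")
    # Each degree l gets its own real channel-mixing matrix, shared by all orders m.
    return np.einsum("bixy,kix->bkxy", x, w)


rng = np.random.default_rng(0)
batch, cin, cout, lmax, mmax = 2, 3, 4, 5, 5
shape = (batch, cin, lmax, mmax)
x = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
x[..., 0] = x[..., 0].real                  # m = 0 coefficients of a real field are real
w = rng.standard_normal((cout, cin, lmax))  # real weight, one value per (out, in, l)

y = contract_dhconv_np(x, w)
print(y.shape)                              # (2, 4, 5, 5)
print(np.abs(y[..., 0].imag).max())         # largest Im at m = 0: zero up to rounding
```

Because the weight depends only on (out, in, l), every order m of a given degree is mixed by the same real matrix, so real inputs at m = 0 stay real.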

caspervanbavel, Nov 29 '24 14:11

Hi @caspervanbavel , thank you for raising this.

Indeed, you are correct: an equivariant convolution would require real-valued filters, as described in Driscoll & Healy, 1994. In practice, in many of the applications we are targeting, the symmetry is weakly broken; for instance, Earth's rotation about its axis singles out the polar axis. We find that in such applications the method works better if complex-valued filters (which are not equivariant) are permitted. In fact, multiplying the coefficients of all orders m simultaneously by a phase e^{-imα} corresponds to a rotation by α around the polar axis.
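The relation between phases and rotations about the polar axis can be checked on a single latitude ring. A minimal NumPy sketch, assuming an equiangular longitude grid so that a rotation by α = 2π·shift/nlon is a cyclic shift by `shift` grid points: the order-m coefficient picks up the phase e^{-imα}, and a filter that multiplies each stored order by a complex weight (the hypothetical `filt` below) commutes with that rotation.

```python
import numpy as np

nlon, shift = 16, 3
alpha = 2 * np.pi * shift / nlon              # rotation angle about the polar axis
rng = np.random.default_rng(1)
x = rng.standard_normal(nlon)                 # one latitude ring of a real field
rotated = np.roll(x, shift)                   # x(phi - alpha): the ring rotated by alpha

m = np.arange(nlon // 2 + 1)                  # stored orders m = 0..nlon//2
phase = np.exp(-1j * m * alpha)               # e^{-i m alpha}: grows linearly with m
print(np.allclose(np.fft.rfft(rotated), phase * np.fft.rfft(x)))  # True


def filt(f, w=0.5 + 2j):
    """Multiply the stored m >= 0 modes by one complex weight (illustrative filter)."""
    return np.fft.irfft(w * np.fft.rfft(f), n=f.size)


# A per-order multiplier commutes with the phase, so filtering and rotating commute.
print(np.allclose(filt(rotated), np.roll(filt(x), shift)))  # True
```

Rotations about other axes are not diagonal in m, so this commutation does not extend to them when the weight is complex.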

I agree that it would be best to offer this functionality for use cases that require strict equivariance.

bonevbs, Dec 02 '24 12:12