s2fft

Support L_output < L_signal in forward_jax with sampling="dh"

Open: mariogeiger opened this issue 1 year ago • 1 comment

Hi,

Would it be possible to compute only the low frequencies from a high resolution grid?

It seems to work with sampling = "mw" but not with sampling = "dh":

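```python
import jax
from s2fft.sampling import s2_samples as samples
from s2fft.transforms.spherical import forward_jax
sampling = "dh"  # fails with "dh"; runs (at least shape-wise) with "mw"
f_shape = samples.f_shape(L=10, sampling=sampling)
f = jax.random.normal(jax.random.PRNGKey(0), f_shape)
flm = forward_jax(f, L=6, sampling=sampling, reality=True)
assert flm.shape == samples.flm_shape(L=6)
```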
With reality=True:

```
ftm = ftm.at[:, L - 1 + m_offset :].set(t)
ValueError: Incompatible shapes for broadcasting: (20, 10) and requested shape (20, 14)
```

With reality=False:

```
ftm = jnp.einsum("tm,t->tm", ftm, weights, optimize=True)
ValueError: Size of label 't' for operand 1 (20) does not match previous terms (12).
```
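The sizes in these errors appear consistent with the grid sizes. A rough sketch, assuming DH sampling uses 2L colatitude (θ) samples and 2L - 1 longitude (φ) samples while MW uses L × (2L - 1); `grid_shape` is a hypothetical helper written for illustration, not an s2fft function (s2fft's own `samples.f_shape` is the authoritative source).

```python
def grid_shape(L: int, sampling: str) -> tuple[int, int]:
    """Hypothetical helper: (n_theta, n_phi) for a bandlimit-L grid.

    Assumes DH uses 2L theta samples and MW uses L, both with 2L - 1 phi samples.
    """
    n_phi = 2 * L - 1
    if sampling == "dh":
        return (2 * L, n_phi)
    if sampling == "mw":
        return (L, n_phi)
    raise ValueError(f"unsupported sampling: {sampling!r}")


L_signal, L_output = 10, 6
sig = grid_shape(L_signal, "dh")  # grid the data actually lives on: (20, 19)
exp = grid_shape(L_output, "dh")  # grid an L=6 DH transform is built for: (12, 11)

# 20 theta rows in the data vs. 12 expected for L=6: the two sizes in the einsum error.
print("theta rows:", sig[0], "vs", exp[0])

# Assuming the reality=True path takes a real FFT along phi, 19 phi samples
# give 19 // 2 + 1 = 10 non-negative-m bins, matching the (20, 10) operand.
print("rfft bins:", sig[1] // 2 + 1)
```

If these assumptions hold, both failures come from the L=6 transform being set up for a 12-row grid while the data has 20 rows.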

Maybe I should just use sampling = "mw", but my old code used sampling = "dh" and I wanted to compare the two.

mariogeiger, Dec 05 '23 09:12

Thanks for the comment, @mariogeiger! This is indeed not supported at present, since we assume the signal is sampled at the resolution corresponding to its bandlimit $L$; all harmonic coefficients with $\ell < L$ are then computed. While MW sampling gives you the right output shape, I'm not sure the correct values are computed.
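To make "all harmonic coefficients with $\ell < L$" concrete, here is a small sketch. It assumes s2fft's 2D coefficient layout of shape (L, 2L - 1), with row ℓ and column L - 1 + m, so entries with |m| > ℓ are unused padding; `valid_mask` is a hypothetical helper for illustration only.

```python
import numpy as np


def valid_mask(L: int) -> np.ndarray:
    """Boolean mask of meaningful entries in an (L, 2L - 1) flm array.

    Assumes row index = ell and column index = L - 1 + m, so |m| <= ell is valid.
    """
    ell = np.arange(L)[:, None]                   # shape (L, 1)
    m = np.arange(2 * L - 1)[None, :] - (L - 1)   # shape (1, 2L - 1), m = -(L-1) .. L-1
    return np.abs(m) <= ell                       # broadcasts to (L, 2L - 1)


mask = valid_mask(6)
print(mask.shape)  # (6, 11)
print(mask.sum())  # 36, i.e. L**2 coefficients for bandlimit L = 6
```

Each degree ℓ contributes 2ℓ + 1 orders m, so a bandlimit of L yields L² coefficients in total.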

The simplest way to support your use case is to compute all harmonic coefficients up to $L_\text{signal}$ and simply discard those with $L_\text{output} \leq \ell < L_\text{signal}$. Of course, this is not the most efficient approach. How often do you need to do transforms in this manner? If it is just a one-off (e.g. to downsample your data), I would recommend this approach. But if you need to do this many times, the performance hit will be substantial.
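A minimal sketch of the "compute everything, then discard" approach, assuming the (L, 2L - 1) coefficient layout with row ℓ and column L - 1 + m. `truncate_flm` is a hypothetical helper, not an s2fft function; the idea is to pass it the output of `forward_jax(f, L=L_signal, sampling="dh", ...)`. Plain NumPy is used here so the example runs standalone; the same slice should work on a `jax.numpy` array.

```python
import numpy as np


def truncate_flm(flm: np.ndarray, L_output: int) -> np.ndarray:
    """Keep only ell < L_output from an (L_signal, 2*L_signal - 1) flm array.

    Assumes row index = ell and column index = L_signal - 1 + m.
    """
    L_signal = flm.shape[0]
    if not 0 < L_output <= L_signal:
        raise ValueError("need 0 < L_output <= L_signal")
    # Keep m = -(L_output - 1) .. L_output - 1, i.e. columns
    # L_signal - L_output .. L_signal + L_output - 2 (inclusive).
    lo = L_signal - L_output
    hi = L_signal + L_output - 1
    return flm[:L_output, lo:hi]


# Toy array that encodes (ell, m) in each entry as 100 * ell + m, so the slice is easy to check.
L_signal, L_output = 10, 6
ell = np.arange(L_signal)[:, None]
m = np.arange(2 * L_signal - 1)[None, :] - (L_signal - 1)
flm_full = (100 * ell + m).astype(float)

flm_low = truncate_flm(flm_full, L_output)
print(flm_low.shape)  # (6, 11), the same layout as an L=6 transform
```

Note that this keeps the full $L_\text{signal}$ cost of the forward transform; only the output is reduced.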

If you could provide a little more detail about your use case and how you handled this previously with DH sampling, we can explore implementing this in s2fft.

jasonmcewen, Dec 06 '23 17:12