s2fft

Support L_output < L_signal in forward_jax with sampling="dh"

Open: mariogeiger opened this issue 7 months ago • 1 comment

Hi,

Would it be possible to compute only the low-frequency coefficients (here L=6) from a signal sampled on a higher-resolution grid (here L=10)?

It seems to work with sampling = "mw", but not with sampling = "dh":

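```python
import jax
from s2fft.sampling import s2_samples as samples  # import paths assumed (s2fft 1.x)
from s2fft.transforms.spherical import forward_jax
sampling = "dh"  # works with "mw", fails with "dh"
f_shape = samples.f_shape(L=10, sampling=sampling)  # high-resolution grid, L=10
f = jax.random.normal(jax.random.PRNGKey(0), f_shape)
flm = forward_jax(f, L=6, sampling=sampling, reality=True)  # only ell < 6 requested
assert flm.shape == samples.flm_shape(L=6)
```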
With reality=True, the call fails with:

```
ftm = ftm.at[:, L - 1 + m_offset :].set(t)
ValueError: Incompatible shapes for broadcasting: (20, 10) and requested shape (20, 14)
```
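A sketch of where the two shapes in this traceback could come from. It rests on assumptions that the traceback alone does not confirm: that the s2fft "dh" grid has 2L colatitude rows and 2L - 1 longitude columns, that the working array ftm is sized from the input signal (L=10), and that the slice start uses the requested L=6 with m_offset = 0. The helper dh_grid_shape is hypothetical. Under those assumptions the arithmetic reproduces both numbers:

```python
# Hypothetical reconstruction of the shapes in the reality=True traceback.
# Assumptions: "dh" grid is (2L, 2L - 1); ftm is sized from the signal's L=10,
# while the slice start L - 1 + m_offset uses the requested L=6 (m_offset = 0).


def dh_grid_shape(L):
    """(ntheta, nphi) of a Driscoll-Healy grid under the assumption above."""
    return 2 * L, 2 * L - 1


L_signal, L_output = 10, 6
ntheta, nphi = dh_grid_shape(L_signal)  # (20, 19)

# A real FFT along phi keeps nphi // 2 + 1 non-negative frequencies: 10 columns.
t_shape = (ntheta, nphi // 2 + 1)

# ftm[:, L - 1 + m_offset:] on a 19-column array with L=6 leaves 19 - 5 = 14 columns.
m_offset = 0
slice_shape = (ntheta, nphi - (L_output - 1 + m_offset))

print(t_shape, slice_shape)  # (20, 10) (20, 14)
```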

With reality=False, it fails with:

```
ftm = jnp.einsum("tm,t->tm", ftm, weights, optimize=True)
ValueError: Size of label 't' for operand 1 (20) does not match previous terms (12).
```
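The two sizes in this message are exactly the number of colatitude samples of a "dh" grid at the two band limits, assuming s2fft uses 2L theta samples for "dh" (2 × 10 = 20, 2 × 6 = 12). The snippet below builds hypothetical operands of those sizes (they are not s2fft internals, and dh_ntheta is a hypothetical helper) to show that the same einsum call rejects them:

```python
import jax.numpy as jnp


def dh_ntheta(L):
    """Colatitude samples on a "dh" grid, assumed here to be 2L."""
    return 2 * L


# Hypothetical operands with the sizes reported in the error message.
ftm = jnp.ones((dh_ntheta(6), 2 * 6 - 1))  # 12 theta rows (2 * 6)
weights = jnp.ones(dh_ntheta(10))  # 20 weights (2 * 10)

raised = False
try:
    jnp.einsum("tm,t->tm", ftm, weights, optimize=True)
except ValueError as err:
    raised = True
    print(err)  # size mismatch for label 't' between the two operands
```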

Maybe I should just use sampling = "mw", but my old code used sampling = "dh" and I wanted to compare the two.
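Until this is supported, one possible workaround (a sketch, not an s2fft feature) is to run forward_jax at the signal's own band limit and then keep only the low degrees. It assumes the s2fft coefficient layout of shape (L, 2L - 1), with column L - 1 + m holding order m, which matches samples.flm_shape. truncate_flm is a hypothetical helper. This still pays for a full L=10 transform, so it helps with comparing "dh" against "mw" rather than with saving compute:

```python
import jax.numpy as jnp


def truncate_flm(flm, L_low):
    """Keep degrees ell < L_low from coefficients laid out as (L, 2L - 1).

    Assumes column L - 1 + m holds order m (hypothetical helper, not part of s2fft).
    """
    L = flm.shape[0]
    if not 0 < L_low <= L:
        raise ValueError(f"L_low must be in 1..{L}, got {L_low}")
    # Orders m = -(L_low - 1) .. L_low - 1 map to columns L - L_low .. L + L_low - 2.
    return flm[:L_low, L - L_low : L + L_low - 1]


# Intended use (not run here):
#   flm_full = forward_jax(f, L=10, sampling="dh", reality=True)
#   flm = truncate_flm(flm_full, 6)

# Self-contained check with a dummy coefficient array for L=10.
L_full = 10
flm_full = jnp.arange(L_full * (2 * L_full - 1)).reshape(L_full, 2 * L_full - 1)
flm_low = truncate_flm(flm_full, 6)
print(flm_low.shape)  # (6, 11)
```

The m = 0 column moves from index 9 in the L=10 array to index 5 in the L=6 array, so each (ell, m) coefficient keeps its value and only the column offset changes.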

mariogeiger, Dec 05 '23 09:12