s2fft
Support L_output < L_signal for forward_jax sampling="dh"
Hi,
Would it be possible to compute only the low frequencies from a high-resolution grid? It seems to work with `sampling="mw"` but not with `sampling="dh"`:
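```python
import jax
from s2fft.sampling import s2_samples as samples
from s2fft.transforms.spherical import forward_jax
sampling = "dh"  # the same code runs with sampling = "mw"
f_shape = samples.f_shape(L=10, sampling=sampling)  # signal grid at L = 10
f = jax.random.normal(jax.random.PRNGKey(0), f_shape)
flm = forward_jax(f, L=6, sampling=sampling, reality=True)  # ask only for L = 6
assert flm.shape == samples.flm_shape(L=6)
```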
With `reality=True`:

```
ftm = ftm.at[:, L - 1 + m_offset :].set(t)
ValueError: Incompatible shapes for broadcasting: (20, 10) and requested shape (20, 14)
```
With `reality=False`:

```
ftm = jnp.einsum("tm,t->tm", ftm, weights, optimize=True)
ValueError: Size of label 't' for operand 1 (20) does not match previous terms (12).
```
Maybe I should just use `sampling="mw"`, but my old code used `sampling="dh"` and I wanted to compare the two.
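Until a lower output band-limit is supported for `"dh"`, one possible workaround (a sketch, not something confirmed by the s2fft maintainers) is to run the transform at the full grid band-limit L = 10 and slice out the ℓ < 6 block afterwards. This assumes the coefficient layout implied by `samples.flm_shape`, namely shape (L, 2L - 1) with order m stored in column m + L - 1. `truncate_flm` is a hypothetical helper, not part of s2fft, and the NumPy array below only stands in for the real `forward_jax` output.

```python
import numpy as np


def truncate_flm(flm, L_out):
    """Return the ell < L_out block of a harmonic coefficient array.

    Assumes an s2fft-style layout of shape (L, 2L - 1), where row ell holds
    the coefficients of that degree and column m + L - 1 holds order m.
    Uses only basic slicing, so it works on NumPy and JAX arrays alike.
    """
    L = flm.shape[0]
    if flm.shape[1] != 2 * L - 1:
        raise ValueError(f"expected shape (L, 2L - 1), got {flm.shape}")
    if not 1 <= L_out <= L:
        raise ValueError(f"L_out must be between 1 and {L}, got {L_out}")
    # Order m = -(L_out - 1) sits at column (L - 1) - (L_out - 1) = L - L_out,
    # and order m = L_out - 1 sits at column (L - 1) + (L_out - 1) = L + L_out - 2.
    start = L - L_out
    stop = L + L_out - 1  # exclusive upper bound of the slice
    return flm[:L_out, start:stop]


if __name__ == "__main__":
    # Stand-in for an L = 10 coefficient array; real code would pass forward_jax output.
    L = 10
    flm_full = np.arange(L * (2 * L - 1)).reshape(L, 2 * L - 1)
    flm_low = truncate_flm(flm_full, 6)
    print(flm_low.shape)  # (6, 11)
```

With the reproduction above, this would read `truncate_flm(forward_jax(f, L=10, sampling=sampling, reality=True), 6)`. The cost is computing the ℓ ≥ 6 coefficients only to discard them, but the ℓ < 6 values should be the ones the L = 10 quadrature produces.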