Run individual `iternorms` on ell blocks

Open Saransh-cpp opened this issue 11 months ago • 3 comments

The function mutates arrays in-place heavily, and it should be refactored to work with immutable arrays and data structures. Moreover, the correlations do not need to be regular and could be irregular/ragged to save memory; hence, we should investigate #382 in parallel with this issue.

Saransh-cpp avatar Jan 23 '25 10:01 Saransh-cpp

I have been looking into the code and the paper, and I think it would be best to have a separate JAX implementation of the function. The array mutations are a smart choice on CPUs because they keep the memory overhead low. Switching to copies would hurt performance on CPUs, and IIRC we discussed that GLASS should run smoothly on both architectures, since each is required for specific needs.

JAX's indexed-update syntax (`x.at[idx].set(y)`) creates copies of arrays when run eagerly on CPU, but under JIT compilation the compiler can perform the update in place when the original array is not reused. So, the JAX implementation would be accelerated on GPUs and memory-efficient on CPUs with JIT compilation.
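To make the copy-versus-in-place point concrete, here is a minimal sketch of a functional array update in JAX. The function name `rescale` and the values are purely illustrative and not taken from GLASS; whether the write is actually done in place is up to the XLA compiler.

```python
import jax
import jax.numpy as jnp


@jax.jit
def rescale(x):
    # `y` is an intermediate that only exists inside the compiled function.
    y = x * 2.0
    # Functional update: semantically returns a new array. Because the old
    # `y` is not used again, the compiler is free to write into its buffer.
    y = y.at[0].set(-1.0)
    return y


x = jnp.arange(4.0)
y = rescale(x)  # `x` itself is never modified
```

The NumPy equivalent would mutate with `y[0] = -1.0`; in JAX the same intent is written as `y.at[0].set(...)`, and the choice between copying and reusing memory is left to the compiler.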

Saransh-cpp avatar Jan 25 '25 12:01 Saransh-cpp

cc: @ntessore

Saransh-cpp avatar Jan 27 '25 14:01 Saransh-cpp

Maybe it makes sense to split this issue into two separate tasks:


On the science side: figure out a way to have different values of ncorr for different $\ell$ (as in $a_{\ell m}$). If we write $a_{\ell m}^{(g)}$ for the alms of "generation" $g$, then we currently have a rectangular array of previous samples:

$$
\left.\begin{matrix}
a_{00}^{(1)} & \ldots & a_{00}^{(\mathtt{ncorr})} \\
a_{10}^{(1)} & \ldots & a_{10}^{(\mathtt{ncorr})} \\
a_{20}^{(1)} & \ldots & a_{20}^{(\mathtt{ncorr})} \\
\vdots & \vdots & \vdots
\end{matrix}\right\} \; m=0
$$

$$
\left.\begin{matrix}
a_{11}^{(1)} & \ldots & a_{11}^{(\mathtt{ncorr})} \\
a_{21}^{(1)} & \ldots & a_{21}^{(\mathtt{ncorr})} \\
\vdots & \vdots & \vdots
\end{matrix}\right\} \; m=1
$$

$$
\left.\begin{matrix}
a_{22}^{(1)} & \ldots & a_{22}^{(\mathtt{ncorr})} \\
\vdots & \vdots & \vdots
\end{matrix}\right\} \; m=2
$$

If we go to $\ell$-dependent values of $\mathtt{ncorr}_\ell$ then this turns into a complicated ragged array:

$$
\left.\begin{matrix}
a_{00}^{(1)} & \ldots & \ldots & \ldots & a_{00}^{(\mathtt{ncorr}_0)} \\
a_{10}^{(1)} & \ldots & \ldots & a_{10}^{(\mathtt{ncorr}_1)} \\
a_{20}^{(1)} & \ldots & a_{20}^{(\mathtt{ncorr}_2)} \\
\vdots & \vdots & \vdots
\end{matrix}\right\} \; m=0
$$

$$
\left.\begin{matrix}
a_{11}^{(1)} & \ldots & \ldots & a_{11}^{(\mathtt{ncorr}_1)} \\
a_{21}^{(1)} & \ldots & a_{21}^{(\mathtt{ncorr}_2)} \\
\vdots & \vdots & \vdots
\end{matrix}\right\} \; m=1
$$

$$
\left.\begin{matrix}
a_{22}^{(1)} & \ldots & a_{22}^{(\mathtt{ncorr}_2)} \\
\vdots & \vdots & \vdots
\end{matrix}\right\} \; m=2
$$

As you can see, this is made more awkward by the HEALPix order of the alms, where we have blocks in $m$, not $\ell$.
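One possible way to store such a ragged array is a single flat buffer plus an offsets table, sketched below. This is only an illustration under stated assumptions, not a proposal for the actual GLASS layout: the helper names `alm_index` and `ragged_offsets` and the `ncorr` values are hypothetical, and the index formula assumes the usual $m$-major alm ordering with $\ell$ running from $m$ to $\ell_{\max}$ inside each block.

```python
from itertools import accumulate


def alm_index(ell, m, lmax):
    """Position of a_{ell m} in m-major order (blocks in m, ell = m..lmax)."""
    return m * (2 * lmax + 1 - m) // 2 + ell


def ragged_offsets(ncorr):
    """Start offset of each alm's row of previous samples in a flat buffer.

    ncorr[ell] is the (hypothetical) number of stored generations for ell.
    """
    lmax = len(ncorr) - 1
    # Row lengths in alm order: for each m, ell runs from m to lmax.
    lengths = [ncorr[ell] for m in range(lmax + 1) for ell in range(m, lmax + 1)]
    return [0, *accumulate(lengths)]


ncorr = [4, 3, 2]  # illustrative values for ell = 0, 1, 2
lmax = len(ncorr) - 1
offsets = ragged_offsets(ncorr)

flat = [0j] * offsets[-1]  # ragged storage: one entry per stored sample
rect_size = (len(offsets) - 1) * max(ncorr)  # size of the rectangular array

i = alm_index(2, 1, lmax)
row = slice(offsets[i], offsets[i + 1])  # previous samples of a_{21}
```

With these toy values the flat buffer holds 16 entries instead of the 24 a rectangular array would need. The awkward part is visible in `ragged_offsets`: because the blocks are in $m$, rows with the same $\ell$ (and hence the same length) are not contiguous in memory.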


On the GPU side: make the code run on GPU/JAX. This should happen with whatever solution to the above we come up with.

ntessore avatar Jan 27 '25 16:01 ntessore