
Derive associative scan algorithm for factorization

Open dfm opened this issue 4 years ago • 5 comments

I've derived the associative scan algorithms for matrix multiplication and solves, but I haven't been able to work out the factorization algorithm yet. There don't seem to be numerical issues for the ops that I've derived so far, but I haven't tested them extensively. This would be interesting because it would allow a parallel implementation on a GPU.
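For anyone curious what the parallelization looks like: the celerite-style recurrences can be written as affine maps f_n = A_n f_{n-1} + b_n, which compose associatively, so `jax.lax.associative_scan` can evaluate all partial states in O(log N) depth. This is just a minimal illustration of the technique, not the celerite2 internals — the names `combine` and `scan_recurrence` and the shapes are assumptions:

```python
import jax
import jax.numpy as jnp


def combine(elem1, elem2):
    # Composition rule for the affine recurrence f_n = A_n @ f_{n-1} + b_n:
    # applying (A1, b1) and then (A2, b2) is the single map (A2 @ A1, A2 @ b1 + b2).
    A1, b1 = elem1
    A2, b2 = elem2
    return A2 @ A1, A2 @ b1 + b2


def scan_recurrence(As, bs):
    # Parallel prefix over the sequence of affine maps.
    # As has shape (N, J, J) and bs has shape (N, J); the b-component of the
    # composed map at step n equals f_n when the recurrence starts from f_0 = 0.
    # combine is written for a single pair, so vmap vectorizes it over the
    # leading axis as associative_scan requires.
    _, fs = jax.lax.associative_scan(jax.vmap(combine), (As, bs))
    return fs
```

The associativity of `combine` is what makes the parallel formulation possible; a sequential `lax.scan` over the same maps gives identical results but runs in O(N) depth.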

dfm avatar Jan 28 '21 20:01 dfm

Hi @dfm, sorry to be plaguing you 😅. I'm working on a JAX project with GPU acceleration and I'd like to use celerite2. If I use it out of the box, I get a warning that says:

NotImplementedError: XLA translation rule for celerite2_factor on platform 'gpu' not found

which brought me here. Is this still on the to-do list?

bmorris3 avatar Jul 09 '21 12:07 bmorris3

There is no GPU support planned for celerite2. It's possible to parallelize some of the algorithms, but in all the tests I've done the parallel version is slower than the CPU one, and it scales badly with the number of kernel terms J (J^3 instead of J^2).

dfm avatar Jul 09 '21 13:07 dfm

Thanks for the quick response! If you have any pointers on alternatives I'd be grateful.

bmorris3 avatar Jul 09 '21 13:07 bmorris3

I don't know of any good JAX libraries for GPs, but it's not too hard to implement the math yourself to try it out. If the GP is your bottleneck, I think it's unlikely that you'll get any benefit from using a GPU; but if your computation is dominated by other parts of the model that benefit from GPU acceleration, and you don't have too many data points, then it might be worth it. Here's an example implementation of naive GP computations using JAX + GPU acceleration that could get you started: https://github.com/dfm/tinygp/blob/main/src/tinygp/gp.py
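To make "naive GP computations" concrete, here is a minimal sketch of a dense, O(N^3) marginal log-likelihood in JAX via a Cholesky factorization. The squared-exponential kernel, the function names, and the hyperparameters here are illustrative choices for this sketch, not the tinygp API:

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def rbf_kernel(x1, x2, amp=1.0, scale=1.0):
    # Squared-exponential kernel; amp and scale are illustrative hyperparameters.
    r2 = (x1[:, None] - x2[None, :]) ** 2
    return amp**2 * jnp.exp(-0.5 * r2 / scale**2)


def gp_log_likelihood(x, y, diag=1e-2, amp=1.0, scale=1.0):
    # Dense GP marginal log-likelihood: build the full N x N kernel matrix,
    # factor it once, and reuse the factor for both the solve and the
    # log-determinant (log det K = 2 * sum(log(diag(L)))).
    N = x.shape[0]
    K = rbf_kernel(x, x, amp, scale) + diag * jnp.eye(N)
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    return -0.5 * (
        y @ alpha
        + 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
        + N * jnp.log(2.0 * jnp.pi)
    )
```

Everything here is a standard jnp/jax.scipy op, so it runs unchanged on CPU or GPU and can be wrapped in `jax.jit` and differentiated with `jax.grad` for hyperparameter fitting; the O(N^3) Cholesky is the reason it only pays off for modest N.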

dfm avatar Jul 09 '21 13:07 dfm

Thanks so much, as always!

bmorris3 avatar Jul 09 '21 14:07 bmorris3