
Support Hessian of gamma-distributed samples

Open NeilGirdhar opened this issue 1 year ago • 3 comments

Fixes https://github.com/google/jax/issues/16076

NeilGirdhar avatar May 25 '24 10:05 NeilGirdhar

@jakevdp Is there any interest in merging something like this? It's really hard to do the same thing in client code without copying thousands of lines of JAX code and then keeping them up to date (which is extremely time-consuming). Or perhaps you have an alternative idea for how I can accomplish this?

NeilGirdhar avatar Jul 22 '24 23:07 NeilGirdhar

Hi @NeilGirdhar, sorry for the delay on this. I'm hoping @froystig or @mattjj can weigh-in here.

I'm a bit concerned about the approach here, because the bounded while-loop might have memory impacts when computing the second derivative for large numbers of samples: my understanding is it would require statically allocating a buffer 256 times larger than the buffer of samples you're generating. It may be that remat can help with that, but I'm not entirely sure. I'd like to make sure we have the right solution here to avoid potentially adding a memory-related footgun to JAX library code.
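To make the concern concrete, here's a toy bounded iteration standing in for the sampler's while-loop (this is an illustrative sketch, not the actual gamma code): reverse-mode AD through a fixed-length loop saves a residual per step, so the saved buffers scale with the static trip count.

```python
import jax
import jax.numpy as jnp

def bounded_iter(x):
    # Toy stand-in for a bounded while-loop: 256 fixed Newton steps
    # converging to sqrt(x). Reverse-mode AD through the scan saves
    # residuals for every step, so memory scales with the trip count.
    def step(c, _):
        return 0.5 * (c + x / c), None
    y, _ = jax.lax.scan(step, x, None, length=256)
    return y

# remat trades the per-step residual memory for recomputation
bounded_iter_remat = jax.checkpoint(bounded_iter)
```

Both versions give the same gradient (here ≈ 1/(2√x), since the iteration converges to √x); the difference is peak memory versus recompute.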

Perhaps there are ways to compute this second derivative more directly, without differentiating through the implementation of the first derivative?

jakevdp avatar Aug 01 '24 18:08 jakevdp

Absolutely no problem about the delay. Congratulations on completing the JAX implementation of the Array API so quickly!

> I'm a bit concerned about the approach here, because the bounded while-loop might have memory impacts when computing the second derivative for large numbers of samples: my understanding is it would require statically allocating a buffer 256 times larger than the buffer of samples you're generating. It may be that remat can help with that, but I'm not entirely sure. I'd like to make sure we have the right solution here to avoid potentially adding a memory-related footgun to JAX library code.
>
> Perhaps there are ways to compute this second derivative more directly, without differentiating through the implementation of the first derivative?

Your concerns make perfect sense to me. I'm going to let the others weigh in, but in the interest of eliminating back-and-forth, I'll make some comments and suggestions if that's okay 😄

First of all, the reason for my frequent force-pushes (sorry if that was noisy!) is that this feature is so important to me that my repo now depends on my PR branch rather than on JAX directly. This way I have access to the feature. I tried lifting the gamma-random-generation code out of JAX, but it's a large mass of code that has changed over the past year, so it was too much work to keep updated.

As for the time and space concerns, I want to first remind readers that only the second-derivative code is slow. I agree with your point that this could be a footgun.

The ideal approach is probably to replace the for loop with a fixed-point solve. (So it would go back to being just a while loop in all cases.) Is there any precedent for fixed-point solving in JAX's source code? I know that JAXopt and tjax have fixed-point solvers. It would be some work, but it should be possible to recast the algorithm slightly so that it fits a fixed-point interface.
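To sketch what I mean, here is a minimal fixed-point solver with implicit differentiation, along the lines of the pattern in JAX's advanced-autodiff docs (the names `fixed_point`, `fwd_solver`, and `newton_sqrt` are hypothetical, not anything in the PR): the forward pass is a plain while-loop, and the backward pass solves its own small fixed-point problem via the implicit function theorem, so memory stays constant in the iteration count.

```python
import jax
import jax.numpy as jnp
from functools import partial

def fwd_solver(f, x0):
    # Plain iteration x_{k+1} = f(x_k) under lax.while_loop:
    # no unrolling, no buffer proportional to the trip count.
    def cond(carry):
        x_prev, x = carry
        return jnp.abs(x_prev - x) > 1e-6
    def body(carry):
        _, x = carry
        return x, f(x)
    _, x_star = jax.lax.while_loop(cond, body, (x0, f(x0)))
    return x_star

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x0):
    return fwd_solver(lambda x: f(a, x), x0)

def fixed_point_fwd(f, a, x0):
    x_star = fixed_point(f, a, x0)
    return x_star, (a, x_star)

def fixed_point_bwd(f, res, v):
    # Implicit function theorem: solve w = v + w·(df/dx) at the fixed
    # point (itself a small fixed-point problem), then pull w back
    # through df/da. Memory is constant in the iteration count.
    a, x_star = res
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    w = fwd_solver(lambda w: v + vjp_x(w)[0], v)
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    a_bar, = vjp_a(w)
    return a_bar, jnp.zeros_like(x_star)

fixed_point.defvjp(fixed_point_fwd, fixed_point_bwd)

# Example: a Newton iteration whose fixed point is sqrt(a)
def newton_sqrt(a):
    return fixed_point(lambda a, x: 0.5 * (x + a / x), a, a)
```

The gamma rejection-sampler derivative would of course need more care than this toy, but the same interface should apply once the algorithm is recast as a fixed-point problem.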

An alternative approach would be to tune the 256 constant, which I think is about ten times too large. I didn't think about speed or memory; I just wanted to get it working. Tuning this constant might solve the speed problem, but the fixed-point solution would make the memory cost constant, I think.

What do you think?

NeilGirdhar avatar Aug 01 '24 20:08 NeilGirdhar

@jakevdp Would you mind taking a look at this again?

NeilGirdhar avatar Feb 10 '25 22:02 NeilGirdhar

We believe we no longer need random_gamma_grad_p as a primitive at all, as proposed in #27628.

froystig avatar Apr 01 '25 01:04 froystig

Awesome!

NeilGirdhar avatar Apr 01 '25 07:04 NeilGirdhar

@NeilGirdhar thanks so much for pushing on this contribution! Roy and I were finally (sorry) reviewing it yesterday, and it's only thanks to your PR that we realized we didn't need this as a primitive anymore.

Our current understanding is that back in #3281 this code was added as a primitive so that it could be lowered to an XLA math library function (to be shared by TF and TF Probability in particular). But then in #15158 that special XLA lowering path was removed. At that point, we could've removed the primitive, but no one realized it until now.

Thanks for pushing on this, and for your patience. It's always great to hear from you.

mattjj avatar Apr 01 '25 15:04 mattjj

No worries, thank you for always focusing on finding the best solution! (And I appreciate the kind words.)

NeilGirdhar avatar Apr 01 '25 15:04 NeilGirdhar