
TFP JAX: The transition kernel drastically decreases speed.

Open SebastianSosa opened this issue 1 year ago • 4 comments

Dear all,

I am currently learning Bayesian analysis and using tensorflow_probability.substrates.jax, but I've encountered some issues. When running jit-compiled NUTS alone under JAX, performance is quite fast. However, when it is combined with TransformedTransitionKernel, the speed decreases drastically. Here's a summary of the time taken:

  • TFP GPU: NUTS alone took 118.2952 seconds
  • TFP GPU: NUTS + Bijector took 1986.8306 seconds
  • TFP GPU: NUTS + DualAveragingStepSizeAdaptation took 141.0955 seconds
  • TFP GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 2397.5875 seconds
  • NumPyro GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 180 seconds

I've conducted speed tests comparing with NumPyro, and essentially, NumPyro with dual averaging step size adaptation and parameter constraints is as fast as tensorflow_probability's NUTS alone.

Could there be something I've missed? Is there room for optimization in this process?

Please find the data and code (the .txt needs to be renamed to .ipynb) for reproducibility enclosed: data.csv gitissue.txt google Colab

Please note that I'm only using the first 100 lines of the data.

Additionally, as a potential cause, I observed similar speed loss when using the LKJ distribution for other models. (I could post one of them if needed.)

Thank you in advance for your assistance.

Sebastian

SebastianSosa avatar Apr 09 '24 09:04 SebastianSosa

Hi - it looks like the Colab is locked down, so I cannot access it.

ColCarroll avatar May 31 '24 11:05 ColCarroll

Does this link allow access?

I made a simulation instead of using real data, as it allows us to evaluate how the models perform with the increase in data size. I can update it in the next few days.

SebastianSosa avatar May 31 '24 11:05 SebastianSosa

Note that the data is not saved with the Colab, so I cannot run this, but it looks as though the problem is with your use of tfp.bijectors.CorrelationCholesky(ni). CorrelationCholesky doesn't take a dimension parameter, so ni is silently being accepted as the validate_args argument.

Downstream, I think this leads to a badly mis-specified posterior, and so TFP NUTS is (correctly) exhausting its tree doublings and doing ~10x as much work.

ColCarroll avatar May 31 '24 14:05 ColCarroll