numpyro
Feature request: batched scale_tril
I would like to apply batched transformations using transforms.LowerCholeskyAffine. So far I can run a "batched" transform as long as loc and scale_tril are the same. However, if I try to use different locs with the same scale_tril, I get the following error message:
ValueError: Only support 2-dimensional scale_tril matrix. Please make a feature request if you need to use this transform with batched scale_tril.
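Until batched scale_tril is supported, a workaround is to write the affine map yourself. LowerCholeskyAffine represents y = loc + scale_tril @ x, and that map broadcasts naturally over leading batch dimensions. The sketch below is not numpyro's API. The helpers batched_lower_cholesky_affine, batched_inverse and batched_log_abs_det_jacobian are hypothetical names, written directly in jax.numpy. It assumes scale_tril has shape (..., N, N), loc and x have shape (..., N), and the batch dimensions broadcast against each other.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


# Hypothetical helper (not part of numpyro): forward map y = loc + scale_tril @ x.
def batched_lower_cholesky_affine(x, loc, scale_tril):
    # x[..., None] makes each event vector an (N, 1) column, so jnp.matmul
    # broadcasts the leading batch dims of scale_tril against those of x.
    return loc + jnp.squeeze(jnp.matmul(scale_tril, x[..., None]), axis=-1)


# Hypothetical helper: inverse map x = scale_tril^{-1} (y - loc).
def batched_inverse(y, loc, scale_tril):
    diff = y - loc
    n = scale_tril.shape[-1]
    batch_shape = jnp.broadcast_shapes(scale_tril.shape[:-2], diff.shape[:-1])
    # Broadcast explicitly so both operands of the triangular solve share
    # the same batch shape.
    L = jnp.broadcast_to(scale_tril, batch_shape + (n, n))
    b = jnp.broadcast_to(diff, batch_shape + (n,))[..., None]
    return jnp.squeeze(solve_triangular(L, b, lower=True), axis=-1)


# Hypothetical helper: log|det J| of the forward map is the sum of
# log|diag(scale_tril)|, giving one value per batch element.
def batched_log_abs_det_jacobian(scale_tril):
    diag = jnp.diagonal(scale_tril, axis1=-2, axis2=-1)
    return jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)


# Usage: three different locs with either one shared or three distinct scale_trils.
loc = jnp.array([[0.0, 0.0], [1.0, -1.0], [2.0, 3.0]])  # shape (3, 2)
L_shared = jnp.array([[2.0, 0.0], [0.5, 1.0]])           # shape (2, 2)
L_batched = jnp.stack([L_shared, 2.0 * L_shared, 3.0 * L_shared])  # (3, 2, 2)
x = jnp.ones((3, 2))

y_shared = batched_lower_cholesky_affine(x, loc, L_shared)    # shape (3, 2)
y_batched = batched_lower_cholesky_affine(x, loc, L_batched)  # shape (3, 2)
```

Because jnp.matmul broadcasts batch dimensions, the same forward function covers both the shared-scale_tril case from the report and a fully batched scale_tril. Only the triangular solve in the inverse needs operands with matching batch shapes, so the inverse broadcasts them explicitly first.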