Can't jit PoissonLogNormalQuadratureCompound log_prob

Open · GianmarcoCallegher opened this issue · 0 comments

If I try to `jit` the `log_prob` method of `PoissonLogNormalQuadratureCompound`:

```python
from jax import jit
import tensorflow_probability.substrates.jax.distributions as tfd

# Wrapping log_prob in jit raises the TypeError shown below.
jit(tfd.PoissonLogNormalQuadratureCompound(0.0, 1.0).log_prob)(1.)
```

I get the following error:

```
TypeError: Shapes must be 1D sequences of concrete values of integer type, got Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>.

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/vr/x979b72d00jgztv5gxjgl5j00000gn/T/ipykernel_26756/1664260717.py:1 (<module>)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/vr/x979b72d00jgztv5gxjgl5j00000gn/T/ipykernel_26756/1664260717.py:1 (<module>)
```

GianmarcoCallegher, Mar 22 '24 09:03
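Until this is fixed in TFP, one possible workaround is to write the compound log-probability directly in JAX so that every shape and quadrature constant is a concrete NumPy value at trace time. This is a sketch under stated assumptions, not TFP's implementation: it integrates Poisson(k | exp(z)) over z ~ Normal(loc, scale^2) with Gauss-Hermite quadrature, whereas the default `quadrature_fn` of `PoissonLogNormalQuadratureCompound` is a quantile-based scheme, so values will differ slightly. The name `poisson_lognormal_log_prob` and the fixed size of 8 nodes are illustrative choices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, logsumexp

# Gauss-Hermite nodes/weights are computed once with NumPy, so they are
# concrete constants (not tracers) when the function below is traced by jit.
_QUADRATURE_SIZE = 8
_GH_NODES, _GH_WEIGHTS = np.polynomial.hermite.hermgauss(_QUADRATURE_SIZE)
# hermgauss weights sum to sqrt(pi); dividing by sqrt(pi) normalizes them to 1.
_LOG_GH_WEIGHTS = np.log(_GH_WEIGHTS) - 0.5 * np.log(np.pi)


def poisson_lognormal_log_prob(k, loc, scale):
    """Hypothetical helper: log P(K = k), K ~ Poisson(exp(Z)), Z ~ N(loc, scale**2).

    Approximates the integral over Z with Gauss-Hermite quadrature; a sketch,
    not TFP's PoissonLogNormalQuadratureCompound.log_prob.
    """
    k = jnp.asarray(k, dtype=jnp.float32)[..., None]    # add a node axis
    log_rate = loc + np.sqrt(2.0) * scale * _GH_NODES   # shape [quadrature_size]
    log_poisson = k * log_rate - jnp.exp(log_rate) - gammaln(k + 1.0)
    return logsumexp(_LOG_GH_WEIGHTS + log_poisson, axis=-1)


log_prob = jax.jit(poisson_lognormal_log_prob)
print(log_prob(1.0, 0.0, 1.0))
```

Because `loc` and `scale` enter only through arithmetic, they can be traced; only the number of quadrature nodes has to be fixed before tracing.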