Bug with IteratedSigmoidCentered
How to replicate:
import oryx
import jax.numpy as jnp
oryx.bijectors.IteratedSigmoidCentered().forward(jnp.array([0., 0., 0.]))
Error message:
TypeError: abstract_eval_fun() missing 1 required keyword-only argument: 'debug_info'
I was wondering whether I'm doing something wrong here, or if this is just a bug in the bijector.
Thanks!
Are you running at HEAD or with the latest PyPI release? I think HEAD should work (https://github.com/jax-ml/oryx/pull/98), but the latest PyPI release won't.
I can make a release sometime soon if I'm right about that (I can also double-check; that doesn't have to be your responsibility!). Let me know how urgent this is!
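A minimal sketch for reporting which versions are installed when answering a question like this, using only the Python standard library. It assumes the PyPI distribution names are `oryx`, `jax` and `jaxlib`. The version string alone may not distinguish a HEAD install from a release, so it is worth also saying how the package was installed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Print the versions worth including in a bug report.
for name in ("oryx", "jax", "jaxlib"):
    print(f"{name}: {installed_version(name)}")

# To try HEAD instead of the PyPI release (assuming the repo installs with pip):
#   pip install git+https://github.com/jax-ml/oryx
```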
Thanks! I'm on the latest PyPI release, so I guess that makes sense. It turns out the version in TensorFlow Probability works, and for my current purposes that's more convenient, so this is not urgent at all.