jbuckman
## Describe the bug

When we call `async_reset` multiple times, we get a crash. The use case here is that I want to run many parallel episodes with the async...
## Motivation

The documentation says auto-reset is enabled by default, but for certain applications it is better to have it turned off.

## Solution

When auto-reset is off, terminated environments...
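The difference between the two modes can be illustrated with a toy, dependency-free sketch. Everything in it (`ToyVectorEnv`, `reset_done`, and the step counter used as the "observation") is hypothetical and is not any library's API. It assumes one plausible reading of "auto-reset off": a terminated sub-environment stays terminated until the caller explicitly resets it.

```python
from typing import List, Tuple


class ToyVectorEnv:
    """Hypothetical batch of counter environments (not a real library API).

    Sub-environment i terminates once it has taken horizons[i] steps.
    """

    def __init__(self, horizons: List[int], auto_reset: bool = True) -> None:
        self.horizons = list(horizons)
        self.auto_reset = auto_reset
        self.reset()

    def reset(self) -> List[int]:
        # Reset every sub-environment.
        self.counters = [0] * len(self.horizons)
        self.finished = [False] * len(self.horizons)
        return list(self.counters)

    def reset_done(self, mask: List[bool]) -> List[int]:
        # Reset only the sub-environments selected by `mask`; the rest are untouched.
        for i, selected in enumerate(mask):
            if selected:
                self.counters[i] = 0
                self.finished[i] = False
        return list(self.counters)

    def step(self) -> Tuple[List[int], List[bool]]:
        terminated = []
        for i, horizon in enumerate(self.horizons):
            if self.finished[i]:
                # Only reachable with auto_reset=False: the episode stays over
                # until the caller resets this sub-environment explicitly.
                terminated.append(True)
                continue
            self.counters[i] += 1
            done = self.counters[i] >= horizon
            terminated.append(done)
            if done:
                if self.auto_reset:
                    # Auto-reset: the returned observation already belongs to a new episode.
                    self.counters[i] = 0
                else:
                    self.finished[i] = True
        return list(self.counters), terminated
```

With `horizons=[1, 3]` and `auto_reset=False`, the first sub-environment reports `terminated=True` and keeps its final counter on every later step until `reset_done([True, False])` is called; with `auto_reset=True`, its counter is already back at 0 in the same step that reports termination.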
I'm running several nested bars. The inner bars may be run many times for each step of the outer bars. When an inner bar completes, it "stays around", and a...
### Description

This does not print anything:

```python
import jax
import jax.numpy as jnp

def scanner(carry, x):
    jax.debug.print("carry={carry} x={x}", carry=carry, x=x)
    return carry + x, x

out = jax.lax.scan(scanner, jnp.array(0.),...
```
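For reference, a complete, self-contained variant of this pattern is sketched below. The `xs` argument (`jnp.arange(3.0)`) and the `run_scan` helper are assumptions, since the original call is cut off. Because `jax.debug.print` callbacks can run asynchronously, the sketch calls `jax.effects_barrier()` to wait for pending side effects before returning; this is a general precaution, not a confirmed diagnosis of the issue.

```python
import jax
import jax.numpy as jnp


def scanner(carry, x):
    # Side effect staged into the scan body; it fires once per iteration at runtime.
    jax.debug.print("carry={carry} x={x}", carry=carry, x=x)
    return carry + x, x


def run_scan():
    # Assumed input: the original call is truncated, so xs here is illustrative.
    xs = jnp.arange(3.0)
    final_carry, ys = jax.lax.scan(scanner, jnp.array(0.0), xs)
    # Block until all pending effects (including the debug prints) have completed.
    jax.effects_barrier()
    return final_carry, ys


if __name__ == "__main__":
    final_carry, ys = run_scan()
    print(final_carry, ys)
```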
### Description

Inside Pallas kernels, we often want a loop, and to speed up compilation, we typically use a scan function such as `jax.lax.fori_loop`. (For example, in the attention kernel...
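The `jax.lax.fori_loop` calling convention can be sketched outside a kernel for simplicity; the Pallas-specific machinery is deliberately omitted, and `sum_of_squares` is a hypothetical example function. When the loop bounds are concrete Python integers, JAX implements `fori_loop` via `scan`; with traced bounds it falls back to a `while_loop`.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(n: int):
    # body_fun(i, carry) -> new carry; i runs from lower (inclusive) to upper (exclusive).
    def body(i, acc):
        return acc + i * i

    # n is a Python int here, so the trip count is static and the loop lowers to scan.
    return jax.lax.fori_loop(0, n, body, jnp.int32(0))
```

For example, `sum_of_squares(5)` accumulates 0 + 1 + 4 + 9 + 16.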