btaba
Hi @Kallinteris-Andreas, you should call `mjx_get_physics_state_put_version` outside of the `jax.jit`. Only once all the computations are done on device (in MJX-land) should you transfer the data back onto...
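As a rough illustration of the "compute on device, transfer once outside `jax.jit`" pattern, here is a minimal sketch. The `rollout` function and its toy update rule are hypothetical stand-ins for an MJX rollout, and `mjx_get_physics_state_put_version` is not reproduced here. The only real API relied on for the host transfer is `jax.device_get`.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def rollout(x0):
    # Hypothetical toy dynamics: every step stays on device inside the jitted function.
    def body(x, _):
        x = x * 0.5 + 1.0
        return x, x
    _, xs = jax.lax.scan(body, x0, None, length=4)
    return xs

xs_device = rollout(jnp.array(0.0))
# Transfer to host exactly once, outside of jax.jit, after all device work is done.
xs_host = jax.device_get(xs_device)
```

Keeping the host transfer out of the jitted function avoids forcing a device-to-host sync on every step.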
Hi @erwincoumans! I'm not able to reproduce the issue. Which jax version are you using?
OK, thanks for the pointer! It turns out that jax>=0.4.6 is incompatible with public colab TPU runtimes (see https://stackoverflow.com/a/75734517). We're pinning the jax/jaxlib versions to >=0.4.6 now, so it's best...
@Daffan Thanks for reporting! Which environment are you using and which backend?
Thanks @imoneoi for the report and @traversaro for adding pytinyrenderer to conda. Indeed, I can confirm that `pip install brax` does not work on Python 3.11. We've been planning to add a...
Hi @joeryjoery, I believe we considered passing around sys as part of the env state, but IIRC we managed to squeeze out better performance using the current implementation. https://github.com/google/brax/blob/a89322496dcb07ac5a7e002c2e1d287c8c64b7dd/brax/envs/wrappers/training.py#L199...
Hi @joeryjoery, I'm not quite following why you want to add extra args to `pipeline.init` and `pipeline.step`. Does something like this not work: `jax.vmap(pipeline.init, in_axes=[custom_in_axes, None, None])(sys, q, qd)`...
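To make the `custom_in_axes` idea concrete, here is a minimal sketch in plain JAX. `Sys` is a hypothetical two-field stand-in for Brax's `System`, and `init` is a hypothetical stand-in for `pipeline.init`. The point is that `in_axes` can be a pytree with the same structure as `sys`, mapping over some fields (`0`) and broadcasting others (`None`). No extra arguments to `init` are needed.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class Sys(NamedTuple):  # hypothetical stand-in for brax's System
    friction: jax.Array
    mass: jax.Array

def init(sys, q, qd):
    # Hypothetical init: combines system params with the initial state.
    return sys.friction * q + sys.mass * qd

# Three systems that differ only in friction; mass is shared.
sys = Sys(friction=jnp.array([0.5, 1.0, 1.5]), mass=jnp.array(2.0))
q, qd = jnp.array(1.0), jnp.array(0.1)

# Map over friction (axis 0), broadcast mass, q and qd.
custom_in_axes = Sys(friction=0, mass=None)
out = jax.vmap(init, in_axes=(custom_in_axes, None, None))(sys, q, qd)
```

Here `out` has one entry per friction value, while the function signature stays untouched.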
Comments and questions on the proposed changes: [1] Subsume part of `System` inside `State`: You can do this already by adding `System` to `state.info`, and rewriting your env code to...
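A minimal sketch of the "put the system in `state.info`" approach, assuming hypothetical `Sys` and `State` NamedTuples in place of Brax's real classes. Because `info` is an ordinary dict inside a pytree, the system is batched, jitted and vmapped along with the rest of the state. `step` only has to read it back out.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class Sys(NamedTuple):  # hypothetical stand-in for brax's System
    friction: jax.Array

class State(NamedTuple):  # hypothetical stand-in for an env State
    x: jax.Array
    info: dict

def reset(sys, x0):
    # Stash the per-env system inside state.info so it travels with the state.
    return State(x=x0, info={'sys': sys})

@jax.jit
def step(state):
    sys = state.info['sys']  # read the system back out of the state
    return state._replace(x=state.x * sys.friction)

# A batch of 3 envs, each carrying its own friction value.
sys = Sys(friction=jnp.array([0.5, 1.0, 2.0]))
state = jax.vmap(reset)(sys, jnp.ones(3))
state = jax.vmap(step)(state)
```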
Hi @joeryjoery, I think we tried a version of this implementation. A few comments: [1] Can you update your impl to make it work for nested fields in `sys`? You...
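For the nested-fields point, one way to handle it is to address fields with dot-separated paths and rebuild each level immutably. This is a minimal pure-Python sketch using frozen dataclasses. `Geom`, `Sys`, `_replace_path` and `replace_nested` are all hypothetical names and not Brax's API, although Brax does ship its own `tree_replace` in `brax/base.py`.

```python
import dataclasses
from typing import Any, Dict, List

@dataclasses.dataclass(frozen=True)
class Geom:  # hypothetical nested struct
    friction: float
    size: float

@dataclasses.dataclass(frozen=True)
class Sys:  # hypothetical stand-in for a system with nested fields
    gravity: float
    geom: Geom

def _replace_path(obj: Any, path: List[str], value: Any) -> Any:
    # Recurse down the path, rebuilding each level with dataclasses.replace.
    head, *rest = path
    if not rest:
        return dataclasses.replace(obj, **{head: value})
    child = getattr(obj, head)
    return dataclasses.replace(obj, **{head: _replace_path(child, rest, value)})

def replace_nested(obj: Any, params: Dict[str, Any]) -> Any:
    # Apply each 'a.b.c'-style path in turn; the input object is never mutated.
    for key, value in params.items():
        obj = _replace_path(obj, key.split('.'), value)
    return obj

sys = Sys(gravity=-9.81, geom=Geom(friction=1.0, size=0.1))
new_sys = replace_nested(sys, {'geom.friction': 0.5, 'gravity': -1.62})
```

Untouched nested fields (here `geom.size`) are carried over unchanged. The original `sys` keeps its old values.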
Hi @joeryjoery, `tree_replace` can be found here: https://github.com/google/brax/blob/f9a4d73181d699db0fa38b07c5a651f5dc8ee231/brax/base.py#L114 Thanks for the context on [2]. I recommend using your own wrapper (to ensure a new system is sampled for every trajectory),...
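As a rough idea of what such a wrapper might look like, here is a minimal sketch that draws a fresh system on every `reset`, so each trajectory gets its own parameters. `ToyEnv`, `ResampleSysWrapper`, and the choice of resampling only `friction` uniformly are all hypothetical and not part of Brax. The sampled system is stored in `state.info`, so `step` needs no changes.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple

class Sys(NamedTuple):  # hypothetical stand-in for brax's System
    friction: jax.Array

class State(NamedTuple):  # hypothetical env state carrying its own system
    x: jax.Array
    info: dict

class ToyEnv:  # hypothetical env whose reset takes the system explicitly
    def __init__(self, sys: Sys):
        self.sys = sys

    def reset(self, sys: Sys) -> State:
        return State(x=jnp.zeros(()), info={'sys': sys})

    def step(self, state: State, action: jax.Array) -> State:
        sys = state.info['sys']
        return state._replace(x=state.x + sys.friction * action)

class ResampleSysWrapper:  # hypothetical wrapper, not a Brax class
    """Samples a new friction value on every reset, i.e. once per trajectory."""

    def __init__(self, env: ToyEnv, low: float, high: float):
        self.env, self.low, self.high = env, low, high

    def reset(self, rng: jax.Array) -> State:
        friction = jax.random.uniform(rng, (), minval=self.low, maxval=self.high)
        sys = self.env.sys._replace(friction=friction)
        return self.env.reset(sys)

    def step(self, state: State, action: jax.Array) -> State:
        # The sampled system rides along in state.info, so step is unchanged.
        return self.env.step(state, action)

env = ResampleSysWrapper(ToyEnv(Sys(friction=jnp.array(1.0))), low=0.5, high=1.5)
rngs = jax.random.split(jax.random.PRNGKey(0), 4)
states = jax.vmap(env.reset)(rngs)  # 4 trajectories, 4 independently sampled systems
states = jax.vmap(env.step)(states, jnp.ones(4))
```

With this design, resampling is confined to `reset`. The per-step path stays as cheap as the unwrapped env.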