brax
Multi-Process/Multi-Controller setup
JAX can run in multi-controller (multi-process) mode: https://docs.jax.dev/en/latest/multi_process.html. Does Brax support it? I am scheduling training with torchx, where each process is assigned a single GPU and the processes need to coordinate work. Is there a guide for using Brax in this setup?
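Whatever the answer on the Brax side, the JAX side of this setup is the same: every process must call `jax.distributed.initialize` before any other JAX call. Below is a minimal sketch, assuming torchx launches workers through torchrun (as its `dist.ddp` component does), so each process sees the torchrun variables `MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE` and `LOCAL_RANK`. The helper names `jax_distributed_kwargs` and `init_jax_distributed` are hypothetical. Using `MASTER_PORT + 1` for the JAX coordinator is also an assumption, chosen to avoid clashing with a port torch may already hold. It does not show whether Brax's training loops shard work across processes.

```python
import os
from typing import Any, Dict, Mapping, Optional


def jax_distributed_kwargs(env: Optional[Mapping[str, str]] = None) -> Dict[str, Any]:
    """Hypothetical helper: map torchrun-style env vars to jax.distributed.initialize kwargs.

    Assumes the launcher (e.g. torchx dist.ddp / torchrun) sets MASTER_ADDR,
    MASTER_PORT, RANK, WORLD_SIZE and LOCAL_RANK for every worker process.
    """
    env = os.environ if env is None else env
    # Assumption: run the JAX coordinator on MASTER_PORT + 1 so it does not
    # collide with a TCP store that torch may already be listening on.
    port = int(env["MASTER_PORT"]) + 1
    return {
        # Process 0 (on MASTER_ADDR) hosts the coordination service.
        "coordinator_address": f"{env['MASTER_ADDR']}:{port}",
        "num_processes": int(env["WORLD_SIZE"]),
        "process_id": int(env["RANK"]),
        # One GPU per process: restrict this process to the GPU at LOCAL_RANK.
        "local_device_ids": [int(env["LOCAL_RANK"])],
    }


def init_jax_distributed() -> None:
    """Hypothetical helper: initialize JAX multi-controller mode from the environment."""
    # Imported here so the env parsing above can be reused without JAX installed.
    import jax

    # Must run before any other JAX computation in this process.
    jax.distributed.initialize(**jax_distributed_kwargs())
    # Each process should report 1 local device and WORLD_SIZE global devices.
    print(
        f"process {jax.process_index()}/{jax.process_count()}: "
        f"{jax.local_device_count()} local / {jax.device_count()} global devices"
    )
```

In the training script, `init_jax_distributed()` would be called once at the top, before building the Brax environment or training loop. If the scheduler already narrows `CUDA_VISIBLE_DEVICES` to a single GPU per process, the visible GPU is index 0, so `local_device_ids` should be `[0]` or omitted.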