brax
While loop in generalized/math forces the use of forward-mode gradients
https://github.com/google/brax/blob/280a1c50fa021b6c17a2a3347fea43a2887382bc/brax/v2/math.py#L278
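For anyone who wants to compute gradients through environment steps, this while loop is a bit annoying ;). JAX does not support reverse-mode differentiation through a while loop with a data-dependent number of iterations (jax.lax.while_loop), so only forward mode works. Forward mode works fine for me for now, but I wanted to ask: is the speed gained from this loop perhaps negligible compared to being able to use reverse-mode gradients?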
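To make the limitation concrete, here is a minimal, self-contained JAX sketch. It is not Brax code: `newton_sqrt` is a hypothetical function that iterates inside `jax.lax.while_loop` until a tolerance is met, the same kind of early-stopping loop as the one linked above. Forward mode (`jax.jacfwd`) differentiates through it, while reverse mode (`jax.grad`) fails because JAX cannot transpose a `while_loop`.

```python
import jax
import jax.numpy as jnp

def newton_sqrt(x, tol=1e-6, max_iter=50):
    """Hypothetical example: sqrt(x) by Newton iteration with a data-dependent stop."""
    x = jnp.asarray(x, dtype=jnp.float32)

    def cond(state):
        y, i = state
        # Keep iterating until the residual is small (or max_iter is reached).
        return (jnp.abs(y * y - x) > tol) & (i < max_iter)

    def body(state):
        y, i = state
        return 0.5 * (y + x / y), i + 1

    y, _ = jax.lax.while_loop(cond, body, (x, jnp.int32(0)))
    return y

# Forward mode: tangents are carried through the loop alongside the primal values.
fwd = jax.jacfwd(newton_sqrt)(2.0)
print("forward-mode d/dx sqrt(x) at x=2:", fwd)  # close to 1 / (2 * sqrt(2))

# Reverse mode: JAX cannot transpose the while_loop and raises an error.
try:
    jax.grad(newton_sqrt)(2.0)
    reverse_ok = True
except Exception as err:  # JAX raises a ValueError here
    reverse_ok = False
    print("reverse mode failed:", type(err).__name__)
```

If reverse mode is needed, JAX's own error message suggests the usual workarounds: a fixed number of iterations with `jax.lax.scan`, or `jax.lax.fori_loop` with static bounds. Both give up the early exit, which is exactly the speed-versus-differentiability trade-off raised here.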
Actually, with the latest Brax updates this can be avoided by setting 'approximate' to False here: https://github.com/google/brax/blob/673a41f780fe0f137507d1c35286577c48bde4d1/brax/v2/generalized/pipeline.py#L84 I don't need it at the moment, but I wanted to note it here. Be aware that the results change slightly.
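A rough sketch of what such a switch trades off, under stated assumptions: `inverse`, `newton_schulz_inv`, `loss` and the tolerance are hypothetical stand-ins and not Brax's implementation. With `approximate=True`, a matrix inverse comes from an early-stopping Newton-Schulz iteration inside a `while_loop`, so reverse mode fails. With `approximate=False`, it uses `jnp.linalg.inv`, which `jax.grad` handles. The two paths can give slightly different values because the iteration stops once a residual falls below a tolerance.

```python
import jax
import jax.numpy as jnp

def newton_schulz_inv(a, tol=1e-4, max_iter=100):
    """Hypothetical iterative inverse: X <- X (2I - A X), stopping once the residual < tol."""
    eye = jnp.eye(a.shape[0], dtype=a.dtype)
    # Classic starting guess A^T / (||A||_1 ||A||_inf); converges for nonsingular A.
    x0 = a.T / (jnp.linalg.norm(a, 1) * jnp.linalg.norm(a, jnp.inf))

    def cond(state):
        x, i = state
        return (jnp.max(jnp.abs(a @ x - eye)) > tol) & (i < max_iter)

    def body(state):
        x, i = state
        return x @ (2.0 * eye - a @ x), i + 1

    x, _ = jax.lax.while_loop(cond, body, (x0, jnp.int32(0)))
    return x

def inverse(a, approximate):
    # `approximate` is a plain Python bool, so the branch is chosen at trace time.
    return newton_schulz_inv(a) if approximate else jnp.linalg.inv(a)

def loss(a, approximate):
    # Scalar objective: 1^T A^{-1} 1.
    ones = jnp.ones(a.shape[0], dtype=a.dtype)
    return ones @ inverse(a, approximate) @ ones

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])

# Exact path: reverse mode works because there is no while_loop.
g_exact = jax.grad(lambda m: loss(m, False))(a)

# Approximate path: values are close to the exact ones but may differ slightly.
diff = jnp.max(jnp.abs(inverse(a, True) - inverse(a, False)))
print("max |approx - exact| entry:", diff)

# Approximate path under reverse mode: the while_loop cannot be transposed.
try:
    jax.grad(lambda m: loss(m, True))(a)
    approx_reverse_ok = True
except Exception:  # JAX raises a ValueError for while_loop transposition
    approx_reverse_ok = False
```

Keeping the flag a Python bool (rather than a traced value) lets JAX trace only the chosen branch, so the exact path never contains the loop that blocks `jax.grad`.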