[RLlib] Compile update logic on learner and use cudagraphs
Why are these changes needed?
In the first attempt to leverage torch compile, we did not introduce a compiled update method on the learner side (1), and we also had little success with torch compiling on the rollout worker side, because weight updates would effectively not happen once we compiled (2).
For (1): This PR compiles on the learner side analogously to what we do for eager tracing, by introducing a possibly_compiled_update() method on TorchLearner.
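Roughly, the idea is to decide once at build time whether the update step gets wrapped by torch.compile, similar to how eager tracing conditionally wraps the update. A minimal, self-contained sketch of that pattern (the SketchLearner class and _uncompiled_update helper below are hypothetical stand-ins, not RLlib's actual TorchLearner code):

```python
import torch
import torch.nn as nn


class SketchLearner:
    """Hypothetical stand-in for TorchLearner, for illustration only."""

    def __init__(self, model: nn.Module, compile_update: bool, backend: str = "inductor"):
        self.model = model
        self.optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        # Decide once, at build time, whether the update is the compiled or the
        # plain eager version -- analogous to conditionally wrapping the update
        # for eager tracing.
        if compile_update:
            self.possibly_compiled_update = torch.compile(
                self._uncompiled_update, backend=backend
            )
        else:
            self.possibly_compiled_update = self._uncompiled_update

    def _uncompiled_update(self, batch: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Forward pass, loss, backward pass, and optimizer step in one callable,
        # so torch.compile can capture the whole update step.
        loss = nn.functional.mse_loss(self.model(batch), target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss


learner = SketchLearner(nn.Linear(4, 1), compile_update=True)
print(learner.possibly_compiled_update(torch.randn(8, 4), torch.randn(8, 1)))
```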
For (2): We get around the issue of not being able to set weights by using cudagraphs as the torch dynamo backend.
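A minimal sketch of what that looks like on the rollout-worker side, assuming a toy policy module and an available CUDA device (the module and the weight-sync step are illustrative, not RLlib's actual rollout-worker code):

```python
import torch
import torch.nn as nn

# Illustrative policy module; not RLlib's actual RLModule.
policy = nn.Linear(4, 2)

if torch.cuda.is_available():
    policy = policy.cuda()
    # Use "cudagraphs" as the torch dynamo backend instead of the default
    # "inductor"; compilation happens lazily on the first call.
    compiled_policy = torch.compile(policy, backend="cudagraphs")

    obs = torch.randn(8, 4, device="cuda")
    actions = compiled_policy(obs)

    # Sync new weights into the same underlying module, as a rollout worker
    # would after receiving updated weights from the learner.
    policy.load_state_dict(
        {k: torch.randn_like(v) for k, v in policy.state_dict().items()}
    )
    actions_after_sync = compiled_policy(obs)
```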
Related tensorboard showing speedups on the rollout worker side:
@kouroshHakha, I've also added a configuration enum instead of relying on the two long strings "complete_update" and "forward_train"; see the sketch below.
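For reference, a sketch of what such an enum could look like (the class name here is illustrative; see the PR diff for the actual definition):

```python
from enum import Enum


class TorchCompileWhatToCompile(str, Enum):
    """Which part of the training pipeline gets torch-compiled."""

    # Compile the learner's entire update step.
    COMPLETE_UPDATE = "complete_update"
    # Compile only the forward_train() pass.
    FORWARD_TRAIN = "forward_train"
```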