Optimize pi0 for torch.compile
What this does
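The original code calls loss.item(), which converts the loss tensor into a Python scalar. By default TorchDynamo cannot trace this data-dependent conversion, so if we call torch.compile() on this policy, the computation graph breaks at that point and PyTorch issues a warning about the resulting performance degradation.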
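A minimal sketch of the pattern, assuming a hypothetical toy module. `ItemLossPolicy`, `TensorLossPolicy`, and `log_metrics` are illustrative names, not PI0Policy's actual code. The idea is to keep the logged loss as a detached tensor inside `forward` and call `.item()` only after the compiled call has returned.

```python
import torch
from torch import nn


class ItemLossPolicy(nn.Module):
    """Hypothetical toy policy mirroring the original pattern: .item() inside forward."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        loss = nn.functional.mse_loss(self.linear(x), target)
        # .item() turns the tensor into a Python float; Dynamo graph-breaks here by default.
        return loss, {"loss": loss.item()}


class TensorLossPolicy(ItemLossPolicy):
    """Same model, but the logging dict keeps a detached tensor instead of a float."""

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        loss = nn.functional.mse_loss(self.linear(x), target)
        # detach() keeps the value as a tensor, which Dynamo can trace without a break.
        return loss, {"loss": loss.detach()}


def log_metrics(loss_dict: dict) -> dict:
    # Convert to Python scalars only outside the compiled region.
    return {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in loss_dict.items()}


if __name__ == "__main__":
    # backend="eager" skips Inductor code generation, which keeps this demo lightweight.
    policy = torch.compile(TensorLossPolicy(), backend="eager")
    loss, loss_dict = policy(torch.randn(8, 4), torch.randn(8, 2))
    loss.backward()
    print(log_metrics(loss_dict))
```

The trade-off is that callers now receive tensors in the loss dict and must convert them themselves when logging.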
How it was tested
Initialize the model and then call torch.compile on it. The graph-break warnings no longer appear.
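One way to turn this manual check into an assertion is to compile with `fullgraph=True`, which makes Dynamo raise an error on any graph break instead of only warning. Setting the environment variable `TORCH_LOGS="graph_breaks"` is another way to list breaks during a normal run. The sketch below assumes a hypothetical `ToyPolicy` stand-in and a hypothetical helper `has_graph_break`; neither is part of PI0Policy.

```python
import torch
import torch._dynamo
from torch import nn


def has_graph_break(module: nn.Module, *inputs: torch.Tensor) -> bool:
    """Return True if Dynamo cannot capture the module's forward as a single graph."""
    torch._dynamo.reset()  # clear compilation caches so earlier runs do not interfere
    # fullgraph=True turns a graph break into an exception; backend="eager" skips Inductor.
    compiled = torch.compile(module, backend="eager", fullgraph=True)
    try:
        compiled(*inputs)
    except Exception:
        # Broad on purpose for this sketch: unrelated errors would also be reported as breaks.
        return True
    return False


class ToyPolicy(nn.Module):
    """Hypothetical stand-in for the policy; use_item toggles the original .item() call."""

    def __init__(self, use_item: bool) -> None:
        super().__init__()
        self.use_item = use_item
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor, target: torch.Tensor):
        loss = nn.functional.mse_loss(self.linear(x), target)
        logged = loss.item() if self.use_item else loss.detach()
        return loss, {"loss": logged}


if __name__ == "__main__":
    x, t = torch.randn(8, 4), torch.randn(8, 2)
    print("with .item():", has_graph_break(ToyPolicy(use_item=True), x, t))
    print("detached tensor:", has_graph_break(ToyPolicy(use_item=False), x, t))
```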
Sorry, I forgot to mention that this happened when using DDP to wrap the compiled PI0Policy. I am not sure about the single-GPU case; I will try it when I have time later today or tomorrow.
This PR has been automatically marked as stale because it has not had recent activity (6 months). It will be closed if no further activity occurs. Thank you for your contributions.
This PR was closed because it has been stalled for 21 days with no activity. Feel free to reopen it if it is still relevant, or to ping a collaborator if you have any questions.