torchtitan

Optimizer in backward with grad clipping is broken

Status: Open. apaz-cli opened this issue 9 months ago • 7 comments

I've been working on something similar (see https://github.com/PrimeIntellect-ai/CPUOptimizer), but I see a problem in your implementation.

If you hide the optimizer step inside the backward pass, the grads are applied and set to None before clip_grad_norm_() ever runs. That means you just... silently don't clip the grads.

Did you realize this? Is it documented somewhere and I just missed it?

And then, a more general research question. You can't clip the grads with optimizer in backward, but most RL recipes right now rely on grad clipping for stability. And we really, really want to hide the optimizer step inside the backward pass, because we're interested in CPU offloading.

Do you have any tips/tricks for better stability without grad clipping? Besides just carefully tuning the other hparams?

I actually implemented a "running grad norm," where each parameter is clipped using the correctly weighted norm of all the params whose grads came before it, but the convergence wasn't as good. It's very stable, but I think it changes the direction of the update. Open to other suggestions if you've got any ideas.
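
One reading of the "running grad norm" scheme, written as a plain function over grads in the order backward produces them: each grad is scaled by the norm of everything seen so far, including itself. The names `running_norm_clip` and `global_clip_scale` and the `eps` value are hypothetical, and the actual weighting used in the experiment may differ. Comparing per-tensor scales against ordinary global clipping shows why the update direction changes:

```python
import math
import torch

def running_norm_clip(grads, max_norm, eps=1e-6):
    """Scale each grad by min(1, max_norm / norm of all grads seen so far).

    `grads` is assumed to be in backward order (last layer first), as the
    hooks would see them. Returns the clipped grads and per-tensor scales.
    """
    running_sq = 0.0
    clipped, scales = [], []
    for g in grads:
        running_sq += g.pow(2).sum().item()
        scale = min(1.0, max_norm / (math.sqrt(running_sq) + eps))
        clipped.append(g * scale)
        scales.append(scale)
    return clipped, scales

def global_clip_scale(grads, max_norm, eps=1e-6):
    """The usual clip_grad_norm_-style coefficient, shared by every tensor."""
    total = math.sqrt(sum(g.pow(2).sum().item() for g in grads))
    return min(1.0, max_norm / (total + eps))

grads = [torch.full((4,), 1.0), torch.full((4,), 3.0)]  # norms 2 and 6
_, running_scales = running_norm_clip(grads, max_norm=1.0)
print(running_scales)                          # first ~0.5, second ~0.158
print(global_clip_scale(grads, max_norm=1.0))  # ~0.158 for both tensors
```

Only the last grad gets the same scale as global clipping; earlier ones are clipped less, so the result is no longer a uniformly rescaled copy of the true gradient.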

Curious what your thoughts are on this.

apaz-cli, Mar 20 '25 03:03

cc: @mori360

tianyu-l, Mar 20 '25 05:03

I think we should disable grad clipping with optimizer in backward for now. This is a known issue with optimizer in backward, as @apaz-cli mentioned. I've also heard a rough idea of using the total norm from the previous step, but I'm not sure it's a realistic approach.

fegin, Mar 20 '25 06:03

@fegin It actually is disabled, so... no change necessary I guess? But a comment or if statement describing the issue would be nice.

I don't think it works to use the norm of the previous step. First, because I tried it and it didn't work very well. But also because I believe the point is that if you have a bad sample in your batch you can reduce its effect. By the time you've applied the update it's too late, and it has no semantic meaning for the next update. You're likely to get lucky and get a number that's close to optimal, but we clip for the unlucky batches.
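
The timing problem can be shown with plain arithmetic. The sketch below uses made-up per-step norms (illustrative values, not measurements) with one bad batch at step 2, and compares the exact clip coefficient with one computed from the previous step's norm:

```python
def clip_scale(norm: float, max_norm: float, eps: float = 1e-6) -> float:
    """Same coefficient clip_grad_norm_ uses: min(1, max_norm / norm)."""
    return min(1.0, max_norm / (norm + eps))

def compare_scales(norms, max_norm):
    # Exact: uses the current step's norm (only known after backward finishes).
    exact = [clip_scale(n, max_norm) for n in norms]
    # Stale: uses the previous step's norm; nothing to go on at step 0.
    stale = [1.0] + [clip_scale(n, max_norm) for n in norms[:-1]]
    return exact, stale

norms = [1.0, 1.1, 50.0, 1.0]  # illustrative values; step 2 is the bad batch
exact, stale = compare_scales(norms, max_norm=1.0)
for step, (e, s) in enumerate(zip(exact, stale)):
    print(f"step {step}: exact={e:.3f} previous-step={s:.3f}")
```

The spike at step 2 goes through almost unclipped (about 0.909 instead of about 0.020), and the normal batch after it is then over-clipped by the stale coefficient.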

I think it would be good to use the median grad norm or something. Maybe a median over a fixed-size buffer, since you could be doing millions of steps.
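
A sketch of the median variant, assuming the backward hooks multiply each grad by `scale()` and that `record()` is called after backward with the total norm the hooks accumulated. `MedianNormClipper`, the window size, and plugging the median into the usual `min(1, max_norm / norm)` coefficient are all assumptions; `collections.deque(maxlen=...)` provides the fixed-size buffer.

```python
import statistics
from collections import deque

class MedianNormClipper:
    """Clip with the median of recent total grad norms (hypothetical sketch)."""

    def __init__(self, max_norm: float, window: int = 1024):
        self.max_norm = max_norm
        # Bounded memory no matter how many steps are run.
        self.history = deque(maxlen=window)

    def scale(self) -> float:
        # Coefficient the hooks would apply to this step's grads.
        if not self.history:
            return 1.0
        median = statistics.median(self.history)
        return min(1.0, self.max_norm / (median + 1e-6))

    def record(self, total_norm: float) -> None:
        # Called after backward with the norm accumulated by the hooks.
        self.history.append(total_norm)

clipper = MedianNormClipper(max_norm=1.0, window=3)
for norm in [2.0, 4.0, 100.0]:
    clipper.record(norm)
print(clipper.scale())  # ~0.25: the median (4.0) ignores the 100.0 outlier
```

The median's robustness to the outlier is the point of the design, but it also means a spike on the current step has no effect on the coefficient applied to that step.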

Haven't implemented this one yet. Was going to run the experiment in a week or two when no longer stressing about a release.

apaz-cli, Mar 20 '25 06:03

Yeah, optimizer-in-backward can't work with gradient clipping, as documented here: https://github.com/pytorch/torchtitan/blob/048f65d7ed81d5f46195dcf557209f4444223e1c/torchtitan/config_manager.py#L247-L254 I currently have no idea how to enable gradient clipping together with optimizer-in-backward.

mori360, Mar 20 '25 17:03

> I don't think it works to use the norm of the previous step. First, because I tried it and it didn't work very well. But also because I believe the point is that if you have a bad sample in your batch you can reduce its effect. By the time you've applied the update it's too late, and it has no semantic meaning for the next update. You're likely to get lucky and get a number that's close to optimal, but we clip for the unlucky batches.

Nice explanation. I'd like to learn whether the median norm works.

fegin, Mar 20 '25 18:03

@mori360 @fegin I'll keep you updated :3

apaz-cli, Mar 20 '25 18:03

@mori360 @fegin I'm still going to run the experiment, but I just thought of a bunch of reasons why median clipping doesn't make any sense and probably won't work.

Actually, it's still fundamentally the same argument. You don't know how much to clip by until the gradient explodes, and you don't know whether it explodes until the end of backward(), because that's when it tends to happen. Clipping the grad with the median is probably a lot like decreasing the LR, which does improve convergence, but isn't what we're asking for and doesn't really solve the exploding-gradients problem.
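
The "a lot like decreasing the LR" intuition can be made concrete for plain SGD (an illustrative assumption; adaptive optimizers such as AdamW normalize by gradient magnitude, so the correspondence is looser there). Write $\tau$ for max_norm, $\eta$ for the learning rate, $g_t$ for the step-$t$ gradient and $\tilde m_t$ for the median of recent total norms:

```math
\begin{aligned}
\text{median clipping:}\quad & \theta_{t+1} = \theta_t - \eta\, s_t\, g_t, && s_t = \min\!\left(1, \frac{\tau}{\tilde m_t}\right),\\
\text{true clipping:}\quad & \theta_{t+1} = \theta_t - \eta\, c_t\, g_t, && c_t = \min\!\left(1, \frac{\tau}{\lVert g_t \rVert}\right).
\end{aligned}
```

When the norms are roughly stationary, $\tilde m_t$ barely moves, so $s_t \approx s$ is a constant and the update is just SGD with learning rate $\eta s$. Only $c_t$ reacts to a large $\lVert g_t \rVert$ on the current step.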

But, still going to try it, because why not.

I think the smarter thing might just be to compute the squared sums of the grads in the backward hooks, then pass the resulting total norm as an argument to the optimizer step so it can apply the clipping there. Unfortunately, that's not an optimization you can do with the regular AdamW API.

apaz-cli, Mar 23 '25 02:03