
why is the total_grad_norm increasing across training?

Open ryanyxw opened this issue 1 year ago • 12 comments

❓ The question

This is a purely conceptual/intuition question, and I think it can only be answered with proper context about OLMo (which is why I didn't go to StackOverflow). I'd be very grateful if someone could answer this.

I noticed while going through the W&B training logs of OLMo 1B and OLMo 7B that optim/total_grad_norm seems to be consistently increasing as training continues.

However, the perplexity (and thus loss) seems to be converging to a local/global minimum. If the weights are converging to a local minimum, the gradient norm should also be decreasing, right? Since the loss landscape flattens out?

I'm a bit confused as to why this is the case. Thanks!

Screenshot 2024-05-25 at 21 53 54

Screenshot 2024-05-25 at 21 54 07

ryanyxw avatar May 26 '24 04:05 ryanyxw

Hey @ryanyxw, this is an interesting phenomenon that seems to be tied to the (effective) learning rate. @viking-sudo-rm is an expert here, but I believe there are theoretical reasons to expect that the grad norm will eventually blow up unless the LR keeps decreasing enough (e.g. with a schedule proportional to 1 / sqrt(step)).

But for whatever reason these grad norm curves have looked different in our latest 7B runs: there's an initial period where the grad norm grows to a peak, then it decreases and seems to settle.
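For reference, here is a minimal sketch of such a 1/sqrt(step) schedule in PyTorch (not taken from the OLMo training code; the model, warmup length, and peak LR are placeholder values):

```python
# Sketch of an inverse-square-root LR schedule, lr(step) proportional to
# 1/sqrt(step), with a short linear warmup. All values are placeholders.
import torch

model = torch.nn.Linear(16, 16)  # stand-in for the real model
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
warmup_steps = 2000

def inv_sqrt(step: int) -> float:
    # LambdaLR multiplies the base LR by this factor at each step.
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return (warmup_steps / (step + 1)) ** 0.5

scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=inv_sqrt)
# call scheduler.step() once per optimizer step
```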

epwalsh avatar May 31 '24 16:05 epwalsh

Oh, that sounds interesting. @epwalsh, is this because your latest model is parameterized well and the old model didn't saturate (a decreasing grad norm means the NN's parameters are still in a region of sharp curvature on the loss surface?)? I saw that one of your researchers has been trying to implement Mu Parameterization recently.

SeunghyunSEO avatar Jun 02 '24 03:06 SeunghyunSEO

a decreasing grad norm means the NN's parameters are still in a region of sharp curvature on the loss surface?

Do you mean "increasing grad norm..."? Maybe, but I'm not sure how to test that theory.

epwalsh avatar Jun 03 '24 16:06 epwalsh

Oh, sorry, you're right, I meant "increasing".

SeunghyunSEO avatar Jun 04 '24 00:06 SeunghyunSEO

However, the perplexity (and thus loss) seems to be converging to a local/global minimum. If the weights are converging to a local minimum, the gradient norm should also be decreasing, right? Since the loss landscape flattens out?

Counterintuitively, this does not need to be the case. If the weights are increasing over time, the grad norm can increase even while the loss decreases (especially if the loss is flattening out). Intuitively, this is because the grad norm is roughly proportional to (or at least depends on) the parameter norm.

Explaining Growing Grad Norm in More Detail

In more mathematical detail, many neural network architectures are $k$-homogeneous w.r.t. their weights $\theta$, meaning that $f(c \theta) = c^k f(\theta)$, for some value $k$:

  • A (bias-free) ReLU network with depth $k$ is $k$-homogeneous
  • A transformer + linear output head is not exactly homogeneous but it is approximately 2-homogeneous (Merrill et al., 2021)

An important implication of $k$-homogeneity is that the gradient is $(k-1)$-homogeneous (derivation here):

$$\nabla f(c \theta) = c^{k-1} \nabla f(\theta)$$

This means that the gradient norm depends on the parameter norm:

$$ \nabla f(\theta) = \lVert \theta \rVert^{k-1} \cdot \nabla f(\theta / \lVert \theta \rVert) $$

$$\therefore \lVert \nabla f(\theta) \rVert = \lVert \theta \rVert^{k-1} \cdot \lVert \nabla f(\theta / \lVert \theta \rVert) \rVert $$

Crucially, this says that making parameters larger while keeping their direction the same will increase the gradient norm.

In the case of transformers, which are approximately 2-homogeneous, we get that along a fixed direction in parameter space, the gradient norm is roughly proportional to the parameter norm:

$$\lVert \nabla f(\theta) \rVert \approx \lVert \theta \rVert \cdot \lVert \nabla f(\theta / \lVert \theta \rVert) \rVert $$

This means that if the direction in which our network is moving, $\theta / \lVert \theta \rVert$, is roughly converged but the parameter norm $\lVert \theta \rVert$ is increasing, then we should expect the gradient norm to increase proportionally.
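As a quick sanity check of the scaling above, here is a minimal sketch (added for illustration, using a bias-free two-layer ReLU MLP, which is exactly 2-homogeneous in its weights) showing that scaling all weights by $c$ scales the gradient norm by roughly $c^{k-1} = c$:

```python
# Check that for a 2-homogeneous network, ||grad f(c*theta)|| = c * ||grad f(theta)||.
import torch

torch.manual_seed(0)
x = torch.randn(8, 16)
w1, w2 = torch.randn(16, 32), torch.randn(32, 1)  # bias-free 2-layer ReLU MLP

def grad_norm(c: float) -> float:
    # Evaluate the gradient of f at the scaled parameters c * theta.
    p1 = (c * w1).requires_grad_(True)
    p2 = (c * w2).requires_grad_(True)
    f = (torch.relu(x @ p1) @ p2).mean()
    g1, g2 = torch.autograd.grad(f, (p1, p2))
    return torch.cat([g1.flatten(), g2.flatten()]).norm().item()

base = grad_norm(1.0)
for c in (1.0, 2.0, 4.0, 8.0):
    print(f"c = {c}: grad norm ratio = {grad_norm(c) / base:.3f}")  # ~ c
```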

viking-sudo-rm avatar Jun 04 '24 15:06 viking-sudo-rm

Dear Will,

Thank you very much for your detailed response!

One thing I am curious about is that the loss somehow keeps going up along with the gradient norm.

image

Please check our loss curve. Would you happen to have any suggestions or solutions to this issue?

Thank you very much, and have a great day!

Best regards,

Shuyue Dec. 13th, 2024

SuperBruceJia avatar Dec 13 '24 18:12 SuperBruceJia

Hi @viking-sudo-rm, I am witnessing the same phenomenon during fine-tuning for a 7B model. I am just curious: Did you try using weight decay to mitigate this issue?

Also, are you aware of the perils of having larger parameters? Should we settle for larger parameters? Or should we strive for convergence with smaller parameter norms?

Alex-Mathai-98 avatar Dec 30 '24 15:12 Alex-Mathai-98

Please check this GitHub solution. Also, please check our configurations and model loader.

I hope this helps!

Merry Christmas and Happy New Year!

Best regards,

Shuyue Dec. 30th, 2024

SuperBruceJia avatar Dec 30 '24 15:12 SuperBruceJia

Hi @SuperBruceJia - thanks for the post. Based on my understanding, your training loss increased, correct? That is not the case for me: the training loss and the validation loss decreased smoothly, but the gradient norm increased.

Alex-Mathai-98 avatar Dec 30 '24 16:12 Alex-Mathai-98

In this case, my suggestion is to set max_grad_norm: 0.01.
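For reference, here is a minimal sketch of what a max_grad_norm-style setting does in a plain PyTorch training step (the model, loss, and data below are placeholders):

```python
# Clip the global gradient norm to 0.01 before the optimizer step
# (this is what a max_grad_norm-style config option does under the hood).
import torch

model = torch.nn.Linear(16, 1)  # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

def training_step(x: torch.Tensor, y: torch.Tensor) -> float:
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # Rescales gradients so their combined L2 norm is at most 0.01;
    # returns the pre-clipping total norm, which is worth logging.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.01)
    opt.step()
    return loss.item()
```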

Please check our configurations and model loader if you are interested.

I hope this helps!

Merry Christmas and Happy New Year!

Best regards,

Shuyue Dec. 30th, 2024

SuperBruceJia avatar Dec 30 '24 17:12 SuperBruceJia

Hi all - sorry I missed this thread.

Also, are you aware of the perils of having larger parameters? Should we settle for larger parameters? Or should we strive for convergence with smaller parameter norms?

In general it seems that parameters, activations, and gradients that are too large should be avoided, especially if you want to train with low precision. A key reason is that floating-point numbers accrue more error when the values they represent are larger. With low precision, the floats don't even have to be that large for this to happen.

Anecdotally, it seems like keeping the grad norm in check might be the most important, but controlling and monitoring the parameters and activations can be helpful for diagnosing grad norm issues. I wrote a bit more about the emergence of large-magnitude "outlier" parameters and activations in early OLMo releases in this blog post. The soon-to-be-released OLMo technical report will also document some of the changes we made in later OLMo versions to keep the grad norm flatter throughout training. The suggestions provided by others above should also be helpful.
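As a toy illustration of the precision point (my own example, not from the blog post or report): the absolute round-off error of a low-precision format such as bfloat16 grows with the magnitude of the stored value.

```python
# Absolute bfloat16 round-off error grows with magnitude (~8 mantissa bits,
# so roughly constant *relative* error but growing *absolute* error).
import torch

for x in (1.001, 1_001.0, 1_000_001.0):
    rounded = torch.tensor(x, dtype=torch.bfloat16).float().item()
    print(f"{x:>12}: stored as {rounded:>12.1f}, abs error {abs(rounded - x):.4g}")
```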

viking-sudo-rm avatar Dec 30 '24 18:12 viking-sudo-rm

Please check our loss curve. Would you happen to have any suggestions or solutions to this issue?

@SuperBruceJia It's hard to say from just the loss and grad norm curves alone and without knowing more details about your setup. One thing that comes to mind is to check which layers or modules in the network have large gradients. In particular, if the large gradients are coming from early layers, the problem could be related to the embedding norm shrinking too much, leading to increasing layer-norm gradients in those layers (my blog post and this paper have further details about this particular failure mode).
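One simple way to run that check (a sketch for illustration, not OLMo's built-in logging) is to record per-parameter gradient norms right after backward() and inspect the largest ones:

```python
# Report per-parameter gradient norms, largest first, to see which
# layers/modules the big gradients are coming from.
import torch

def top_grad_norms(model: torch.nn.Module, k: int = 10) -> list[tuple[str, float]]:
    norms = [
        (name, p.grad.norm().item())
        for name, p in model.named_parameters()
        if p.grad is not None
    ]
    return sorted(norms, key=lambda kv: kv[1], reverse=True)[:k]

# After loss.backward() (and before optimizer.step()):
#   for name, norm in top_grad_norms(model):
#       print(f"{name}: {norm:.3e}")
```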

viking-sudo-rm avatar Dec 30 '24 18:12 viking-sudo-rm

Hi, thanks again for the inquiry! We’re currently working on closing out old tickets, so we’re closing this out for now, but if you require a follow-up response, please re-open and we will get back to you!

baileykuehl avatar Jul 01 '25 17:07 baileykuehl