OLMo

Activations Exploding Across Layers

Open · c3-utsavdutta98 opened this issue 10 months ago • 7 comments

❓ The question

I was curious whether there are any explicit mechanisms in place to keep activation norms from exploding with the initialization OLMo 2 uses. Specifically, with an $\mathcal{N}(0, 0.02^2)$ init (std 0.02) followed by `x = x + norm(attn(x, ...))` layers, wouldn't the variance of the activations keep increasing across layers? Section 3.2 of the paper suggests otherwise, though.

Unfortunately, when I initialize a random model, I don't see the same behaviour: the activation variance increases steadily across layers.

Was wondering if someone from the team could shed light on what prevents this from happening?

FYI, my random model isn't an OLMo 2 model but a similar transformer-based architecture; I do use QK layer norm in my attention layers.
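A minimal NumPy sketch of the setup described in the question, under loudly stated assumptions: each attention/MLP sublayer is replaced by one random linear map with $\mathcal{N}(0, 0.02^2)$ weights (a hypothetical stand-in, not OLMo 2's actual blocks), `norm` is a parameter-free RMSNorm with its gain at 1 as at initialization, and the embeddings use the same 0.02 std. Because `norm(...)` rescales every update to unit RMS no matter how small the weights are, the residual-stream variance in this toy model climbs by roughly 1 per sublayer, consistent with the steady increase described above.

```python
import numpy as np


def rms_norm(v, eps=1e-6):
    # parameter-free RMSNorm over the last axis (learned gain fixed at 1, as at init)
    return v / np.sqrt(np.mean(v ** 2, axis=-1, keepdims=True) + eps)


def residual_variance_by_layer(d_model=512, n_layers=16, n_tokens=64, init_std=0.02, seed=0):
    rng = np.random.default_rng(seed)
    # token embeddings drawn from N(0, init_std^2)
    x = rng.normal(0.0, init_std, size=(n_tokens, d_model))
    variances = [float(x.var())]
    for _ in range(n_layers):
        # hypothetical stand-in for attention/MLP: one random linear map with N(0, init_std^2) weights
        W = rng.normal(0.0, init_std, size=(d_model, d_model))
        x = x + rms_norm(x @ W)  # post-sublayer norm: x = x + norm(f(x))
        variances.append(float(x.var()))
    return variances


if __name__ == "__main__":
    for layer, var in enumerate(residual_variance_by_layer()):
        print(f"after layer {layer:2d}: variance {var:.3f}")
```

In this sketch the 0.02 init only sets the tiny starting variance; the growth rate is set by the norm, since every update is forced to unit RMS.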

c3-utsavdutta98 · Feb 15 '25 00:02

Section 3.2 does not graph activations; it graphs gradients, and it shows the gradient norm across training steps, not across layers. We have not checked whether activations grow as you go through the layers, but they should. Every layer should increase the activation norm by about a factor of 2 ($\sqrt{2}$ for each of its two residual additions, attention and MLP).
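Where a factor of $\sqrt{2}$ per residual addition can come from, under an idealization (not a measurement on OLMo 2): if the residual stream $x$ and the normalized sublayer output $u$ are zero-mean and uncorrelated, the cross term vanishes and squared norms add,

$$
\mathbb{E}\lVert x + u \rVert^2 = \mathbb{E}\lVert x \rVert^2 + 2\,\mathbb{E}\langle x, u \rangle + \mathbb{E}\lVert u \rVert^2 = \mathbb{E}\lVert x \rVert^2 + \mathbb{E}\lVert u \rVert^2 ,
$$

so when $x$ and $u$ have the same RMS, one addition multiplies the norm by $\sqrt{2}$. If instead $u$ keeps unit RMS while $x$ grows (a fixed norm gain), then starting from a unit-RMS stream with hidden size $d$, after $k$ additions $\lVert x_k \rVert \approx \sqrt{(k+1)\,d}$, and the per-addition factor $\sqrt{(k+1)/k}$ shrinks toward 1.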

There is an unresolvable tension here. If you want the activations on the residual stream to always stay in the same range, how can you keep adding to them from the MLPs and attention? You can scale down the contributions and the residual stream, but then you diminish the contributions of the earlier layers. In practice, growth by $\sqrt{2}$ per MLP/attention block seems to be fine, but there could be something better out there.
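To make the tension concrete, here is a hypothetical NumPy sketch of the "scale the residual stream" option, assuming each normalized sublayer output is an independent, unit-RMS random vector and the stream starts at unit RMS. Rescaling after every addition, $x \leftarrow (x + u)/\sqrt{2}$, keeps the RMS near 1, but every earlier update is divided by $\sqrt{2}$ again at each later sublayer, so after 12 sublayers the first update's coefficient is $2^{-6} = 1/64$.

```python
import numpy as np


def rms(v):
    return float(np.sqrt(np.mean(v ** 2)))


def rescaled_residual(d_model=4096, n_sublayers=12, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(d_model)  # residual stream starting at unit RMS (assumption)
    weights = []                      # current coefficient of each sublayer's update inside x
    rms_history = [rms(x)]
    for _ in range(n_sublayers):
        u = rng.standard_normal(d_model)
        u /= rms(u)                   # hypothetical unit-RMS stand-in for norm(sublayer(x))
        x = (x + u) / np.sqrt(2.0)    # shrink stream and update together to keep variance fixed
        weights = [w / np.sqrt(2.0) for w in weights] + [1.0 / np.sqrt(2.0)]
        rms_history.append(rms(x))
    return rms_history, weights


if __name__ == "__main__":
    history, weights = rescaled_residual()
    print("RMS per step:", [round(r, 2) for r in history])
    print("coefficient of the first update:", weights[0])
```

The RMS stays flat, but the geometric decay of early contributions is exactly the cost described above.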

dirkgr · Feb 21 '25 20:02

Thanks for the reply! And yeah, it seems with this norm layout the growth is inevitable.

I wonder if you tried re-normalizing the main branch (the residual stream)? This was proposed in the Swin Transformer V2 paper, where they do it every 6 layers.
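A sketch of re-normalizing the main branch, with a similar idealization: the stream starts at unit RMS, each layer adds one independent unit-RMS update (a simplification; real layers have two sublayers), and every 6 layers a parameter-free RMSNorm is applied to the residual stream itself. In this toy model the RMS climbs to about $\sqrt{6}$ within each 6-layer block and resets to 1 at each renorm, rather than reaching about 5 after 24 layers; as with rescaling, each renorm also shrinks everything accumulated so far.

```python
import numpy as np


def rms(v):
    return float(np.sqrt(np.mean(v ** 2)))


def periodic_renorm(d_model=4096, n_layers=24, every=6, eps=1e-6, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(d_model)  # residual stream starting at unit RMS (assumption)
    history = []
    for layer in range(1, n_layers + 1):
        u = rng.standard_normal(d_model)
        u /= rms(u)                   # hypothetical unit-RMS stand-in for norm(sublayer(x))
        x = x + u                     # usual residual addition
        if layer % every == 0:
            # extra parameter-free RMSNorm applied to the main branch itself
            x = x / np.sqrt(np.mean(x ** 2) + eps)
        history.append(rms(x))
    return history


if __name__ == "__main__":
    for layer, value in enumerate(periodic_renorm(), start=1):
        print(f"after layer {layer:2d}: RMS {value:.2f}")
```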

Also, regarding Section 3.2, this is the excerpt I was referring to:

[Figure: screenshot of the Section 3.2 excerpt]

There, it says the "growth_exponent" stays near 0 for both activations and gradients, which I assume means the activations are also not growing in norm (somehow?).

I believe the expected norm of an $N$-dimensional vector whose entries $X_i$ are drawn i.i.d. from $\mathcal{N}(0, \sigma^2)$ is approximately $\sigma\sqrt{N}$ for large $N$, i.e., it scales with the square root of the variance. So a zero growth exponent in the norms would imply that the variances also stay constant across layers.
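A quick numerical check of that scaling, assuming i.i.d. Gaussian entries: the mean norm stays close to $\sigma\sqrt{N}$, and doubling the standard deviation (quadrupling the variance) exactly doubles it when the same random draws are reused.

```python
import numpy as np


def mean_norm(std, n_dims=4096, n_samples=500, seed=0):
    rng = np.random.default_rng(seed)
    # each row is one vector with i.i.d. entries X_i ~ N(0, std^2)
    samples = rng.normal(0.0, std, size=(n_samples, n_dims))
    return float(np.linalg.norm(samples, axis=1).mean())


for std in (0.02, 0.04, 1.0):
    # empirical mean norm vs. the large-N approximation std * sqrt(N)
    print(f"std={std}: mean norm {mean_norm(std):.4f}, std*sqrt(N) {std * np.sqrt(4096):.4f}")
```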

c3-utsavdutta98 · Feb 21 '25 21:02

@dirkgr Here's a pretty basic check on this. I collected the output (residual stream) of every layer for a single prompt, then computed per-layer summaries: the L2 norm over the hidden dimension, averaged over token positions, and the mean and variance over token positions and the hidden dimension.

Norms:

[Figure: activation norm per layer]

Means:

[Figure: activation mean per layer]

Variance:

[Figure: activation variance per layer]

Code (I used the [nnsight](https://www.nnsight.net/) library):

```python
#%%
import matplotlib.pyplot as plt
import nnsight
import torch
from nnsight import LanguageModel

model_name = "allenai/OLMo-2-1124-7B"
lm = LanguageModel(model_name)
#%%
prompt = "The capital of France is"
with lm.trace(prompt):
    activations = nnsight.list().save()
    for l in lm.model.layers:
        # element 0 of each decoder layer's output is its hidden state (the residual stream)
        activations.append(l.output[0])

# one prompt (batch size 1): squeeze to shape (n_layers, n_tokens, d_model)
activations = torch.stack(activations).squeeze()
#%%
# one value per layer; var/mean pool over token positions and the hidden dimension
variances = activations.var(dim=(-1, -2)).numpy(force=True)
means = activations.mean(dim=(-1, -2)).numpy(force=True)
# L2 norm over the hidden dimension, averaged over token positions
norms = activations.norm(dim=-1).mean(dim=-1).numpy(force=True)
#%%
plt.plot(norms, marker="o")
plt.xlabel("layer")
plt.ylabel("L2 norm over hidden dimension,\naveraged over tokens")
plt.grid(True)
plt.show()
#%%
plt.plot(means, marker="o")
plt.xlabel("layer")
plt.ylabel("mean of activations over\ntokens and hidden dimension")
plt.grid(True)
plt.show()
#%%
plt.plot(variances, marker="o")
plt.xlabel("layer")
plt.ylabel("variance of activations over\ntokens and hidden dimension")
plt.grid(True)
plt.show()
```

loftusa · Mar 27 '25 14:03

I also plotted pairwise $\lambda$ values, where $\lambda = \frac{1}{n_{\text{layers}}} \log \frac{\lVert v' \rVert}{\lVert v \rVert}$, $v'$ is the activation at layer $i$, and $v$ is the activation at layer $i+1$:

[Figure: pairwise $\lambda$ between consecutive layers]

as well as the case where $v'$ is the first layer:

[Figure: $\lambda$ relative to the first layer]
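For reference, a small sketch of how those $\lambda$ values can be computed from a list of per-layer norms, following the formula above with $v'$ as the earlier layer; taking $n_{\text{layers}}$ to be the number of entries in `norms` is an assumption. Note the sign convention: with $v'$ in the numerator and $v$ the later layer, growing norms give negative $\lambda$, so swap the arguments if you want growth to show up as positive.

```python
import numpy as np


def growth_exponent(norm_ref, norm_other, n_layers):
    # lambda = (1 / n_layers) * log(||v'|| / ||v||), with v' = norm_ref and v = norm_other
    return float(np.log(norm_ref / norm_other) / n_layers)


def pairwise_and_first_layer(norms):
    norms = np.asarray(norms, dtype=float)
    n_layers = len(norms)  # assumption: normalize by the number of measured layers
    # v' = layer i, v = layer i + 1
    pairwise = [growth_exponent(norms[i], norms[i + 1], n_layers) for i in range(n_layers - 1)]
    # v' = first layer, v = layer j
    from_first = [growth_exponent(norms[0], norms[j], n_layers) for j in range(1, n_layers)]
    return pairwise, from_first
```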

loftusa · Mar 27 '25 14:03

Why does the mean of the activations start below zero?

dirkgr · Mar 28 '25 15:03

> Why does the mean of the activations start below zero?

The mean of the weights of every module in `self_attn` in the first layer is (slightly) negative, and there's no bias term:

```
q_proj.weight: tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.170, 0.149] μ=-7.518e-07 σ=0.009
k_proj.weight: tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.178, 0.148] μ=-1.179e-06 σ=0.009
v_proj.weight: tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.267, 0.253] μ=-1.314e-06 σ=0.016
o_proj.weight: tensor[4096, 4096] n=16777216 (64Mb) x∈[-0.622, 0.621] μ=-1.888e-06 σ=0.016
```

Unless I'm missing something, if you just sample activations from a normal distribution, that'd make them skew slightly negative on average before they're added into the residual stream.
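One sanity check on that reasoning, using only the summary statistics posted above: if the entries of each matrix are treated as i.i.d. draws (an assumption; trained weights are not), the standard error of the mean is $\sigma/\sqrt{n}$ with $n = 4096^2$, and every reported mean is less than one standard error from zero. The printed $\sigma$ values are rounded, so these z-scores are approximate. Separately, for a bias-free projection the expected output is $W\,\mathbb{E}[x]$, so a small weight mean can only shift the output mean when the input itself has a nonzero mean.

```python
import math

# (name, mean, std) as printed in the per-module summary above
stats = [
    ("q_proj", -7.518e-07, 0.009),
    ("k_proj", -1.179e-06, 0.009),
    ("v_proj", -1.314e-06, 0.016),
    ("o_proj", -1.888e-06, 0.016),
]
n = 4096 * 4096  # entries per weight matrix


def z_score(mean, std, n):
    # mean divided by the standard error of the sample mean, treating entries as i.i.d. (assumption)
    return mean / (std / math.sqrt(n))


for name, mean, std in stats:
    print(f"{name}: z = {z_score(mean, std, n):+.2f}")
```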

loftusa · Mar 29 '25 11:03

I thought you'd be plotting the mean of the absolute values. Otherwise it doesn't really show the effect you're after, which is (I think) that the magnitude of the activations grows from one layer to the next. You get that with your norm and variance graphs, but I'm not sure the mean graph contributes to my understanding of what's going on.
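A sketch of the per-layer summaries being discussed, assuming the activations have been stacked into an array of shape `(n_layers, n_tokens, d_model)` (for example, the stacked hidden states from the script above, converted to NumPy). Unlike the signed mean, the mean absolute value tracks magnitude: for zero-mean Gaussian activations it equals $\sigma\sqrt{2/\pi}$, so it grows in step with the RMS.

```python
import numpy as np


def per_layer_stats(acts):
    # acts: array of shape (n_layers, n_tokens, d_model)
    acts = np.asarray(acts, dtype=np.float64)
    mean = acts.mean(axis=(1, 2))             # signed mean: can hover near zero even as magnitudes grow
    mean_abs = np.abs(acts).mean(axis=(1, 2))  # mean absolute value: tracks magnitude
    rms = np.sqrt((acts ** 2).mean(axis=(1, 2)))
    return mean, mean_abs, rms
```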

Anyways, sorry for the delay, but I find this super interesting. Can you make the same graphs for the 13B and 32B? I'm asking because your variance graph looks really scary, and it would lead you to believe that it's impossible to scale much deeper. But the 13B and 32B are much deeper models and they worked fine. I'm curious to see what happens to the activations there.

dirkgr · Apr 19 '25 00:04

Hi, thanks again for the inquiry! We’re currently working on closing out old tickets, so we’re closing this out for now, but if you require a follow-up response, please re-open and we will get back to you!

baileykuehl · Jul 01 '25 17:07