
Changing *access* order of attention and MLP outputs affects model predictions unexpectedly


I'm attempting to zero-ablate all self-attention outputs of the last token across all layers, such that the model's prediction should only depend on the last token and not be affected by any preceding tokens.

However, I've observed unexpected behavior: when I modify attention outputs, the order in which I subsequently access the outputs of the MLP and attention components significantly alters the final prediction. This issue arises even though I expect the output to remain consistent, as the only explicit change made is setting the attention outputs to zero.

Reproducible Example:

from nnsight import NNsight, LanguageModel
import nnsight
import torch

model = LanguageModel("Qwen/Qwen2.5-7B")
model.eval()
model.generation_config.pad_token_id = model.tokenizer.pad_token_id

with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Access the attention output first, then the MLP output, then ablate
        layer_attn = layer.self_attn.output[0]
        layer_mlp = layer.mlp.output
        layer.self_attn.output[0][:, -1, :] = 0

    final_out_one = model.output[0].save()

with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Same ablation, but the MLP output is accessed before the attention output
        layer_mlp = layer.mlp.output
        layer_attn = layer.self_attn.output[0]
        layer.self_attn.output[0][:, -1, :] = 0

    final_out_two = model.output[0].save()

for final_out in [final_out_one, final_out_two]:
    probs = torch.nn.functional.softmax(final_out[:, -1, :], dim=-1)
    probs_argsort = probs.argsort(dim=-1, descending=True)
    # Rank of the " ABCD" token in the sorted next-token distribution
    final_probs = (probs_argsort[0] == model.tokenizer.encode(" ABCD")[0]).nonzero()[0].item()
    print(final_probs)

print("WHY DID THIS CHANGE??")

Observed Output:

7743
119759
WHY DID THIS CHANGE??

Expected Output:

The prediction should remain identical between runs, regardless of the order in which the attention and MLP outputs are accessed, as the zero-ablation operation is identical in both cases.

Environment:

NNsight version: 0.4.5

Model: Qwen/Qwen2.5-7B

PyTorch version: 2.6.0+cu124

Michaelikarasik opened this issue on Apr 13, 2025

Hey @Michaelikarasik, the NNsight docs state:

> Interventions within the tracing context do not necessarily execute in the order they are defined. Instead, their execution is tied to the module they are associated with.

I believe this is no longer true in the latest NNsight versions, but that hasn't been communicated clearly enough. Using nnsight 0.4.5 and the dev branch:

from nnsight import NNsight, LanguageModel
import nnsight
import torch

model = LanguageModel("Maykeye/TinyLLama-v0")
model.eval()
model.generation_config.pad_token_id = model.tokenizer.pad_token_id

with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Attention output accessed before the MLP output
        layer_attn = layer.self_attn.output[0]
        layer_mlp = layer.mlp.output
        layer.self_attn.output[0][:, -1, :] = 0

    final_out_one = model.output[0].save()

with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # MLP output accessed before the attention output
        layer_mlp = layer.mlp.output
        layer_attn = layer.self_attn.output[0]
        layer.self_attn.output[0][:, -1, :] = 0

    final_out_two = model.output[0].save()

# Unmodified forward pass for reference
with model.trace('Test prompt') as tracer:
    final_out_vanilla = model.output[0].save()

print(f"final_out_one - final_out_two: {(final_out_one - final_out_two).abs().max()}")
print(f"final_out_vanilla - final_out_two: {(final_out_vanilla - final_out_two).abs().max()}")
print(f"final_out_one - final_out_vanilla: {(final_out_one - final_out_vanilla).abs().max()}")

Yields:

final_out_one - final_out_two: 6.797832489013672
final_out_vanilla - final_out_two: 6.797832489013672
final_out_one - final_out_vanilla: 0.0

My guess is that the ablation in the final_out_one trace never actually modifies your forward pass, for some obscure reason linked to this new behavior.
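
A minimal sketch of one way to test that guess, assuming the same TinyLLama model as above and that re-accessing and saving self_attn.output[0] after the in-place edit reflects that edit (the variable name attn_out_ablated is just illustrative):

from nnsight import LanguageModel

model = LanguageModel("Maykeye/TinyLLama-v0")

with model.trace('Test prompt'):
    layer = model.model.layers[0]
    # Zero-ablate the last token's attention output in layer 0 only
    layer.self_attn.output[0][:, -1, :] = 0
    # Re-access the attention output and save it for inspection after the trace
    attn_out_ablated = layer.self_attn.output[0].save()

# Expect tensor(0.) if the in-place edit actually reached the forward pass
print(attn_out_ablated[:, -1, :].abs().max())

Repeating the check with layer.mlp.output accessed before the attention output would show whether the edit is silently dropped in that ordering.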

Butanium commented on Apr 14, 2025

Hi @Michaelikarasik - thanks for submitting this issue.

As rightly pointed out by Clement, the inconsistency you are experiencing is due to accessing modules out of order. In short, the general rule is to access modules following the model's original order in its computation graph. Hence, the correct version of your experiment should look like the following:

from nnsight import NNsight, LanguageModel
import nnsight
import torch

model = LanguageModel("Qwen/Qwen2.5-7B")
model.eval()
model.generation_config.pad_token_id = model.tokenizer.pad_token_id

with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Access and ablate the attention output before touching the MLP output,
        # matching the order in which the modules run during the forward pass
        layer_attn = layer.self_attn.output[0]
        layer.self_attn.output[0][:, -1, :] = 0
        layer_mlp = layer.mlp.output

    final_out_one = model.output[0].save()

for final_out in [final_out_one]:
    probs = torch.nn.functional.softmax(final_out[:, -1, :], dim=-1)
    probs_argsort = probs.argsort(dim=-1, descending=True)
    final_probs = (probs_argsort[0] == model.tokenizer.encode(" ABCD")[0]).nonzero()[0].item()
    print(final_probs)

AdamBelfki3 commented on Apr 14, 2025

@AdamBelfki3 @Butanium Thanks!! Do you think an exception or warning should be raised in such cases?

Michaelikarasik commented on Apr 15, 2025

Yes, I think there should be a warning/exception for this kind of failure.

Butanium commented on Apr 15, 2025

@Michaelikarasik In 0.5, order is enforced and you will indeed receive an error!

JadenFiotto-Kaufman commented on Jul 7, 2025