Changing *access* order of attention and MLP outputs affects model predictions unexpectedly
I'm attempting to zero-ablate the self-attention outputs at the last token position across all layers, so that the model's prediction depends only on the last token and is not affected by any preceding tokens.
However, I've observed unexpected behavior: the order in which I access the MLP and attention outputs inside the trace significantly alters the final prediction, even though the only explicit change in both cases is setting the attention outputs to zero, so I would expect the output to be identical.
Reproducible Example:
from nnsight import NNsight, LanguageModel
import nnsight
import torch
model = LanguageModel("Qwen/Qwen2.5-7B")
model.eval()
model.generation_config.pad_token_id = model.tokenizer.pad_token_id
with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Access the attention output first, then the MLP output, then zero-ablate.
        layer_attn = layer.self_attn.output[0]
        layer_mlp = layer.mlp.output
        layer.self_attn.output[0][:, -1, :] = 0
    final_out_one = model.output[0].save()

with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Same ablation, but the MLP output is accessed before the attention output.
        layer_mlp = layer.mlp.output
        layer_attn = layer.self_attn.output[0]
        layer.self_attn.output[0][:, -1, :] = 0
    final_out_two = model.output[0].save()

for final_out in [final_out_one, final_out_two]:
    # Rank of the " ABCD" token in the predicted distribution at the last position.
    probs = torch.nn.functional.softmax(final_out[:, -1, :], dim=-1)
    probs_argsort = probs.argsort(dim=-1, descending=True)
    final_probs = (probs_argsort[0] == model.tokenizer.encode(" ABCD")[0]).nonzero()[0].item()
    print(final_probs)

print("WHY DID THIS CHANGE??")
Observed Output:
7743
119759
WHY DID THIS CHANGE??
Expected Output:
The prediction should be identical between the two runs, regardless of the order in which the attention and MLP outputs are accessed, since the zero-ablation operation is the same in both cases.
Environment:
NNsight version: 0.4.5
Model: Qwen/Qwen2.5-7B
PyTorch version: 2.6.0+cu124
Hey @Michaelikarasik, the NNsight docs state: "Interventions within the tracing context do not necessarily execute in the order they are defined. Instead, their execution is tied to the module they are associated with."
I think this is no longer true in the latest NNsight versions, but that hasn't been communicated publicly enough. Using nnsight 0.4.5 and the dev branch:
from nnsight import NNsight, LanguageModel
import nnsight
import torch
model = LanguageModel("Maykeye/TinyLLama-v0")
model.eval()
model.generation_config.pad_token_id = model.tokenizer.pad_token_id
with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Ordering "one": attention output accessed before the MLP output.
        layer_attn = layer.self_attn.output[0]
        layer_mlp = layer.mlp.output
        layer.self_attn.output[0][:, -1, :] = 0
    final_out_one = model.output[0].save()

with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Ordering "two": MLP output accessed before the attention output.
        layer_mlp = layer.mlp.output
        layer_attn = layer.self_attn.output[0]
        layer.self_attn.output[0][:, -1, :] = 0
    final_out_two = model.output[0].save()

with model.trace('Test prompt') as tracer:
    # No intervention at all, for reference.
    final_out_vanilla = model.output[0].save()
print(f"final_out_one - final_out_two: {(final_out_one - final_out_two).abs().max()}")
print(f"final_out_vanilla - final_out_two: {(final_out_vanilla - final_out_two).abs().max()}")
print(f"final_out_one - final_out_vanilla: {(final_out_one - final_out_vanilla).abs().max()}")
Yields:
final_out_one - final_out_two: 6.797832489013672
final_out_vanilla - final_out_two: 6.797832489013672
final_out_one - final_out_vanilla: 0.0
My guess is that in the final_out_one case the ablation doesn't actually modify your forward pass, for some obscure reason linked to this new behavior.
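If you want to probe that directly, here is a rough diagnostic sketch. It only reuses the access/save pattern already shown above, and the variable names (attn_after, final_out_check) are just illustrative; I haven't verified how the saved proxy interacts with the in-place edit, so treat the result with caution:

with model.trace('Test prompt') as tracer:
    layer = model.model.layers[0]
    layer_attn = layer.self_attn.output[0]   # access attention first, as in the ordering where the ablation seems to have no effect
    layer_mlp = layer.mlp.output             # then access the MLP output
    layer.self_attn.output[0][:, -1, :] = 0  # zero-ablate the last token's attention output
    attn_after = layer.self_attn.output[0].save()  # save what the attention output looks like afterwards
    final_out_check = model.output[0].save()

# If the ablation actually registered, the last-token slice should be all zeros.
print(attn_after[:, -1, :].abs().max())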
Hi @Michaelikarasik - thanks for submitting this issue.
As rightly pointed out by Clement, the inconsistency you are experiencing is due to accessing modules out of order. In short, the general rule is to access modules in the order they appear in the model's computation graph: within each decoder layer, self_attn runs before mlp, so you should access (and edit) the attention output before touching the MLP output. Hence, the correct version of your experiment should look like the following:
from nnsight import NNsight, LanguageModel
import nnsight
import torch
model = LanguageModel("Qwen/Qwen2.5-7B")
model.eval()
model.generation_config.pad_token_id = model.tokenizer.pad_token_id
with model.trace('Test prompt') as tracer:
    for layer_idx, layer in enumerate(model.model.layers):
        # Follow the computation graph: access the attention output, apply the
        # ablation, and only then access the MLP output.
        layer_attn = layer.self_attn.output[0]
        layer.self_attn.output[0][:, -1, :] = 0
        layer_mlp = layer.mlp.output
    final_out_one = model.output[0].save()

for final_out in [final_out_one]:
    probs = torch.nn.functional.softmax(final_out[:, -1, :], dim=-1)
    probs_argsort = probs.argsort(dim=-1, descending=True)
    final_probs = (probs_argsort[0] == model.tokenizer.encode(" ABCD")[0]).nonzero()[0].item()
    print(final_probs)
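To sanity-check that the ablation is now actually being applied, you can also compare against a run with no intervention, the same way Clement did above. This is just an illustrative check, not part of the fix:

with model.trace('Test prompt') as tracer:
    final_out_vanilla = model.output[0].save()

# With the corrected access order, the ablated run should now differ from the vanilla run.
print(f"final_out_one - final_out_vanilla: {(final_out_one - final_out_vanilla).abs().max()}")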
@AdamBelfki3 @Butanium Thanks!! Do you think an exception or warning should be raised in such cases?
Yes, I think there should be a warning/exception for this kind of failure.
@Michaelikarasik In 0.5, order is enforced and you will indeed receive an error!