Easy-Transformer
[Proposal] Have ActivationCache.get_full_resid_decomposition support passing in a vector/tensor to project onto
Proposal
Allow ActivationCache.get_full_resid_decomposition to receive a project_output_onto
tensor, either of shape [d_model] or [d_model, num_outputs], and multiply the output by it. Internally, rather than taking (neurons * W_out), take neurons * (W_out @ project_output_onto); this is much more memory efficient, since the full [d_mlp, d_model] per-neuron tensor is never materialized.
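Roughly, the idea is the following. This is a minimal sketch of the proposed computation, not the current API; the `project_output_onto` name comes from the proposal, and all shapes and values are illustrative:

```python
import torch

batch, pos, d_mlp, d_model = 2, 16, 2048, 512

post = torch.randn(batch, pos, d_mlp)  # neuron activations for one MLP layer
W_out = torch.randn(d_mlp, d_model)    # MLP output weight matrix

# Current behaviour: materialize the full per-neuron decomposition,
# shape [batch, pos, d_mlp, d_model] -- the memory-hungry tensor.
# full_decomp = post[..., None] * W_out

# Proposed: project W_out first, so the big tensor never exists.
project_output_onto = torch.randn(d_model)         # a single [d_model] direction
per_neuron = post * (W_out @ project_output_onto)  # [batch, pos, d_mlp]

# With a [d_model, num_outputs] tensor, several directions at once:
directions = torch.randn(d_model, 4)
per_neuron_multi = post[..., None] * (W_out @ directions)  # [batch, pos, d_mlp, 4]
```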
Motivation
There are a ton of neurons, and creating a [d_mlp, d_model] tensor at every position and batch can blow out your GPU memory fast. This means that if we just want, e.g., the contribution of each neuron to the output logit of the correct next token, we can feed in that direction and save memory.
This is a bit messy, since there are many ways we might want to do this (e.g. having a different output vector per position, for each correct next token), but this seems like a good MVP.
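For the single-direction case above, usage might look something like this. This is hypothetical, assuming the proposed argument gets added; `model.W_U`, `expand_neurons`, and the `correct_token` id are illustrative uses of the existing API:

```python
# Hypothetical usage, assuming the proposed project_output_onto argument exists.
# model is a loaded transformer and cache an ActivationCache from run_with_cache.
correct_token = 42                       # illustrative id of the correct next token
logit_dir = model.W_U[:, correct_token]  # [d_model] unembedding direction

per_component = cache.get_full_resid_decomposition(
    expand_neurons=True,            # one entry per neuron, not per MLP layer
    project_output_onto=logit_dir,  # proposed new argument
)
# per_component would then hold a scalar per component instead of a full
# [d_model] vector, keeping memory usage manageable.
```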
Will take a look at this whilst doing recursive DLA (direct logit attribution)