DecisionTransformerInterpretability
Mega Card: Improve Analysis App in various ways to facilitate better interpretability analysis of the new models
Analysis features
Static
Composition
- [x] Make composition maps
- [x] Replace composition scores with strip plots?
- [ ] Create a meta-composition score. Something that measures total influence?
- [ ] How do we check for composition between MLP_in and W_out? (Seems expensive? Maybe tie it to very specific hypotheses.)
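
As a reference point for the composition items above, here is a minimal sketch of a pairwise composition score. It assumes the Frobenius-norm ratio `||A B||_F / (||A||_F ||B||_F)` used in "A Mathematical Framework for Transformer Circuits"; whether the app computes exactly this is an assumption, and `composition_score`, `w_ov` and `w_qk` are hypothetical names with random stand-in weights.

```python
import numpy as np


def composition_score(w_upstream: np.ndarray, w_downstream: np.ndarray) -> float:
    """Frobenius-norm composition score ||A B||_F / (||A||_F * ||B||_F).

    Assumes both matrices are expressed in residual-stream coordinates so
    that the product A @ B is well defined.
    """
    product = w_upstream @ w_downstream
    denom = np.linalg.norm(w_upstream) * np.linalg.norm(w_downstream)
    return float(np.linalg.norm(product) / denom)


# Random stand-ins for real weights (hypothetical shapes).
rng = np.random.default_rng(0)
d_model = 16
w_ov = rng.normal(size=(d_model, d_model))  # earlier head's OV circuit
w_qk = rng.normal(size=(d_model, d_model))  # later head's QK circuit
score = composition_score(w_ov, w_qk)
```

Because the Frobenius norm is submultiplicative, the score always lies in `[0, 1]`. The same function could in principle be applied to MLP weight pairs, but the intermediate product then has one dimension of size `d_mlp`, which is likely where the cost concern comes from.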
Dynamic
Logit Lens
- [x] By Layer
- [x] By Layer accumulated
- [x] By Head
Attention Maps:
- [ ] Make it easier to export a nice visualization of the attention map (cv is actually not great for that).
- [ ] Make it possible to calculate the rank-k approximation to the attention map.
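
One way the rank-k item above could be approached is a truncated SVD of a single attention pattern. The sketch below assumes the pattern is available as a 2D `[query_pos, key_pos]` array; `rank_k_approx` is a hypothetical helper and the attention pattern is a random causal toy, not output from the real model.

```python
import numpy as np


def rank_k_approx(attn: np.ndarray, k: int) -> np.ndarray:
    """Best rank-k approximation (in Frobenius norm) via truncated SVD."""
    u, s, vt = np.linalg.svd(attn, full_matrices=False)
    # Scale the first k left singular vectors by their singular values,
    # then project back with the first k right singular vectors.
    return (u[:, :k] * s[:k]) @ vt[:k, :]


# Toy causal attention pattern: masked softmax over random scores.
rng = np.random.default_rng(0)
seq_len = 8
scores = rng.normal(size=(seq_len, seq_len))
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
scores = np.where(causal_mask, scores, -np.inf)
attn = np.exp(scores - scores.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)

approx = rank_k_approx(attn, k=2)
rel_error = np.linalg.norm(attn - approx) / np.linalg.norm(attn)
```

Note that the approximation is not guaranteed to be non-negative or to have rows summing to one, so it is better read as a diagnostic of how low-rank the pattern is than as a valid attention map.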
Causal
Activation Patching (features)
- [x] Set up component
- [x] Set up RTG Metric
- [x] Residual stream patching.
- [x] Patching via Attn and MLP
- [x] Head All Pos Patching
- [x] Head Specific Pos Patching (do later)
- [x] Head All Pos by Component
- [x] MLP at different Positions
- [ ] Show counterfactual attention map (ie: show difference in attention given intervention)
- [ ] Show what the logit diff is for each metric score.
Activation Patching (token variations):
- [x] Action (fairly easy)
- [x] Key/Ball (important!)
- [ ] Timestep (also fairly easy)
RTG Scan
- [x] Switch to using t-lens for decomp
- [x] Provide more than one level of decomp
- [x] Add a clustergram to show heads which mediate a similar relationship between RTG and logits/logit diff
Congruence -> If features aren't in superposition, what effect do they have on the predictions?
- [x] Pos
- [x] Time
- [x] W_in
- [x] W_Out
- [x] MLP Out
Renew old features:
- [ ] QK circuit visualizations for action and RTG embeddings
SVD Decomp / Explore ways to use dimensionality reduction to quickly understand what heads are doing.
- [ ] QK Circuit SVD (#69)
- [ ] OV Circuit SVD
Cache Characterization?
- [ ] Plot L2 norm of residual streams (along with mean and std)
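
A minimal sketch of the statistic behind the residual-norm item above, assuming the cached residual streams can be stacked into one `[n_layers, batch, pos, d_model]` array. `residual_norm_stats` is a hypothetical helper and the cache here is random data; plotting the two returned arrays (for example as a line with an error band) is left to whatever plotting library the app already uses.

```python
import numpy as np


def residual_norm_stats(resid: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Mean and std of the residual-stream L2 norm for each layer.

    resid: array of shape [n_layers, batch, pos, d_model].
    Returns two arrays of shape [n_layers].
    """
    norms = np.linalg.norm(resid, axis=-1)     # [n_layers, batch, pos]
    flat = norms.reshape(norms.shape[0], -1)   # pool over batch and position
    return flat.mean(axis=1), flat.std(axis=1)


# Random stand-in for a stacked activation cache (hypothetical shapes).
rng = np.random.default_rng(0)
resid = rng.normal(size=(4, 32, 10, 64))
mean_norm, std_norm = residual_norm_stats(resid)
```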
Advanced
Implement Path Patching
- [ ] Understand Callum's code.
Implement AVEC
- [ ] Reread post to see if we can find.
Several things I feel are missing which are required for exploratory analysis to be more complete:
- [ ] visualise dot product of time embeddings with each other
- [ ] visualise dot product of positional embeddings with each other
- [ ] Use Jay's head type analysis but write specific patterns for attending to RTG, attending to positive RTG, attending to states, and attending to actions.
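
The two embedding dot-product items above amount to computing a Gram matrix of the embedding table. A minimal sketch, assuming the time (or positional) embeddings are available as a `[n_entries, d_model]` array; `cosine_similarity_matrix` is a hypothetical helper and the embedding table is random data.

```python
import numpy as np


def cosine_similarity_matrix(emb: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of an embedding table."""
    normed = emb / np.linalg.norm(emb, axis=-1, keepdims=True)
    return normed @ normed.T


# Random stand-in for a learned time-embedding table (hypothetical shapes).
rng = np.random.default_rng(0)
n_timesteps, d_model = 20, 32
time_emb = rng.normal(size=(n_timesteps, d_model))

raw_dot = time_emb @ time_emb.T               # raw dot products
cos_sim = cosine_similarity_matrix(time_emb)  # norm-independent version
```

Showing both heatmaps side by side may be useful: the raw dot products reflect embedding norms, while the cosine version isolates direction.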
Several things I feel will be required for falsifying predictions of how the model is working:
- [ ] implement a variant of path patching for DTs either in a notebook or as part of the app.
- [ ] CaSc (causal scrubbing): not sure how feasible this is, but it has always been the goal.