pyvene
[P1] Speed up training of multiple DAS interventions with caching
Commonly, we want to exhaustively train DAS on every layer and position (or e.g. every attention head in a layer) to find which ones are causally relevant for the model's computations. When dealing with a fixed dataset, we could speed this process up by caching and reusing activations. It is unclear what the best way to implement this is, but a minimal example should already be possible with `CollectIntervention` and the `activations_sources=` arg at inference.
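To make the idea concrete, here is a rough, library-agnostic sketch of the caching pattern in plain Python. It deliberately does not call pyvene: `ActivationCache`, the toy `layers`, and the `dataset` dict are all hypothetical names invented for illustration. The assumption is that the model is frozen while DAS trains, so activations on the source inputs depend only on the input and can be recorded once and reused across every (layer, position) configuration. In pyvene, the recording step would presumably be done with `CollectIntervention` and the cached values passed back through `activations_sources=`, but that wiring is not shown here.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]
Layer = Callable[[Vector], Vector]


class ActivationCache:
    """Hypothetical helper: stores per-layer activations of a frozen model."""

    def __init__(self, layers: List[Layer]) -> None:
        self.layers = layers
        self.store: Dict[Tuple[int, int], Vector] = {}
        self.forward_passes = 0  # number of full model runs actually executed

    def _run(self, example_id: int, x: Vector) -> None:
        # One forward pass records the output of every layer for this example.
        self.forward_passes += 1
        h = x
        for layer_idx, layer in enumerate(self.layers):
            h = layer(h)
            self.store[(example_id, layer_idx)] = list(h)

    def get(self, example_id: int, x: Vector, layer_idx: int) -> Vector:
        # Run the model only if this example has not been seen before.
        if (example_id, layer_idx) not in self.store:
            self._run(example_id, x)
        return self.store[(example_id, layer_idx)]


# Toy stand-in for a frozen model: three elementwise "layers".
layers: List[Layer] = [
    lambda h: [v + 1.0 for v in h],
    lambda h: [2.0 * v for v in h],
    lambda h: [v - 3.0 for v in h],
]

# Fixed dataset of source inputs, keyed by example id; each has two positions.
dataset: Dict[int, Vector] = {0: [1.0, 2.0], 1: [0.0, -1.0]}

cache = ActivationCache(layers)
lookups = 0

# Exhaustive sweep: one (hypothetical) DAS training run per layer and position.
for layer_idx in range(len(layers)):
    for position in range(2):
        for example_id, x in dataset.items():
            source_activation = cache.get(example_id, x, layer_idx)[position]
            lookups += 1  # a real run would feed this into the intervention

print(f"lookups: {lookups}, forward passes: {cache.forward_passes}")
```

In this toy sweep the source inputs are run through the model once per example rather than once per (layer, position, example) triple. Note that only source-side activations are safe to reuse this way; base-run activations downstream of the intervention site change with the intervention and still need to be recomputed.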