[Proposal] Update and discuss behavior of SVD
Proposal
This proposal is in two parts, addressing TL's current SVD implementation.
Note, the SVD is typically written as:
$$ U,S,V^T = \ldots $$
The important part here is the $V^T$.
Currently, TL returns $V$, not $V^T$, because TL uses `torch.svd` rather than `torch.linalg.svd`, and `torch.svd` returns $V$.
https://github.com/neelnanda-io/TransformerLens/blob/347e65235140d942301452a55505c9215b36806a/transformer_lens/FactoredMatrix.py#L118-L125
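To illustrate the difference in plain PyTorch (not TL code; `torch.svd` is deprecated but, at the time of writing, still available):

```python
import torch

M = torch.randn(4, 3)

# Deprecated API: the third return value is V, so reconstruction needs a transpose.
U, S, V = torch.svd(M)
assert torch.allclose(M, U @ torch.diag(S) @ V.T, atol=1e-5)

# torch.linalg.svd: the third return value is Vh (= V^T), used directly.
U2, S2, Vh = torch.linalg.svd(M, full_matrices=False)
assert torch.allclose(M, U2 @ torch.diag(S2) @ Vh, atol=1e-5)
```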
To address this, we can do one or both of:
- small: `torch.svd` in `FactoredMatrix` is deprecated and should be replaced by `torch.linalg.svd`. This can be done while keeping all current behavior (see the sketch after this list). Note that right now TL uses the name `Vh` in `FactoredMatrix`, but the returned matrix is actually `V`, so if this is done we should also do the renaming.
- big: change the implementation to instead return `V_T`, as `torch.linalg.svd` does. This would be a breaking change.
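For concreteness, here is a minimal sketch of the "small" option (names are illustrative, not the actual `FactoredMatrix` code): keep the factored computation, swap `torch.svd` for `torch.linalg.svd`, and transpose back at the end so callers still receive $V$.

```python
import torch

def factored_svd_current_convention(A: torch.Tensor, B: torch.Tensor):
    """SVD of A @ B computed from the factors, still returning V (not V^T)."""
    Ua, Sa, Vh_a = torch.linalg.svd(A, full_matrices=False)
    Ub, Sb, Vh_b = torch.linalg.svd(B, full_matrices=False)
    # SVD of the small middle factor; A @ B itself is never materialized.
    Um, Sm, Vh_m = torch.linalg.svd(
        Sa.diag_embed() @ Vh_a @ Ub @ Sb.diag_embed(), full_matrices=False
    )
    # Transpose back to V so existing callers are unaffected.
    return Ua @ Um, Sm, (Vh_m @ Vh_b).transpose(-2, -1)
```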
Pitch
I'd like to make a pitch for doing both of the above. That is, switch to `torch.linalg.svd` and return $V^T$ directly. This would let you do:
U, S, V_T = model.OV.svd()
torch.dist(model.OV, U @ torch.diag(S) @ V_T)
as in the docs, and, at least for me, this behaves more as expected.
Note: the above is only conceptual. Right now, I don't think you can use the `FactoredMatrix` directly and instead must use `model.OV.AB`; also, the call to `diag` depends on the sizes of the other tensors and should probably be `.diag_embed`.
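For reference, a hedged version of the same check that runs against today's behavior (assuming a small pretrained model; the third return value is named `Vh` but is actually $V$, hence the transpose):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
U, S, V = model.OV.svd()  # current behavior: third value is V, not V^T
approx = U @ S.diag_embed() @ V.transpose(-2, -1)
print(torch.dist(model.OV.AB, approx))  # small, numerical error only
```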
The code change would be very small. Something like:
Ua, Sa, V_Ta = torch.linalg.svd(model.OV.A, full_matrices=False)
Ub, Sb, V_Tb = torch.linalg.svd(model.OV.B, full_matrices=False)
# SVD of the small middle factor, so the full product A @ B is never materialized
middle = Sa.diag_embed() @ V_Ta @ Ub @ Sb.diag_embed()
Um, Sm, V_Tm = torch.linalg.svd(middle, full_matrices=False)
Uf = Ua @ Um
Sf = Sm
V_Tf = V_Tm @ V_Tb
return Uf, Sf, V_Tf
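As a standalone sanity check of the proposed convention, on random factors rather than TL weights:

```python
import torch

A, B = torch.randn(8, 4), torch.randn(4, 8)

Ua, Sa, V_Ta = torch.linalg.svd(A, full_matrices=False)
Ub, Sb, V_Tb = torch.linalg.svd(B, full_matrices=False)
middle = Sa.diag_embed() @ V_Ta @ Ub @ Sb.diag_embed()
Um, Sm, V_Tm = torch.linalg.svd(middle, full_matrices=False)
Uf, Sf, V_Tf = Ua @ Um, Sm, V_Tm @ V_Tb

# With V^T returned directly, U @ diag(S) @ V^T reconstructs the product
# with no extra transpose.
assert torch.allclose(A @ B, Uf @ Sf.diag_embed() @ V_Tf, atol=1e-5)
```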
I can submit a PR, but I thought it might be best to discuss first.
Alternatives
The first plan should be considered the 'safe and easy' alternative.
Checklist
- [X] I have checked that there is no similar issue in the repo (required)
Not directly related, but if this change is made, the following behavior may be worth clarifying in the docs:
`model.OV.svd` returns the "thin"/"compact" SVD, having performed the more efficient factored calculation directly.
In principle, one could instead do the full SVD, `torch.linalg.svd(model.OV.AB)`, and then take the first $r = \text{rank}$ columns of $U$ and rows of $V^T$.
These do not need to match exactly! The SVD is only unique up to a simultaneous $\pm$ sign flip of corresponding left and right singular vectors.
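A small standalone illustration of that sign ambiguity (not TL code): flipping the sign of a left singular vector together with its matching right singular vector gives an equally valid SVD.

```python
import torch

M = torch.randn(5, 5)
U, S, Vh = torch.linalg.svd(M)

# Flip one left singular vector and the corresponding right singular vector.
U2, Vh2 = U.clone(), Vh.clone()
U2[:, 0] *= -1
Vh2[0, :] *= -1

# Both factorizations reconstruct M equally well.
assert torch.allclose(U @ torch.diag(S) @ Vh, U2 @ torch.diag(S) @ Vh2, atol=1e-5)
```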
Happy to just send a pull request for the first option, keeping current behavior, while this is discussed.