
[Proposal] Update and discuss behavior of SVD

Open • diego898 opened this issue 2 years ago • 2 comments

Proposal

This proposal has two parts. Both address TransformerLens's (TL) current SVD implementation.

Note that the SVD of a matrix $M$ is typically written as:

$$ M = U S V^T $$

The important part here is the $V^T$.

Currently, TL returns $V$, not $V^T$. This is because it uses `torch.svd` rather than `torch.linalg.svd`, and `torch.svd` returns $V$.

https://github.com/neelnanda-io/TransformerLens/blob/347e65235140d942301452a55505c9215b36806a/transformer_lens/FactoredMatrix.py#L118-L125
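Here is a minimal sketch of the two conventions on a small random matrix. The shape and seed are arbitrary. With `torch.svd`, the third output is $V$, so rebuilding the matrix needs a transpose. With `torch.linalg.svd`, the third output is already $V^T$ (named `Vh` in the PyTorch docs).

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.float64)

# Deprecated API: the third output is V, so A == U @ diag(S) @ V^T
U_old, S_old, V_old = torch.svd(A)
A_old = U_old @ torch.diag(S_old) @ V_old.mT

# Current API: the third output is V^T ("Vh"), so no transpose is needed
U_new, S_new, Vh_new = torch.linalg.svd(A, full_matrices=False)
A_new = U_new @ torch.diag(S_new) @ Vh_new

print(torch.dist(A, A_old).item(), torch.dist(A, A_new).item())
```

Both reconstructions recover `A`. The only difference is where the transpose lives.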

To address this, we can do one or both of the following:

  • small: `torch.svd` is deprecated and should be replaced by `torch.linalg.svd` in `FactoredMatrix`. This change can keep all current behavior. Note that `FactoredMatrix` currently names this factor `Vh`, but it actually holds $V$. If we make this change, we should also rename it to `V`.
  • big: change the implementation to return $V^T$ instead, as `torch.linalg.svd` does. This would be a breaking change.
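For the small option, here is a minimal sketch of how to switch APIs while keeping the current output convention. It calls `torch.linalg.svd` with `full_matrices=False`, which matches the reduced output of `torch.svd`'s default, and then transposes the last factor back. The helper name `svd_returning_v` is hypothetical and is not part of TL.

```python
import torch

def svd_returning_v(A: torch.Tensor):
    """Hypothetical helper (not TL API): mimics torch.svd(A) using torch.linalg.svd.

    Returns (U, S, V) with A == U @ S.diag_embed() @ V.mH, i.e. V rather than V^T.
    """
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)  # reduced SVD, like torch.svd's default
    # .mH transposes (and conjugates) the last two dims, turning V^H back into V
    return U, S, Vh.mH

torch.manual_seed(0)
A = torch.randn(2, 4, 6, dtype=torch.float64)  # batched, non-square example
U, S, V = svd_returning_v(A)
print(U.shape, S.shape, V.shape)  # torch.Size([2, 4, 4]) torch.Size([2, 4]) torch.Size([2, 6, 4])
```

Callers that expect $V$ would see no change. Only the internal call is swapped.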

Pitch

I'd like to pitch doing both of the above: switch to `torch.linalg.svd`, and return $V^T$ directly. This would let you do:

```python
U, S, V_T = model.OV.svd()
torch.dist(model.OV, U @ torch.diag(S) @ V_T)
```

just as in the PyTorch docs. At least for me, this behaves more as expected.

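Note: the above is only conceptual. Right now, I don't think you can pass a `FactoredMatrix` directly (e.g., to `torch.dist`). Instead you must use `model.OV.AB`. Also, the correct diagonal call depends on the shapes of the other tensors. `torch.diag` only builds a matrix from a single vector, so with leading (layer, head) dimensions it should probably be `.diag_embed()`.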
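Here is a sketch of why `.diag_embed()` is needed once there are leading dimensions. The shapes are arbitrary: 2 layers, 3 heads, and 4×4 matrices. A dense random tensor stands in for `model.OV.AB`. `torch.diag` rejects inputs with more than two dimensions. `diag_embed` instead builds one diagonal matrix per leading index.

```python
import torch

torch.manual_seed(0)
# Stand-in for a dense batch of per-layer, per-head matrices (e.g. model.OV.AB)
M = torch.randn(2, 3, 4, 4, dtype=torch.float64)

U, S, Vh = torch.linalg.svd(M, full_matrices=False)
print(S.shape)  # torch.Size([2, 3, 4])

# torch.diag only accepts 1-D or 2-D input, so this raises for batched S
try:
    torch.diag(S)
    diag_failed = False
except RuntimeError:
    diag_failed = True

# diag_embed turns the last dim into a diagonal matrix for every batch entry
M_rebuilt = U @ S.diag_embed() @ Vh
print(diag_failed, torch.dist(M, M_rebuilt).item())
```

An equivalent alternative is broadcasting: `U * S[..., None, :] @ Vh` scales the columns of `U` without materializing the diagonal matrix.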

The code change would be very small. Something like:

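```python
import torch

# Sketch of the proposed FactoredMatrix.svd body (self.A @ self.B is the product)
def svd(self):
    # Thin SVD of each factor: A = Ua @ diag(Sa) @ V_Ta and B = Ub @ diag(Sb) @ V_Tb
    Ua, Sa, V_Ta = torch.linalg.svd(self.A, full_matrices=False)
    Ub, Sb, V_Tb = torch.linalg.svd(self.B, full_matrices=False)

    # A @ B = Ua @ middle @ V_Tb, where middle is at most mdim x mdim
    middle = Sa.diag_embed() @ V_Ta @ Ub @ Sb.diag_embed()
    Um, Sm, V_Tm = torch.linalg.svd(middle, full_matrices=False)

    # Fold the outer orthogonal factors back in: A @ B = Uf @ diag(Sf) @ V_Tf
    Uf = Ua @ Um
    Sf = Sm
    V_Tf = V_Tm @ V_Tb

    return Uf, Sf, V_Tf
```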
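To sanity-check the algorithm sketched above outside TL, here is a standalone version. The free function `factored_svd(A, B)` is a hypothetical stand-in for the method. It assumes `A` has shape `[..., ldim, mdim]` and `B` has shape `[..., mdim, rdim]`. The decomposition only ever calls the SVD on the factors and the small `middle` matrix. In the usage below, the dense product is formed only to check the result.

```python
import torch

def factored_svd(A: torch.Tensor, B: torch.Tensor):
    """Hypothetical helper: thin SVD of A @ B without decomposing A @ B directly.

    Returns (U, S, V_T) with A @ B == U @ S.diag_embed() @ V_T.
    """
    # Thin SVD of each factor
    Ua, Sa, V_Ta = torch.linalg.svd(A, full_matrices=False)
    Ub, Sb, V_Tb = torch.linalg.svd(B, full_matrices=False)

    # A @ B == Ua @ middle @ V_Tb, and middle is small
    middle = Sa.diag_embed() @ V_Ta @ Ub @ Sb.diag_embed()
    Um, Sm, V_Tm = torch.linalg.svd(middle, full_matrices=False)

    # Products of matrices with orthonormal columns (rows) keep orthonormal columns (rows)
    return Ua @ Um, Sm, V_Tm @ V_Tb

torch.manual_seed(0)
A = torch.randn(3, 10, 4, dtype=torch.float64)  # batch of 3, ldim=10, mdim=4
B = torch.randn(3, 4, 12, dtype=torch.float64)  # mdim=4, rdim=12
U, S, V_T = factored_svd(A, B)
print(torch.dist(A @ B, U @ S.diag_embed() @ V_T).item())
```

The singular values returned this way should match the leading singular values of the dense product. The singular vectors may differ from a dense SVD by sign.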

I can submit a PR, but thought it might be best to discuss first.

Alternatives

The first (small) option should be considered the 'safe and easy' alternative.

Checklist

  • [X] I have checked that there is no similar issue in the repo (required)

diego898 • Jul 14 '23 18:07

Not directly related, but if this change is made, the following behavior may be worth clarifying in the docs:

`model.OV.svd()` returns the "thin"/"compact" SVD. It computes this directly with the more efficient factored algorithm.

In principle, one could instead compute the full SVD with `torch.linalg.svd(model.OV.AB)`. You would then keep the first $r = \operatorname{rank}$ columns of $U$, entries of $S$, and rows of $V^T$ yourself.

The two results need not match exactly! The SVD is not unique. Each pair of left and right singular vectors is only determined up to a shared sign ($\pm$), even when the singular values are distinct.
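Here is a small sketch of why the two results need not match exactly, and of how to compare them anyway. A dense rank-4 matrix built from random factors stands in for `model.OV.AB`. Rather than relying on two code paths happening to disagree, the code flips the sign of two singular-vector pairs by hand. The helper `same_up_to_sign` is hypothetical (not part of TL or PyTorch). It assumes distinct, nonzero singular values.

```python
import torch

def same_up_to_sign(U1, V_T1, U2, V_T2, atol=1e-8):
    """Hypothetical helper: compare two SVDs up to a shared sign per (u_i, v_i) pair.

    Assumes distinct, nonzero singular values so each pair is unique up to sign.
    """
    # Sign of <u1_i, u2_i> for each singular pair i (columns of U)
    signs = torch.sign((U1 * U2).sum(dim=-2))
    # Flip columns of U2 and the matching rows of V_T2, then compare
    return torch.allclose(U1, U2 * signs[..., None, :], atol=atol) and torch.allclose(
        V_T1, V_T2 * signs[..., :, None], atol=atol
    )

torch.manual_seed(0)
# Rank-4 matrix (dense stand-in for model.OV.AB)
M = torch.randn(10, 4, dtype=torch.float64) @ torch.randn(4, 10, dtype=torch.float64)
r = 4

# Full SVD, then keep the first r columns of U, entries of S and rows of V^T
U_full, S_full, V_T_full = torch.linalg.svd(M)
U_r, S_r, V_T_r = U_full[:, :r], S_full[:r], V_T_full[:r, :]

# Flipping the sign of a (u_i, v_i) pair together gives another valid SVD
flip = torch.tensor([1.0, -1.0, 1.0, -1.0], dtype=torch.float64)
U_alt, V_T_alt = U_r * flip, V_T_r * flip[:, None]

print(torch.allclose(U_r @ S_r.diag_embed() @ V_T_r, M))
print(torch.allclose(U_alt @ S_r.diag_embed() @ V_T_alt, M))
print(torch.allclose(U_r, U_alt), same_up_to_sign(U_r, V_T_r, U_alt, V_T_alt))
```

Comparing the rank-one terms $u_i v_i^T$ is another sign-invariant check, since the two sign flips cancel.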

diego898 • Jul 14 '23 18:07

Happy to send a pull request for the first option (keeping current behavior) while this is discussed.

diego898 • Jul 24 '23 18:07