NTK4A
Implement BiRNN NTK Calculation (#1)
This PR implements calculation of the BiRNN NTK, with the BiRNN hidden states combined by summation and by concatenation.
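For illustration only, the two ways of combining hidden states might look like the sketch below. It is not code from this PR: the helper names (`rnn_states`, `birnn_states`, `random_weights`), the tanh nonlinearity, and the weight scaling are all assumptions, and it shows a finite-width forward pass rather than the NTK computation itself.

```python
import math
import random


def random_weights(hidden, dim, rng):
    """Hypothetical helper: Gaussian input and recurrent weights, scaled by fan-in."""
    w_in = [[rng.gauss(0, 1) / math.sqrt(dim) for _ in range(dim)]
            for _ in range(hidden)]
    w_rec = [[rng.gauss(0, 1) / math.sqrt(hidden) for _ in range(hidden)]
             for _ in range(hidden)]
    return w_in, w_rec


def rnn_states(xs, w_in, w_rec):
    """Run a simple tanh RNN over xs and return the hidden state at every step."""
    h = [0.0] * len(w_rec)
    states = []
    for x in xs:
        h = [
            math.tanh(
                sum(w_in[i][j] * x[j] for j in range(len(x)))
                + sum(w_rec[i][k] * h[k] for k in range(len(h)))
            )
            for i in range(len(w_rec))
        ]
        states.append(h)
    return states


def birnn_states(xs, fwd, bwd, mode):
    """Combine forward and backward hidden states by 'sum' or 'concat'."""
    h_fwd = rnn_states(xs, *fwd)
    # The backward RNN reads the reversed sequence; flip its states back
    # so that index t lines up with the forward state at time t.
    h_bwd = rnn_states(xs[::-1], *bwd)[::-1]
    if mode == "sum":
        return [[a + b for a, b in zip(hf, hb)] for hf, hb in zip(h_fwd, h_bwd)]
    if mode == "concat":
        return [hf + hb for hf, hb in zip(h_fwd, h_bwd)]
    raise ValueError("mode must be 'sum' or 'concat'")
```

With hidden size n, summation keeps each state at n entries while concatenation produces 2n entries per time step.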
The PR adds the Python notebook, edits to utils, and NTK Frobenius distance files for BiRNN.
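As a rough sketch of what such a distance measures, assuming it is the Frobenius norm of the difference between two kernel matrices; the function name `frob_distance` is hypothetical and is not taken from the PR's files.

```python
import math


def frob_distance(k_a, k_b):
    """Frobenius norm of the difference between two same-shaped matrices."""
    same_rows = len(k_a) == len(k_b)
    if not same_rows or any(len(ra) != len(rb) for ra, rb in zip(k_a, k_b)):
        raise ValueError("kernel matrices must have the same shape")
    # Square root of the sum of squared entrywise differences.
    return math.sqrt(
        sum((a - b) ** 2 for ra, rb in zip(k_a, k_b) for a, b in zip(ra, rb))
    )
```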
The plots notebook is left unedited to maintain consistency with the paper.
Merged via squash before creating PR.
Co-authored-by: Michael Santacroce (Microsoft Email)