reverse-engineering-neural-networks
Add tests for datasets.
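In particular, the tests should check that the index is less than the sequence length. An out-of-range index does not raise an error in JAX: reads are clamped to the array bounds, so an index past the end silently returns the last element.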
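A minimal sketch of such a test, assuming a dataset that yields token sequences plus one integer index per example. `make_toy_dataset` and `select_at` are hypothetical stand-ins, not functions from this repo. The last lines show why the test matters: under `jax.jit` the index is traced, so an out-of-range value is clamped instead of raising.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def select_at(seq, index):
    # `index` is traced under jit, so JAX cannot raise on an out-of-range
    # value; the read is clamped to the last valid position instead.
    return seq[index]


def make_toy_dataset(num_examples, seq_len, seed=0):
    """Hypothetical stand-in for a dataset generator: returns token
    sequences plus one index per example pointing into each sequence."""
    rng = np.random.default_rng(seed)
    tokens = rng.integers(0, 10, size=(num_examples, seq_len))
    # integers(low, high) excludes `high`, so indices lie in [0, seq_len).
    index = rng.integers(0, seq_len, size=(num_examples,))
    return jnp.asarray(tokens), jnp.asarray(index)


def test_index_within_sequence_length():
    tokens, index = make_toy_dataset(num_examples=32, seq_len=8)
    seq_len = tokens.shape[1]
    # The check asked for above: every index must address a real position.
    assert bool(jnp.all(index >= 0))
    assert bool(jnp.all(index < seq_len))


if __name__ == "__main__":
    test_index_within_sequence_length()
    # Index 7 is past the end of a length-5 sequence, yet no error is
    # raised: the read is clamped and returns the last element.
    print(select_at(jnp.arange(5), 7))
```

A real test would call the repo's own dataset constructors in place of `make_toy_dataset`; the assertion on `index < seq_len` is the part that carries over.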
In addition, should we change the select in loss and accuracy to use `index - 1`? This would keep them consistent with the other datasets, where `index` is given as the length of the sentences.
That might get confusing if one passes in `index = 0`, because the selection would then use `-1` and pick the last element.
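A minimal sketch of the `index - 1` convention and the `index = 0` pitfall, assuming outputs of shape `(batch, seq_len)` and one length per row. `select_last_valid` and `check_lengths` are hypothetical helpers, not code from this repo. JAX follows NumPy here: a negative index counts from the end, so a length of 0 silently selects the final position.

```python
import jax.numpy as jnp


def select_last_valid(outputs, lengths):
    """Hypothetical helper: pick outputs[b, lengths[b] - 1] for each row b.

    `lengths` follows the convention of the other datasets (number of
    valid steps), so the last valid position is `lengths - 1`.
    """
    positions = lengths - 1
    return outputs[jnp.arange(outputs.shape[0]), positions]


def check_lengths(lengths, seq_len):
    # Guard for the ambiguity raised in this thread: a length of 0 gives
    # position -1, which wraps around to the last element.
    if not bool(jnp.all((lengths >= 1) & (lengths <= seq_len))):
        raise ValueError("lengths must lie in [1, seq_len]")


outputs = jnp.arange(12).reshape(3, 4)  # 3 sequences of length 4
lengths = jnp.array([4, 2, 0])          # the last row has length 0
print(select_last_valid(outputs, lengths))
```

The first two rows select their true last steps (3 and 5), while the length-0 row selects 11, the last element of its row. Running `check_lengths` before selecting turns that silent wrap-around into an explicit error.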