JAX implementation
I tried to write a JAX implementation of HRM, specifically the Quick Demo: Sudoku Solver. However, even after implementing all the details, accuracy stayed at 0% for whole Sudoku puzzles and around 10% for individual numbers, which corresponds to random guessing. I can't spot any difference from the original that would explain this.
The only intentional differences are in the data feed, where I dynamically check which samples need to be replaced (so no sample can be skipped), and in the time embedding, which I made learnable to avoid errors with the attention implementation (this shouldn’t have such a large impact on performance).
Here is my repo: The important files are in the "src" folder, along with "adam_atan2.py", "data_loader.py", and "train.py".
PS: In the future I plan to clean up the code, but I want to make it functional first.
It seems that the model hasn't shown any learning progress. Check the gradient magnitude, the parameter norm, and the loss trend over training steps.
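A minimal sketch of how those three quantities could be logged in JAX, assuming the parameters are an ordinary pytree. The `loss_fn` here is a hypothetical stand-in (a single linear layer with MSE loss), not the HRM loss from the repo, and `global_norm` and `diagnostics` are illustrative helper names.

```python
import jax
import jax.numpy as jnp


def global_norm(tree):
    """L2 norm taken over every leaf of a pytree (parameters or gradients)."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def loss_fn(params, x, y):
    # Hypothetical stand-in for the real model: one linear layer, MSE loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def diagnostics(params, x, y):
    """Return the loss, the global gradient and parameter norms, and per-leaf gradient norms."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    per_leaf = jax.tree_util.tree_map(
        lambda g: jnp.sqrt(jnp.sum(jnp.square(g))), grads
    )
    return {
        "loss": loss,
        "grad_norm": global_norm(grads),
        "param_norm": global_norm(params),
        "per_leaf_grad_norm": per_leaf,
    }


# Toy inputs so the sketch runs on its own.
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

stats = diagnostics(params, x, y)
print({k: v for k, v in stats.items() if k != "per_leaf_grad_norm"})
```

Logging these every few steps separates the common cases: a gradient norm that is zero or NaN from the first step points at a broken gradient path (for example a misplaced `stop_gradient`), while a healthy gradient norm with an unchanged parameter norm suggests the optimizer update is not being applied. The per-leaf norms help locate which part of the model receives no gradient.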