
JAX implementation

Open · ondra117 opened this issue 4 months ago · 1 comment

I tried to make a JAX implementation of HRM, specifically of the Quick Demo: Sudoku Solver. However, even after implementing all the details, accuracy stayed at 0% for full Sudoku puzzles and around 10% for individual numbers (per-cell accuracy), which corresponds to random guessing. I can't spot any difference from the original that would explain this.
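The two figures measure different things, and together they are exactly the signature of a model at chance. The minimal JAX sketch below assumes logits of shape (batch, cells, vocab) and integer labels of shape (batch, cells) with no masked positions. The function name `sudoku_accuracies` is hypothetical. Uniform guessing over nine digits gives about 1/9 (roughly 11%) per cell, though the exact value depends on the vocabulary size. For a whole puzzle it gives (1/9)^81, because all 81 cells must be right at once, which is effectively 0.

```python
import jax
import jax.numpy as jnp


def sudoku_accuracies(logits, labels):
    """Return (per-cell accuracy, exact full-puzzle accuracy).

    Assumes logits: (batch, cells, vocab), labels: (batch, cells) int ids,
    with every position counted (no ignore/mask label).
    """
    preds = jnp.argmax(logits, axis=-1)               # (batch, cells)
    correct = preds == labels                          # bool per cell
    cell_acc = jnp.mean(correct)                       # fraction of correct cells
    exact_acc = jnp.mean(jnp.all(correct, axis=-1))    # puzzle counts only if all cells match
    return cell_acc, exact_acc


# Random logits over 9 classes: per-cell accuracy lands near 1/9,
# while exact-puzzle accuracy is 0.
labels = jax.random.randint(jax.random.PRNGKey(0), (256, 81), 0, 9)
random_logits = jax.random.normal(jax.random.PRNGKey(1), (256, 81, 9))
cell_acc, exact_acc = sudoku_accuracies(random_logits, labels)
print(float(cell_acc), float(exact_acc))
```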

There are only two deliberate differences. The first is in the data feed, where I dynamically check which samples need to be replaced (so no sample is skipped). The second is that I made the time embedding learnable to avoid errors with the attention implementation, which shouldn't have such a large impact on performance.
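One way to refill halted slots without skipping any sample is to consume only as many fresh samples as there are halted rows, assigning them in order. The sketch below assumes a (batch,) boolean `halted` mask and a queue `fresh` of upcoming samples with the same leading batch axis. `refill_halted` is a hypothetical helper, not code from the HRM repository or from the linked port.

```python
import jax.numpy as jnp


def refill_halted(halted, current, fresh):
    """Replace halted rows of `current` with the next unused rows of `fresh`.

    halted:  (batch,) bool, True where a sample has finished.
    current: (batch, ...) in-flight inputs, labels, or recurrent state.
    fresh:   (batch, ...) upcoming samples; only the first sum(halted)
             rows are consumed, so no sample is skipped.
    Returns the refilled array and the number of fresh rows used.
    """
    # The k-th halted row (0-based) takes fresh[k]; cumsum assigns those ranks.
    rank = jnp.cumsum(halted.astype(jnp.int32)) - 1
    rank = jnp.clip(rank, 0, fresh.shape[0] - 1)  # non-halted rows get a dummy index
    gathered = fresh[rank]
    # Broadcast the (batch,) mask over the trailing axes before selecting.
    mask = halted.reshape(halted.shape + (1,) * (current.ndim - 1))
    refilled = jnp.where(mask, gathered, current)
    return refilled, jnp.sum(halted.astype(jnp.int32))


halted = jnp.array([False, True, False, True])
current = jnp.array([[0], [1], [2], [3]])
fresh = jnp.array([[10], [11], [12], [13]])
out, used = refill_halted(halted, current, fresh)
print(out.tolist(), int(used))
```

Here rows 1 and 3 receive `fresh[0]` and `fresh[1]`, and `used` is 2, so the loader would advance by two samples. The same helper would be applied to inputs, labels, and the recurrent state (with `fresh` set to the broadcast initial state) using the same mask. Refilling one of them but not the others would pair puzzles with the wrong targets or stale state, which is one way accuracy could stay at chance.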

Here is my repo: The important files are in the "src" folder, along with "adam_atan2.py", "data_loader.py", and "train.py".
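Because `adam_atan2.py` is hand-written, a quick check against the update rule can rule the optimizer in or out. Adam-atan2 replaces Adam's m̂/(√v̂ + ε) with atan2(m̂, b·√v̂), so the step is bounded and needs no ε. The JAX sketch below handles a single array and omits weight decay. The function name and the defaults for `lr`, `beta1`, `beta2`, `a`, and `b` are placeholders, not HRM's settings or the reference package's API.

```python
import jax.numpy as jnp


def adam_atan2_step(param, grad, m, v, step, lr=1e-4,
                    beta1=0.9, beta2=0.999, a=1.0, b=1.0):
    """One Adam-atan2 update for a single array (sketch).

    step is the 1-based update count used for bias correction.
    a and b are scale constants; the defaults here are placeholders.
    """
    m = beta1 * m + (1.0 - beta1) * grad          # first moment
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second moment
    m_hat = m / (1.0 - beta1 ** step)             # bias correction
    v_hat = v / (1.0 - beta2 ** step)
    # atan2 replaces the epsilon-regularized division of standard Adam.
    update = a * jnp.arctan2(m_hat, b * jnp.sqrt(v_hat))
    return param - lr * update, m, v


grad = jnp.array([2.0, -0.5, 3.0])
new_param, m, v = adam_atan2_step(jnp.zeros(3), grad, jnp.zeros(3), jnp.zeros(3), step=1)
print(new_param)
```

On the first step the bias-corrected moments equal g and g², so with b = 1 every nonzero coordinate moves by exactly lr·a·π/4 against the sign of its gradient. If the port's first update differs from that, the optimizer is a likely place to look.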

PS: In the future I plan to clean up the code, but I want to make it functional first.

ondra117 · Aug 28 '25 11:08

It seems the model hasn't shown any learning progress. Check the gradient magnitude, parameter norm, and loss trends.

imoneoi · Sep 08 '25 14:09
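The checks suggested in the reply can be logged directly from the training step. This is a minimal JAX sketch, assuming a pure `loss_fn(params, batch)` and plain SGD standing in for the real optimizer. `global_norm`, `train_step_with_stats`, and `toy_loss` are hypothetical names used only for illustration, and the toy least-squares problem stands in for the HRM model.

```python
import jax
import jax.numpy as jnp


def global_norm(tree):
    """L2 norm taken over every leaf of a pytree (params or grads)."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))


def train_step_with_stats(params, batch, loss_fn, lr=1e-2):
    """Plain SGD step that also returns loss, grad norm and param norm."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    stats = {
        "loss": loss,                          # loss at the pre-update params
        "grad_norm": global_norm(grads),
        "param_norm": global_norm(new_params),
    }
    return new_params, stats


# Toy least-squares problem standing in for the real model (hypothetical).
def toy_loss(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
history = []
for step in range(5):
    params, stats = train_step_with_stats(params, (x, y), toy_loss, lr=0.1)
    history.append({k: float(v) for k, v in stats.items()})
    print(step, history[-1])
```

If `grad_norm` is exactly zero or NaN, gradients are not reaching the parameters. One possible cause in an HRM port is a `jax.lax.stop_gradient` that covers the final recurrent step as well as the earlier ones. If `grad_norm` looks healthy but `param_norm` never moves, the update is computed but not applied. A flat loss while both norms change may instead point at the data or labels.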