
SAINT on EdNet data

clara2911 opened this issue 4 years ago · 12 comments

Would you be able to add your code for running your implementation of SAINT on EdNet as well, besides the example on random data?

clara2911 avatar Jun 03 '21 12:06 clara2911

It would be great to have the complete code for reproducing the results from the paper as well, because I failed to reach the reported performance with my implementation. Mine was about 0.55 compared to 0.78 from the paper.

kwonmha avatar Feb 16 '22 03:02 kwonmha

Hi kwonmha, I'm also trying to reproduce the paper :)

This implementation of SAINT is not completely finished. For example, dropout is missing, the position embeddings are wrongly added in every layer (instead of only before the first encoder/decoder layer), and the LayerNorm should be placed as in the Attention Is All You Need paper (after the multi-head attention and after the FFN). You can also change the positional encoding to the sinusoidal one from Attention Is All You Need.
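
To make the fixes concrete, here's a rough PyTorch sketch (the class names and hyperparameter defaults are mine for illustration, not the repo's actual code): the positional encoding is added once to the input embeddings, and each layer uses post-LN "Add & Norm" with dropout.

    import math
    import torch
    import torch.nn as nn

    class SinusoidalPositionalEncoding(nn.Module):
        """Fixed sinusoidal encoding, applied once to the input embeddings
        (not re-added in every encoder/decoder layer)."""
        def __init__(self, d_model, max_len=512, dropout=0.1):
            super().__init__()
            self.dropout = nn.Dropout(dropout)
            position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float()
                                 * (-math.log(10000.0) / d_model))
            pe = torch.zeros(max_len, d_model)
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

        def forward(self, x):  # x: (batch, seq_len, d_model)
            return self.dropout(x + self.pe[:, : x.size(1)])

    class PostLNEncoderLayer(nn.Module):
        """Post-LN layer as in Attention Is All You Need:
        LayerNorm after the multi-head attention and after the FFN."""
        def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
            super().__init__()
            self.attn = nn.MultiheadAttention(d_model, n_heads,
                                              dropout=dropout, batch_first=True)
            self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                     nn.Dropout(dropout), nn.Linear(d_ff, d_model))
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, mask=None):
            attn_out, _ = self.attn(x, x, x, attn_mask=mask)
            x = self.norm1(x + self.dropout(attn_out))     # Add & Norm
            x = self.norm2(x + self.dropout(self.ffn(x)))  # Add & Norm
            return x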

All the rest is correct :) I reach AUC=0.76 with it, but I'm not able to get the last 2%, and my metrics also crash if I use a d_model of 512 like in the paper (it works only with a smaller model).

Nino-SEGALA avatar Apr 13 '22 08:04 Nino-SEGALA

Hi, @Nino-SEGALA Thanks for the information.

In my case, I think the problem is in the data processing or the data itself, not in the modeling, because my model works fine with the EdNet data from Kaggle.

Do you have any plans to upload your code to your GitHub?

kwonmha avatar Apr 14 '22 04:04 kwonmha

I will try to upload it here with a Pull Request :)

I don't understand: it works with EdNet from Kaggle, but not with EdNet from the paper? What is the difference between them? Can you link both datasets? :)

Nino-SEGALA avatar Apr 15 '22 13:04 Nino-SEGALA

@Nino-SEGALA Here's the link to the dataset I mentioned: https://github.com/riiid/ednet. It's KT-1, and you also need to download the content data.

kwonmha avatar Apr 18 '22 02:04 kwonmha

Yes, I also use this one (and get 0.76 AUC with d_model=128; with a larger model, d_model=512, I get AUC=0.5 too :/)

Maybe you can try with a smaller model.

And is this the dataset you meant by 'my model works fine with EdNet data from Kaggle'? :)

Nino-SEGALA avatar Apr 19 '22 07:04 Nino-SEGALA

Thanks for the info!

kwonmha avatar Apr 19 '22 08:04 kwonmha

@kwonmha Here's the corrected code for SAINT: https://github.com/arshadshk/SAINT-pytorch/pull/6

Let me know if you solve the training of SAINT with a large model dimension (d_model=512) :D

Nino-SEGALA avatar Apr 26 '22 09:04 Nino-SEGALA

@Nino-SEGALA Have you tried applying the Noam learning rate scheduling mentioned in the paper? It's in the Training Details section.

I had the same problem where the AUC stays around 0.5 with dimensions 256 and 512, and the validation AUC goes above 0.7 with the Noam scheme. It looks necessary for training a large transformer model.

Noam scheduler code link: I used Lina Achaji's code, and I added

    def zero_grad(self):
        # delegate to the wrapped optimizer
        self.optimizer.zero_grad()

to the class for convenience.
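
For reference, the full wrapper looks roughly like this (a sketch of the lr formula from Attention Is All You Need; the warmup=4000 default and the Adam settings in the usage note are the transformer paper's values, not necessarily SAINT's exact setting):

    import torch

    class NoamOpt:
        """Noam schedule: lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
        def __init__(self, optimizer, d_model, warmup=4000):
            self.optimizer = optimizer
            self.d_model = d_model
            self.warmup = warmup
            self.step_num = 0

        def step(self):
            # update the learning rate before every optimizer step
            self.step_num += 1
            lr = self.d_model ** -0.5 * min(self.step_num ** -0.5,
                                            self.step_num * self.warmup ** -1.5)
            for group in self.optimizer.param_groups:
                group["lr"] = lr
            self.optimizer.step()

        def zero_grad(self):
            self.optimizer.zero_grad()

    # usage, with the Adam hyperparameters from the transformer paper:
    # opt = NoamOpt(torch.optim.Adam(model.parameters(),
    #               betas=(0.9, 0.98), eps=1e-9), d_model=512)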

As it changes the learning rate with respect to the step count, the batch_size looks important, since it affects the number of steps in training.

I got 0.7746 AUC with dim 256 and 0.7727 with dim 512.

kwonmha avatar May 25 '22 06:05 kwonmha

Thanks a lot for your comment kwonmha!

I did my training without the Noam scheme, and since implementing it I hadn't re-run the big trainings. It is indeed what makes the difference! I didn't reach metrics as high as yours, but my model didn't train until convergence (it stopped a bit early). I'll let you know when I have my final results :D

@kwonmha could you also share your ACC, RMSE and BCE loss if you have them?

Nino-SEGALA avatar Jun 04 '22 16:06 Nino-SEGALA

Sorry, but I haven't measured metrics other than AUC so far.

kwonmha avatar Jun 08 '22 04:06 kwonmha

I got 0.7666 AUC with dim 256 and 0.7537 with dim 512. And my dim 512 training crashed afterwards (AUC=0.6), even though it uses the Noam scheme now :/

Nino-SEGALA avatar Jun 10 '22 12:06 Nino-SEGALA