tgn
Running the script fails due to an assertion error
Hi team, when I tried to run the script on the Reddit dataset, it failed after 6 training epochs.
I am running the script on an NVIDIA RTX 3080 Ti, which does not support CUDA 10.1. Therefore, I upgraded PyTorch to version 2.1 built with CUDA 12.1. The other requirements stated in README.md (pandas, scikit-learn, Python 3.7) are unchanged.
What could be causing this?
If you need any more details about my environment, please comment and I will provide them.
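For reference, here is a minimal sketch of one way to collect the version details most relevant to this setup (Python, PyTorch and the CUDA version it was built against, the GPU, pandas, and scikit-learn). The function name `collect_env_info` is hypothetical. PyTorch itself also ships a fuller report via `python -m torch.utils.collect_env`.

```python
import importlib
import platform


def collect_env_info():
    """Gather basic version info for a bug report (hypothetical helper).

    Optional packages are imported lazily, so the script still runs and
    reports None for any package that is not installed.
    """
    info = {
        "python": platform.python_version(),
        "os": platform.platform(),
    }

    # Versions of the other requirements listed in README.md
    for name in ("pandas", "sklearn"):
        try:
            module = importlib.import_module(name)
            info[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            info[name] = None

    try:
        import torch
    except ImportError:
        info["torch"] = None
        return info

    info["torch"] = torch.__version__
    # CUDA version the installed torch build targets (None for CPU-only builds)
    info["torch_cuda"] = torch.version.cuda
    info["cuda_available"] = torch.cuda.is_available()
    if info["cuda_available"]:
        info["gpu"] = torch.cuda.get_device_name(0)
        info["compute_capability"] = torch.cuda.get_device_capability(0)
    return info


if __name__ == "__main__":
    for key, value in collect_env_info().items():
        print(f"{key}: {value}")
```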