BASS
RuntimeError in M_Step()
Hello,
I'm trying to run your project on a Windows 10 machine using Python 3.6, PyTorch 1.3.0, TorchVision 0.4.1 and cudatoolkit 10.1.243. However, I get the following runtime error at line 848 of BASS.py:
Nk.index_add_(0, argmax, Global.ones)
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 in call to th_index_add
Looking at the PyTorch documentation for this function, it seems that argmax and Global.ones might have to be swapped, but doing so causes other errors further on in the code.
PS: I already tried running it with the --cpu argument, but it didn't help.
I have never faced this problem. Please contact me at [email protected] and I will guide you through it. Please send me the image you are trying to segment and the arguments you used.
Hi, I've seen the same problem. According to the PyTorch documentation, the index parameter should be a LongTensor. In Nk.index_add_(0, argmax, Global.ones), argmax is an IntTensor, which causes the error. You can fix this by manually changing argmax to argmax.long() in that call, i.e. Nk.index_add_(0, argmax.long(), Global.ones).
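To see why swapping the two arguments is not the fix, here is a pure-Python sketch of what the call computes. index_add_(dim, index, source) adds source[i] into position index[i], so argmax is already in the index slot and Global.ones in the source slot; only the index dtype is wrong. The sketch assumes (not confirmed from BASS.py) that argmax holds one cluster label per pixel and that Global.ones is a vector of ones of the same length, so Nk ends up holding the number of pixels assigned to each cluster. The helper index_add_1d and the data are hypothetical.

```python
from typing import List, Sequence


def index_add_1d(nk: List[float], index: Sequence[int], source: Sequence[float]) -> List[float]:
    """Hypothetical pure-Python analogue of nk.index_add_(0, index, source) for 1-D data.

    For every position i, source[i] is added to nk[index[i]]; repeated
    indices accumulate. nk is modified in place and also returned.
    """
    if len(index) != len(source):
        raise ValueError("index and source must have the same length")
    for i, k in enumerate(index):
        nk[k] += source[i]
    return nk


# Hypothetical data: the cluster label (argmax) of 6 pixels over K = 3 clusters.
argmax = [0, 2, 2, 1, 2, 0]
ones = [1.0] * len(argmax)  # plays the role of Global.ones (assumed all ones)
Nk = [0.0] * 3              # per-cluster counts, initialised to zero
index_add_1d(Nk, argmax, ones)
print(Nk)  # [2.0, 1.0, 3.0]
```

In the real PyTorch 1.3 code the argument order stays the same; the only change is converting the index to int64: Nk.index_add_(0, argmax.long(), Global.ones).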