BASS

RuntimeError in M_Step()

Open • CaswellBerry opened this issue 4 years ago • 2 comments

Hello, I'm trying to run your project on a Windows 10 machine with Python 3.6, PyTorch 1.3.0, TorchVision 0.4.1 and cudatoolkit 10.1.243. However, I get a runtime error at line 848 of BASS.py, in the call Nk.index_add_(0, argmax, Global.ones):

RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #3 in call to th_index_add

Judging from the PyTorch documentation for this function, it seemed that argmax and Global.ones might need to be swapped, but doing so causes other errors later in the code.

PS: I already tried running it with the --cpu argument, but that didn't help.

CaswellBerry • Mar 16 '20 15:03

I have never run into this problem. Please contact me at [email protected] and I will guide you through it. Please send me the image you are trying to segment and the arguments you used.

uzielroy • Mar 17 '20 09:03

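Hi, I've run into the same problem. According to the PyTorch documentation, the index argument of index_add_ must be a LongTensor. In Nk.index_add_(0, argmax, Global.ones), argmax is an IntTensor, which causes the error. The argument order is already correct; only the dtype of the index is wrong. Converting it manually with argmax.long(), i.e. Nk.index_add_(0, argmax.long(), Global.ones), solves this.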

xrz000 • Mar 23 '20 10:03
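A minimal, self-contained sketch of the suggested fix. It assumes that argmax holds one cluster label per pixel (stored as int32, as in the report) and that Global.ones is a float tensor of ones with one entry per pixel. The variable names below are hypothetical stand-ins, not the actual BASS.py code. Under those assumptions, the call counts how many pixels fall into each cluster, so the result can be cross-checked against torch.bincount.

```python
import torch

# Hypothetical stand-ins for the BASS.py variables (assumed shapes and dtypes):
# argmax: per-pixel cluster label stored as int32, like the IntTensor in the report
# ones:   one float weight of 1.0 per pixel, standing in for Global.ones
num_clusters = 4
argmax = torch.tensor([0, 2, 2, 3, 0, 2], dtype=torch.int32)
ones = torch.ones(argmax.shape[0])
Nk = torch.zeros(num_clusters)

# Cast the index to int64 (LongTensor) before calling index_add_; without the
# cast, PyTorch 1.3 raises the "Expected object of scalar type Long" error.
# Each pixel adds 1.0 to the slot of its cluster, so Nk ends up as per-cluster counts.
Nk.index_add_(0, argmax.long(), ones)
print(Nk)  # tensor([2., 0., 3., 1.])

# Under the assumptions above this matches a plain histogram of the labels.
print(torch.bincount(argmax.long(), minlength=num_clusters))  # tensor([2, 0, 3, 1])
```

Casting at the call site keeps the change local to line 848. Alternatively, argmax could be produced as int64 wherever it is created, so that other indexing calls that expect a LongTensor also work.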