MusicTransformer-pytorch
a mistake in the "train.py"
Hello Jason,
Thanks for your implementation! Good job!
However, when I tried to train the Music Transformer, I found a mistake in train.py.
I received this error:
Traceback (most recent call last):
File "train.py", line 92, in
After digging further, I found the cause: line 91 of train.py returns a tuple, which is then passed to metric_set(). However, metric_set() expects a tensor, not a tuple. Adding a second variable to receive the returned attention weights avoids the error.
Do you know how to fix this in code?
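One way to do this on the calling side, as a minimal sketch rather than the actual train.py code: unpack the returned tuple into two variables and pass only the first one to the metric. Here `forward_stub` and `metric_stub` are hypothetical stand-ins for the model call and metric_set(), and plain lists stand in for tensors so the snippet runs without PyTorch.

```python
def forward_stub(x):
    """Hypothetical stand-in for the model call: returns (output, attention_weights)."""
    output = [v * 2 for v in x]
    attention_weights = [[1.0] for _ in x]
    return output, attention_weights


def metric_stub(pred, target):
    """Hypothetical stand-in for metric_set(): accepts a single output, not a tuple."""
    if isinstance(pred, tuple):
        raise TypeError("expected a single output, got a tuple")
    # Fraction of positions where the prediction matches the target.
    return sum(p == t for p, t in zip(pred, target)) / len(target)


batch = [1, 2, 3]
target = [2, 4, 7]

# Before: the whole tuple is passed on, which fails.
try:
    metric_stub(forward_stub(batch), target)
    failed = False
except TypeError:
    failed = True

# After: a second variable receives the attention weights.
pred, attn_weights = forward_stub(batch)
accuracy = metric_stub(pred, target)
```

The second variable can simply be ignored if the attention weights are not needed during training.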
if self.training:
    return fc.contiguous()
else:
    return fc.contiguous(), [weight.contiguous() for weight in w]
Splitting the original line 43 of model.py into the lines above solves the issue.
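To illustrate what this change means for callers, here is a small sketch under stated assumptions: `ToyModel` is a hypothetical stand-in for the real model, its `training` attribute mimics the flag that torch.nn.Module toggles with train() and eval(), and lists stand in for tensors. In training mode the caller gets a single output, while in evaluation mode it still gets a tuple and has to unpack it.

```python
class ToyModel:
    """Hypothetical stand-in for the model; `training` mimics torch.nn.Module.training."""

    def __init__(self):
        self.training = True

    def forward(self, x):
        fc = [v + 1 for v in x]   # stand-in for the final output
        w = [[0.5, 0.5]]          # stand-in for the attention weights
        if self.training:
            return fc
        else:
            return fc, w


model = ToyModel()

# Training mode: a single output that can go straight into a metric.
train_out = model.forward([1, 2])

# Evaluation mode: output plus attention weights, so the caller unpacks two values.
model.training = False
eval_out, eval_weights = model.forward([1, 2])
```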
I made a PR #11 to fix this.