MusicTransformer-pytorch

A mistake in train.py

Open · 1996Wanglei opened this issue 5 years ago · 3 comments

Hello Jason, thanks for your implementation, good job! However, when I tried to train the Music Transformer, I found a mistake in train.py. I received this error:

    Traceback (most recent call last):
      File "train.py", line 92, in <module>
        metrics = metric_set(sample, batch_y)
      File "/home/mirlab2019/wanglei/MusicTransformer-pytorch/custom/metrics.py", line 69, in __call__
        return self.forward(input=input, target=target)
      File "/home/mirlab2019/wanglei/MusicTransformer-pytorch/custom/metrics.py", line 75, in forward
        for k, metric in self.metrics.items()}
      File "/home/mirlab2019/wanglei/MusicTransformer-pytorch/custom/metrics.py", line 75, in <dictcomp>
        for k, metric in self.metrics.items()}
    AttributeError: 'tuple' object has no attribute 'to'

After further investigation, I found the cause: the model call on line 91 of train.py returns a tuple (the output plus the attention weights), and that tuple is then passed to metric_set(), which expects a tensor. Assigning the returned attention weights to a separate variable avoids the error.
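A minimal, framework-free sketch of the failure and the caller-side fix described above. `FakeTensor`, `toy_forward`, and `check_metric` are hypothetical stand-ins, not code from this repository: `FakeTensor` mimics only the `.to()` call that custom/metrics.py makes, and `toy_forward` mimics a model that returns `(output, attention_weights)` when not training.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only .to() is modeled."""

    def __init__(self, values):
        self.values = values

    def to(self, device):
        # A real tensor would be copied to `device`; here we just return self.
        return self


def toy_forward(x, training):
    """Hypothetical forward pass: a bare output while training,
    an (output, attention_weights) tuple otherwise."""
    output = FakeTensor(x)
    attention_weights = [FakeTensor(x)]
    if training:
        return output
    return output, attention_weights


def check_metric(prediction, target):
    """Hypothetical metric that, like the one in the traceback,
    calls .to() on its inputs before comparing them."""
    prediction = prediction.to("cpu")
    target = target.to("cpu")
    return prediction.values == target.values


batch_x = [1, 2, 3]
batch_y = FakeTensor([1, 2, 3])

# Buggy pattern: the whole tuple reaches the metric, and tuple has no .to().
error_message = None
sample = toy_forward(batch_x, training=False)
try:
    check_metric(sample, batch_y)
except AttributeError as err:
    error_message = str(err)

# Fix: receive the attention weights in their own variable.
sample, attention_weights = toy_forward(batch_x, training=False)
result = check_metric(sample, batch_y)
```

Here `error_message` holds the same `'tuple' object has no attribute 'to'` message as the traceback, while the unpacked call lets the metric run normally.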

1996Wanglei · Nov 11 '19

Do you know how to fix this in code?

mzliang-annie · Dec 17 '19

        if self.training:
            return fc.contiguous()
        else:
            return fc.contiguous(), [weight.contiguous() for weight in w]

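Replacing the original line 43 in model.py with the code above solves the issue: in training mode the model now returns only the output tensor, so train.py no longer passes a tuple to metric_set().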
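To see why this mode-dependent return fixes training, here is a framework-free sketch of the same pattern. `ToyModel` is hypothetical and is not the repository's model.py: it mimics only the `training` flag and the `train()`/`eval()` switches of `torch.nn.Module`, and uses plain lists in place of tensors.

```python
class ToyModel:
    """Hypothetical stand-in for an nn.Module: models only the
    `training` flag and the mode-dependent return type."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        fc = [2 * v for v in x]        # stand-in for the output logits
        w = [[0.5, 0.5] for _ in x]    # stand-in for the attention weights
        if self.training:
            # Training: return only the output, so metric code gets one object.
            return fc
        # Evaluation: also return the attention weights; callers must unpack.
        return fc, w


model = ToyModel()

model.train()
train_out = model([1, 2, 3])

model.eval()
eval_out, eval_weights = model([1, 2, 3])
```

The trade-off is that the return type now depends on the mode, so any evaluation code still has to unpack `(output, weights)`.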

mzliang-annie · Dec 20 '19

I opened PR #11 to fix this.

adamoudad · Nov 29 '20