influenza_transformer
Wrong dimensions in sandbox.py example
There seems to be an issue with the input parameters in the sandbox.py example. Running it fails with the following error:
Traceback (most recent call last):
  File "/Users/anshumansinha/Desktop/StructRepGen_Dev/influenza_transformer-main/sandbox.py", line 163, in <module>
    prediction = model(src, tgt, src_mask, tgt_mask)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/Desktop/StructRepGen_Dev/influenza_transformer-main/transformer_timeseries.py", line 226, in forward
    decoder_output = self.decoder(
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 369, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 716, in forward
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 725, in _sa_block
    x = self.self_attn(x, x, x,
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1205, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/functional.py", line 5251, in multi_head_attention_forward
    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([48, 48]), but should be (128, 128).
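One plausible reading of the error, not yet confirmed against sandbox.py: 48 is the target sequence length used to build tgt_mask, while 128 looks like a batch size. That pattern appears when tensors laid out as [batch, seq, features] are passed to PyTorch transformer layers built with the default batch_first=False, which treat dimension 0 as the sequence. The minimal sketch below reproduces the mismatch with plain torch.nn layers only; it does not use the model from transformer_timeseries.py, and the sizes, d_model, nhead and the causal_mask/run_decoder helpers are hypothetical choices made to match the error message.

```python
import torch
import torch.nn as nn

# Assumed sizes, chosen to match the error message:
# a batch of 128 sequences and a decoder input length of 48.
BATCH, SRC_LEN, TGT_LEN, D_MODEL, N_HEAD = 128, 48, 48, 16, 4


def causal_mask(size: int) -> torch.Tensor:
    # Float mask with -inf above the diagonal, used for decoder self-attention
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)


def run_decoder(batch_first: bool):
    """Run a one-layer decoder on batch-first tensors; return output shape or error text."""
    torch.manual_seed(0)
    layer = nn.TransformerDecoderLayer(
        d_model=D_MODEL, nhead=N_HEAD, batch_first=batch_first
    )
    decoder = nn.TransformerDecoder(layer, num_layers=1)
    # Both tensors are laid out as [batch, seq, features]
    memory = torch.rand(BATCH, SRC_LEN, D_MODEL)
    tgt = torch.rand(BATCH, TGT_LEN, D_MODEL)
    tgt_mask = causal_mask(TGT_LEN)  # shape [48, 48]
    try:
        return tuple(decoder(tgt, memory, tgt_mask=tgt_mask).shape)
    except RuntimeError as err:
        return str(err)


# batch_first=False: dim 0 (128) is read as the sequence length, so the
# [48, 48] mask no longer matches and PyTorch raises an attn_mask shape error.
print(run_decoder(batch_first=False))
# batch_first=True: dim 1 (48) is the sequence length and the mask fits.
print(run_decoder(batch_first=True))  # (128, 48, 16)
```

If this is indeed the cause, the two usual remedies are to construct the encoder/decoder layers with batch_first=True, or to keep the default and transpose the inputs to [seq, batch, features] (for example with tgt.transpose(0, 1)) before calling the model, so the layout matches the one the mask was built for.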