
Wrong dimensions in sandbox.py example

anshumansinha16 opened this issue 1 year ago

There is an issue with the input parameters in the sandbox.py file. Running it produces the following error:

Traceback (most recent call last):
  File "/Users/anshumansinha/Desktop/StructRepGen_Dev/influenza_transformer-main/sandbox.py", line 163, in <module>
    prediction = model(src, tgt, src_mask, tgt_mask)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/Desktop/StructRepGen_Dev/influenza_transformer-main/transformer_timeseries.py", line 226, in forward
    decoder_output = self.decoder(
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 369, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 716, in forward
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 725, in _sa_block
    x = self.self_attn(x, x, x,
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1205, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/Users/anshumansinha/miniconda3/lib/python3.10/site-packages/torch/nn/functional.py", line 5251, in multi_head_attention_forward
    raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
RuntimeError: The shape of the 2D attn_mask is torch.Size([48, 48]), but should be (128, 128).
(victor_env) (base) anshumansinha@Anshumans-MacBook-Pro-3 influenza_transformer-main % 
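
The error says the decoder's self-attention received a 2D attn_mask of shape [48, 48] while the target sequence it is attending over has length 128, i.e. the mask was built for a different sequence length than the tensor actually passed in (a batch_first mix-up can also cause the mask size to be read off the wrong dimension). A minimal sketch of sizing the masks from the tensors themselves so they cannot drift out of sync — the variable names and shapes here are illustrative, not the repo's exact code:

```python
import torch

def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    # Causal mask: -inf above the diagonal (future positions), 0 elsewhere.
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)

# Hypothetical batch-first shapes (batch, seq_len, features) for illustration.
src = torch.rand(32, 48, 1)   # encoder input, 48 time steps
tgt = torch.rand(32, 128, 1)  # decoder input, 128 time steps

# Derive mask sizes from the actual tensors rather than hard-coded constants.
src_seq_len = src.shape[1]
tgt_seq_len = tgt.shape[1]
tgt_mask = generate_square_subsequent_mask(tgt_seq_len)  # (128, 128) self-attn mask
src_mask = torch.zeros(tgt_seq_len, src_seq_len)         # (128, 48) cross-attn mask
```

With the masks derived this way, the decoder self-attention mask matches the (128, 128) shape the traceback demands; if the model was constructed with batch_first=False, the sequence length would instead live in dim 0 of each tensor.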

anshumansinha16 · Jul 03 '23 02:07