
TypeError: BasicGNN.forward() got an unexpected keyword argument 'return_attention_weights'

Open · Batene opened this issue 1 year ago · 1 comment

🐛 Describe the bug

In `GATv2Conv`, `return_attention_weights` is expected to be passed as an argument to the `forward` function. When using `GATv2Conv` indirectly through `GAT` with `v2=True`, I cannot pass `return_attention_weights=True` when calling the `GAT` model (i.e., its `forward` function); doing so throws the error above. One can pass `return_attention_weights=True` when initializing `GAT`, but then it is never used: the `forward` function of `GATv2Conv` does not read `self.return_attention_weights` (no such attribute exists). It takes the `return_attention_weights` argument instead of reading a class attribute.
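A minimal plain-Python sketch of why the call fails. `ToyConv` and `ToyBasicGNN` are hypothetical stand-ins, not PyG code; the only behavior carried over from the issue is that the layer accepts `return_attention_weights` in `forward`, while the wrapper's `forward` has a fixed signature without it, so Python rejects the keyword before any layer is reached.

```python
# Hypothetical stand-ins (not PyG code) that reproduce the shape of the problem.


class ToyConv:
    """Stand-in for a layer such as GATv2Conv: the flag is a forward() argument."""

    def forward(self, x, edge_index, return_attention_weights=None):
        out = x  # placeholder for the actual message passing
        if return_attention_weights:
            return out, "attention weights would go here"
        return out


class ToyBasicGNN:
    """Stand-in for a wrapper such as BasicGNN: forward() does not accept
    return_attention_weights and calls each layer without it."""

    def __init__(self, num_layers=2):
        self.convs = [ToyConv() for _ in range(num_layers)]

    def forward(self, x, edge_index):
        for conv in self.convs:
            x = conv.forward(x, edge_index)
        return x


model = ToyBasicGNN()
error_message = None
try:
    model.forward([1.0], [[0], [0]], return_attention_weights=True)
except TypeError as err:
    # Same kind of TypeError as in the issue title.
    error_message = str(err)
print(error_message)
```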

My solution for now is:

  1. In the `__init__` function of `GATv2Conv`, add:

     `self.return_attention_weights = kwargs.get("return_attention_weights")`
    
  2. In the `forward` function of `GATv2Conv` (https://github.com/pyg-team/pytorch_geometric/blob/38bb5f29375e9ba9dc56654af8d3f35551480f6e/torch_geometric/nn/conv/gatv2_conv.py#L312), check the attribute instead of the argument:

     `if isinstance(self.return_attention_weights, bool):`
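The idea behind this workaround (store the flag at construction time, then read it in `forward`) can be sketched in plain Python. `ToyAttentionLayer` is a hypothetical stand-in, not the patched `GATv2Conv`, and its outputs are placeholders. One deliberate difference from the patch above: here an explicit `forward` argument still takes precedence, and the attribute is only a fallback.

```python
class ToyAttentionLayer:
    """Hypothetical stand-in for a patched attention layer; not PyG's GATv2Conv."""

    def __init__(self, **kwargs):
        # get() returns None instead of raising KeyError when the flag is omitted.
        self.return_attention_weights = kwargs.get("return_attention_weights")

    def forward(self, x, return_attention_weights=None):
        # An explicit argument takes precedence; otherwise use the value from __init__.
        if return_attention_weights is None:
            return_attention_weights = self.return_attention_weights
        out = [2 * v for v in x]  # placeholder for message passing
        alpha = [1.0 / len(x)] * len(x)  # placeholder: uniform attention
        if return_attention_weights:
            return out, alpha
        return out


layer = ToyAttentionLayer(return_attention_weights=True)
out, alpha = layer.forward([1.0, 2.0])
print(out, alpha)
```

The toy uses a truthiness check, so passing `False` (or nothing) yields the plain output only.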
    

I am using PyTorch Geometric version 2.5.2. I assume the same problem occurs when using `v2=False` (with `GATConv`).

Versions

Python version: 3.11.7

Batene · Apr 06 '24 09:04

Are you referring to `models.GAT`? If you need to return attention weights, I suggest using the GNN layers directly and building your own model on top of them.

rusty1s · Apr 08 '24 13:04