pytorch_geometric
TypeError: BasicGNN.forward() got an unexpected keyword argument 'return_attention_weights'
🐛 Describe the bug
In GATv2Conv, return_attention_weights is expected to be passed as a parameter to the forward function.
When using GATv2Conv indirectly, via GAT with `v2=True`, I cannot pass `return_attention_weights=True` when calling the GAT model (i.e. its forward function). Doing so throws the error above.
One can pass `return_attention_weights=True` when initializing GAT, but it is then never used: the forward function of GATv2Conv does not read a `self.return_attention_weights` attribute (no such attribute exists). Instead, it expects `return_attention_weights` as a parameter of the call.
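A stripped-down, pure-Python sketch of the two behaviours described above. `Layer` and `Wrapper` are hypothetical stand-ins, not PyG classes: the layer takes the flag as a per-call argument, while the wrapper's `forward` has a fixed signature and its constructor only passes extra keyword arguments along. As a result, a flag set at init time is never consulted.

```python
class Layer:
    """Hypothetical stand-in for a conv layer: the flag is a per-call argument."""

    def __init__(self, **kwargs):
        # Extra init kwargs are accepted but never stored or read again.
        pass

    def forward(self, x, return_attention_weights=None):
        out = [2 * v for v in x]
        # Mirrors the per-call check quoted in the workaround below.
        if isinstance(return_attention_weights, bool):
            return out, "attention"
        return out


class Wrapper:
    """Hypothetical stand-in for a model wrapper with a fixed forward signature."""

    def __init__(self, **kwargs):
        # Init kwargs are forwarded to the layer's constructor only.
        self.layer = Layer(**kwargs)

    def forward(self, x):
        # No return_attention_weights parameter here, so callers cannot pass it.
        return self.layer.forward(x)


model = Wrapper(return_attention_weights=True)
print(model.forward([1, 2]))  # the init-time flag has no effect

try:
    model.forward([1, 2], return_attention_weights=True)
except TypeError as err:
    print(err)  # unexpected keyword argument 'return_attention_weights'
```

Python rejects the unknown keyword at the wrapper's call boundary, before any layer code runs, which is the same shape as the `TypeError` in the title.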
My solution for now is:
- in the `__init__` function of GATv2Conv, add: `self.return_attention_weights = kwargs["return_attention_weights"]`
- in the `forward` function of GATv2Conv (https://github.com/pyg-team/pytorch_geometric/blob/38bb5f29375e9ba9dc56654af8d3f35551480f6e/torch_geometric/nn/conv/gatv2_conv.py#L312), change the check at the linked line to: `if isinstance(self.return_attention_weights, bool):`
I am using pytorch geometric version 2.5.2.
I assume the same problem occurs when using `v2=False` (with GATConv).
Versions
Python version: 3.11.7
Are you referring to `models.GAT`? If you need to return attention weights, I suggest using the GNN layers directly and building your own model on top of them.