pytorch_geometric
[Fix] Improve dim_size handling in SetTransformerAggregation to prevent CUDA crash
This PR improves the robustness of SetTransformerAggregation by:
- Automatically setting `dim_size = index.max() + 1` if `dim_size` is not provided.
- Raising a clear error if `index.max() >= dim_size` to avoid CUDA crashes during evaluation.
This is especially helpful for datasets like PPI, where `data.batch` may be missing. It replaces hard-to-debug GPU errors with clear, early validation.