pytorch_geometric
Fixing one of the `captum` tests.
This PR fixes the following 9 failing test cases, all of which are `ShapleyValueSampling` parametrizations of `test_captum_explainer_multiclass_classification`:
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.node-MaskType.object-MaskType.attributes-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1, 1]' is invalid for input of size 3
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.node-MaskType.object-None-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1]' is invalid for input of size 3
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.node-None-MaskType.attributes-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1, 1]' is invalid for input of size 3
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.edge-MaskType.object-MaskType.attributes-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1, 1]' is invalid for input of size 3
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.edge-MaskType.object-None-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1]' is invalid for input of size 3
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.edge-None-MaskType.attributes-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1, 1]' is invalid for input of size 3
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.graph-MaskType.object-MaskType.attributes-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1, 1]' is invalid for input of size 3
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.graph-MaskType.object-None-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1]' is invalid for input of size 3
FAILED test/explain/algorithm/test_captum_explainer.py::test_captum_explainer_multiclass_classification[index1-ModelTaskLevel.graph-None-MaskType.attributes-ShapleyValueSampling] - RuntimeError: shape '[-1, 2, 1, 1]' is invalid for input of size 3
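For readers unfamiliar with this error, the snippet below is a minimal sketch of the same class of `RuntimeError` in plain PyTorch. It does not reproduce the Captum or PyG code path touched by this PR; the helper `try_view` is hypothetical and exists only for illustration. It shows that a tensor with 3 elements cannot be viewed as `[-1, 2, 1, 1]`, because the inferred dimension would have to be 3 / 2, while a tensor with 4 elements can.

```python
import torch


def try_view(numel, shape):
    """Hypothetical helper: try to view a 1-D tensor with `numel` elements
    as `shape`; return None on success, else the RuntimeError message."""
    try:
        torch.arange(numel).view(*shape)
        return None
    except RuntimeError as err:
        return str(err)


# 3 elements are not divisible by 2 * 1 * 1, so -1 cannot be inferred.
msg_3 = try_view(3, (-1, 2, 1, 1))
# 4 elements are divisible by 2, so the view succeeds with shape [2, 2, 1, 1].
msg_4 = try_view(4, (-1, 2, 1, 1))

print(msg_3)
print(msg_4)
```

The failing cases above all report an input of size 3 against a target shape whose second dimension is 2, which is the same mismatch this sketch triggers.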
NOTE: This is a resubmission of PR #9513, which had some issues with my captum branch.