Adding IDGI to Captum
(making a new PR since I missed creating the design doc earlier) RE: #1246
IDGI API Design
Background
IDGI (Important Direction Gradient Integration) is a generalized framework that can be applied on top of any Integrated Gradients-style method. It reduces explanation noise by keeping only the "important direction" (the direction of the gradient) at each step of the Riemann sum. The framework requires an underlying path/method to be defined. The original paper reports results for 3 underlying methods: IG, Blur IG (BlurIG), and Guided IG (GIG). In every case, IDGI applied on top of the underlying method outperforms the standalone method on the metrics reported in the paper.
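To make the "important direction" concrete, assume the straight-line IG path with $n$ steps $x_i = x' + \tfrac{i}{n}(x - x')$ from the baseline $x'$ to the input $x$, a scalar model output $f$ for the target class, and gradients $g_i = \nabla f(x_i)$ taken at the left end of each step. Under these assumptions, and as we read the paper's formulation, the left Riemann sum of IG and the IDGI update attribute feature $j$ as:

$$
\phi^{\mathrm{IG}}_j = \sum_{i=0}^{n-1} g_{i,j}\,\bigl(x_{i+1,j} - x_{i,j}\bigr),
\qquad
\phi^{\mathrm{IDGI}}_j = \sum_{i=0}^{n-1} \frac{g_{i,j}^{2}}{\lVert g_i \rVert_2^{2}}\,\bigl(f(x_{i+1}) - f(x_i)\bigr)
$$

Because the weights $g_{i,j}^2 / \lVert g_i \rVert_2^2$ sum to 1 over $j$, summing $\phi^{\mathrm{IDGI}}_j$ over all features telescopes to $f(x) - f(x')$ exactly, while the IG sum only approaches it as $n$ grows. Each step's output change is distributed only along the gradient direction $g_i$.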
Requires:
- Model
- X = Input
- n = number of Riemann integration steps
- index = the index of the output class that the image belongs to (for classification models)
- Baseline image
Pseudocode:
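A minimal, dependency-free Python sketch of the IDGI accumulation on the straight-line IG path, using the inputs listed above. Assumptions: the model and the target class `index` are collapsed into a scalar function `f`, and `grad_f` is a hand-written gradient standing in for autograd. `idgi_straight_line`, `f`, and `grad_f` are hypothetical names, and skipping zero-gradient steps is our own choice rather than something the paper specifies.

```python
from typing import Callable, List, Sequence


def idgi_straight_line(
    f: Callable[[Sequence[float]], float],
    grad_f: Callable[[Sequence[float]], List[float]],
    x: Sequence[float],
    baseline: Sequence[float],
    n: int = 50,
) -> List[float]:
    """Hypothetical IDGI sketch on the straight-line (IG) path from `baseline` to `x`."""
    # Riemann points x_0 = baseline, ..., x_n = x.
    points = [
        [b + (k / n) * (xi - b) for xi, b in zip(x, baseline)]
        for k in range(n + 1)
    ]
    outputs = [f(p) for p in points]  # target-class output at each point
    attributions = [0.0] * len(x)
    for i in range(n):
        g = grad_f(points[i])  # gradient at the left end of step i
        sq = [gj * gj for gj in g]
        norm_sq = sum(sq)
        if norm_sq == 0.0:
            continue  # assumption: skip steps with no gradient direction
        d = outputs[i + 1] - outputs[i]  # output change over step i
        for j in range(len(x)):
            # Distribute d along the gradient direction with weight g_j^2 / ||g||^2.
            attributions[j] += sq[j] / norm_sq * d
    return attributions


# Toy model: scalar output for one target class, with its analytic gradient.
def f(v: Sequence[float]) -> float:
    return v[0] * v[1] + 2 * v[0] + v[1] ** 2


def grad_f(v: Sequence[float]) -> List[float]:
    return [v[1] + 2, v[0] + 2 * v[1]]


attr = idgi_straight_line(f, grad_f, x=[1.0, 3.0], baseline=[0.0, 0.0], n=20)
print(attr, sum(attr))  # sum(attr) should match f(x) - f(baseline) = 14 up to rounding
```

Since the per-step weights g_j^2 / ||g||^2 sum to 1, the attributions add up to f(x) - f(baseline), apart from any skipped zero-gradient steps. In Captum, `grad_f` would be replaced by autograd gradients of the model output for `target`.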
Proposed Captum API Design for IDGI:
The design will closely mirror the IntegratedGradients class. The IDGI framework is generic and works for any classification or regression model. For now, we only implement IDGI on top of the original IG method, which defines the path as a linear interpolation between the baseline and the input image; in the future, an additional method argument could select a different path, such as those defined by GIG or BlurIG. The IDGI class inherits from GradientAttribution and contains the complete implementation of the algorithm. The implementation
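Because IDGI only consumes the sequence of path points, the path can be swapped out behind a `method` argument as the paragraph above suggests. The sketch below is illustrative only: `idgi`, `idgi_along_path`, `linear_path`, `staircase_path`, and `PATH_METHODS` are hypothetical names, and `staircase_path` is a toy stand-in for an alternative path, not GIG or BlurIG.

```python
from typing import Callable, Dict, List, Sequence

Point = List[float]
PathFn = Callable[[Sequence[float], Sequence[float], int], List[Point]]


def linear_path(x: Sequence[float], baseline: Sequence[float], n: int) -> List[Point]:
    # IG's path: n + 1 evenly spaced points from baseline to x.
    return [[b + (k / n) * (xi - b) for xi, b in zip(x, baseline)] for k in range(n + 1)]


def staircase_path(x: Sequence[float], baseline: Sequence[float], n: int) -> List[Point]:
    # Toy alternative path (NOT GIG or BlurIG): move one feature at a time.
    current = list(baseline)
    path = [list(current)]
    for j in range(len(x)):
        start = current[j]
        for k in range(1, n + 1):
            current[j] = start + (k / n) * (x[j] - start)
            path.append(list(current))
    return path


def idgi_along_path(f, grad_f, path: List[Point]) -> List[float]:
    # IDGI needs only the path points, their outputs, and their gradients.
    outputs = [f(p) for p in path]
    attributions = [0.0] * len(path[0])
    for i in range(len(path) - 1):
        sq = [gj * gj for gj in grad_f(path[i])]
        norm_sq = sum(sq)
        if norm_sq == 0.0:
            continue  # assumption: skip steps with no gradient direction
        d = outputs[i + 1] - outputs[i]
        for j, s in enumerate(sq):
            attributions[j] += s / norm_sq * d
    return attributions


# Hypothetical registry behind a future `method` argument.
PATH_METHODS: Dict[str, PathFn] = {"ig": linear_path, "staircase": staircase_path}


def idgi(f, grad_f, x, baseline, n_steps: int = 50, method: str = "ig") -> List[float]:
    path = PATH_METHODS[method](x, baseline, n_steps)
    return idgi_along_path(f, grad_f, path)


# Toy model and gradient (stand-ins for a network and autograd).
def f(v):
    return v[0] * v[1] + 2 * v[0] + v[1] ** 2


def grad_f(v):
    return [v[1] + 2, v[0] + 2 * v[1]]


for m in ("ig", "staircase"):
    attr = idgi(f, grad_f, x=[1.0, 3.0], baseline=[0.0, 0.0], n_steps=10, method=m)
    print(m, [round(a, 4) for a in attr], round(sum(attr), 6))
```

Keeping the IDGI accumulation separate from path construction means a new path only has to produce its list of points; the attribution loop stays unchanged.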
Constructor:
IDGI(forward_func: Callable)
Argument Descriptions:
forward_func (Callable): the forward function of the model (e.g., a torch.nn.Module) for which attributions are desired. This is consistent with the other attribution constructors in Captum.
attribute:
attribute(inputs: TensorOrTupleOfTensorsGeneric,
baselines: BaselineType,
target: TargetType,
additional_forward_args: Any,
n_steps: int,
internal_batch_size: Union[None, int],
return_convergence_delta: bool)
Argument Descriptions:
These arguments follow the same definitions as in the existing IntegratedGradients.attribute method.
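One practical consequence for `return_convergence_delta`: if the delta is defined as in IntegratedGradients (sum of attributions minus the difference between the outputs at the input and at the baseline), IDGI's delta should be zero up to floating-point rounding, because its per-step terms telescope. The sketch below uses hypothetical helper names and a toy quadratic model to compare it with a left-Riemann IG sum. Captum's IntegratedGradients defaults to a different quadrature rule, so the IG number here is illustrative only.

```python
from typing import List, Sequence


def f(v: Sequence[float]) -> float:
    # Toy scalar model output (stand-in for the target-class output).
    return v[0] * v[1] + 2 * v[0] + v[1] ** 2


def grad_f(v: Sequence[float]) -> List[float]:
    # Analytic gradient of f (stand-in for autograd).
    return [v[1] + 2, v[0] + 2 * v[1]]


def straight_line(x, baseline, n):
    return [[b + (k / n) * (xi - b) for xi, b in zip(x, baseline)] for k in range(n + 1)]


def ig_left_riemann(x, baseline, n):
    # IG approximated with a left Riemann sum: g(x_i) * (x_{i+1} - x_i).
    pts = straight_line(x, baseline, n)
    attr = [0.0] * len(x)
    for i in range(n):
        g = grad_f(pts[i])
        for j in range(len(x)):
            attr[j] += g[j] * (pts[i + 1][j] - pts[i][j])
    return attr


def idgi(x, baseline, n):
    # IDGI: distribute each step's output change along the gradient direction.
    pts = straight_line(x, baseline, n)
    attr = [0.0] * len(x)
    for i in range(n):
        sq = [gj * gj for gj in grad_f(pts[i])]
        norm_sq = sum(sq)
        if norm_sq == 0.0:
            continue
        d = f(pts[i + 1]) - f(pts[i])
        for j in range(len(x)):
            attr[j] += sq[j] / norm_sq * d
    return attr


def convergence_delta(attr, x, baseline):
    # Sum of attributions minus the output difference between input and baseline.
    return sum(attr) - (f(x) - f(baseline))


x, baseline, n = [1.0, 3.0], [0.0, 0.0], 10
ig_delta = convergence_delta(ig_left_riemann(x, baseline, n), x, baseline)
idgi_delta = convergence_delta(idgi(x, baseline, n), x, baseline)
print(f"IG delta:   {ig_delta:.6f}")
print(f"IDGI delta: {idgi_delta:.6f}")
```

For this toy model the left-Riemann IG delta is exactly -12/n, so it shrinks as `n_steps` grows, while the IDGI delta stays at rounding level for any number of steps.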
Here's a comparison of the existing methods (top) and IDGI (bottom)