
Adding IDGI to Captum

Open ShreeSinghi opened this issue 1 year ago • 4 comments

(making a new PR since I missed creating the design doc earlier) RE: #1246

IDGI API Design

Background

IDGI (Important Direction Gradient Integration) is a general framework that can be applied on top of any Integrated Gradients (IG) based method. It reduces explanation noise by considering only the "important direction" when accumulating the Riemann sum. The method requires an underlying path/method to be defined. The original paper reports results for three underlying paths/methods (IG, BlurIG, and GIG), and in all three cases IDGI combined with the underlying method outperforms the underlying method on its own.
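As a rough sketch of what "only considering the important direction" means, the per-step update can be written as below. The notation is an assumption introduced here for illustration: $x_0, \dots, x_n$ are the points on the underlying path (baseline to input), $f_c$ is the model output for the target index $c$, and $I_i$ is the attribution for input feature $i$. This reflects my reading of the paper's update rule rather than a definitive statement of it.

```latex
\begin{aligned}
d_j &= f_c(x_{j+1}) - f_c(x_j), \qquad g_j = \nabla_x f_c(x_j), \\
I_i &= \sum_{j=0}^{n-1} \frac{g_{j,i}^{2}}{g_j \cdot g_j}\, d_j .
\end{aligned}
```

Each step's change in output $d_j$ is distributed across features in proportion to the squared gradient components, so only the gradient direction at $x_j$ contributes.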

Requires:

  • model
  • X = input
  • n = number of Riemann integration steps
  • index = the index of the output class that the input belongs to (for classification models)
  • baseline input (a baseline image for image models)

Pseudocode:

[image: IDGI pseudocode]
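Since the pseudocode is only available as an image, here is a minimal, framework-free sketch of the loop as I understand it, using the straight-line IG path. Everything named here (`idgi_linear_path`, `toy_f`, `toy_grad`) is hypothetical and not part of Captum. The gradient is supplied analytically instead of through autograd, and skipping steps with a zero gradient is an assumption of this sketch.

```python
from typing import Callable, List, Sequence


def idgi_linear_path(
    f: Callable[[Sequence[float]], float],
    grad_f: Callable[[Sequence[float]], Sequence[float]],
    x: Sequence[float],
    baseline: Sequence[float],
    n_steps: int = 50,
) -> List[float]:
    """Sketch of IDGI along the straight line from baseline to x."""
    # Path points x_0 = baseline, ..., x_n = x (linear interpolation, as in IG).
    path = [
        [b + (j / n_steps) * (xi - b) for xi, b in zip(x, baseline)]
        for j in range(n_steps + 1)
    ]
    outputs = [f(p) for p in path]
    attributions = [0.0] * len(x)
    for j in range(n_steps):
        d = outputs[j + 1] - outputs[j]   # change in model output over this step
        g = grad_f(path[j])               # gradient at the start of the step
        g_sq = [gi * gi for gi in g]
        norm = sum(g_sq)
        if norm == 0.0:
            continue  # assumption: skip steps that have no gradient direction
        for i, gi_sq in enumerate(g_sq):
            # Distribute d across features in proportion to squared gradient.
            attributions[i] += gi_sq / norm * d
    return attributions


# Hypothetical toy model: f(x) = x0^2 + 3 * x1, with its analytic gradient.
def toy_f(x: Sequence[float]) -> float:
    return x[0] ** 2 + 3.0 * x[1]


def toy_grad(x: Sequence[float]) -> List[float]:
    return [2.0 * x[0], 3.0]


attr = idgi_linear_path(toy_f, toy_grad, x=[1.0, 2.0], baseline=[0.0, 0.0], n_steps=20)
total = sum(attr)
expected = toy_f([1.0, 2.0]) - toy_f([0.0, 0.0])
print(attr, total, expected)
```

In this sketch the attributions sum to `f(x) - f(baseline)` up to floating-point error, because the per-step weights sum to one and the output differences telescope. A Captum implementation would presumably obtain the gradients from autograd and batch the path points.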

Proposed Captum API Design for IDGI:

The design will be very similar to that of the IntegratedGradients class. The IDGI framework is generic and works for any classification or regression model. For now, we only implement IDGI for the original IG method, which defines the path as a linear interpolation between the baseline and the input image. In the future, we can accept an additional method argument to choose a different path, such as the one defined by GIG or BlurIG. The IDGI class inherits from GradientAttribution and contains the complete implementation of the algorithm.

Constructor:

IDGI(forward_func: Callable) 

Argument Descriptions:

  • forward_func - the forward function of the model (typically a torch.nn.Module) for which attributions are desired. This is consistent with all other attribution constructors in Captum.

attribute:

attribute(inputs: TensorOrTupleOfTensorsGeneric,
          baselines: BaselineType,
          target: TargetType,
          additional_forward_args: Any,
          n_steps: int,
          internal_batch_size: Union[None, int],
          return_convergence_delta: bool)

Argument Descriptions:

These arguments follow the standard definitions used by the existing IntegratedGradients class. Here's a comparison of the existing methods (top) and IDGI (bottom):

[image: comparison of existing methods (top) and IDGI (bottom)]
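One note on return_convergence_delta, offered as a sketch under stated assumptions rather than a claim about the final implementation. Suppose the path points $x_0, \dots, x_n$ include both the baseline ($x_0$) and the input ($x_n$), and each step distributes the output change $d_j = f_c(x_{j+1}) - f_c(x_j)$ over features with weights $g_{j,i}^2 / (g_j \cdot g_j)$, where $g_j = \nabla_x f_c(x_j) \neq 0$. Then the attributions $I_i$ sum to the output difference (the notation here is introduced for illustration only):

```latex
\sum_i I_i
  = \sum_{j=0}^{n-1} d_j \sum_i \frac{g_{j,i}^{2}}{g_j \cdot g_j}
  = \sum_{j=0}^{n-1} d_j
  = f_c(x_n) - f_c(x_0).
```

Under those assumptions the convergence delta would be zero up to floating-point error. If the step points do not include the endpoints, as with some quadrature rules, the sum no longer telescopes exactly and a nonzero delta would be expected.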

ShreeSinghi, May 24 '24 09:05