ExplainableAI.jl
Explainable AI in Julia using Flux.jl.
This package implements interpretability methods and visualizations for neural networks, similar to Captum and Zennit for PyTorch and iNNvestigate for Keras models.
Installation
This package supports Julia ≥ 1.6. To install it, open the Julia REPL and run:
julia> ]add ExplainableAI
Example
Let's use Layer-wise Relevance Propagation (LRP) to explain why an image of a castle gets classified as such, using a pre-trained VGG16 model from Metalhead.jl:
using ExplainableAI
using Flux
using Metalhead
using FileIO, HTTP
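# Load pre-trained VGG16, flatten nested Chains and strip the final softmax so that explanations are computed on the logits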
model = VGG(16, pretrain=true).layers
model = strip_softmax(flatten_chain(model))
# Load input
url = HTTP.URI("https://raw.githubusercontent.com/adrhill/ExplainableAI.jl/gh-pages/assets/heatmaps/castle.jpg")
img = load(url)
input = preprocess_imagenet(img)
input = reshape(input, 224, 224, 3, :) # reshape to WHCN format
# Run XAI method
analyzer = LRP(model)
expl = analyze(input, analyzer) # or: expl = analyzer(input)
# Show heatmap
heatmap(expl)
# Or analyze & show heatmap directly
heatmap(input, analyzer)
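The example removes the model's final softmax with strip_softmax before running LRP. One reason this is commonly done (our interpretation; the README does not spell it out) is that a softmax score depends on every logit, so explaining it mixes in evidence against the competing classes. This plain-Julia sketch uses a hypothetical toy_softmax helper and made-up logits rather than the VGG16 model, and compares the gradient of a single logit with the gradient of the corresponding softmax output:

```julia
# Hypothetical helper: numerically stable softmax (not the Flux/NNlib function)
toy_softmax(z) = exp.(z .- maximum(z)) ./ sum(exp.(z .- maximum(z)))

# Jacobian of softmax: ∂s_i/∂z_j = s_i * (δ_ij - s_j)
function softmax_jacobian(z)
    s = toy_softmax(z)
    n = length(z)
    return [s[i] * ((i == j) - s[j]) for i in 1:n, j in 1:n]
end

z = [2.0, 1.0, 0.1]   # made-up logits for three classes
c = 1                 # class to explain

grad_logit = [j == c ? 1.0 : 0.0 for j in eachindex(z)]   # ∂z_c/∂z: one-hot
grad_softmax = softmax_jacobian(z)[c, :]                   # ∂s_c/∂z: row c of the Jacobian

println("∂z_c/∂z = ", grad_logit)
println("∂s_c/∂z = ", grad_softmax)
```

The logit's gradient touches only its own class, while the softmax gradient is negative for every other class (and sums to zero), which is the coupling that removing the softmax avoids.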
We can also get an explanation for the activation of the output neuron corresponding to the "street sign" class by specifying its output neuron index, 920:
analyze(input, analyzer, 920) # for explanation
heatmap(input, analyzer, 920) # for heatmap
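Passing an output index changes which neuron's activation is explained. The sketch below illustrates the idea on a hypothetical three-output linear model in plain Julia, where the gradient of output k with respect to the input is simply row k of the weight matrix. It assumes, as the castle example suggests, that omitting the index explains the most strongly activated output; toy_model and explain_gradient are illustrative names, not part of ExplainableAI.jl.

```julia
# Hypothetical toy setup: a linear "network" with three output neurons
W = [ 1.0  -2.0   0.5;
      0.0   3.0  -1.0;
     -1.0   1.0   2.0]
b = [0.1, -0.2, 0.0]
toy_model(x) = W * x .+ b

# For a linear model, the gradient of output k w.r.t. x is row k of W.
# If no index is given, explain the most strongly activated output.
function explain_gradient(x, k=nothing)
    y = toy_model(x)
    idx = isnothing(k) ? argmax(y) : k
    return idx, W[idx, :]
end

x = [1.0, 0.5, 2.0]
k_max, e_max = explain_gradient(x)   # top-scoring output (neuron 3 for this input)
k_2, e_2 = explain_gradient(x, 2)    # explicitly selected output neuron 2
```

Here the outputs are [1.1, -0.7, 3.5], so the default explains neuron 3, while passing 2 yields a different explanation for the same input.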
Heatmaps for all implemented analyzers are shown in the following table. Red regions are of positive relevance for the selected class, whereas blue regions are of negative relevance.
Analyzer | Heatmap for class "castle" | Heatmap for class "street sign" |
---|---|---|
LRP with EpsilonPlus composite | (heatmap image) | (heatmap image) |
LRP with EpsilonPlusFlat composite | (heatmap image) | (heatmap image) |
LRP with EpsilonAlpha2Beta1 composite | (heatmap image) | (heatmap image) |
LRP with EpsilonAlpha2Beta1Flat composite | (heatmap image) | (heatmap image) |
LRP with EpsilonGammaBox composite | (heatmap image) | (heatmap image) |
LRP | (heatmap image) | (heatmap image) |
InputTimesGradient | (heatmap image) | (heatmap image) |
Gradient | (heatmap image) | (heatmap image) |
SmoothGrad | (heatmap image) | (heatmap image) |
IntegratedGradients | (heatmap image) | (heatmap image) |
The code used to generate these heatmaps can be found here.
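The red/blue convention used in these heatmaps can be reproduced with a simple diverging colormap. This hedged sketch assumes a W×H×C relevance array, sums it over the color channels, divides by the largest absolute value so that zero relevance stays centered, and maps the result to white (zero), red (positive) and blue (negative). relevance_to_rgb is a hypothetical helper; ExplainableAI.jl's own heatmap function may reduce, scale and color relevance differently.

```julia
# Hypothetical helper: map a W×H×C relevance array to a blue-white-red RGB image
function relevance_to_rgb(R::AbstractArray{<:Real,3})
    Rsum = dropdims(sum(R; dims=3); dims=3)   # reduce color channels: W×H
    m = maximum(abs, Rsum)
    Rn = m == 0 ? zero(Rsum) : Rsum ./ m      # centered scaling to [-1, 1]
    # White for zero relevance, red for positive, blue for negative
    r = [v >= 0 ? 1.0 : 1.0 + v for v in Rn]
    g = [1.0 - abs(v) for v in Rn]
    b = [v <= 0 ? 1.0 : 1.0 - v for v in Rn]
    return Rn, cat(r, g, b; dims=3)
end

R = zeros(2, 2, 3)
R[1, 1, :] .= 1.0     # pixel with positive relevance
R[2, 2, :] .= -0.5    # pixel with negative relevance
Rn, rgb = relevance_to_rgb(R)
```

The positive pixel becomes pure red, the negative pixel a light blue, and pixels with zero relevance stay white.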
Video demonstration
Check out our talk at JuliaCon 2022 for a demonstration of the package.
Methods
Currently, the following analyzers are implemented:
├── Gradient
├── InputTimesGradient
├── SmoothGrad
├── IntegratedGradients
└── LRP
    ├── Rules
    │   ├── ZeroRule
    │   ├── EpsilonRule
    │   ├── GammaRule
    │   ├── WSquareRule
    │   ├── FlatRule
    │   ├── ZBoxRule
    │   ├── ZPlusRule
    │   ├── AlphaBetaRule
    │   └── PassRule
    └── Composite
        ├── EpsilonGammaBox
        ├── EpsilonPlus
        ├── EpsilonPlusFlat
        ├── EpsilonAlpha2Beta1
        └── EpsilonAlpha2Beta1Flat
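The gradient-based analyzers in this list differ mainly in how they combine gradients. The following self-contained sketch shows their textbook definitions on a toy function with a hand-written gradient: the plain gradient, input times gradient, SmoothGrad (averaging gradients over noisy copies of the input) and IntegratedGradients (averaging gradients along a straight path from a baseline, here all zeros, approximated with a midpoint Riemann sum). All function names and default parameters (n, σ, the baseline) are assumptions for illustration; the package instead differentiates real networks with automatic differentiation.

```julia
using Random

# Toy scalar "output neuron" with a known gradient (hypothetical example)
w = [2.0, -1.0, 0.5]
f(x) = sum(w .* tanh.(x))
∇f(x) = w .* (1 .- tanh.(x) .^ 2)

gradient_expl(x) = ∇f(x)
input_times_gradient(x) = x .* ∇f(x)

# SmoothGrad: average the gradient over n noisy copies of the input
function smoothgrad(x; n=50, σ=0.1, rng=MersenneTwister(0))
    return sum(∇f(x .+ σ .* randn(rng, length(x))) for _ in 1:n) ./ n
end

# IntegratedGradients: (x - x0) times the average gradient on the straight
# path from the baseline x0 to x, using a midpoint Riemann sum with n steps
function integrated_gradients(x; x0=zero(x), n=100)
    avg_grad = sum(∇f(x0 .+ ((k - 0.5) / n) .* (x .- x0)) for k in 1:n) ./ n
    return (x .- x0) .* avg_grad
end

x = [0.5, -1.0, 2.0]
g = gradient_expl(x)
ig = integrated_gradients(x; n=1000)
println(sum(ig), " ≈ ", f(x) - f(zero(x)))
```

The printed comparison checks the completeness property of IntegratedGradients: the attributions sum to the difference between the output at the input and at the baseline, up to discretization error.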
One of the design goals of ExplainableAI.jl is extensibility: custom composites are easy to define, and the package can be extended with custom LRP rules.
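Extensibility through custom rules fits naturally with Julia's multiple dispatch. The sketch below is a hypothetical, self-contained illustration and not ExplainableAI.jl's actual rule interface: each LRP rule is a type, and a dense-layer backward pass, R_in[i] = a[i] * Σ_j W[j, i] * R_out[j] / z[j], calls modify_weight, modify_bias and stabilize, which rules override as needed. A new rule (ToyGammaRule here) only defines the method it changes, and a "composite" is sketched as one rule per layer of a small linear network. The Toy prefix avoids clashing with the package's own rule names.

```julia
# Hypothetical LRP sketch for dense layers y = W * a .+ b (W is out × in).
# None of these names are ExplainableAI.jl API.
abstract type AbstractToyRule end
struct ToyZeroRule <: AbstractToyRule end
struct ToyEpsilonRule <: AbstractToyRule
    ϵ::Float64
end
struct ToyZPlusRule <: AbstractToyRule end

# Defaults: keep weights and bias unchanged, no stabilizer
modify_weight(::AbstractToyRule, W) = W
modify_bias(::AbstractToyRule, b) = b
stabilize(::AbstractToyRule, z) = z

# Rule-specific overrides
stabilize(r::ToyEpsilonRule, z) = z .+ r.ϵ .* ifelse.(z .>= 0, 1.0, -1.0)
modify_weight(::ToyZPlusRule, W) = max.(W, 0)
modify_bias(::ToyZPlusRule, b) = zero(b)

# Redistribute output relevance R_out onto the layer inputs a
function lrp_dense(rule::AbstractToyRule, W, b, a, R_out)
    Wm = modify_weight(rule, W)
    z = stabilize(rule, Wm * a .+ modify_bias(rule, b))
    return a .* (transpose(Wm) * (R_out ./ z))
end

# A custom rule: a new type plus only the method it changes (boost positive weights)
struct ToyGammaRule <: AbstractToyRule
    γ::Float64
end
modify_weight(r::ToyGammaRule, W) = W .+ r.γ .* max.(W, 0)

# A toy "composite": one rule per layer, applied from the output back to the input
function lrp_chain(rules, layers, x)
    acts = [x]
    for (W, b) in layers
        push!(acts, W * acts[end] .+ b)   # linear layers only, for brevity
    end
    R = acts[end]                         # start from the output activations
    for l in length(layers):-1:1
        W, b = layers[l]
        R = lrp_dense(rules[l], W, b, acts[l], R)
    end
    return R
end

W = [1.0 -1.0; 0.5 2.0]
b = [0.0, 0.0]
a = [1.0, 2.0]
R_out = W * a .+ b                                   # [-1.0, 4.5]
R_zero = lrp_dense(ToyZeroRule(), W, b, a, R_out)
R_eps = lrp_dense(ToyEpsilonRule(0.1), W, b, a, R_out)
R_gam = lrp_dense(ToyGammaRule(0.25), W, b, a, R_out)
R_comp = lrp_chain([ToyZeroRule(), ToyZeroRule()], [(W, b), (W, b)], a)
R_mix = lrp_chain([ToyZPlusRule(), ToyEpsilonRule(0.1)], [(W, b), (W, b)], a)
```

With zero biases the ToyZeroRule conserves relevance (R_zero sums to the layer output), while the ε stabilizer absorbs a little of it; mixing rules per layer, as in R_mix, is loosely what a composite does.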
Roadmap
In the future, we would like to include:
- PatternNet
- DeepLIFT
- LIME
- Shapley values via ShapML.jl
Contributions are welcome!
Acknowledgements
Adrian Hill acknowledges support by the Federal Ministry of Education and Research (BMBF) for the Berlin Institute for the Foundations of Learning and Data (BIFOLD) (01IS18037A).