captum
Visualization feature is not working
🐛 Bug
I am following the tutorial "Model Interpretation for Pretrained ResNet Model". When I try to create the visualization, I get an error.
To Reproduce
Steps to reproduce the behavior:
- Run the tutorial at https://captum.ai/tutorials/Resnet_TorchVision_Interpret#Model-Interpretation-for-Pretrained-ResNet-Model. The visualization call fails with:
      7     np.transpose(transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
      8     method='heat_map',
      9     cmap=default_cmap,
     10     show_colorbar=True,
     11     sign='positive',
     12     outlier_perc=1)

File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/captum/attr/_utils/visualization.py:250, in visualize_image_attr(attr, original_image, method, sign, plt_fig_axis, outlier_perc, cmap, alpha_overlay, show_colorbar, title, fig_size, use_pyplot)
    248 plt_axis.set_yticklabels([])
    249 plt_axis.set_xticklabels([])
--> 250 plt_axis.grid(b=False)
    252 heat_map = None
    253 # Show original image

File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/matplotlib/axes/_base.py:3194, in _AxesBase.grid(self, visible, which, axis, **kwargs)
   3192 _api.check_in_list(['x', 'y', 'both'], axis=axis)
   3193 if axis in ['x', 'both']:
-> 3194     self.xaxis.grid(visible, which=which, **kwargs)
   3195 if axis in ['y', 'both']:
   3196     self.yaxis.grid(visible, which=which, **kwargs)

File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/matplotlib/axis.py:1660, in Axis.grid(self, visible, which, **kwargs)
   1657 if which in ['major', 'both']:
   1658     gridkw['gridOn'] = (not self._major_tick_kw['gridOn']
   1659                         if visible is None else visible)
-> 1660     self.set_tick_params(which='major', **gridkw)
   1661 self.stale = True

File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/matplotlib/axis.py:932, in Axis.set_tick_params(self, which, reset, **kwargs)
    919 """
    920 Set appearance parameters for ticks, ticklabels, and gridlines.
    921
   (...)
    929 gridlines.
    930 """
    931 _api.check_in_list(['major', 'minor', 'both'], which=which)
--> 932 kwtrans = self._translate_tick_params(kwargs)
    934 # the kwargs are stored in self._major/minor_tick_kw so that any
    935 # future new ticks will automatically get them
    936 if reset:

File /scratch/anaconda3/envs/XAI_BM/lib/python3.10/site-packages/matplotlib/axis.py:1076, in Axis._translate_tick_params(kw, reverse)
   1074 for key in kw:
   1075     if key not in allowed_keys:
-> 1076         raise ValueError(
   1077             "keyword %s is not recognized; valid keywords are %s"
   1078             % (key, allowed_keys))
   1079 kwtrans.update(kw_)
   1080 return kwtrans

ValueError: keyword grid_b is not recognized; valid keywords are ['size', 'width', 'color', 'tickdir', 'pad', 'labelsize', 'labelcolor', 'zorder', 'gridOn', 'tick1On', 'tick2On', 'label1On', 'label2On', 'length', 'direction', 'left', 'bottom', 'right', 'top', 'labelleft', 'labelbottom', 'labelright', 'labeltop', 'labelrotation', 'grid_agg_filter', 'grid_alpha', 'grid_animated', 'grid_antialiased', 'grid_clip_box', 'grid_clip_on', 'grid_clip_path', 'grid_color', 'grid_dash_capstyle', 'grid_dash_joinstyle', 'grid_dashes', 'grid_data', 'grid_drawstyle', 'grid_figure', 'grid_fillstyle', 'grid_gapcolor', 'grid_gid', 'grid_in_layout', 'grid_label', 'grid_linestyle', 'grid_linewidth', 'grid_marker', 'grid_markeredgecolor', 'grid_markeredgewidth', 'grid_markerfacecolor', 'grid_markerfacecoloralt', 'grid_markersize', 'grid_markevery', 'grid_mouseover', 'grid_path_effects', 'grid_picker', 'grid_pickradius', 'grid_rasterized', 'grid_sketch_params', 'grid_snap', 'grid_solid_capstyle', 'grid_solid_joinstyle', 'grid_transform', 'grid_url', 'grid_visible', 'grid_xdata', 'grid_ydata', 'grid_zorder', 'grid_aa', 'grid_c', 'grid_ds', 'grid_ls', 'grid_lw', 'grid_mec', 'grid_mew', 'grid_mfc', 'grid_mfcalt', 'grid_ms']
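Reading the frames from the bottom up explains why the error names `grid_b` rather than `b`: the newer `Axes.grid(self, visible, which, axis, **kwargs)` has no `b` parameter, so `b=False` falls into `**kwargs`, is forwarded to `Axis.grid`, receives a `grid_` prefix, and is finally rejected when the tick parameters are validated. The sketch below is a simplified, hypothetical model of that keyword plumbing only; the `model_*` functions and the shortened key set are illustrative and are not matplotlib's code, and the `grid_` prefixing is inferred from the error message.

```python
# Simplified, hypothetical model of the keyword flow in the traceback above.
# Only the keyword plumbing is modeled; matplotlib's real functions do much more.

# A small subset of the keys accepted for tick/grid styling
# (the full list appears in the ValueError message).
ALLOWED_TICK_KEYS = {"gridOn", "grid_color", "grid_linestyle", "grid_linewidth", "grid_alpha"}


def model_set_tick_params(**kw):
    """Reject any keyword that is not a known tick/grid parameter."""
    for key in kw:
        if key not in ALLOWED_TICK_KEYS:
            raise ValueError(f"keyword {key} is not recognized")
    return kw


def model_axis_grid(visible=None, **kwargs):
    """Like Axis.grid: extra line-style keywords are forwarded with a 'grid_' prefix."""
    gridkw = {f"grid_{name}": value for name, value in kwargs.items()}
    if visible is not None:
        gridkw["gridOn"] = visible
    return model_set_tick_params(**gridkw)


def model_axes_grid(visible=None, **kwargs):
    """Like the newer Axes.grid signature: there is no `b`, so `b` lands in **kwargs."""
    return model_axis_grid(visible, **kwargs)


if __name__ == "__main__":
    print(model_axes_grid(visible=False))  # {'gridOn': False}
    try:
        model_axes_grid(b=False)
    except ValueError as err:
        print(err)  # keyword grid_b is not recognized
```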
Expected behavior
The attribution heat map is displayed without an error.
Environment
Describe the environment used for Captum
- PyTorch: '2.1.0.dev20230711+cu121'
- OS: Linux
- Captum from source: pip install git+https://github.com/pytorch/captum.git
- Build command you used (if compiling from source):
- Python: 3.10.9
- CUDA version: 12.1
- GPU: A100
Hi,
I managed to reproduce this error.
It can be fixed by changing line 250 of captum/attr/_utils/visualization.py from
plt_axis.grid(b=False)
to
plt_axis.grid(visible=False)
After this change, it worked on my side.
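The b argument of matplotlib.axes.Axes.grid was deprecated in favor of visible in matplotlib 3.5 and has since been removed, so newer releases pass b through to the tick parameters, where it is rejected as grid_b. If you cannot or don't want to modify Captum's files, downgrading to matplotlib 3.4.3 might work (not tested).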
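If Captum's files cannot be edited, another stopgap is to wrap `Axes.grid` so that a legacy `b=` keyword is renamed to `visible=` before the call reaches matplotlib. A minimal sketch, assuming matplotlib 3.5 or later (where the `visible` parameter exists); `accept_legacy_b` is a hypothetical helper, not part of Captum or matplotlib.

```python
import functools


def accept_legacy_b(grid_method):
    """Wrap an Axes.grid-style method so the removed `b=` keyword still works.

    Hypothetical compatibility shim: newer matplotlib only accepts `visible`,
    so a `b` keyword is renamed before the call is forwarded.
    """

    @functools.wraps(grid_method)
    def wrapper(self, *args, **kwargs):
        if "b" in kwargs:
            legacy_value = kwargs.pop("b")
            # If the caller passed both keywords, keep the explicit `visible=`.
            kwargs.setdefault("visible", legacy_value)
        return grid_method(self, *args, **kwargs)

    return wrapper
```

It could be applied once, before calling Captum's visualization functions, with `from matplotlib.axes import Axes` followed by `Axes.grid = accept_legacy_b(Axes.grid)`. Patching a third-party class globally is a workaround rather than a fix; installing a Captum version that contains the fix removes the need for it.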
The issue has already been fixed: #1118
You can install directly from source to get the fix:
pip install git+https://github.com/pytorch/captum.git
Thanks, that solved my issue over here, too!