DeepExplain

Fetch argument None has invalid type <class 'NoneType'>

Open ahof1704 opened this issue 4 years ago • 0 comments

Hi,

I am trying to use DeepExplain together with Concept Saliency Maps, but I run into the following error:

# Notebook cell: guided-backprop attributions of concept scores via DeepExplain.
# NOTE(review): K (Keras backend), torch, data (presumably torch.utils.data),
# model, device, reducer, zfish_age, train_files, test_files, files_max,
# path_to_augmented, new_channel, size_frame, concept_vectors and attr are
# assumed to be defined in earlier notebook cells — confirm before running.
import keras
sess = K.get_session()
print('sess: ', sess)
from ConceptSaliencyMaps.deepexplain.tensorflow import DeepExplain
from ConceptSaliencyMaps.deepexplain.utils import preprocess

# Collect every train/test file whose path contains one of the names in
# files_max (same order and duplicates as the original nested loop).
all_files = train_files + test_files
list_files = [
    file_name2
    for file_name in files_max
    for file_name2 in all_files
    if file_name in file_name2
]

test_set2 = zfish_age(list_files, path_to_save=path_to_augmented, test=True, transform=True,
                      new_channel=new_channel, new_size_frame=size_frame, verbose=False)
test_generator2 = data.DataLoader(test_set2, batch_size=1, shuffle=False, num_workers=20)

# TF placeholder that DeepExplain feeds with img_array. Attributions are only
# possible for targets that the TF graph actually computes from this tensor.
input_img = keras.Input(shape=(50, 128, 128))
method = 'guidedbp'  # loop-invariant, hoisted out of the loop

with DeepExplain(session=sess, graph=sess.graph) as de:
    with torch.no_grad():
        for xis, _, _, labels_name in test_generator2:
            print('labels_name: {}'.format(labels_name))

            input_tensor = input_img
            # session.run feed values should be NumPy arrays, not torch tensors.
            img_array = xis.reshape([1, 50, 128, 128]).cpu().numpy()
            _, zis = model(xis.to(device))
            print('zis.shape: ', zis.shape)  # torch.Size([1, 256])
            latents = reducer.transform(zis.cpu().detach())
            print('latents.shape: ', latents.shape)  # (1, 2)

            # latents comes from the PyTorch model + reducer, i.e. from outside
            # the TF graph, so K.sum(latents * concept_vec) is a graph constant
            # with no path back to input_tensor. Its gradient w.r.t.
            # input_tensor is None, which is what surfaces inside DeepExplain as
            # "Fetch argument None has invalid type <class 'NoneType'>".
            concept_score = [K.sum(latents * concept_vec) for concept_vec in concept_vectors[attr]]
            for score in concept_score:
                if any(grad is None for grad in K.gradients(score, input_tensor)):
                    raise ValueError(
                        'Concept score is not connected to input_tensor in the TF graph; '
                        'DeepExplain can only attribute targets computed from its input '
                        'tensor. Build the encoder and latent projection in TF/Keras on '
                        'input_tensor, or use a PyTorch-native attribution method instead.')
            attributions_guided = [de.explain(method, score, input_tensor, img_array)
                                   for score in concept_score]

Error:

TypeError                                 Traceback (most recent call last)
<ipython-input-169-177871cfe4fc> in <module>
     73 
     74             concept_score = [K.sum(latents*i) for i in concept_vectors[attr]]
---> 75             attributions_guided = [de.explain(method, i, input_tensor, img_array) for i in concept_score]

<ipython-input-169-177871cfe4fc> in <listcomp>(.0)
     73 
     74             concept_score = [K.sum(latents*i) for i in concept_vectors[attr]]
---> 75             attributions_guided = [de.explain(method, i, input_tensor, img_array) for i in concept_score]

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in explain(self, method, T, X, xs, **kwargs)
    733         _ENABLED_METHOD_CLASS = method_class
    734         method = _ENABLED_METHOD_CLASS(T, X, xs, self.session, self.keras_phase_placeholder, **kwargs)
--> 735         result = method.run()
    736         if issubclass(_ENABLED_METHOD_CLASS, GradientBasedMethod) and _GRAD_OVERRIDE_CHECKFLAG == 0:
    737             warnings.warn('DeepExplain detected you are trying to use an attribution method that requires '

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in run(self)
    463         for alpha in list(np.linspace(1. / self.steps, 1.0, self.steps)):
    464             xs_mod = [xs * alpha for xs in self.xs] if self.has_multiple_inputs else self.xs * alpha
--> 465             _attr = self.session_run(attributions, xs_mod)
    466             if gradient is None: gradient = _attr
    467             else: gradient = [g + a for g, a in zip(gradient, _attr)]

../ConceptSaliencyMaps/deepexplain/tensorflow/methods.py in session_run(self, T, xs)
     94         if self.keras_learning_phase is not None:
     95             feed_dict[self.keras_learning_phase] = 0
---> 96         return self.session.run(T, feed_dict)
     97 
     98     def _set_check_baseline(self):

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    954     try:
    955       result = self._run(None, fetches, feed_dict, options_ptr,
--> 956                          run_metadata_ptr)
    957       if run_metadata:
    958         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1163     # Create a fetch handler to take care of the structure of fetches.
   1164     fetch_handler = _FetchHandler(
-> 1165         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1166 
   1167     # Run request and get response.

..lib/python3.7/site-packages/tensorflow_core/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
    472     """
    473     with graph.as_default():
--> 474       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    475     self._fetches = []
    476     self._targets = []

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in for_fetch(fetch)
    264     elif isinstance(fetch, (list, tuple)):
    265       # NOTE(touts): This is also the code path for namedtuples.
--> 266       return _ListFetchMapper(fetch)
    267     elif isinstance(fetch, collections_abc.Mapping):
    268       return _DictFetchMapper(fetch)

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in __init__(self, fetches)
    373     """
    374     self._fetch_type = type(fetches)
--> 375     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    376     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    377 

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in <listcomp>(.0)
    373     """
    374     self._fetch_type = type(fetches)
--> 375     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    376     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    377 

../lib/python3.7/site-packages/tensorflow_core/python/client/session.py in for_fetch(fetch)
    261     if fetch is None:
    262       raise TypeError('Fetch argument %r has invalid type %r' %
--> 263                       (fetch, type(fetch)))
    264     elif isinstance(fetch, (list, tuple)):
    265       # NOTE(touts): This is also the code path for namedtuples.

TypeError: Fetch argument None has invalid type <class 'NoneType'>

I have also reported this issue on the Concept Saliency Maps GitHub repository, but both the developer and I believe the problem lies in DeepExplain. Do you have any insight into what is going wrong?

Please let me know if you need any further information about the problem. Thanks!

ahof1704 avatar May 30 '21 14:05 ahof1704