GradCAM with SSD model producing zero gradients for some Conv layers

Open sriram-praveen-work opened this issue 2 years ago • 3 comments
I have implemented the GradCAM algorithm for an SSD model, but I can generate a heatmap for only a few layers. The convolutional layers towards the end of the model produce empty (all-zero) gradients. This is the architecture of the model:

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 300, 300, 3  0           []                               
                                )]                                                                
                                                                                                  
 identity_layer (Lambda)        (None, 300, 300, 3)  0           ['input_1[0][0]']                
                                                                                                  
 input_mean_normalization (Lamb  (None, 300, 300, 3)  0          ['identity_layer[0][0]']         
 da)                                                                                              
                                                                                                  
 input_channel_swap (Lambda)    (None, 300, 300, 3)  0           ['input_mean_normalization[0][0]'
                                                                 ]                                
                                                                                                  
 conv1_1 (Conv2D)               (None, 300, 300, 64  1792        ['input_channel_swap[0][0]']     
                                )                                                                 
                                                                                                  
 conv1_2 (Conv2D)               (None, 300, 300, 64  36928       ['conv1_1[0][0]']                
                                )                                                                 
                                                                                                  
 pool1 (MaxPooling2D)           (None, 150, 150, 64  0           ['conv1_2[0][0]']                
                                )                                                                 
                                                                                                  
 conv2_1 (Conv2D)               (None, 150, 150, 12  73856       ['pool1[0][0]']                  
                                8)                                                                
                                                                                                  
 conv2_2 (Conv2D)               (None, 150, 150, 12  147584      ['conv2_1[0][0]']                
                                8)                                                                
                                                                                                  
 pool2 (MaxPooling2D)           (None, 75, 75, 128)  0           ['conv2_2[0][0]']                
                                                                                                  
 conv3_1 (Conv2D)               (None, 75, 75, 256)  295168      ['pool2[0][0]']                  
                                                                                                  
 conv3_2 (Conv2D)               (None, 75, 75, 256)  590080      ['conv3_1[0][0]']                
                                                                                                  
 conv3_3 (Conv2D)               (None, 75, 75, 256)  590080      ['conv3_2[0][0]']                
                                                                                                  
 pool3 (MaxPooling2D)           (None, 38, 38, 256)  0           ['conv3_3[0][0]']                
                                                                                                  
 conv4_1 (Conv2D)               (None, 38, 38, 512)  1180160     ['pool3[0][0]']                  
                                                                                                  
 conv4_2 (Conv2D)               (None, 38, 38, 512)  2359808     ['conv4_1[0][0]']                
                                                                                                  
 conv4_3 (Conv2D)               (None, 38, 38, 512)  2359808     ['conv4_2[0][0]']                
                                                                                                  
 pool4 (MaxPooling2D)           (None, 19, 19, 512)  0           ['conv4_3[0][0]']                
                                                                                                  
 conv5_1 (Conv2D)               (None, 19, 19, 512)  2359808     ['pool4[0][0]']                  
                                                                                                  
 conv5_2 (Conv2D)               (None, 19, 19, 512)  2359808     ['conv5_1[0][0]']                
                                                                                                  
 conv5_3 (Conv2D)               (None, 19, 19, 512)  2359808     ['conv5_2[0][0]']                
                                                                                                  
 pool5 (MaxPooling2D)           (None, 19, 19, 512)  0           ['conv5_3[0][0]']                
                                                                                                  
 fc6 (Conv2D)                   (None, 19, 19, 1024  4719616     ['pool5[0][0]']                  
                                )                                                                 
                                                                                                  
 fc7 (Conv2D)                   (None, 19, 19, 1024  1049600     ['fc6[0][0]']                    
                                )                                                                 
                                                                                                  
 conv6_1 (Conv2D)               (None, 19, 19, 256)  262400      ['fc7[0][0]']                    
                                                                                                  
 conv6_padding (ZeroPadding2D)  (None, 21, 21, 256)  0           ['conv6_1[0][0]']                
                                                                                                  
 conv6_2 (Conv2D)               (None, 10, 10, 512)  1180160     ['conv6_padding[0][0]']          
                                                                                                  
 conv7_1 (Conv2D)               (None, 10, 10, 128)  65664       ['conv6_2[0][0]']                
                                                                                                  
 conv7_padding (ZeroPadding2D)  (None, 12, 12, 128)  0           ['conv7_1[0][0]']                
                                                                                                  
 conv7_2 (Conv2D)               (None, 5, 5, 256)    295168      ['conv7_padding[0][0]']          
                                                                                                  
 conv8_1 (Conv2D)               (None, 5, 5, 128)    32896       ['conv7_2[0][0]']                
                                                                                                  
 conv8_2 (Conv2D)               (None, 3, 3, 256)    295168      ['conv8_1[0][0]']                
                                                                                                  
 conv9_1 (Conv2D)               (None, 3, 3, 128)    32896       ['conv8_2[0][0]']                
                                                                                                  
 conv4_3_norm (L2Normalization)  (None, 38, 38, 512)  512        ['conv4_3[0][0]']                
                                                                                                  
 conv9_2 (Conv2D)               (None, 1, 1, 256)    295168      ['conv9_1[0][0]']                
                                                                                                  
 conv4_3_norm_mbox_conf (Conv2D  (None, 38, 38, 84)  387156      ['conv4_3_norm[0][0]']           
 )                                                                                                
                                                                                                  
 fc7_mbox_conf (Conv2D)         (None, 19, 19, 126)  1161342     ['fc7[0][0]']                    
                                                                                                  
 conv6_2_mbox_conf (Conv2D)     (None, 10, 10, 126)  580734      ['conv6_2[0][0]']                
                                                                                                  
 conv7_2_mbox_conf (Conv2D)     (None, 5, 5, 126)    290430      ['conv7_2[0][0]']                
                                                                                                  
 conv8_2_mbox_conf (Conv2D)     (None, 3, 3, 84)     193620      ['conv8_2[0][0]']                
                                                                                                  
 conv9_2_mbox_conf (Conv2D)     (None, 1, 1, 84)     193620      ['conv9_2[0][0]']                
                                                                                                  
 conv4_3_norm_mbox_loc (Conv2D)  (None, 38, 38, 16)  73744       ['conv4_3_norm[0][0]']           
                                                                                                  
 fc7_mbox_loc (Conv2D)          (None, 19, 19, 24)   221208      ['fc7[0][0]']                    
                                                                                                  
 conv6_2_mbox_loc (Conv2D)      (None, 10, 10, 24)   110616      ['conv6_2[0][0]']                
                                                                                                  
 conv7_2_mbox_loc (Conv2D)      (None, 5, 5, 24)     55320       ['conv7_2[0][0]']                
                                                                                                  
 conv8_2_mbox_loc (Conv2D)      (None, 3, 3, 16)     36880       ['conv8_2[0][0]']                
                                                                                                  
 conv9_2_mbox_loc (Conv2D)      (None, 1, 1, 16)     36880       ['conv9_2[0][0]']                
                                                                                                  
 conv4_3_norm_mbox_conf_reshape  (None, 5776, 21)    0           ['conv4_3_norm_mbox_conf[0][0]'] 
  (Reshape)                                                                                       
                                                                                                  
 fc7_mbox_conf_reshape (Reshape  (None, 2166, 21)    0           ['fc7_mbox_conf[0][0]']          
 )                                                                                                
                                                                                                  
 conv6_2_mbox_conf_reshape (Res  (None, 600, 21)     0           ['conv6_2_mbox_conf[0][0]']      
 hape)                                                                                            
                                                                                                  
 conv7_2_mbox_conf_reshape (Res  (None, 150, 21)     0           ['conv7_2_mbox_conf[0][0]']      
 hape)                                                                                            
                                                                                                  
 conv8_2_mbox_conf_reshape (Res  (None, 36, 21)      0           ['conv8_2_mbox_conf[0][0]']      
 hape)                                                                                            
                                                                                                  
 conv9_2_mbox_conf_reshape (Res  (None, 4, 21)       0           ['conv9_2_mbox_conf[0][0]']      
 hape)                                                                                            
                                                                                                  
 conv4_3_norm_mbox_priorbox (An  (None, 38, 38, 4, 8  0          ['conv4_3_norm_mbox_loc[0][0]']  
 chorBoxes)                     )                                                                 
                                                                                                  
 fc7_mbox_priorbox (AnchorBoxes  (None, 19, 19, 6, 8  0          ['fc7_mbox_loc[0][0]']           
 )                              )                                                                 
                                                                                                  
 conv6_2_mbox_priorbox (AnchorB  (None, 10, 10, 6, 8  0          ['conv6_2_mbox_loc[0][0]']       
 oxes)                          )                                                                 
                                                                                                  
 conv7_2_mbox_priorbox (AnchorB  (None, 5, 5, 6, 8)  0           ['conv7_2_mbox_loc[0][0]']       
 oxes)                                                                                            
                                                                                                  
 conv8_2_mbox_priorbox (AnchorB  (None, 3, 3, 4, 8)  0           ['conv8_2_mbox_loc[0][0]']       
 oxes)                                                                                            
                                                                                                  
 conv9_2_mbox_priorbox (AnchorB  (None, 1, 1, 4, 8)  0           ['conv9_2_mbox_loc[0][0]']       
 oxes)                                                                                            
                                                                                                  
 mbox_conf (Concatenate)        (None, 8732, 21)     0           ['conv4_3_norm_mbox_conf_reshape[
                                                                 0][0]',                          
                                                                  'fc7_mbox_conf_reshape[0][0]',  
                                                                  'conv6_2_mbox_conf_reshape[0][0]
                                                                 ',                               
                                                                  'conv7_2_mbox_conf_reshape[0][0]
                                                                 ',                               
                                                                  'conv8_2_mbox_conf_reshape[0][0]
                                                                 ',                               
                                                                  'conv9_2_mbox_conf_reshape[0][0]
                                                                 ']                               
                                                                                                  
 conv4_3_norm_mbox_loc_reshape   (None, 5776, 4)     0           ['conv4_3_norm_mbox_loc[0][0]']  
 (Reshape)                                                                                        
                                                                                                  
 fc7_mbox_loc_reshape (Reshape)  (None, 2166, 4)     0           ['fc7_mbox_loc[0][0]']           
                                                                                                  
 conv6_2_mbox_loc_reshape (Resh  (None, 600, 4)      0           ['conv6_2_mbox_loc[0][0]']       
 ape)                                                                                             
                                                                                                  
 conv7_2_mbox_loc_reshape (Resh  (None, 150, 4)      0           ['conv7_2_mbox_loc[0][0]']       
 ape)                                                                                             
                                                                                                  
 conv8_2_mbox_loc_reshape (Resh  (None, 36, 4)       0           ['conv8_2_mbox_loc[0][0]']       
 ape)                                                                                             
                                                                                                  
 conv9_2_mbox_loc_reshape (Resh  (None, 4, 4)        0           ['conv9_2_mbox_loc[0][0]']       
 ape)                                                                                             
                                                                                                  
 conv4_3_norm_mbox_priorbox_res  (None, 5776, 8)     0           ['conv4_3_norm_mbox_priorbox[0][0
 hape (Reshape)                                                  ]']                              
                                                                                                  
 fc7_mbox_priorbox_reshape (Res  (None, 2166, 8)     0           ['fc7_mbox_priorbox[0][0]']      
 hape)                                                                                            
                                                                                                  
 conv6_2_mbox_priorbox_reshape   (None, 600, 8)      0           ['conv6_2_mbox_priorbox[0][0]']  
 (Reshape)                                                                                        
                                                                                                  
 conv7_2_mbox_priorbox_reshape   (None, 150, 8)      0           ['conv7_2_mbox_priorbox[0][0]']  
 (Reshape)                                                                                        
                                                                                                  
 conv8_2_mbox_priorbox_reshape   (None, 36, 8)       0           ['conv8_2_mbox_priorbox[0][0]']  
 (Reshape)                                                                                        
                                                                                                  
 conv9_2_mbox_priorbox_reshape   (None, 4, 8)        0           ['conv9_2_mbox_priorbox[0][0]']  
 (Reshape)                                                                                        
                                                                                                  
 mbox_conf_softmax (Activation)  (None, 8732, 21)    0           ['mbox_conf[0][0]']              
                                                                                                  
 mbox_loc (Concatenate)         (None, 8732, 4)      0           ['conv4_3_norm_mbox_loc_reshape[0
                                                                 ][0]',                           
                                                                  'fc7_mbox_loc_reshape[0][0]',   
                                                                  'conv6_2_mbox_loc_reshape[0][0]'
                                                                 , 'conv7_2_mbox_loc_reshape[0][0]
                                                                 ',                               
                                                                  'conv8_2_mbox_loc_reshape[0][0]'
                                                                 , 'conv9_2_mbox_loc_reshape[0][0]
                                                                 ']                               
                                                                                                  
 mbox_priorbox (Concatenate)    (None, 8732, 8)      0           ['conv4_3_norm_mbox_priorbox_resh
                                                                 ape[0][0]',                      
                                                                  'fc7_mbox_priorbox_reshape[0][0]
                                                                 ',                               
                                                                  'conv6_2_mbox_priorbox_reshape[0
                                                                 ][0]',                           
                                                                  'conv7_2_mbox_priorbox_reshape[0
                                                                 ][0]',                           
                                                                  'conv8_2_mbox_priorbox_reshape[0
                                                                 ][0]',                           
                                                                  'conv9_2_mbox_priorbox_reshape[0
                                                                 ][0]']                           
                                                                                                  
 predictions (Concatenate)      (None, 8732, 33)     0           ['mbox_conf_softmax[0][0]',      
                                                                  'mbox_loc[0][0]',               
                                                                  'mbox_priorbox[0][0]']          
                                                                                                  
 decoded_predictions (DecodeDet  (None, 200, 6)      0           ['predictions[0][0]']            
 ections)                                                                                         
                                                                                                  
==================================================================================================
Total params: 26,285,486
Trainable params: 26,285,486
Non-trainable params: 0
__________________________________________________________________________________________________

This is the change that I've made to the GradCAM code:

import tensorflow as tf


def grad_cam(img_array, model, last_conv_layer_name, pred_index=None):
    # First, we create a model that maps the input image to the activations
    # of the chosen conv layer as well as the output predictions
    grad_model = tf.keras.models.Model(
        model.inputs, [model.get_layer(last_conv_layer_name).output, model.output]
    )

    # Then, we compute the confidence score of one decoded detection while
    # the tape records the forward pass
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = 0
        # preds has shape (batch, 200, 6); take detection `pred_index` of the first image
        detection = preds[:, pred_index][0]
        # Element 1 of the detection is used as the class confidence score
        conf_tensor = tf.reshape(detection[1], [1])

    # This is the gradient of the confidence score with regard to the output
    # feature map of the chosen conv layer. It is taken outside the `with`
    # block so that the tape does not record the gradient computation itself.
    grads = tape.gradient(conf_tensor, last_conv_layer_output)

    # This is a vector where each entry is the mean intensity of the gradient
    # over a specific feature map channel
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # We multiply each channel in the feature map array
    # by "how important this channel is" with regard to the chosen detection
    # then sum all the channels to obtain the heatmap class activation
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    # For visualization purpose, we will also normalize the heatmap between 0 & 1
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()
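
To make the last two steps of the function concrete, here is a minimal NumPy sketch of the channel weighting and normalization on their own. It is not part of the code above: the name `grad_cam_heatmap` and the `eps` guard are assumptions added for illustration. It shows that when the gradients are all zero, the heatmap is all zero too, so dividing by its maximum would give NaN unless that case is handled.

```python
import numpy as np


def grad_cam_heatmap(activations, grads, eps=1e-8):
    """Hypothetical helper: activations and grads have shape (H, W, C)."""
    # Channel weights: mean gradient over the two spatial axes.
    weights = grads.mean(axis=(0, 1))                 # shape (C,)
    # Weighted sum of the feature maps, followed by a ReLU.
    heatmap = np.maximum(activations @ weights, 0.0)  # shape (H, W)
    peak = heatmap.max()
    # With all-zero gradients the peak is 0; skip the division to avoid NaN.
    return heatmap / peak if peak > eps else heatmap


rng = np.random.default_rng(0)
acts = rng.random((5, 5, 8))

# All-zero gradients: the heatmap stays at zero instead of turning into NaN.
print(grad_cam_heatmap(acts, np.zeros_like(acts)).max())
# Non-zero gradients: the heatmap is scaled so that its maximum is 1.
print(grad_cam_heatmap(acts, np.ones_like(acts)).max())
```

The guard only changes the degenerate case; for non-zero gradients the result matches the weighting and normalization used in `grad_cam`.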

I can choose Conv2D layers up to "conv6_2" to generate the gradients, but I cannot generate any gradients for the layers after it. I am using the class confidence score and the feature map of the selected layer to compute the gradients.
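
One thing that may be worth checking is which layers are upstream of each confidence head at all, since a score can only send gradient back to layers it was computed from. The sketch below is a plain-Python diagnostic on the graph structure. The `PARENTS` dictionary is copied by hand from the "Connected to" column of the summary above (only the fc7 to conv9_2 chain and the confidence heads), and the helper name `ancestors` is made up for illustration. It does not tell which head produced the detection at `pred_index`, so it is a possible explanation to test, not a confirmed cause.

```python
# Parent links copied from the "Connected to" column of the model summary
# (subset: the fc7..conv9_2 chain and the per-layer confidence heads).
PARENTS = {
    "conv6_1": ["fc7"],
    "conv6_padding": ["conv6_1"],
    "conv6_2": ["conv6_padding"],
    "conv7_1": ["conv6_2"],
    "conv7_padding": ["conv7_1"],
    "conv7_2": ["conv7_padding"],
    "conv8_1": ["conv7_2"],
    "conv8_2": ["conv8_1"],
    "conv9_1": ["conv8_2"],
    "conv9_2": ["conv9_1"],
    "fc7_mbox_conf": ["fc7"],
    "conv6_2_mbox_conf": ["conv6_2"],
    "conv7_2_mbox_conf": ["conv7_2"],
    "conv8_2_mbox_conf": ["conv8_2"],
    "conv9_2_mbox_conf": ["conv9_2"],
}


def ancestors(layer, parents=PARENTS):
    """Return every layer in `parents` that `layer` depends on."""
    seen, stack = set(), list(parents.get(layer, []))
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(parents.get(node, []))
    return seen


# For each head, report whether the later feature maps feed into it.
for head in ("conv6_2_mbox_conf", "conv7_2_mbox_conf", "conv9_2_mbox_conf"):
    upstream = ancestors(head)
    print(head, {name: name in upstream for name in ("conv6_2", "conv7_2", "conv8_2")})
```

If the selected detection turns out to come from the conv6_2 head, for example, then conv7_1 and everything after it would not be on the path from that score back to the input, which would be consistent with the behaviour described above.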

sriram-praveen-work, Apr 28 '23 04:04