
HELP REGARDING THE NON LOCAL BLOCK CODE

Open nanmehta opened this issue 3 years ago • 0 comments

```python
import tensorflow as tf

def Nonlocalblock(x):
    # Static input shape: (batch, height, width, channels)
    batch_size, height, width, in_channels = x.get_shape().as_list()
    print("height", height)
    print("width", width)
    print("in_channels", in_channels)
    print("shape", x.get_shape())

    # Gating: 1x1 convolution followed by element-wise multiplication with the input
    g1 = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(x)
    g1 = tf.math.multiply(g1, x)
    print("g1", g1.shape)

    # 1x1 convolutions producing the g, phi and theta embeddings
    g = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)
    print("g", g.shape)
    phi = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)
    print("phi", phi.shape)
    theta = tf.keras.layers.Conv2D(in_channels, 1, strides=(1, 1), padding='same')(g1)
    print("theta", theta.shape)

    # Flatten the spatial dimensions
    hw = height * width
    g_x = tf.reshape(g, [batch_size, hw, in_channels])
    g_x = tf.squeeze(g_x, axis=0)
    print("g_x", g_x.shape)                                         # 64, 16384
    theta_x = tf.reshape(theta, [batch_size, hw, in_channels])      # 64, 16384
    print("theta_x", theta_x.shape)
    phi_x = tf.reshape(phi, [batch_size, hw, in_channels])
    phi_x1 = tf.squeeze(phi_x, axis=0)
    print("phi_x", phi_x.shape)
    # theta_x1 = tf.transpose(theta_x, [0, 2, 1])                   # 16384, 64
    # theta_x1 = tf.squeeze(theta_x1, axis=0)
    # print("theta_x1", theta_x1.shape)

    # Attention map and weighted aggregation
    f = tf.matmul(theta_x, g_x, transpose_b=True)                   # 64, 64
    print("f", f.shape)
    f = tf.nn.softmax(f, -1)
    y = tf.matmul(phi_x1, f)
    print("y", y.shape)
    y1 = tf.nn.softmax(y)
    print("y1", y1.shape)                                           # 64, 16384
    y1 = tf.reshape(y1, [batch_size, height, width, in_channels])
    print("y1", y1.shape)
    return y1
```
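For illustration, a minimal sketch of how the block might be wired into a Keras functional model and reproduce the problem described below (hypothetical network and shapes, not from the original post):

```python
import tensorflow as tf

# Hypothetical feature-map size; the actual network and shapes are assumptions.
inputs = tf.keras.Input(shape=(128, 128, 64))      # static batch dimension is None
features = tf.keras.layers.Conv2D(64, 3, padding='same')(inputs)
attended = Nonlocalblock(features)                 # x.get_shape().as_list()[0] is None here,
                                                   # so the tf.reshape calls inside fail
model = tf.keras.Model(inputs, attended)
```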

I have implemented this non-local attention block as shown in the code above, but when I use it inside a network the batch size is always None, so the multiplication and reshape operations raise errors.
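Not a fix from this repo, just a minimal sketch of one common way to handle a dynamic batch size, assuming TF 2.x: take only the spatially known dimensions from the static shape and let `tf.reshape` infer the batch axis with `-1` (or read it at runtime via `tf.shape(x)[0]`). Keeping the batch axis also removes the need for the `tf.squeeze(axis=0)` calls, since `tf.matmul` batches over the leading dimension:

```python
import tensorflow as tf

def nonlocal_block_dynamic(x):
    # Hypothetical variant of the block above; only the batch/reshape handling differs.
    _, height, width, in_channels = x.get_shape().as_list()   # batch entry may be None
    hw = height * width

    theta = tf.keras.layers.Conv2D(in_channels, 1, padding='same')(x)
    phi = tf.keras.layers.Conv2D(in_channels, 1, padding='same')(x)
    g = tf.keras.layers.Conv2D(in_channels, 1, padding='same')(x)

    # -1 lets TensorFlow infer the batch axis at runtime, so None is never
    # passed to tf.reshape (tf.shape(x)[0] would work as well).
    theta_x = tf.reshape(theta, [-1, hw, in_channels])   # (batch, hw, c)
    phi_x = tf.reshape(phi, [-1, hw, in_channels])       # (batch, hw, c)
    g_x = tf.reshape(g, [-1, hw, in_channels])           # (batch, hw, c)

    # Attention over spatial positions, computed per example in the batch,
    # so no tf.squeeze(axis=0) is needed.
    f = tf.matmul(theta_x, phi_x, transpose_b=True)      # (batch, hw, hw)
    f = tf.nn.softmax(f, axis=-1)

    y = tf.matmul(f, g_x)                                # (batch, hw, c)
    return tf.reshape(y, [-1, height, width, in_channels])
```

Calling this variant inside the same kind of `tf.keras.Model` construction as above should then build without the None-batch error.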

nanmehta avatar Mar 08 '21 17:03 nanmehta