gradient-variance-loss
TensorFlow 1.x implementation
I was able to get this working in TensorFlow 1.x; tested on TensorFlow 1.15.4 with NumPy 1.18.0.
def gradientvariance(target, output, size=128, patch_size=8):
    """Gradient variance loss between ``target`` and ``output`` (TensorFlow 1.x).

    Each image batch is reduced to a single channel by averaging over its
    last axis. It is then Sobel-filtered in x and y and cut into
    non-overlapping ``patch_size`` x ``patch_size`` patches. The variance of
    the gradient values inside every patch is compared between target and
    output.

    Args:
        target: Ground-truth image batch. Assumed NHWC with spatial size
            ``size`` x ``size`` (TODO confirm against the caller).
        output: Predicted image batch with the same layout as ``target``.
        size: Height/width used to pin the static shape of the grayscale
            maps to ``[None, size, size, 1]``.
        patch_size: Side length and stride of the patches over which the
            gradient variance is measured. Defaults to 8.

    Returns:
        Scalar tensor: the mean over patches of
        ``|var(grad_x(target)) - var(grad_x(output))|``, plus the same term
        for the y gradient.
    """
    # Sobel kernels in conv2d's [height, width, in_channels, out_channels] layout.
    kernel_x = tf.reshape(
        tf.constant([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], tf.float32),
        [3, 3, 1, 1])
    kernel_y = tf.reshape(
        tf.constant([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], tf.float32),
        [3, 3, 1, 1])

    # Non-overlapping windows: the kernel size equals the stride.
    window = [1, patch_size, patch_size, 1]

    def _gray(images):
        # Channel mean -> single-channel map. placeholder_with_default pins
        # the static shape so the patch extraction sees known dimensions.
        gray = tf.reduce_mean(images, axis=-1, keepdims=True)
        return tf.placeholder_with_default(gray, [None, size, size, 1])

    def _patch_variance(gray, kernel):
        # Per-patch variance of the Sobel response, shape [batch, rows, cols].
        grad = tf.nn.conv2d(gray, kernel, strides=[1, 1, 1, 1], padding='SAME')
        patches = tf.image.extract_image_patches(
            grad, ksizes=window, strides=window,
            rates=[1, 1, 1, 1], padding='VALID')
        # The last axis holds the patch_size * patch_size values of one patch
        # (depth is 1), so reducing over it yields the variance *within*
        # each patch rather than across patch locations.
        return tf.math.reduce_variance(patches, axis=-1)

    gray_target = _gray(target)
    gray_output = _gray(output)

    loss_x = tf.reduce_mean(tf.abs(
        _patch_variance(gray_target, kernel_x)
        - _patch_variance(gray_output, kernel_x)))
    # Same comparison as the x term: a difference between target and output,
    # not a sum (a sum would just reward low gradient energy).
    loss_y = tf.reduce_mean(tf.abs(
        _patch_variance(gray_target, kernel_y)
        - _patch_variance(gray_output, kernel_y)))
    return loss_x + loss_y
Thanks for your work!