Layerwise-Relevance-Propagation
Changing neural network architecture doesn't improve heatmaps
I modified the MNIST program to work on my own data: I added more convolutional layers to the network architecture and trained on cat/dog images instead of the MNIST digits. The heat maps I get highlight features other than the cat or dog as well. Please let me know what has to be done to get proper heat maps.
model.py
import tensorflow as tf
class MNIST_CNN:
def __init__(self, name='MNIST_CNN'):
    self.name = name
def convlayer(self, input, shape, name):
    # convolution + bias + ReLU; shape = [kh, kw, in_channels, out_channels]
    w_conv = tf.Variable(tf.truncated_normal(shape=shape, dtype=tf.float32, stddev=0.1), name='w_{0}'.format(name))
    b_conv = tf.Variable(tf.constant(0.0, shape=shape[-1:], dtype=tf.float32), name='b_{0}'.format(name))
    conv = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(input, w_conv, [1, 1, 1, 1], padding='SAME'), b_conv), name=name)
    return w_conv, b_conv, conv
def fclayer(self, input, shape, name, prop=True):
    # fully connected layer; with prop=False only the variables are created (used for the output layer)
    w_fc = tf.Variable(tf.truncated_normal(shape=shape, dtype=tf.float32, stddev=0.1), name='w_{0}'.format(name))
    b_fc = tf.Variable(tf.constant(0.0, shape=shape[-1:], dtype=tf.float32), name='b_{0}'.format(name))
    if prop:
        fc = tf.nn.relu(tf.nn.bias_add(tf.matmul(input, w_fc), b_fc), name=name)
        return w_fc, b_fc, fc
    else:
        return w_fc, b_fc
def __call__(self, images, reuse=False):
    with tf.variable_scope(self.name) as scope:
if reuse:
scope.reuse_variables()
activations = []
with tf.variable_scope('input'):
images = tf.reshape(images, [-1, 128, 128, 1], name='input')
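# NOTE: the network sees a single grayscale channel; color information from
# the cat/dog images is discarded at this point.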
activations += [images, ]
with tf.variable_scope('conv1'):
w_conv1, b_conv1, conv1 = self.convlayer(images, [3, 3, 1, 64], 'conv1')
activations += [conv1, ]
with tf.variable_scope('conv2'):
w_conv2, b_conv2, conv2 = self.convlayer(conv1, [3, 3, 64, 64], 'conv2')
activations += [conv2, ]
with tf.variable_scope('max_pool1'):
max_pool1 = tf.nn.max_pool(conv2, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name='max_pool1')
activations += [max_pool1, ]
with tf.variable_scope('conv3'):
w_conv3, b_conv3, conv3 = self.convlayer(max_pool1, [3, 3, 64, 128], 'conv3')
activations += [conv3, ]
with tf.variable_scope('conv4'):
w_conv4, b_conv4, conv4 = self.convlayer(conv3, [3, 3, 128, 128], 'conv4')
activations += [conv4, ]
with tf.variable_scope('max_pool2'):
max_pool2 = tf.nn.max_pool(conv4, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name='max_pool2')
activations += [max_pool2, ]
with tf.variable_scope('conv5'):
w_conv5, b_conv5, conv5 = self.convlayer(max_pool2, [3, 3, 128, 256], 'conv5')
activations += [conv5, ]
with tf.variable_scope('conv6'):
w_conv6, b_conv6, conv6 = self.convlayer(conv5, [3, 3, 256, 256], 'conv6')
activations += [conv6, ]
with tf.variable_scope('max_pool3'):
max_pool3 = tf.nn.max_pool(conv6, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name='max_pool3')
activations += [max_pool3, ]
with tf.variable_scope('conv7'):
w_conv7, b_conv7, conv7 = self.convlayer(max_pool3, [3, 3, 256, 512], 'conv7')
activations += [conv7, ]
with tf.variable_scope('conv8'):
w_conv8, b_conv8, conv8 = self.convlayer(conv7, [3, 3, 512, 512], 'conv8')
activations += [conv8, ]
with tf.variable_scope('max_pool4'):
max_pool4 = tf.nn.max_pool(conv8, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME', name='max_pool4')
activations += [max_pool4, ]
with tf.variable_scope('flatten'):
flatten = tf.contrib.layers.flatten(max_pool4)
activations += [flatten, ]
with tf.variable_scope('fc1'):
n_in = int(flatten.get_shape()[1])
w_fc1, b_fc1, fc1 = self.fclayer(flatten, [n_in, 4096], 'fc1')
activations += [fc1, ]
with tf.variable_scope('fc2'):
n_in = int(fc1.get_shape()[1])
w_fc2, b_fc2, fc2 = self.fclayer(fc1, [n_in, 4096], 'fc2')
activations += [fc2, ]
with tf.variable_scope('dropout2'):
dropout2 = tf.nn.dropout(fc2, keep_prob=0.5, name='dropout2')
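# NOTE: keep_prob is hard-coded, so dropout stays active at inference time;
# feeding it through a placeholder (1.0 at test time) is the usual fix.
# Also, dropout2 is not appended to activations, so the LRP pass in lrp.py
# treats fc2 as the direct input to the output layer.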
with tf.variable_scope('output'):
w_fc3, b_fc3 = self.fclayer(dropout2, [4096, 2], 'fc3', prop=False)
logits = tf.nn.bias_add(tf.matmul(dropout2, w_fc3), b_fc3, name='logits')
preds = tf.nn.softmax(logits, name='output')
activations += [preds, ]
return activations, logits
@property
def params(self):
    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
train.py
from utils import DataGenerator, MNISTLoader
from model import MNIST_CNN
import tensorflow as tf
logdir = './logs/'
chkpt = './logs/model.ckpt'
n_epochs = 20
batch_size = 20
class Trainer:
def __init__(self):
self.dataloader = MNISTLoader()
self.x_train, self.y_train = self.dataloader.train
#print("Train shape")
#print(self.x_train.shape)
self.x_validation, self.y_validation = self.dataloader.validation
with tf.variable_scope('MNIST_CNN'):
self.model = MNIST_CNN()
self.X = tf.placeholder(tf.float32, [None,128,128], name='X')
self.y = tf.placeholder(tf.float32, [None, 2], name='y')
self.activations, self.logits = self.model(self.X)
tf.add_to_collection('LayerwiseRelevancePropagation', self.X)
for act in self.activations:
tf.add_to_collection('LayerwiseRelevancePropagation', act)
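# L2-penalize weight variables only; bias variables are named 'b_...' and are
# filtered out by the substring check below.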
self.l2_loss = tf.add_n([tf.nn.l2_loss(p) for p in self.model.params if 'b' not in p.name]) * 0.001
self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.y)) + self.l2_loss
self.optimizer = tf.train.AdamOptimizer().minimize(self.cost, var_list=self.model.params)
self.preds = tf.equal(tf.argmax(self.logits, axis=1), tf.argmax(self.y, axis=1))
self.accuracy = tf.reduce_mean(tf.cast(self.preds, tf.float32))
self.cost_summary = tf.summary.scalar(name='Cost', tensor=self.cost)
self.accuracy_summary = tf.summary.scalar(name='Accuracy', tensor=self.accuracy)
self.summary = tf.summary.merge_all()
def run(self):
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
self.saver = tf.train.Saver()
self.file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())
self.train_batch = DataGenerator(self.x_train, self.y_train, batch_size)
self.validation_batch = DataGenerator(self.x_validation, self.y_validation, batch_size)
for epoch in range(n_epochs):
self.train(sess, epoch)
self.validate(sess)
self.saver.save(sess, chkpt)
def train(self, sess, epoch):
n_batches = self.x_train.shape[0] // batch_size
if self.x_train.shape[0] % batch_size != 0:
n_batches += 1
avg_cost = 0
avg_accuracy = 0
for batch in range(n_batches):
x_batch, y_batch = next(self.train_batch)
_, batch_cost, batch_accuracy, summ = sess.run([self.optimizer, self.cost, self.accuracy, self.summary],
feed_dict={self.X: x_batch, self.y: y_batch})
avg_cost += batch_cost
avg_accuracy += batch_accuracy
self.file_writer.add_summary(summ, epoch * n_batches + batch)
completion = batch / n_batches
print_str = '|'+int(completion * 20)*'#'+ (19 - int(completion * 20)) * ' ' + '|'
print('\rEpoch {0:>3} {1} {2:3.0f}% Cost {3:6.4f} Accuracy {4:6.4f}'.format('#' + str(epoch + 1), print_str, completion * 100, avg_cost / (batch + 1), avg_accuracy / (batch + 1)), end='')
#print("end="' ')
def validate(self, sess):
n_batches = self.x_validation.shape[0] // batch_size
if self.x_validation.shape[0] % batch_size != 0:
n_batches += 1
avg_accuracy = 0
for batch in range(n_batches):
x_batch, y_batch = next(self.validation_batch)
avg_accuracy += sess.run([self.accuracy, ], feed_dict={self.X: x_batch, self.y: y_batch})[0]
avg_accuracy /= n_batches
print('Validation Accuracy {0:6.4f}'.format(avg_accuracy))
if __name__ == '__main__':
    Trainer().run()
utils.py
import gzip
import pickle
import os
import glob
import cv2
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt
DATA_PATH = './mnist_png/training'
TEST_PATH = './mnist_png/testing'
class DataGenerator:
def __init__(self, X, y, batch_size):
assert(X.shape[0] == y.shape[0])
self.X = X
self.y = y
self.batch_size = batch_size
self.num_samples = X.shape[0]
self.num_batches = X.shape[0] // self.batch_size
if X.shape[0] % self.batch_size != 0:
self.num_batches += 1
self.batch_index = 0
def __iter__(self):
return self
def __next__(self, shuffle=True):
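# note: the built-in next() cannot pass arguments, so shuffle is
# effectively always True when iterating this generator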
if self.batch_index == self.num_batches:
self.batch_index = 0
if shuffle:
indices = np.random.permutation(self.num_samples)
self.X = self.X[indices]
self.y = self.y[indices]
start = self.batch_index * self.batch_size
end = min(self.num_samples, start + self.batch_size)
self.batch_index += 1
return self.X[start: end], self.y[start: end]
class MNISTLoader:
def __init__(self, loc=DATA_PATH):
self.loc = loc
self.run()
def run(self):
classes = ['cats','dogs']
images = []
labels = []
ids = []
cls = []
for fld in classes: # assuming data directory has a separate folder for each class, and that each folder is named after the class
index = classes.index(fld)
#print('Loading {} files (Index: {})'.format(fld, index))
path = os.path.join(DATA_PATH, fld, '*g')
files = glob.glob(path)
for fl in files:
image = cv2.imread(fl)
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#print(image.shape)
image = cv2.resize(image, (128, 128), interpolation = cv2.INTER_LINEAR)
image = image.astype(np.float32)
image = np.multiply(image, 1.0 / 255.0)
images.append(image)
label = np.zeros(len(classes))
label[index] = 1.0
labels.append(label)
flbase = os.path.basename(fl)
ids.append(flbase)
cls.append(fld)
images = np.array(images)
print('Total number of images {}'.format(len(images)))
labels = np.array(labels)
validation_size = int(0.2 * images.shape[0])
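# NOTE: images are loaded class-by-class and never shuffled before this
# split, so the validation slice can end up containing only the first class.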
self.x_validation = images[:validation_size]
self.y_validation = labels[:validation_size]
self.x_train = images[validation_size:]
self.y_train = labels[validation_size:]
print('Total number of training images {}'.format(len(self.x_train)))
print('Total number of Validation images {}'.format(len(self.x_validation)))
test_images=[]
test_labels=[]
for fld in classes: # assuming data directory has a separate folder for each class, and that each folder is named after the class
index = classes.index(fld)
#print('Loading {} files (Index: {})'.format(fld, index))
path = os.path.join(TEST_PATH, fld, '*g')
files = glob.glob(path)
for fl in files:
test_image = cv2.imread(fl)
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2GRAY)
test_image = cv2.resize(test_image, (128, 128), interpolation = cv2.INTER_LINEAR)
test_image = test_image.astype(np.float32)
test_image = np.multiply(test_image, 1.0 / 255.0)  # scale to [0, 1] so test images match the training preprocessing
test_images.append(test_image)
test_label = np.zeros(len(classes))
test_label[index] = 1.0
test_labels.append(test_label)
test_images = np.array(test_images)
test_labels = np.array(test_labels)
self.x_test = test_images
self.y_test = test_labels
def get_samples(self):
    return self.x_test
@property
def train(self):
return self.x_train, self.y_train
@property
def validation(self):
return self.x_validation, self.y_validation
@property
def test(self):
return self.x_test, self.y_test
if __name__ == '__main__':
    dl = MNISTLoader()
train = dl.train
validation = dl.validation
test = dl.test
dg = DataGenerator(train[0], train[1], 20)
for i in range(5):
x, y = next(dg)
#print(i, x.shape, y.shape)
print('x_train shape', train[0].shape)
print('y_train shape', train[1].shape)
print('x_validation shape', validation[0].shape)
print('y_validation shape', validation[1].shape)
print('x_test shape', test[0].shape)
print('y_test shape', test[1].shape)
print(dl.get_samples())
lrp.py
from utils import MNISTLoader
from tensorflow.python.ops import gen_nn_ops
from matplotlib.cm import get_cmap
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2
logdir = './logs/'
chkpt = './logs/model.ckpt'
resultsdir = './results/'
class LayerwiseRelevancePropagation:
    def __init__(self):
        self.dataloader = MNISTLoader()
        self.epsilon = 1e-10
with tf.Session() as sess:
saver = tf.train.import_meta_graph('{0}.meta'.format(chkpt))
saver.restore(sess, tf.train.latest_checkpoint(logdir))
weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='MNIST_CNN')
self.activations = tf.get_collection('LayerwiseRelevancePropagation')
self.X = self.activations[0]
self.act_weights = {}
for act in self.activations[2:]:
for wt in weights:
if len(act.name.split('/'))>2:
name = act.name.split('/')[2]
if name == wt.name.split('/')[2]:
if name not in self.act_weights:
self.act_weights[name] = wt
self.activations = self.activations[:0:-1]
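# reverse the activation list (the raw input at index 0 is excluded), so
# relevance is propagated from the softmax output back toward the first conv layer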
self.relevances = self.get_relevances()
def get_relevances(self):
    relevances = [self.activations[0], ]
for i in range(1, len(self.activations)):
if len(self.activations[i - 1].name.split('/'))>2:
name = self.activations[i - 1].name.split('/')[2]
#print(name)
if 'output' in name or 'fc' in name:
relevances.append(self.backprop_fc(name, self.activations[i], relevances[-1]))
elif 'flatten' in name:
relevances.append(self.backprop_flatten(self.activations[i], relevances[-1]))
elif 'max_pool' in name:
relevances.append(self.backprop_max_pool2d(self.activations[i], relevances[-1]))
elif 'conv' in name:
relevances.append(self.backprop_conv2d(name, self.activations[i], relevances[-1]))
else:
#raise 'Error parsing layer!'
print("Error parsing layer!")
return relevances
def backprop_fc(self, name, activation, relevance):
    w = self.act_weights[name]
    w_pos = tf.maximum(0.0, w)
    z = tf.matmul(activation, w_pos) + self.epsilon
    s = relevance / z
    c = tf.matmul(s, tf.transpose(w_pos))
    return c * activation
def backprop_flatten(self, activation, relevance):
    shape = activation.get_shape().as_list()
    shape[0] = -1
    return tf.reshape(relevance, shape)
def backprop_max_pool2d(self, activation, relevance, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1]):
    z = tf.nn.max_pool(activation, ksize, strides, padding='SAME') + self.epsilon
    s = relevance / z
    c = gen_nn_ops.max_pool_grad_v2(activation, z, s, ksize, strides, padding='SAME')
    return c * activation
def backprop_conv2d(self, name, activation, relevance, strides=[1, 1, 1, 1]):
    w = self.act_weights[name]
    w_pos = tf.maximum(0.0, w)
    z = tf.nn.conv2d(activation, w_pos, strides, padding='SAME') + self.epsilon
    s = relevance / z
    c = tf.nn.conv2d_backprop_input(tf.shape(activation), w_pos, s, strides, padding='SAME')
    return c * activation
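# For reference, each backprop_* method above implements the LRP z+ rule with
# an epsilon stabilizer: the relevance assigned to input unit i of a layer is
#   R_i = a_i * sum_j( w_ij^+ * R_j / (sum_k a_k * w_kj^+ + epsilon) ),
# where w^+ = max(0, w). Using only the positive weights keeps relevance
# non-negative and (up to epsilon) conserved from layer to layer.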
def get_heatmap(self, i):
    samples = self.dataloader.get_samples()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('{0}.meta'.format(chkpt))
        saver.restore(sess, tf.train.latest_checkpoint(logdir))
        # NOTE: this only colormaps the raw input image; self.relevances is
        # never evaluated here, so the returned "heatmap" is the sample itself.
        cmap = get_cmap(name='rainbow')
        heatmap = cmap(samples[i].flatten())[:, :1]
    return heatmap.reshape(128, 128)
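# A minimal sketch (the method name is hypothetical) of how the computed
# relevances could be visualized instead of the raw image: evaluate the last
# relevance tensor for one sample and collapse it to a 2-D map.
def get_relevance_heatmap(self, i):
    samples = self.dataloader.get_samples()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('{0}.meta'.format(chkpt))
        saver.restore(sess, tf.train.latest_checkpoint(logdir))
        # relevances[-1] is the deepest relevance map produced by get_relevances()
        r = sess.run(self.relevances[-1], feed_dict={self.X: samples[i:i + 1]})
    r = r.squeeze()
    if r.ndim == 3:
        r = r.sum(axis=-1)  # collapse the channel axis to a single 2-D map
    return r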
def test(self):
    samples = self.dataloader.get_samples()
    leng = len(samples)
with tf.Session() as sess:
saver = tf.train.import_meta_graph('{0}.meta'.format(chkpt))
saver.restore(sess, tf.train.latest_checkpoint(logdir))
R = sess.run(self.relevances, feed_dict={self.X: samples})
return leng
if __name__ == '__main__':
    # build the graph once; re-instantiating LayerwiseRelevancePropagation
    # inside the loop would rebuild and restore the whole TF graph per image
    lrp = LayerwiseRelevancePropagation()
    lent = lrp.test()
    print(lent)
    for i in range(lent):
        heatmap = lrp.get_heatmap(i)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.axis('off')
        ax.imshow(heatmap, cmap='Reds', interpolation='bilinear')
        fig.savefig('{0}{1}.jpg'.format(resultsdir, i))