Convolutional-LSTM-in-Tensorflow
Convolutional-LSTM-in-Tensorflow copied to clipboard
Hidden state not equated to None?
Hi Oliver, Thanks for a nice implementation. I had one problem in understanding main_conv_lstm.py. Shouldn't the hidden state be reset to None before each iteration (after line 143 in main_conv_lstm.py)? If this is not done, I think the last hidden state of the previous batch will be provided as the first hidden state of the next batch.
Thanks
Thanks for the comment!
The hidden state is reset to zero after every batch automatically. This is because, in the computational graph, the first hidden state is a tf.zeros node made in the ConvRNNCell zero_state function. If the hidden state were a variable and I updated that variable every batch, then I might have to worry. Sorry this explanation is lousy.
We can make sure that the hidden state is reset to zero after every batch by printing it, though. Here is some silly code I wrote to check it. When run, the first value in the matrix is 0, indicating that the first hidden state in the sequence is zero.
import os.path
import time
import numpy as np
import tensorflow as tf
import cv2
import bouncing_balls as b
import layer_def as ld
import BasicConvLSTMCell
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', './checkpoints/train_store_conv_lstm',
                           """dir to store trained net""")
tf.app.flags.DEFINE_integer('seq_length', 10,
                            """total number of frames per training sequence""")
tf.app.flags.DEFINE_integer('seq_start', 5,
                            """ start of seq generation""")
tf.app.flags.DEFINE_integer('max_step', 200000,
                            """max num of steps""")
tf.app.flags.DEFINE_float('keep_prob', 1.0,
                          """for dropout""")
tf.app.flags.DEFINE_float('lr', .001,
                          """learning rate for the Adam optimizer""")
tf.app.flags.DEFINE_integer('batch_size', 64,
                            """batch size for training""")
tf.app.flags.DEFINE_float('weight_init', .1,
                          """weight init for fully connected layers""")
# FIX: this flag is read in train() (FLAGS.num_balls) but was never defined,
# which raises an AttributeError at runtime. Default of 2 matches the usual
# bouncing-balls setup -- confirm against bouncing_balls.bounce_vec.
tf.app.flags.DEFINE_integer('num_balls', 2,
                            """number of balls in each generated sequence""")

# Four-character codec code for the .mov writer. cv2.cv was removed in
# OpenCV 3, so fall back to the modern API when the legacy one is absent.
try:
  fourcc = cv2.cv.CV_FOURCC('m', 'p', '4', 'v')  # OpenCV 2.x
except AttributeError:
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')       # OpenCV 3+
def generate_bouncing_ball_sample(batch_size, seq_length, shape, num_balls):
  """Generate a batch of bouncing-ball video sequences.

  Args:
    batch_size: number of independent sequences in the batch.
    seq_length: number of frames per sequence.
    shape: height/width of each (square) frame in pixels.
    num_balls: number of balls simulated per sequence.

  Returns:
    np.ndarray of shape (batch_size, seq_length, shape, shape, 3).
  """
  dat = np.zeros((batch_size, seq_length, shape, shape, 3))
  for i in xrange(batch_size):
    # FIX: pass `shape` through instead of the hard-coded 32 so the
    # simulated frames always match the output array's resolution.
    # Backward-compatible: the only caller in this file passes shape=32.
    dat[i, :, :, :, :] = b.bounce_vec(shape, num_balls, seq_length)
  return dat
def network(inputs, hidden, lstm=True):
  """Build one step of the encoder -> (conv-LSTM | conv) -> decoder net.

  Args:
    inputs: image batch tensor for this timestep (presumably NHWC,
            32x32x3 -- confirm against the placeholder in train()).
    hidden: conv-LSTM state from the previous timestep, or None to start
            from a fresh zero state.
    lstm: if True, use a BasicConvLSTMCell at the bottleneck; otherwise a
          plain conv layer (stateless).

  Returns:
    (x_1, hidden, hidden_old): the decoded next-frame prediction, the new
    LSTM state, and the state that was fed INTO the cell this step
    (returned so callers can verify the first state of a sequence is zero).
  """
  # encoder: two strided convs downsample 32x32 -> 8x8 feature maps
  conv1 = ld.conv_layer(inputs, 3, 2, 16, "encode_1")
  # conv2
  conv2 = ld.conv_layer(conv1, 3, 1, 16, "encode_2")
  # conv3
  conv3 = ld.conv_layer(conv2, 3, 2, 32, "encode_3")
  # conv4
  conv4 = ld.conv_layer(conv3, 1, 1, 32, "encode_4")
  y_0 = conv4
  if lstm:
    # conv lstm cell
    with tf.variable_scope('conv_lstm', initializer=tf.random_uniform_initializer(-.01, 0.1)):
      cell = BasicConvLSTMCell.BasicConvLSTMCell([8, 8], [3, 3], 32)
      if hidden is None:
        # First step of a sequence: the state is a graph-constant zeros
        # node, so it resets automatically every batch.
        hidden = cell.zero_state(FLAGS.batch_size, tf.float32)
      #####################################
      # spit out old hidden state to record
      #####################################
      hidden_old = hidden
      y_1, hidden = cell(y_0, hidden)
  else:
    y_1 = ld.conv_layer(y_0, 3, 1, 32, "encode_3")
    # BUG FIX: hidden_old was never assigned on this branch, so the
    # return below raised NameError whenever lstm=False. Echo the
    # (possibly None) incoming state instead.
    hidden_old = hidden
  # decoder: transpose convs upsample 8x8 back to 32x32x3
  # conv5
  conv5 = ld.transpose_conv_layer(y_1, 1, 1, 32, "decode_5")
  # conv6
  conv6 = ld.transpose_conv_layer(conv5, 3, 2, 16, "decode_6")
  # conv7
  conv7 = ld.transpose_conv_layer(conv6, 3, 1, 16, "decode_7")
  # x_1
  x_1 = ld.transpose_conv_layer(conv7, 3, 2, 3, "decode_8", True)  # set activation to linear
  ##################################
  # Added returning old hidden state
  ##################################
  return x_1, hidden, hidden_old
# make a template for reuse: tf.make_template wraps network() so that the
# variables it creates are defined once and shared across every subsequent
# call -- each timestep in the unrolled sequence reuses the same weights.
network_template = tf.make_template('network', network)
def train():
  """Train ring_net for a number of steps."""
  with tf.Graph().as_default():
    # make inputs: a full sequence of 32x32 RGB frames per batch element
    x = tf.placeholder(tf.float32, [None, FLAGS.seq_length, 32, 32, 3])
    # possible dropout inside
    keep_prob = tf.placeholder("float")
    x_dropout = tf.nn.dropout(x, keep_prob)
    # create network
    x_unwrap = []
    # conv network -- hidden=None makes network() build a fresh zero state,
    # so the LSTM state does NOT leak across batches
    hidden = None
    #########################################################
    # Store hidden state to see if its really setting to zero
    #########################################################
    hidden_store = []
    for i in xrange(FLAGS.seq_length-1):
      if i < FLAGS.seq_start:
        # conditioning phase: feed ground-truth frames
        x_1, hidden, hidden_old = network_template(x_dropout[:,i,:,:,:], hidden)
      else:
        # generation phase: feed the network's own previous prediction
        x_1, hidden, hidden_old = network_template(x_1, hidden)
      x_unwrap.append(x_1)
      ###################
      # Grab hidden state
      ###################
      hidden_store.append(tf.reduce_sum(hidden_old))
    # pack them all together: (seq, batch, ...) -> (batch, seq, ...)
    x_unwrap = tf.stack(x_unwrap)
    x_unwrap = tf.transpose(x_unwrap, [1,0,2,3,4])
    ##########
    # stack it
    ##########
    hidden_store = tf.stack(hidden_store)
    # this part will be used for generating video: a longer (50-step)
    # unroll sharing the same template (and therefore the same weights)
    x_unwrap_g = []
    hidden_g = None
    for i in xrange(50):
      if i < FLAGS.seq_start:
        x_1_g, hidden_g, hidden_old = network_template(x_dropout[:,i,:,:,:], hidden_g)
      else:
        x_1_g, hidden_g, hidden_old = network_template(x_1_g, hidden_g)
      x_unwrap_g.append(x_1_g)
    # pack them generated ones
    x_unwrap_g = tf.stack(x_unwrap_g)
    x_unwrap_g = tf.transpose(x_unwrap_g, [1,0,2,3,4])
    # calc total loss (compare x_t to x_t+1): only the generated portion of
    # the sequence contributes (predictions at index i target frame i+1)
    loss = tf.nn.l2_loss(x[:,FLAGS.seq_start+1:,:,:,:] - x_unwrap[:,FLAGS.seq_start:,:,:,:])
    tf.summary.scalar('loss', loss)
    # training
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)
    # List of all Variables
    variables = tf.global_variables()
    # Build a saver
    saver = tf.train.Saver(tf.global_variables())
    # Summary op
    summary_op = tf.summary.merge_all()
    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()
    # Start running operations on the Graph.
    sess = tf.Session()
    # init if this is the very time training
    print("init network from scratch")
    sess.run(init)
    # Summary op
    graph_def = sess.graph.as_graph_def(add_shapes=True)
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph_def=graph_def)
    for step in xrange(FLAGS.max_step):
      # NOTE(review): assumes a `num_balls` flag is defined -- confirm it
      # exists among the tf.app.flags definitions for this script.
      dat = generate_bouncing_ball_sample(FLAGS.batch_size, FLAGS.seq_length, 32, FLAGS.num_balls)
      t = time.time()
      #########################
      # Return hidden state too
      #########################
      _, loss_r, h_s = sess.run([train_op, loss, hidden_store],feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
      elapsed = time.time() - t
      #########################################################
      # print values for hidden state, The first should be zero
      #########################################################
      print(h_s)
      if step%100 == 0 and step != 0:
        summary_str = sess.run(summary_op, feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
        summary_writer.add_summary(summary_str, step)
        print("time per batch is " + str(elapsed))
      print(step)
      print(loss_r)
      assert not np.isnan(loss_r), 'Model diverged with loss = NaN'
      if step%1000 == 0:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
        print("saved to " + FLAGS.train_dir)
        # make video: run the 50-step generation graph on the latest batch
        # and write the frames out as a .mov
        # NOTE(review): indentation was lost in this paste -- presumably the
        # video block lives inside the step%1000 branch; confirm against the
        # original main_conv_lstm.py.
        print("now generating video!")
        video = cv2.VideoWriter()
        success = video.open("generated_conv_lstm_video.mov", fourcc, 4, (180, 180), True)
        dat_gif = dat
        ims = sess.run([x_unwrap_g],feed_dict={x:dat_gif, keep_prob:FLAGS.keep_prob})
        # sess.run returns a one-element list; [0][0] takes the first
        # sequence of the batch
        ims = ims[0][0]
        print(ims.shape)
        for i in xrange(50 - FLAGS.seq_start):
          # clamp negatives and rescale [0,1] -> [0,255] for the codec
          x_1_r = np.uint8(np.maximum(ims[i,:,:,:], 0) * 255)
          new_im = cv2.resize(x_1_r, (180,180))
          video.write(new_im)
        video.release()
def main(argv=None):  # pylint: disable=unused-argument
  """Entry point: wipe any previous checkpoint directory, then train."""
  train_dir = FLAGS.train_dir
  if tf.gfile.Exists(train_dir):
    tf.gfile.DeleteRecursively(train_dir)
  tf.gfile.MakeDirs(train_dir)
  train()
if __name__ == '__main__':
  # tf.app.run() parses the command-line flags and then calls main()
  tf.app.run()