imitation-learning
Add support for training without tf.Estimator
Currently, training only works through the tf.Estimator framework. Some users may prefer a lower-level approach: building the training operation themselves and calling it directly with sess.run. A starting point for that version might look like this:
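```python
import tensorflow as tf  # TF 1.x API (tf.Session, tf.estimator)
import input_fn  # this repo's input pipeline module (import path assumed)
import trainer  # this repo's model_fn module (import path assumed)


def main(tfrecord_path, bs, sbs, num_steps):
    # train_input_fn returns an input function; calling it yields (features, labels) tensors.
    features, labels = input_fn.train_input_fn(tfrecord_path, batch_size=bs, shuffle_buffer_size=sbs)()
    # model_fn returns a tf.estimator.EstimatorSpec; only its train_op is needed here.
    spec = trainer.model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        # Each sess.run(train_op) consumes one batch, so this loop counts steps, not epochs.
        for _ in range(num_steps):
            sess.run(spec.train_op)
```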
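To make the "epochs vs. steps" distinction concrete, here is a minimal self-contained sketch of the same sess.run pattern. It does not use the repo's modules: a hypothetical toy linear model (`train_toy_model`, fitting y = 3x + 1) stands in for `trainer.model_fn`, and it assumes TensorFlow 2.x with the TF1-style graph API under `tf.compat.v1`. Instead of counting loop iterations, the dataset is repeated `num_epochs` times and training runs until the iterator raises `tf.errors.OutOfRangeError`.

```python
import numpy as np
import tensorflow as tf

# TF1-style graph-mode API, available as tf.compat.v1 in TF 2.x.
tf1 = tf.compat.v1


def train_toy_model(num_epochs=50, batch_size=16, learning_rate=0.1, seed=0):
    """Hypothetical toy example: fit y = 3x + 1 with a plain sess.run loop."""
    rng = np.random.RandomState(seed)
    x = rng.uniform(-1.0, 1.0, size=(256, 1)).astype(np.float32)
    y = (3.0 * x + 1.0).astype(np.float32)

    graph = tf.Graph()
    with graph.as_default():
        # repeat(num_epochs) makes the iterator run out after num_epochs passes.
        dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(256, seed=seed)
        dataset = dataset.repeat(num_epochs).batch(batch_size)
        features, labels = tf1.data.make_one_shot_iterator(dataset).get_next()

        # Stand-in for model_fn (hypothetical): one linear layer, MSE loss, SGD.
        w = tf1.get_variable("w", shape=[1, 1], initializer=tf1.zeros_initializer())
        b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
        predictions = tf.matmul(features, w) + b
        loss = tf.reduce_mean(tf.square(predictions - labels))
        global_step = tf1.train.get_or_create_global_step()
        train_op = tf1.train.GradientDescentOptimizer(learning_rate).minimize(
            loss, global_step=global_step)
        init_op = tf1.global_variables_initializer()

    with tf1.Session(graph=graph) as sess:
        sess.run(init_op)
        steps = 0
        # Train until the input pipeline is exhausted, i.e. for num_epochs epochs.
        while True:
            try:
                sess.run(train_op)
                steps += 1
            except tf.errors.OutOfRangeError:
                break
        final_w, final_b, final_step = sess.run([w, b, global_step])
    return steps, float(final_w[0, 0]), float(final_b[0]), int(final_step)
```

With the defaults, `train_toy_model()` should run 800 steps (256 examples / 16 per batch × 50 epochs) and return weights close to w = 3 and b = 1. Letting the input pipeline decide when to stop keeps the epoch count in one place (`repeat`), which avoids the step/epoch confusion of a fixed-length loop.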