imitation-learning

Add support for ways of training other than tf.Estimator

Open roberttorfason opened this issue 6 years ago • 0 comments

Currently, training only works through the tf.Estimator framework. Some users may prefer a lower-level approach: building the graph themselves and calling the training operation directly with `sess.run`. A starting point for that version might look like this:

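```python
import tensorflow as tf  # TF 1.x API (tf.Session, tf.global_variables_initializer)

import input_fn  # project modules; exact import paths assumed
import trainer


def main(tfrecord_path, bs, sbs, num_steps):
    features, labels = input_fn.train_input_fn(tfrecord_path, batch_size=bs, shuffle_buffer_size=sbs)()
    model = trainer.model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)  # an EstimatorSpec
    train_op = model.train_op

    with tf.Session() as sess:
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        for step in range(num_steps):
            sess.run(train_op)  # each call trains on one batch, so this counts steps, not epochs
```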
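The snippet above depends on the project's `input_fn` and `trainer` modules. To show the `sess.run` pattern in isolation, here is a minimal, self-contained sketch on toy data. It assumes TensorFlow with the `tf.compat.v1` API (available in both TF 1.x and 2.x); `toy_model_fn` and `train_toy` are hypothetical names, and the one-variable linear regression only stands in for the real `model_fn`. Because a plain loop around `sess.run(train_op)` counts batches rather than epochs, this sketch lets the `tf.data` pipeline repeat for `num_epochs` and stops when the iterator raises `tf.errors.OutOfRangeError`.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, also available in TF 2.x


def toy_model_fn(features, labels):
    """Hypothetical stand-in for trainer.model_fn: a 1-D linear regression."""
    w = tf1.get_variable("w", shape=[1, 1], initializer=tf1.zeros_initializer())
    b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
    predictions = tf.matmul(features, w) + b
    loss = tf1.losses.mean_squared_error(labels, predictions)
    train_op = tf1.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)
    return loss, train_op


def train_toy(num_epochs=3, batch_size=10, shuffle_buffer_size=100):
    """Trains with sess.run until the input pipeline is exhausted.

    Returns the number of training steps run and the per-step batch losses.
    """
    x = np.linspace(0.0, 1.0, 100, dtype=np.float32).reshape(-1, 1)
    y = 2.0 * x + 1.0  # toy targets from y = 2x + 1
    with tf.Graph().as_default():  # fresh graph, so the function can be called repeatedly
        dataset = (tf.data.Dataset.from_tensor_slices((x, y))
                   .shuffle(shuffle_buffer_size)
                   .batch(batch_size)
                   .repeat(num_epochs))  # the pipeline, not the loop, counts epochs
        features, labels = tf1.data.make_one_shot_iterator(dataset).get_next()
        loss, train_op = toy_model_fn(features, labels)

        losses = []
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            while True:
                try:
                    # The fetched loss is the batch loss computed before this update.
                    _, loss_value = sess.run([train_op, loss])
                except tf.errors.OutOfRangeError:  # raised once all epochs are consumed
                    break
                losses.append(loss_value)
    return len(losses), losses
```

With the defaults, `train_toy()` runs 30 steps: 100 examples in batches of 10 give 10 steps per epoch, repeated for 3 epochs.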

roberttorfason, Dec 12 '18 14:12