
[question] How do I load a tensorflow ckpt?

Open Syzygianinfern0 opened this issue 3 years ago • 4 comments

I am trying to load a pre-trained model from some old code that uses this framework, and my familiarity with TensorFlow is very limited. I've tried several approaches, but I can't find the right way to load it 🤯

Below is a representative example of how the model is created and saved. I just want to load the saved weights back for evaluation.

import tensorflow as tf
from stable_baselines import PPO1
from stable_baselines.common.policies import FeedForwardPolicy

training_sess = None


class MyMlpPolicy(FeedForwardPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **_kwargs):
        super(MyMlpPolicy, self).__init__(
            sess,
            ob_space,
            ac_space,
            n_env,
            n_steps,
            n_batch,
            reuse,
            net_arch=[{"pi": [32, 16], "vf": [32, 16]}],
            feature_extraction="mlp",
            **_kwargs
        )
        global training_sess
        training_sess = sess


# `env` is a Gym environment created elsewhere in the original code
model = PPO1(MyMlpPolicy, env)

# This is how the model is saved
with model.graph.as_default():
    saver = tf.train.Saver()
    saver.save(training_sess, "./model_0.ckpt")

# The step above produces four files:
# 1. checkpoint
# 2. model_0.ckpt.data-00000-of-00001
# 3. model_0.ckpt.index
# 4. model_0.ckpt.meta

Syzygianinfern0 avatar Jan 02 '22 11:01 Syzygianinfern0

Please fill in the issue template. If you only want to save the full agent, you do not need to do any TF stuff; just use the save and load functions (see the examples in the docs). We cannot offer custom tech support for saving/loading in a custom way like this.
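As a minimal sketch of the save/load route mentioned above: stable-baselines models expose a `save` method and a `load` classmethod, so a round trip needs no TensorFlow calls. The helper name `save_and_reload` and the path `ppo1_agent` are hypothetical, chosen only for illustration.

```python
def save_and_reload(model, path="ppo1_agent"):
    """Save a stable-baselines model to disk, then load it back.

    `model` is assumed to be an already constructed agent, e.g.
    PPO1(MyMlpPolicy, env). The helper name and default path are
    illustrative, not part of the library.
    """
    model.save(path)                # serialises parameters and hyperparameters
    return type(model).load(path)   # `load` is a classmethod, e.g. PPO1.load(path)


# Usage (assumes `model` was created as in the question):
# restored = save_and_reload(model)
# action, _states = restored.predict(obs)
```

Note that this only helps when the agent was saved with `save` in the first place; it does not read raw `.ckpt` files.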

Miffyli avatar Jan 02 '22 14:01 Miffyli

We also highly recommend switching to Stable-Baselines3 (PyTorch).

araffin avatar Jan 02 '22 15:01 araffin

> We also highly recommend switching to Stable-Baselines3 (PyTorch).

Yeah that is what I currently use. I just need to run some old code for a comparison. I just have their provided weights.
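For running the old code with only the checkpoint files, one possible approach is to rebuild the same model and restore the variables with `tf.train.Saver.restore`, which takes the checkpoint prefix (`./model_0.ckpt`) rather than any of the individual files. This is an untested sketch, not a supported stable-baselines workflow: it assumes TensorFlow 1.x, that the graph is rebuilt with the same policy class and environment spaces so variable names match, and that `model.sess` is the session the policy received. The helper names are hypothetical.

```python
import re


def checkpoint_prefix(path):
    """Strip a .index / .meta / .data-XXXXX-of-XXXXX suffix, leaving the prefix."""
    return re.sub(r"\.(index|meta|data-\d+-of-\d+)$", "", path)


def restore_checkpoint(model, path):
    """Restore TF1 checkpoint variables into an already built model.

    Assumes `model` exposes `graph` and `sess` (as PPO1 does after
    construction) and that its graph matches the one that was saved.
    """
    import tensorflow as tf  # TensorFlow 1.x, as used by stable-baselines

    with model.graph.as_default():
        saver = tf.train.Saver()
        saver.restore(model.sess, checkpoint_prefix(path))
    return model


# Usage (assumes MyMlpPolicy and env are defined as in the question):
# model = PPO1(MyMlpPolicy, env)
# restore_checkpoint(model, "./model_0.ckpt.index")
```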

Syzygianinfern0 avatar Jan 02 '22 15:01 Syzygianinfern0

> Yeah that is what I currently use. I just need to run some old code for a comparison. I just have their provided weights.

In that case you should look at the set_parameters function in the SB3 documentation :).

You can close this issue if your question has been answered.
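As a rough sketch of the `set_parameters` route in SB3: the method accepts a dictionary keyed by object name (for example `"policy"`), and with `exact_match=False` entries such as the optimizer state may be omitted. The helper name below is hypothetical, and converting TensorFlow variables into a PyTorch state dict is left out: the names must be mapped by hand, and dense kernels stored as (in, out) in TensorFlow need transposing to the (out, in) layout of `torch.nn.Linear`.

```python
def load_policy_weights(model, policy_state_dict):
    """Copy policy weights into an SB3 model, leaving optimizer state untouched.

    `policy_state_dict` is assumed to map SB3 policy parameter names
    (as listed by model.policy.state_dict()) to torch tensors.
    """
    # exact_match=False lets the dict omit entries such as "policy.optimizer"
    # and tolerates missing keys inside the policy state dict.
    model.set_parameters({"policy": policy_state_dict}, exact_match=False)
    return model


# Usage (assumes an SB3 model with a matching network architecture):
# load_policy_weights(sb3_model, converted_state_dict)
```

Because missing keys are tolerated in this mode, it is worth comparing the keys of the converted dictionary against `model.policy.state_dict()` before relying on the result.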

Miffyli avatar Jan 02 '22 16:01 Miffyli