stable-baselines
[question] How do I load a TensorFlow ckpt?
I am trying to load a pre-trained model from some old code that uses this framework, and my familiarity with TensorFlow is very limited. I've tried several approaches to loading the model, but I can't find the right way 🤯
The snippet below is representative of how the model is created and then saved. I just want to load the weights back afterwards for evaluation.
```python
import gym
import tensorflow as tf
from stable_baselines import PPO1
from stable_baselines.common.policies import FeedForwardPolicy

# Filled in by the policy constructor so the session can be saved later.
training_sess = None


class MyMlpPolicy(FeedForwardPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **_kwargs):
        super(MyMlpPolicy, self).__init__(
            sess,
            ob_space,
            ac_space,
            n_env,
            n_steps,
            n_batch,
            reuse,
            net_arch=[{"pi": [32, 16], "vf": [32, 16]}],
            feature_extraction="mlp",
            **_kwargs
        )
        global training_sess
        training_sess = sess


# `env` is not defined in the original code; CartPole-v1 is only a placeholder.
env = gym.make("CartPole-v1")
model = PPO1(MyMlpPolicy, env)

# This is how the model is saved
with model.graph.as_default():
    saver = tf.train.Saver()
    saver.save(training_sess, "./model_0.ckpt")

# The above step produces 4 files:
# 1. checkpoint
# 2. model_0.ckpt.data-00000-of-00001
# 3. model_0.ckpt.index
# 4. model_0.ckpt.meta
```
Please fill in the issue template. If you only want to save and restore the full agent, you do not need any TF-specific code: just use the `save` and `load` functions (see the examples in the docs). We cannot offer custom tech support for saving/loading in a custom way like this.
We also highly recommend switching to Stable-Baselines3 (PyTorch).
Yeah, that is what I currently use. I just need to run some old code for a comparison, and all I have are the weights they provided.
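If all you have are the checkpoint files, it can help to look at what is actually stored in them before mapping anything into a new model. A sketch using `tf.train.list_variables` and `tf.train.load_checkpoint`, which, to my knowledge, can also read TF1 name-based checkpoints such as `model_0.ckpt` from TF 2. `read_checkpoint_weights` and its `scope` filter are hypothetical names.

```python
import numpy as np
import tensorflow as tf


def read_checkpoint_weights(ckpt_prefix, scope=None):
    """Return {variable_name: numpy array} for every variable stored in a TF checkpoint.

    ckpt_prefix is the path without the .index/.data suffix, e.g. "./model_0.ckpt".
    scope is an optional name-prefix filter; None keeps every variable.
    """
    reader = tf.train.load_checkpoint(ckpt_prefix)
    weights = {}
    for name, _shape in tf.train.list_variables(ckpt_prefix):
        if scope is not None and not name.startswith(scope):
            continue
        weights[name] = np.asarray(reader.get_tensor(name))
    return weights
```

Printing `name, array.shape` for each entry of `read_checkpoint_weights("./model_0.ckpt")` should show which variables belong to the policy and value networks.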
In that case, you should look at the `set_parameters` function in the SB3 documentation :)
You can close this issue if your question has been answered.