
Pretrained models

juliusfrost opened this issue on Nov 05 '19 · 11 comments

Is there any chance you could release the pretrained models for the implemented algorithms? It would accelerate research for some, and help those without access to good hardware.

juliusfrost avatar Nov 05 '19 07:11 juliusfrost

Which networks do you need, and on which envs?

bmazoure avatar Nov 06 '19 04:11 bmazoure

I'd like to get models on the task of Atari: Discrete Control from Vision, for Policy gradient algorithms. As a start, PPO on Pong-v0, Seaquest-v0, SpaceInvaders-v0 would be very helpful. I am trying to learn a model of the environment on top of the algorithm.

juliusfrost avatar Nov 06 '19 10:11 juliusfrost

Oh, interesting idea... I haven't actually saved any of the learned agents from development, or else I could post them. It may be especially worth doing if we ever go back and run R2D1 again. If we do more runs I'll save and share.

Otherwise, hopefully someone can provide!

astooke avatar Nov 06 '19 20:11 astooke

Another related question: are there any example scripts that show how to save, and then resume training from, a given snapshot? I didn't see any, but I am currently making one now. Happy to share if it makes sense.

DanielTakeshi avatar Nov 21 '19 20:11 DanielTakeshi

Saving and loading is really easy. Here's example_1 showing how to train normally, and in particular, building the algorithm and the agent.

https://github.com/astooke/rlpyt/blob/75e96cda433626868fd2a30058be67b99bbad810/examples/example_1.py#L36-L37

To load a pre-trained model, which ideally was saved in the way that rlpyt provides (see https://github.com/astooke/rlpyt/issues/66 for details), you just have to use the initial_optim_state_dict and initial_model_state_dict arguments, which are here:

https://github.com/astooke/rlpyt/blob/75e96cda433626868fd2a30058be67b99bbad810/rlpyt/algos/dqn/dqn.py#L40

and here:

https://github.com/astooke/rlpyt/blob/75e96cda433626868fd2a30058be67b99bbad810/rlpyt/agents/base.py#L22

Here's what a code sketch would look like:

# Imports for this sketch (module paths as in the rlpyt repo).
import torch

from rlpyt.agents.dqn.atari.atari_dqn_agent import AtariDqnAgent
from rlpyt.algos.dqn.dqn import DQN

# Get the snapshot_pth, which contains agent and optim state dicts.
# You will need to adjust snapshot_pth based on your code.
data = torch.load(snapshot_pth)
itr = data['itr']  # might be useful for logging / debugging
cum_steps = data['cum_steps']  # might be useful for logging / debugging
agent_state_dict = data['agent_state_dict']  # 'model' and 'target' keys
optimizer_state_dict = data['optimizer_state_dict']

# Supply the optimizer state dict here.
algo = DQN(min_steps_learn=5e4, initial_optim_state_dict=optimizer_state_dict)

# Supply the model state dict here. This auto-initializes the target net as well.
agent = AtariDqnAgent(initial_model_state_dict=agent_state_dict['model'])

# Then proceed as normal, creating the sampler, the runner, and the training code.
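
# (Hypothetical continuation, roughly in the style of example_1.py -- the
# sampler / runner settings below are only illustrative; adjust them and the
# game to your setup. These imports would normally go at the top of the file.)
from rlpyt.envs.atari.atari_env import AtariEnv
from rlpyt.runners.minibatch_rl import MinibatchRl
from rlpyt.samplers.serial.sampler import SerialSampler

sampler = SerialSampler(
    EnvCls=AtariEnv,
    env_kwargs=dict(game="pong"),  # match the game the snapshot was trained on
    batch_T=4,
    batch_B=1,
)
runner = MinibatchRl(
    algo=algo,
    agent=agent,
    sampler=sampler,
    n_steps=50e6,
    log_interval_steps=1e5,
    affinity=dict(cuda_idx=None),
)
runner.train()  # wrap in logger_context(...) as in example_1 to keep saving snapshots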

DanielTakeshi avatar Nov 22 '19 00:11 DanielTakeshi

What if I want to use the saved model state dictionary (in params.pkl) to sample actions? The goal is to have that in a loop and run env.step(action) with the output action from the rlpyt-trained agent. I have tried initializing the agent as:

agent = DdpgAgent(initial_model_state_dict=agent_state_dict['model'])

but the agent's model parameter doesn't get initialized.

Is there a way of doing this already, or is rlpyt not intended to be used in this way? I guess the alternative is to create an instance of the torch model used for the agent's policy...

ritalaezza avatar Jan 17 '20 15:01 ritalaezza

@ritalaezza The model is built during the agent's initialization. Most __init__() methods in rlpyt only save the args; the real initialization is done in the initialize() method.

For your question, one workaround I found is to manually call the agent's initialize(). Then initial_model_state_dict will work.
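
For example, a rough sketch (the env construction here is just a placeholder; build whatever env you trained on so that env.spaces matches):

import torch
from rlpyt.agents.qpg.ddpg_agent import DdpgAgent
from rlpyt.envs.gym import make as gym_make

data = torch.load("params.pkl")  # adjust path
agent_state_dict = data["agent_state_dict"]

env = gym_make("Pendulum-v0")  # placeholder env id; use your own
agent = DdpgAgent(initial_model_state_dict=agent_state_dict["model"])
agent.initialize(env.spaces)  # builds the model(s) and then loads the state dict
agent.eval_mode(itr=0)  # eval mode; should also turn off exploration noise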

kaixin96 avatar Jan 27 '20 02:01 kaixin96

@kaixin96 Thank you for the tip. I had already made another workaround... As I had guessed, creating an instance of the torch model used for the agent's policy and then calling load_state_dict() works as well.
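
Roughly like this (the exact model class and sizes have to match your training config; MuMlpModel is, I believe, the default policy model for DdpgAgent):

import torch
from rlpyt.models.qpg.mlp import MuMlpModel

data = torch.load("params.pkl")  # adjust path
obs_dim, act_dim = 3, 1  # placeholders; use your env's dimensions
model = MuMlpModel(observation_shape=(obs_dim,), hidden_sizes=[400, 300],
                   action_size=act_dim)  # sizes must match those used in training
model.load_state_dict(data["agent_state_dict"]["model"])
model.eval()

obs = torch.zeros(obs_dim)  # stand-in for a real observation
with torch.no_grad():
    action = model(obs, prev_action=None, prev_reward=None)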

But I actually like your way better, thanks again :)

ritalaezza avatar Jan 28 '20 14:01 ritalaezza

Another thing I forgot to mention is that you need to be careful about epsilon, the exploration schedule. However, if you load the itr variable, I think that should be enough, as epsilon is determined from it. See my issue report: https://github.com/astooke/rlpyt/issues/110
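
Roughly, the relevant bit is (untested; see the issue above for the caveats when resuming):

# The runner calls this each iteration; epsilon-greedy DQN agents anneal
# epsilon from itr inside sample_mode(), so the restored itr sets epsilon.
agent.sample_mode(itr)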

DanielTakeshi avatar Feb 24 '20 17:02 DanielTakeshi

@ritalaezza Hello! Did you succeed with this idea?

"The goal is to have that in a loop and run env.step(action) with the output action from the rlpyt-trained agent."

I am trying to do the same, but got confused. There is no sample_action function for the agent. There is a step function, but it requires the previous action and reward. I am new to this, so I may have missed something. Would appreciate any help :)

kzorina avatar Mar 04 '20 11:03 kzorina

@kzorina agent.step() is the right function to use. If your agent does not use the previous action and previous reward, you can pass None for those, or else write an agent which doesn't have those inputs.
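
For a concrete sketch of such a loop (untested, using the built-in Atari env and DQN agent; adjust the path, env, and agent class to your case):

import torch
from rlpyt.agents.dqn.atari.atari_dqn_agent import AtariDqnAgent
from rlpyt.envs.atari.atari_env import AtariEnv

data = torch.load("params.pkl")  # adjust path
env = AtariEnv(game="pong")
agent = AtariDqnAgent(initial_model_state_dict=data["agent_state_dict"]["model"])
agent.initialize(env.spaces)
agent.eval_mode(itr=data["itr"])  # itr > 0 so the DQN agent uses its eval epsilon

obs = env.reset()
prev_action = torch.tensor(0)   # placeholders; the feedforward DQN model ignores these
prev_reward = torch.tensor(0.)
done = False
while not done:
    agent_step = agent.step(torch.as_tensor(obs), prev_action, prev_reward)
    action = agent_step.action
    obs, reward, done, info = env.step(action.item())
    prev_action = action
    prev_reward = torch.tensor(reward, dtype=torch.float)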

Hope that helps!

astooke avatar Mar 06 '20 22:03 astooke