stable-baselines3

[Question] DQN optimizer parameters

Open · rtkbv opened this issue 1 year ago · 1 comment

❓ Question

I have a question about the optimizer initialization process in DQNPolicy. While working on a custom DQN model, I noticed that when creating the optimizer, we pass all of the parameters of DQNPolicy, which include the parameters of both q_net and q_net_target:

```python
self.q_net = self.make_q_net()
self.q_net_target = self.make_q_net()
self.q_net_target.load_state_dict(self.q_net.state_dict())
self.q_net_target.set_training_mode(False)

# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(  # type: ignore[call-arg]
    self.parameters(),
    lr=lr_schedule(1),
    **self.optimizer_kwargs,
)
```

Source: /stable_baselines3/dqn/policies.py

However, since q_net_target parameters are updated using Polyak update, is there a reason why they are included in the optimizer? Would it be correct to use self.q_net.parameters() instead of self.parameters()?

Thanks.


rtkbv · Jun 12 '24 07:06
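A minimal sketch of the change proposed in the question, written against plain PyTorch rather than the real SB3 classes: `make_q_net` below is a hypothetical stand-in (a tiny `nn.Sequential`), and `polyak_update` is a hand-written soft update in the spirit of SB3's helper, not a copy of it. The optimizer receives only the online network's parameters, one gradient step leaves the target network untouched, and the target then moves only through the Polyak update.

```python
import torch
import torch.nn as nn


def make_q_net(obs_dim: int = 4, n_actions: int = 2) -> nn.Module:
    # Hypothetical stand-in for DQNPolicy.make_q_net(): a tiny MLP Q-network.
    return nn.Sequential(nn.Linear(obs_dim, 16), nn.ReLU(), nn.Linear(16, n_actions))


def polyak_update(params, target_params, tau: float) -> None:
    # Soft update without gradient tracking: target <- (1 - tau) * target + tau * online.
    with torch.no_grad():
        for p, tp in zip(params, target_params):
            tp.mul_(1.0 - tau).add_(p, alpha=tau)


torch.manual_seed(0)
q_net = make_q_net()
q_net_target = make_q_net()
q_net_target.load_state_dict(q_net.state_dict())

# Proposed change: give the optimizer only the online Q-network's parameters.
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

optimized_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
target_ids = {id(p) for p in q_net_target.parameters()}

# One gradient step on a dummy TD loss; the bootstrap target is computed without grad.
obs = torch.randn(8, 4)
next_obs = torch.randn(8, 4)
actions = torch.randint(0, 2, (8, 1))
rewards = torch.ones(8, 1)
with torch.no_grad():
    td_target = rewards + 0.99 * q_net_target(next_obs).max(dim=1, keepdim=True).values
current_q = q_net(obs).gather(1, actions)
loss = nn.functional.smooth_l1_loss(current_q, td_target)

target_before_step = [tp.clone() for tp in q_net_target.parameters()]
optimizer.zero_grad()
loss.backward()
optimizer.step()
target_after_step = [tp.clone() for tp in q_net_target.parameters()]

# The target network moves only through the Polyak update.
polyak_update(q_net.parameters(), q_net_target.parameters(), tau=0.005)
target_after_polyak = [tp.clone() for tp in q_net_target.parameters()]
```

SB3 ships its own `polyak_update` in `stable_baselines3.common.utils`; the local version only avoids depending on SB3. Note that the target parameters never receive gradients here (their `.grad` stays `None`), and, as far as I know, PyTorch's built-in optimizers such as Adam skip parameters without a gradient. If that holds, passing `self.parameters()` is unnecessary bookkeeping rather than a source of wrong updates.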

> However, since q_net_target parameters are updated using Polyak update, is there a reason why they are included in the optimizer? Would it be correct to use self.q_net.parameters() instead of self.parameters()?

Yes, I guess we did that by copy/pasting from other parts of the code.

That's what we do for SAC, for instance (the target network parameters are not passed to the optimizer): https://github.com/DLR-RM/stable-baselines3/blob/4efee92fbad70f85aa094e27bd0a740274121795/stable_baselines3/sac/policies.py#L304

We do have a check, though: https://github.com/DLR-RM/stable-baselines3/blob/4efee92fbad70f85aa094e27bd0a740274121795/tests/test_cnn.py#L195-L203

I would be happy to receive a PR that fixes this issue =)

araffin · Jun 12 '24 07:06

Hello, is there anything to modify in the check you mention in test_cnn.py? It seems the check is not performed on the optimizer's parameters. Instead, it only compares the old and new parameters of the target Q-network, and does the same for the Q-network. Did I misunderstand something?

corentinlger · Jul 03 '24 08:07
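Picking up the point that the existing test compares parameters before and after training rather than inspecting the optimizer, here is a sketch of a check that looks at the optimizer directly. `ToyPolicy` and the two helper functions are hypothetical, written only for illustration; the check compares tensor identities across `optimizer.param_groups`, a standard attribute of `torch.optim.Optimizer`.

```python
import torch
import torch.nn as nn


def optimizer_param_ids(optimizer: torch.optim.Optimizer) -> set:
    # Collect the identity of every tensor the optimizer will update, across all param groups.
    return {id(p) for group in optimizer.param_groups for p in group["params"]}


def target_params_in_optimizer(optimizer: torch.optim.Optimizer, target_net: nn.Module) -> list:
    # Return the names of target-network parameters that were handed to the optimizer.
    ids = optimizer_param_ids(optimizer)
    return [name for name, p in target_net.named_parameters() if id(p) in ids]


class ToyPolicy(nn.Module):
    # Hypothetical policy mirroring the q_net / q_net_target layout from the question.
    def __init__(self) -> None:
        super().__init__()
        self.q_net = nn.Linear(4, 2)
        self.q_net_target = nn.Linear(4, 2)
        self.q_net_target.load_state_dict(self.q_net.state_dict())


policy = ToyPolicy()
# Behaviour quoted in the question: all policy parameters, target network included.
all_params_opt = torch.optim.Adam(policy.parameters(), lr=1e-3)
# Proposed behaviour: only the online network.
q_net_only_opt = torch.optim.Adam(policy.q_net.parameters(), lr=1e-3)

leaked = target_params_in_optimizer(all_params_opt, policy.q_net_target)
clean = target_params_in_optimizer(q_net_only_opt, policy.q_net_target)
print(leaked)  # ['weight', 'bias']
print(clean)   # []
```

On a real model, the same helper could presumably be called as `target_params_in_optimizer(model.policy.optimizer, model.policy.q_net_target)` on a DQN instance; with the code quoted in the question that list would be non-empty, and a regression test would assert that it is empty.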

> Hello, is there anything to modify in the check you mention in test_cnn.py?

Probably not, indeed.

araffin · Jul 05 '24 13:07