stable-baselines3
[Question] DQN optimizer parameters
❓ Question
I have a question about the optimizer initialization process in DQNPolicy. While working on a custom DQN model, I noticed that when creating the optimizer, we pass all of the parameters of DQNPolicy, which include the parameters of both q_net and q_net_target:
```python
self.q_net = self.make_q_net()
self.q_net_target = self.make_q_net()
self.q_net_target.load_state_dict(self.q_net.state_dict())
self.q_net_target.set_training_mode(False)

# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(  # type: ignore[call-arg]
    self.parameters(),
    lr=lr_schedule(1),
    **self.optimizer_kwargs,
)
```
Source: /stable_baselines3/dqn/policies.py
However, since q_net_target parameters are updated using Polyak update, is there a reason why they are included in the optimizer? Would it be correct to use self.q_net.parameters() instead of self.parameters()?
Thanks.
Checklist
- [X] I have checked that there is no similar issue in the repo
- [X] I have read the documentation
- [X] If code there is, it is minimal and working
- [X] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
> However, since q_net_target parameters are updated using Polyak update, is there a reason why they are included in the optimizer? Would it be correct to use self.q_net.parameters() instead of self.parameters()?
Yes, I guess that came from copy-pasting from other parts of the code.
For SAC, for instance, we do not pass the target network parameters to the optimizer: https://github.com/DLR-RM/stable-baselines3/blob/4efee92fbad70f85aa094e27bd0a740274121795/stable_baselines3/sac/policies.py#L304
We do have a check though: https://github.com/DLR-RM/stable-baselines3/blob/4efee92fbad70f85aa094e27bd0a740274121795/tests/test_cnn.py#L195-L203
I would be happy to receive a PR that fixes this issue =)
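For reference, here is a minimal sketch of the idea in plain PyTorch. It is not the SB3 code: the two `nn.Linear` layers stand in for `q_net` and `q_net_target`, and the loss, batch, and `tau` value are placeholders chosen only for illustration. The optimizer receives the online network's parameters only, and the target network is moved by a hand-written Polyak update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for q_net / q_net_target (assumption: tiny linear Q-networks)
q_net = nn.Linear(4, 2)
q_net_target = nn.Linear(4, 2)
q_net_target.load_state_dict(q_net.state_dict())
q_net_target.train(False)

# Only the online network's parameters are handed to the optimizer
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-2)

online_before = [p.detach().clone() for p in q_net.parameters()]
target_before = [p.detach().clone() for p in q_net_target.parameters()]

# One gradient step on a placeholder regression loss
obs = torch.randn(8, 4)
with torch.no_grad():
    td_target = q_net_target(obs) + 1.0  # placeholder, not a real TD target
loss = nn.functional.mse_loss(q_net(obs), td_target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The gradient step moved the online network but left the target untouched
online_changed = any(
    not torch.equal(b, p) for b, p in zip(online_before, q_net.parameters())
)
target_unchanged = all(
    torch.equal(b, p) for b, p in zip(target_before, q_net_target.parameters())
)

# Hand-written Polyak update: target <- (1 - tau) * target + tau * online
tau = 0.5  # illustrative value
with torch.no_grad():
    for p, tp in zip(q_net.parameters(), q_net_target.parameters()):
        tp.mul_(1 - tau).add_(tau * p)

target_moved = any(
    not torch.equal(b, p) for b, p in zip(target_before, q_net_target.parameters())
)
```

Note that PyTorch optimizers skip parameters whose `.grad` is `None`, so passing the target parameters as well should not change training in practice. Restricting the optimizer to `q_net.parameters()` mainly makes the intent explicit.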
Hello, is there anything to modify in the check you mention in test_cnn.py? It does not seem to be performed on the optimizer's parameters. Instead, it seems to only compare the parameters of target_q_network before and after training, and similarly for q_network. Did I misunderstand something?
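If a check on the optimizer itself is wanted, one possible shape is sketched below. This is not the existing test in test_cnn.py: `params_in_optimizer` and `ToyPolicy` are hypothetical names, and the toy policy only mimics the `q_net` / `q_net_target` layout. The helper looks at `optimizer.param_groups` and reports which parameters of a module are registered there.

```python
import torch
import torch.nn as nn


def params_in_optimizer(module: nn.Module, optimizer: torch.optim.Optimizer) -> list:
    """Return the names of `module` parameters registered in `optimizer`."""
    opt_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return [name for name, p in module.named_parameters() if id(p) in opt_ids]


class ToyPolicy(nn.Module):
    """Hypothetical stand-in for a policy with online and target Q-networks."""

    def __init__(self) -> None:
        super().__init__()
        self.q_net = nn.Linear(4, 2)
        self.q_net_target = nn.Linear(4, 2)


policy = ToyPolicy()

# Current behaviour: every parameter of the policy goes to the optimizer
opt_all = torch.optim.Adam(policy.parameters(), lr=1e-3)
# Proposed behaviour: only the online network's parameters
opt_online = torch.optim.Adam(policy.q_net.parameters(), lr=1e-3)

leaked_all = params_in_optimizer(policy.q_net_target, opt_all)
leaked_online = params_in_optimizer(policy.q_net_target, opt_online)
```

With the first optimizer the target network's `weight` and `bias` show up in the list, and with the second the list is empty, so a test could simply assert that the list is empty.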