[Feature] Enable parameter reset in loss module
Description
Allows resetting the parameters in the loss module.
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2017
:x: 4 New Failures, 1 Unrelated Failure
As of commit 4b29473a32561b6251e01eeb6d50f51adc957690 with merge base 87f3437b26a8841e534b62ef6aa020d5fc287a90:
NEW FAILURES - The following jobs have failed:
- Continuous Benchmark (PR) / CPU Pytest benchmark (gh): Workflow failed! Resource not accessible by integration
- Continuous Benchmark (PR) / GPU Pytest benchmark (gh): Workflow failed! Resource not accessible by integration
- Habitat Tests on Linux / tests (3.9, 11.6) / linux-job (gh): RuntimeError: Command docker exec -t 012a098c6051fb52e5a3e8062b34e1b7c0b5b7679a2414910fc6cbcfa5776379 /exec failed with exit code 139
- Unit-tests on MacOS CPU / tests (3.8) / macos-job (gh): test/test_modules.py::TestMultiAgent::test_multiagent_mlp[batch1-None-False-True-3]
BROKEN TRUNK - The following job failed but was already failing on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
- Unit-tests on Linux / tests-cpu (3.8) / linux-job (gh): AttributeError: 'OrphanPath' object has no attribute 'exists'
Thanks for this!
We'll need tests for the feature.
How do we handle the target parameters?
Wouldn't something like this be a bit more robust?
```python
from torchrl.objectives import DQNLoss
from torchrl.modules import QValueActor
from torch import nn

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")

with loss.value_network_params.to_module(loss.value_network):
    loss.apply(
        lambda module: module.reset_parameters()
        if hasattr(module, "reset_parameters")
        else None
    )
```
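On the testing point: a minimal pytest-style sketch of what such a test could look like, exercising the reset pattern from the snippet above (the test name and the reset entry point are assumptions for illustration, not the PR's final API):

```python
import torch
from torch import nn

from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss


def test_reset_parameters_changes_values():
    torch.manual_seed(0)
    module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
    value_net = QValueActor(module=module, action_space="categorical")
    loss = DQNLoss(value_network=value_net, action_space="categorical")

    # snapshot the current values before resetting
    before = loss.value_network_params.clone()

    # reset pattern from the snippet above (stand-in for the PR's actual API)
    with loss.value_network_params.to_module(loss.value_network):
        loss.apply(
            lambda m: m.reset_parameters() if hasattr(m, "reset_parameters") else None
        )

    after = loss.value_network_params
    # at least one parameter tensor should differ after re-initialization
    assert any(
        not torch.equal(before[key], after[key])
        for key in before.keys(include_nested=True, leaves_only=True)
    )
```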
I like the solution! But we are accessing the parameters directly in your example, so we would need to define a reset function manually, which I think is perfectly fine because then the user gets to decide how to reset weights and biases:
```python
def reset_parameters(params):
    """User-specified resetting function, depending on their needs for initialization."""
    if len(params.shape) > 1:
        # weights
        nn.init.xavier_uniform_(params)
    elif len(params.shape) == 1:
        # biases
        nn.init.zeros_(params)
    else:
        raise ValueError("Unknown parameter shape: {}".format(params.shape))


with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda x: reset_parameters(x.data) if hasattr(x, "data") else None)
```
And for handling the target_value_network_params I think we could simply do something like:

```python
loss.target_value_network_params.update(loss.value_network_params)
```

What do you think? I think we can close the draft. But we might want to mention how to reset parameters somewhere in the docs.
> loss.target_value_network_params.update(loss.value_network_params)

This won't work because the target params are locked (you can't update them). They're locked because we want to avoid this kind of operation :) You should update the data in place:

```python
loss.target_value_network_params.apply(
    lambda dest, src: dest.data.copy_(src), loss.value_network_params
)
```
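As a quick sanity check (a sketch that assumes the `loss` object from the DQN example above), the target parameters should match the online ones right after the in-place copy:

```python
# after the copy, every target tensor should equal its online counterpart
for key, target in loss.target_value_network_params.items(
    include_nested=True, leaves_only=True
):
    assert (target == loss.value_network_params[key]).all()
```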
> ```python
> def reset_parameters(params):
>     """User-specified resetting function, depending on their needs for initialization."""
>     if len(params.shape) > 1:
>         # weights
>         nn.init.xavier_uniform_(params)
>     elif len(params.shape) == 1:
>         # biases
>         nn.init.zeros_(params)
>     else:
>         raise ValueError("Unknown parameter shape: {}".format(params.shape))
>
> with loss.value_network_params.to_module(loss.value_network):
>     loss.apply(lambda x: reset_parameters(x.data) if hasattr(x, "data") else None)
> ```
Unfortunately this isn't very generic:
1. All tensors have a `data` attribute, even buffers. By doing this you will also use Xavier init on batch-norm buffers if they're 2d.
2. If the model has a mixture of linear, conv and other layers, it's going to be hard to have fine-grained control over the params being updated.
Not all module parameters are "weights" and "biases", and "biases" can be 2d (my point is: the dimension is a very indirect indicator of a tensor's role in a model).
The way I usually see this done is to use the module's `reset_parameters` if there is one, which gives better control over the differences in initialization methods across module types.
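For illustration, a small sketch of that per-module pattern in plain PyTorch (the model here is made up): each layer's own `reset_parameters` applies the scheme it was designed with, e.g. Kaiming-uniform for `nn.Linear` weights versus ones/zeros and cleared running stats for batch norm, which is where the finer-grained control comes from.

```python
from torch import nn


def reset_if_possible(module):
    # containers like nn.Sequential and parameter-free layers like nn.ReLU
    # don't define reset_parameters and are simply skipped
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()


model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
model.apply(reset_if_possible)  # nn.Module.apply visits every submodule
```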
Maybe we could allow the user to pass a reset function, but in that case we don't even need to re-populate the module (we can just do `tensordict.apply(reset)`). Note that you could also do:

```python
def reset(name, tensor):
    if name == "bias":
        tensor.data.zero_()
    if name == "weight":
        nn.init.xavier_uniform_(tensor)

tensordict.apply(reset, named=True)
```
which is more straightforward IMO