
[Feature] Enable parameter reset in loss module

Open: BY571 opened this issue 1 year ago • 4 comments

Description

Allows resetting the parameters of the loss module.

BY571 · Mar 18 '24 14:03

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2017

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 1 Unrelated Failure

As of commit 4b29473a32561b6251e01eeb6d50f51adc957690 with merge base 87f3437b26a8841e534b62ef6aa020d5fc287a90:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · Mar 18 '24 14:03

Thanks for this!

We'll need tests for the feature.

How do we handle the target parameters?

Wouldn't something like this be a bit more robust?

```python
from torchrl.objectives import DQNLoss
from torchrl.modules import QValueActor
from torch import nn

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))

value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")

# Temporarily populate the value network with the loss module's parameters, then let
# every submodule that defines reset_parameters() re-initialize itself in place
with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda m: m.reset_parameters() if hasattr(m, "reset_parameters") else None)
```

I like the solution! But since we access the parameters directly in your example, we would need to define a reset function manually. I think that's perfectly fine, because then the user decides how to reset weights and biases:

```python
from torch import nn


def reset_parameters(params):
    """User-specified resetting function, depending on their initialization needs."""
    if len(params.shape) > 1:
        # weights
        nn.init.xavier_uniform_(params)
    elif len(params.shape) == 1:
        # biases
        nn.init.zeros_(params)
    else:
        raise ValueError(f"Unknown parameter shape: {params.shape}")


with loss.value_network_params.to_module(loss.value_network):
    # nn.Module.apply passes submodules (not tensors) to the function, so visit
    # each submodule's own tensors, parameters and buffers alike
    loss.apply(
        lambda m: [reset_parameters(t.data) for t in (*m.parameters(recurse=False), *m.buffers(recurse=False))]
    )
```

And for handling the `target_value_network_params`, I think we could simply do something like:

```python
loss.target_value_network_params.update(loss.value_network_params)
```

What do you think? I think we can close the draft, but we might want to mention somewhere in the docs how to reset parameters.

BY571 · Mar 20 '24 08:03


Calling `update` on `loss.target_value_network_params` won't work, because the target params are locked (you can't update them). They're locked precisely to prevent this kind of operation :) You should update the data in place instead:

```python
# Copy the freshly reset values into the (locked) target params, tensor by tensor
loss.target_value_network_params.apply(
    lambda dest, src: dest.data.copy_(src), loss.value_network_params
)
```
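The in-place idea can be illustrated without torchrl. Below is a minimal plain-PyTorch sketch, assuming the target network is simply a frozen deep copy of the online network; `reset_module_parameters` and `hard_update_` are hypothetical helper names, not torchrl API. The target tensors are overwritten with `copy_` under `torch.no_grad()`, so the existing tensor objects (and their storage) are kept rather than swapped out.

```python
import copy

import torch
from torch import nn


def reset_module_parameters(model: nn.Module) -> None:
    """Hypothetical helper: re-initialize every submodule that defines reset_parameters()."""
    for m in model.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()


def hard_update_(target: nn.Module, source: nn.Module) -> None:
    """Hypothetical helper: copy source values into the existing target tensors in place."""
    with torch.no_grad():
        for t, s in zip(target.parameters(), source.parameters()):
            t.copy_(s)


online = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
target = copy.deepcopy(online).requires_grad_(False)  # frozen target copy

# Remember where the target weight lives, to check it is not replaced
ptr_before = target[0].weight.data_ptr()

reset_module_parameters(online)  # online weights are re-drawn; target is now stale
hard_update_(target, online)     # sync the target values in place
```

Because only values are copied, anything holding a reference to the target tensors keeps seeing the same objects, which is the property a locked parameter container relies on.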

vmoens · Mar 20 '24 11:03


Regarding the `reset_parameters` function above: unfortunately, this isn't very generic.

1. All tensors have a `data` attribute, even buffers. By doing this you will also apply Xavier init to batch-norm buffers if they're 2-D.
2. If the model has a mixture of linear, conv and other layers, it's going to be hard to get fine-grained control over which params are updated.

Not all parameters are "weights" and "biases", and "biases" can be 2-D. My point is that a tensor's dimensionality is a very indirect indicator of its role in a model.
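To make the point about dimensionality concrete, here is a small plain-PyTorch sketch; `shape_based_reset` is a hypothetical helper that mirrors the shape rule from the earlier snippet. Applied to every floating-point parameter and buffer of a model containing a `LayerNorm` and a `BatchNorm1d`, it zeroes `LayerNorm.weight` and the `running_var` buffer, both of which PyTorch initializes to ones, simply because they are 1-D.

```python
import torch
from torch import nn


def shape_based_reset(t: torch.Tensor) -> None:
    """Hypothetical helper: choose the init from the tensor's shape alone."""
    if t.ndim > 1:
        nn.init.xavier_uniform_(t)  # treated as a "weight"
    elif t.ndim == 1:
        nn.init.zeros_(t)  # treated as a "bias"


model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.BatchNorm1d(8))

# Visit every floating-point tensor, buffers included
# (num_batches_tracked is an integer scalar, so it is skipped)
with torch.no_grad():
    for t in [*model.parameters(), *model.buffers()]:
        if t.is_floating_point():
            shape_based_reset(t)

layer_norm, batch_norm = model[1], model[2]
```

After this runs, the `LayerNorm` only outputs its bias, and the batch-norm statistics no longer describe anything, even though no tensor had an "unexpected" shape.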

The approach I usually see is to call the module's `reset_parameters` method when it has one, which gives better control over the different initialization methods each layer needs.
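A minimal plain-PyTorch sketch of the module-driven approach, assuming an ordinary `nn.Module` with no torchrl functional parameters involved; `reset_all` is a hypothetical helper name. Each layer applies its own default initialization, and `BatchNorm1d.reset_parameters()` also restores the running statistics, so no layer type needs special-casing.

```python
import torch
from torch import nn


def reset_all(model: nn.Module) -> None:
    """Hypothetical helper: let each layer re-initialize itself with its own defaults."""
    for m in model.modules():
        reset = getattr(m, "reset_parameters", None)
        if callable(reset):
            reset()


model = nn.Sequential(
    nn.Conv1d(2, 4, kernel_size=3),
    nn.BatchNorm1d(4),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 6, 8),  # a length-8 input gives length 6 after the conv
)

# Simulate some drift: a train-mode forward pass updates BatchNorm's running stats,
# and the affine BatchNorm weight is perturbed by hand
model.train()
with torch.no_grad():
    model(torch.randn(16, 2, 8))
    model[1].weight.mul_(3.0)

reset_all(model)
bn = model[1]
```

Layers without a `reset_parameters` method (`ReLU`, `Flatten`, the `Sequential` container itself) are skipped, which is what makes this approach work on mixed architectures.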

Maybe we could let the user pass a reset function, but in that case we don't even need to re-populate the module: we can just call `tensordict.apply(reset)` on the parameters. Note that you could also do:

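```python
from torch import nn


def reset(name, tensor):
    # Dispatch on the tensor's name rather than its shape
    if name == "bias":
        tensor.data.zero_()
    elif name == "weight":
        nn.init.xavier_uniform_(tensor)


# named_apply passes (key, tensor) pairs; by default the key is the leaf name
tensordict.named_apply(reset)
```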

which is more straightforward IMO
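For readers without tensordict at hand, the same name-based dispatch can be sketched over `named_parameters()` in plain PyTorch. `reset_by_name` is a hypothetical helper, and it assumes the leaf name is the part of the full parameter name after the last dot. Buffers are never visited, because only parameters are iterated.

```python
import torch
from torch import nn


def reset_by_name(model: nn.Module) -> None:
    """Hypothetical helper: pick the init from the parameter's leaf name, not its shape."""
    for full_name, param in model.named_parameters():
        leaf = full_name.rsplit(".", 1)[-1]  # e.g. "2.weight" -> "weight"
        if leaf == "bias":
            nn.init.zeros_(param)
        elif leaf == "weight":
            nn.init.xavier_uniform_(param)


model = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
reset_by_name(model)
```

One caveat with this exact rule: `xavier_uniform_` needs a tensor with at least 2 dimensions, so a 1-D `weight` such as `LayerNorm.weight` would raise a `ValueError`. Name-based dispatch is clearer than shape-based dispatch, but it still has to know which layers it will meet.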

vmoens · Mar 20 '24 11:03