
[Feature Request] different activation functions in the network_architecture through the policy_kwargs

AlexPasqua opened this issue 3 years ago · 5 comments

🚀 Feature

Introduce the possibility of passing multiple activation functions to the policy network using the policy_kwargs.

Motivation

From what I understand, through the policy_kwargs it is possible to pass a single activation function to be used by the net_arch part of the policy network. Oftentimes, though, the policy net (pi) and the value function net (vf) need different activation functions. It looks like the only way to have different activation functions in these two sub-networks is to implement our own policy network, as shown in the advanced example here in your documentation. This is also mentioned in issue #481.
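For context, this is how a single activation function is currently passed through the policy_kwargs (a minimal sketch; the environment and layer sizes are just placeholders):

from torch import nn
from stable_baselines3 import A2C

# Current API: one activation_fn, applied to every hidden layer of both pi and vf
model = A2C(
    'MlpPolicy',
    'CartPole-v1',
    policy_kwargs=dict(
        net_arch=[256, dict(pi=[128], vf=[32])],
        activation_fn=nn.ReLU,
    ),
)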

Alternatives

Ideally, it would be possible to specify multiple activation functions: one for the shared layers and one for each layer of the two sub-networks (policy net (pi) and value net (vf)), mimicking how the architecture is passed. The architecture is passed as [<shared layers>, dict(vf=[<non-shared value network layers>], pi=[<non-shared policy network layers>])] (source: here), so the same structure could be reused, with PyTorch activation functions in place of the integers.

Example:

from torch.nn import ReLU, Softmax, Tanh

model = A2C('MultiInputPolicy', env,
             policy_kwargs=dict(
                 net_arch=[256, dict(pi=[128, 50], vf=[32, 1])],
                 activation_fn=[Tanh, dict(pi=[ReLU, Softmax], vf=[ReLU, ReLU])]
             )
        )

AlexPasqua · Sep 09 '22 10:09

Hello, sounds reasonable (even though I doubt changing the activation per layer will make a big difference). Could you do a draft PR to see how much complexity it adds? (we would have to add that feature for off-policy algorithms too for consistency)

Oftentimes, though, the policy net (pi) and value function net (vf) need different activation functions.

do you have references for that?

araffin · Sep 11 '22 20:09

Could you do a draft PR to see how much complexity it adds? (we would have to add that feature for off-policy algorithms too for consistency)

Initially I'll need to create my own policy network, because I need to implement this in a project pretty quickly, but I'll start working on the draft PR in a few days.

Oftentimes, though, the policy net (pi) and value function net (vf) need different activation functions.

do you have references for that?

For example, I'm trying to implement the model in this paper: it has a final softmax layer in the actor network (policy net) but not in the critic network (value net), and I think this can happen pretty often (though I'm not an expert).

AlexPasqua · Sep 12 '22 14:09

have a final softmax layer in the actor network (policy net)

I see, in that case there is a misunderstanding: this is already the case for PPO with discrete actions: https://github.com/DLR-RM/stable-baselines3/blob/c4f54fcf047d7bf425fb6b88a3c8ed23fe375f9b/stable_baselines3/common/distributions.py#L275
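Concretely (a minimal illustration using plain PyTorch rather than the SB3 code itself), the Categorical distribution normalizes the action logits with a softmax internally, so no explicit softmax layer is needed in the policy net:

import torch as th

logits = th.tensor([1.0, 2.0, 3.0])
dist = th.distributions.Categorical(logits=logits)
# .probs is softmax(logits): the normalization happens inside the distribution
assert th.allclose(dist.probs, th.softmax(logits, dim=-1))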

araffin · Sep 12 '22 14:09

@araffin ok so, if I understood correctly, the last layer of the policy net for discrete actions automatically has a softmax activation function, while the one I put in the policy_kwargs is used in all the other layers of both the policy net and the value net. Is that correct?

(I'll try to work on a draft PR for that feature anyway!)

AlexPasqua · Sep 13 '22 07:09

is used in all the other layers of both the policy net and the value net. Is that correct?

yes

araffin · Sep 19 '22 20:09

Closing in favor of https://github.com/DLR-RM/stable-baselines3/pull/1292

for reference, if one wants to have different activation functions for actor vs critic:


from typing import Callable, Tuple

from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
from torch import nn
import torch as th


class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.

    :param feature_dim: dimension of the features extracted
        with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: number of units for the last layer of the policy network
    :param last_layer_dim_vf: number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            nn.ReLU(),
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
            nn.Tanh(),
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):

        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )
        # Disable orthogonal initialization
        self.ortho_init = False
        # In case you want custom features extractor,
        # you can define them here:
        # self.pi_features_extractor = ...
        # self.vf_features_extractor = ...
        # self.share_features_extractor = False

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)


model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
model.learn(5000, progress_bar=True)
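A possible variation (just a sketch, not taken from #1292; ConfigurableNetwork and its activation_pi/activation_vf arguments are hypothetical names) would be to make the activation classes constructor arguments, so only the extractor needs to change:

from typing import Type

class ConfigurableNetwork(CustomNetwork):
    """Same as CustomNetwork above, but with the activation classes as arguments."""

    def __init__(
        self,
        feature_dim: int,
        activation_pi: Type[nn.Module] = nn.ReLU,
        activation_vf: Type[nn.Module] = nn.Tanh,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__(feature_dim, last_layer_dim_pi, last_layer_dim_vf)
        # Rebuild the two heads with the requested activation functions
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            activation_pi(),
        )
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
            activation_vf(),
        )

With this, _build_mlp_extractor would instantiate ConfigurableNetwork(self.features_dim, activation_pi=nn.ReLU, activation_vf=nn.Tanh) (or whatever combination is needed) instead of CustomNetwork.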

araffin · Jan 23 '23 15:01

@araffin now that the code is simpler and there are no shared layers in the mlp_extractor, what about re-proposing the idea of #1116 - i.e. passing the desired activation functions to the constructor when creating the model? This way there wouldn't be the need to implement a policy net from scratch just to change the activation functions. If you're interested, I could open a draft PR to give an idea.

AlexPasqua · Jan 23 '23 17:01

what about re-proposing the idea of https://github.com/DLR-RM/stable-baselines3/pull/1116 - i.e. passing the desired activation functions to the constructor when creating the model?

I think defining a custom policy is now simple enough that this should not be needed. The fact that having different activation functions for actor/critic is quite unusual and that the performance gain is unknown still holds (I'm pretty sure there won't be much difference).

araffin · Jan 24 '23 10:01

what about re-proposing the idea of #1116 - i.e. passing the desired activation functions to the constructor when creating the model?

I think defining a custom policy is now simple enough that this should not be needed. The fact that having different activation functions for actor/critic is quite unusual and that the performance gain is unknown still holds (I'm pretty sure there won't be much difference).

Alright, no problem, custom policies are indeed pretty simple now

AlexPasqua · Jan 24 '23 10:01