
Action space for SB3

Open vicbentuupc opened this issue 1 year ago • 1 comments

Hi,

I'm pretty new to this and am trying to train a PPO agent with SB3. I used the code from the example, with my custom environment instead, whose action space is the following: "action_space = gym.spaces.MultiDiscrete([3, 3, 2, 2, 2, 2, 2, 2])"

This got me the following error: "UserWarning: Using a target size (torch.Size([256, 8])) that is different to the input size (torch.Size([256, 18])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."

18 corresponds to my action space encoded as one-hot vectors, so I tried using the following function to convert the actions to the appropriate shape:

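def encode_actions(self, actions):
    # Assumes `import torch as th`; self.action_dims is assumed to be [3, 3, 2, 2, 2, 2, 2, 2].
    original_shape = actions.shape
    # Flatten the leading dimensions so each row holds one MultiDiscrete action.
    actions_flat = actions.view(-1, len(self.action_dims))

    encoded = []
    for idx, dim in enumerate(self.action_dims):
        # One-hot encode each sub-action with its own number of classes.
        one_hot = th.nn.functional.one_hot(actions_flat[:, idx].long(), num_classes=dim)
        encoded.append(one_hot)

    # Concatenate to sum(self.action_dims) columns, then restore the leading dimensions.
    encoded_actions = th.cat(encoded, dim=-1).float()
    encoded_actions = encoded_actions.view(*original_shape[:-1], -1)
    return encoded_actions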

But still got: "Error: The size of tensor a (18) must match the size of tensor b (256) at non-singleton dimension 1"

I'm pretty new to this, so maybe it's just that the library isn't designed to work with MultiDiscrete envs.

vicbentuupc, Apr 12 '25 16:04

This is because the one-hot operation expands the dimension of the original action, e.g. 2 -> [0, 0, 1].

There are 8 separate sub-actions in your action space, and their one-hot widths add up to 3+3+2+2+2+2+2+2 = 18.
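
To make the 8 -> 18 arithmetic concrete, here is a minimal sketch in plain Python (no torch, so it runs anywhere). The helper name `one_hot_multidiscrete` and the sample action are hypothetical and only for illustration. The sub-action sizes are the ones from the question. It is not RLeXplore's or SB3's own code, just the same layout you get from concatenating one one-hot block per sub-action.

```python
# Sizes of each sub-action, taken from the MultiDiscrete space in the question.
NVEC = [3, 3, 2, 2, 2, 2, 2, 2]


def one_hot_multidiscrete(action, nvec=NVEC):
    """Concatenate one one-hot block per sub-action (hypothetical helper)."""
    if len(action) != len(nvec):
        raise ValueError(f"expected {len(nvec)} sub-actions, got {len(action)}")
    encoded = []
    for value, n in zip(action, nvec):
        if not 0 <= value < n:
            raise ValueError(f"sub-action value {value} is outside [0, {n})")
        block = [0.0] * n      # one slot per possible value of this sub-action
        block[value] = 1.0     # mark the chosen value
        encoded.extend(block)
    return encoded


# Hypothetical sample action: one integer per sub-action (8 in total).
raw_action = [2, 0, 1, 0, 1, 1, 0, 1]
encoded = one_hot_multidiscrete(raw_action)

print(len(raw_action), len(encoded))  # 8 18
```

So a batch of raw actions has shape (256, 8), while anything built on the one-hot encoding has shape (256, 18). Both sides of a comparison or loss need to use the same representation.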

myismyname, Apr 28 '25 10:04