Transformers-RL

How to combine it with the PPO algorithm?

Status: Open. borninfreedom opened this issue 3 years ago; 0 comments.

If my observation is an image of shape (4, 84, 84) and the action dimension is 3, how should I modify the code below?

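```python
import torch
# TransformerGaussianPolicy comes from the Transformers-RL repository
# (its import path is not shown in this issue).

if __name__ == '__main__':
    # Shape: (seq_size, batch_size, dim); dim should preferably be even
    states = torch.randn(1, 1, 4)
    print("=> Testing Policy")
    policy = TransformerGaussianPolicy(state_dim=states.shape[-1], act_dim=4)
    for i in range(10):
        act = policy(states)
        action = act[0].sample()  # act[0] is a distribution exposing .sample()
        print(torch.isnan(action).any(), action.shape)
```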
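One possible way to handle (4, 84, 84) image observations is to encode each frame stack with a small CNN and let the Transformer attend over the resulting feature vectors instead of raw state vectors. This sketch does not use the repository's `TransformerGaussianPolicy`, whose internals are not shown in this issue. `ConvEncoder` and `ImageTransformerGaussianPolicy` are hypothetical names, the conv sizes are the widely used Atari DQN stack (an assumption), PyTorch's built-in `nn.TransformerEncoder` stands in for the repo's Transformer, and the 3 actions are assumed to be continuous, as a Gaussian policy implies.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class ConvEncoder(nn.Module):
    """Hypothetical encoder: stacked frames (N, 4, 84, 84) -> (N, feature_dim).

    Uses the common Atari DQN conv stack; these sizes are an assumption.
    """

    def __init__(self, in_channels: int = 4, feature_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),  # 84 -> 20
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),           # 20 -> 9
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),           # 9 -> 7
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, feature_dim), nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ImageTransformerGaussianPolicy(nn.Module):
    """Hypothetical actor-critic: CNN per frame stack, Transformer over time."""

    def __init__(self, act_dim: int = 3, feature_dim: int = 256,
                 n_heads: int = 4, n_layers: int = 2):
        super().__init__()
        self.encoder = ConvEncoder(4, feature_dim)
        # feature_dim must be divisible by n_heads for multi-head attention.
        layer = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=n_heads)
        self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.mu_head = nn.Linear(feature_dim, act_dim)
        self.log_std = nn.Parameter(torch.zeros(act_dim))  # state-independent std
        self.value_head = nn.Linear(feature_dim, 1)           # critic, needed by PPO

    def forward(self, obs: torch.Tensor):
        # obs: (seq_len, batch, 4, 84, 84), i.e. the issue's (seq, batch, dim)
        # layout with each state vector replaced by an image stack.
        seq_len, batch = obs.shape[:2]
        feats = self.encoder(obs.reshape(seq_len * batch, *obs.shape[2:]))
        feats = feats.reshape(seq_len, batch, -1)  # (seq, batch, feature_dim)
        # Causal mask: step t may only attend to steps <= t.
        mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=obs.device),
            diagonal=1,
        )
        h = self.transformer(feats, mask=mask)  # batch_first=False by default
        dist = Normal(self.mu_head(h), self.log_std.exp())
        value = self.value_head(h).squeeze(-1)  # (seq, batch)
        return dist, value


if __name__ == '__main__':
    states = torch.randn(1, 1, 4, 84, 84)  # seq_size, batch_size, C, H, W
    policy = ImageTransformerGaussianPolicy(act_dim=3)
    for i in range(10):
        dist, value = policy(states)
        action = dist.sample()
        print(torch.isnan(action).any(), action.shape)
```

If `TransformerGaussianPolicy` accepts an arbitrary `state_dim`, feeding it the `ConvEncoder` output with `state_dim=256, act_dim=3` might also work, but that is untested here. For discrete actions, `torch.distributions.Categorical(logits=...)` would replace `Normal`.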

Posted by borninfreedom, Jul 24 '21 03:07
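On the PPO side of the question, the policy mainly needs to return an action distribution and a value estimate. The following is a minimal sketch of the standard PPO clipped-surrogate loss, assuming a Gaussian policy whose per-dimension log-probabilities are summed over the action vector. `ppo_loss` and its coefficient defaults (`clip_eps=0.2`, `vf_coef=0.5`, `ent_coef=0.01`) are illustrative assumptions, not taken from Transformers-RL.

```python
import torch
from torch.distributions import Normal


def ppo_loss(dist, values, actions, old_log_probs, advantages, returns,
             clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    """Hypothetical PPO clipped-surrogate loss for a Gaussian policy.

    dist: Normal with batch shape (..., act_dim); actions: (..., act_dim);
    values, old_log_probs, advantages, returns: (...).
    """
    # Joint log-prob of each action vector = sum over action dimensions.
    log_probs = dist.log_prob(actions).sum(-1)
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()
    value_loss = (returns - values).pow(2).mean()
    entropy = dist.entropy().sum(-1).mean()  # bonus that encourages exploration
    return policy_loss + vf_coef * value_loss - ent_coef * entropy


if __name__ == '__main__':
    mu = torch.zeros(8, 3, requires_grad=True)  # stand-in for a policy's mean output
    dist = Normal(mu, torch.ones(3))
    actions = dist.sample()
    old_log_probs = dist.log_prob(actions).sum(-1).detach()
    loss = ppo_loss(dist, torch.zeros(8), actions, old_log_probs,
                    advantages=torch.randn(8), returns=torch.randn(8))
    loss.backward()
    print(loss.item(), mu.grad.shape)
```

During rollouts one would store `dist.log_prob(action).sum(-1)`, computed under `torch.no_grad()`, as `old_log_probs`, together with `value`. Advantages and returns come from whichever estimator you choose (GAE is a common one). The same rollout data is then reused for a few epochs of minibatch updates on this loss.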