Transformers-RL
How can I combine it with the PPO algorithm?
If my observation is an image of shape (4, 84, 84) and the action dimension is 3, how should I modify the code below?
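import torch
# TransformerGaussianPolicy is defined in the Transformers-RL repository; import it from there.

if __name__ == '__main__':
    states = torch.randn(1, 1, 4)  # seq_size, batch_size, dim - better if dim % 2 == 0
    print("=> Testing Policy")
    policy = TransformerGaussianPolicy(state_dim=states.shape[-1], act_dim=4)
    for i in range(10):
        act = policy(states)
        action = act[0].sample()  # act[0] is sampled from, so it is presumably an action distribution
        print(torch.isnan(action).any(), action.shape)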
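The test above feeds flat state vectors of shape (seq_size, batch_size, dim). One common way to handle (4, 84, 84) image observations is to encode each frame stack into a flat feature vector first and pass that vector to the policy as its state. The sketch below is a minimal version of that idea, assuming a DQN-style CNN; `ImageEncoder` and `feature_dim=64` are hypothetical choices, not part of Transformers-RL.

```python
import torch
import torch.nn as nn


class ImageEncoder(nn.Module):
    """Hypothetical CNN mapping stacked frames (4, 84, 84) to a flat feature vector."""

    def __init__(self, in_channels: int = 4, feature_dim: int = 64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),  # 84 -> 20
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),           # 20 -> 9
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),           # 9 -> 7
            nn.Flatten(),
        )
        self.fc = nn.Linear(64 * 7 * 7, feature_dim)

    def forward(self, frames: torch.Tensor) -> torch.Tensor:
        # frames: (seq_size, batch_size, 4, 84, 84)
        seq, batch = frames.shape[:2]
        x = frames.reshape(seq * batch, *frames.shape[2:])  # merge seq and batch for Conv2d
        x = torch.relu(self.fc(self.conv(x)))
        return x.reshape(seq, batch, -1)  # (seq_size, batch_size, feature_dim)


encoder = ImageEncoder()
obs = torch.rand(1, 1, 4, 84, 84)  # one 4-frame observation
states = encoder(obs)
print(states.shape)  # torch.Size([1, 1, 64])
```

Under this assumption, the policy would be built with `state_dim=64` (even, as the original comment recommends) and `act_dim=3`, and the encoder's parameters would be optimized together with the policy's so that the image features are learned end to end.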
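For the PPO part, the policy mainly needs to provide log-probabilities of the actions it took. Assuming `act[0]` is a `torch.distributions` object (its `.sample()` call suggests this, but the snippet does not confirm it), the standard clipped surrogate loss can be computed as sketched below. `ppo_clip_loss` is a hypothetical helper, and a diagonal `Normal` over 3 action dimensions stands in for the policy's output. With `Normal`, per-dimension log-probs are summed over the action dimension; a `MultivariateNormal` would already return a summed value.

```python
import torch
from torch.distributions import Normal


def ppo_clip_loss(new_log_prob, old_log_prob, advantages, clip_eps=0.2):
    """Clipped surrogate objective from PPO, negated so it can be minimized."""
    ratio = torch.exp(new_log_prob - old_log_prob)  # pi_new(a|s) / pi_old(a|s)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()


torch.manual_seed(0)
mean = torch.zeros(5, 3, requires_grad=True)      # stand-in for the policy's mean output
dist = Normal(mean, torch.ones(5, 3))             # batch of 5 states, action dim 3
actions = dist.sample()                           # collected during the rollout
old_log_prob = dist.log_prob(actions).sum(-1).detach()
advantages = torch.randn(5)                       # e.g. from GAE in a real setup

loss = ppo_clip_loss(dist.log_prob(actions).sum(-1), old_log_prob, advantages)
loss.backward()
```

In a full loop, `old_log_prob` would be stored while collecting rollouts, and the loss would be recomputed for several epochs over that batch, usually alongside a value loss and an entropy bonus.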