
[Feature Request] Add Attention nets (GTrXL model in particular)

Open RemiG3 opened this issue 2 years ago • 13 comments

🚀 Feature Request

This feature request is a duplicate from stable-baselines3 (see https://github.com/DLR-RM/stable-baselines3/issues/177).

The idea is to add the GTrXL model in the contrib repo from the paper Stabilizing Transformers for Reinforcement Learning, as done in RLlib: https://github.com/ray-project/ray/blob/master/rllib/models/torch/attention_net.py.
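For reference (a standalone sketch, not the RLlib code nor anything from an SB3 branch): the core change in GTrXL is identity-map reordering plus replacing the residual connections of a Transformer-XL block with GRU-style gating layers whose bias is initialized so that each block starts out close to the identity, as described in the Stabilizing Transformers paper. A minimal PyTorch version of such a gating layer could look roughly like this:

```python
# Minimal sketch of the GRU-style gating layer used by GTrXL in place of
# residual connections (Parisotto et al., 2020). Not taken from any existing repo.
import torch
import torch.nn as nn


class GRUGate(nn.Module):
    """Gated residual connection: g(x, y) replaces x + y in a transformer block."""

    def __init__(self, dim: int, bias_init: float = 2.0):
        super().__init__()
        self.w_r = nn.Linear(dim, dim, bias=False)
        self.u_r = nn.Linear(dim, dim, bias=False)
        self.w_z = nn.Linear(dim, dim, bias=False)
        self.u_z = nn.Linear(dim, dim, bias=False)
        self.w_g = nn.Linear(dim, dim, bias=False)
        self.u_g = nn.Linear(dim, dim, bias=False)
        # Positive gate bias so the block starts close to the identity mapping,
        # which is what stabilizes early training in the paper.
        self.bias = nn.Parameter(torch.full((dim,), bias_init))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # x: sub-layer input (the residual stream), y: sub-layer output (attention or MLP)
        r = torch.sigmoid(self.w_r(y) + self.u_r(x))
        z = torch.sigmoid(self.w_z(y) + self.u_z(x) - self.bias)
        h = torch.tanh(self.w_g(y) + self.u_g(r * x))
        return (1.0 - z) * x + z * h
```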

~~@araffin has already mentioned that he created it and will make it public (comment).~~

~~I wonder if this is still relevant?~~

RemiG3 avatar Mar 15 '23 09:03 RemiG3

@araffin has already mentioned that he created it and will make it public (https://github.com/DLR-RM/stable-baselines3/issues/177#issuecomment-703268927).

I meant the SB3 contrib repo.

For GTrXL, are you willing to contribute that algorithm? Please read the contributing guide carefully if you decide to.

araffin avatar Mar 15 '23 16:03 araffin

I meant the SB3 contrib repo.

Sorry for the misunderstanding.

For GTrXL, are you willing to contribute that algorithm?

I'm not sure yet, I will try to implement it for my experiments first.

RemiG3 avatar Mar 16 '23 09:03 RemiG3

Also related: https://github.com/maohangyu/TIT_open_source

araffin avatar Mar 31 '23 11:03 araffin

@RemiG3 hey, have you started to implement it? Maybe I can lend a hand with it :)

richardjozsa avatar Apr 05 '23 22:04 richardjozsa

Yes, I have implemented it, but not tested it properly. I'm currently having some trouble with my custom environment that I'm trying to resolve.

@araffin is it possible to create a new branch for this feature (to share the code)? If it is possible, I'll clean up the code and push it to this new branch soon.

RemiG3 avatar Apr 06 '23 17:04 RemiG3

Yes, I have implemented it, but not tested it properly. I'm currently having some trouble with my custom environment that I'm trying to resolve.

@araffin is it possible to create a new branch for this feature (to share the code)? If it is possible, I'll clean up the code and push it to this new branch soon.

yes, that's what a fork and pull request are meant for

araffin avatar Apr 06 '23 18:04 araffin

I have come across this: Transformers-RL, which is quite modular and easy to tune. The only downside is that it has only been implemented for a Gaussian policy.

richardjozsa avatar Apr 08 '23 11:04 richardjozsa

Hey, I finally made the PR #176 to share the code.

It should work, but I'm not sure about the performance. It would be nice if someone could run comparisons with other methods (or the RLlib attention net, for example). I won't have time in the next few days.

RemiG3 avatar Apr 11 '23 20:04 RemiG3

RemiG3, thank you for adding the attention net to contrib. What would the shape of the input look like, for example if I want to use the CartPole environment? Thanks again.

eric000888 avatar Apr 18 '23 16:04 eric000888

Thank you, @eric000888, for reporting this (feel free to provide the code you tested as you did in your first edits).

I have updated the branch to fix a bug in the dimension of the minibatches. But I still get an exception when batch_size = 1 or n_steps = 1, and I found the same exception with RecurrentPPO.

So it should now work for batch_size > 1 and n_steps > 1 (as with RecurrentPPO).

EDIT: I also added assertions for these cases, as in the original PPO.
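For reference, the guards meant here are along the lines of what SB3's PPO does in its constructor, paraphrased below as a small hypothetical helper (not copied from the SB3 source): advantage normalization needs more than one sample per minibatch, so degenerate batch_size / n_steps values are rejected up front.

```python
# Hypothetical helper mirroring the kind of sanity checks SB3's PPO performs in
# __init__ (paraphrased, not the actual SB3 code): advantage normalization
# divides by the minibatch standard deviation, which is undefined for one sample.
def check_rollout_sizes(batch_size: int, n_steps: int, n_envs: int, normalize_advantage: bool = True) -> None:
    if not normalize_advantage:
        return
    assert batch_size > 1, "`batch_size` must be greater than 1 when normalizing advantages"
    assert n_steps * n_envs > 1, (
        f"`n_steps * n_envs` must be greater than 1, got n_steps={n_steps} and n_envs={n_envs}"
    )
```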

RemiG3 avatar Apr 19 '23 19:04 RemiG3

RemiG3, sorry for the late response, here is the code from my first post:

```python
import gym  # or `import gymnasium as gym`, depending on the SB3 version
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib.ppo_attention.ppo_attention import AttentionPPO
from sb3_contrib.ppo_attention.policies import MlpAttnPolicy

VE = DummyVecEnv([lambda: gym.make("CartPole-v1")])

model = AttentionPPO(
    "MlpAttnPolicy", VE,
    n_steps=240, learning_rate=0.0003, verbose=1, batch_size=12,
    ent_coef=0.03, vf_coef=0.5, seed=1, n_epochs=10, max_grad_norm=1,
    gae_lambda=0.95, gamma=0.99, device="cpu",
    policy_kwargs=dict(net_arch=dict(pi=[64, 32], vf=[64, 32])),
)
```

First I create a vectorized environment and then set up the model like the LSTM RecurrentPPO, then run model.learn(). I traced the code and found that the internal calculations return valid numbers at the beginning, but after a few loops they start returning 'NA' and then training stops. I saw that some other implementations use stacked frames with a sliding window as the input format, so I'm a little confused about what the correct input format should be. But from your code I think the input should be just one record at a time, with no need to stack records.

I followed the code and saw that you concatenate the input tensor with the memory, but the input from SB3 is a single record at first, and after the first full loop it becomes a batch of records; that throws the error, since the memory is still a single tensor instead of a batch.
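To make the shape mismatch concrete, here is a standalone toy example in plain PyTorch (not the code from the PR) of concatenating a Transformer-XL style memory with new inputs along the time axis:

```python
# Toy illustration of why the batch dimensions must agree when a memory tensor
# is concatenated with new inputs: during rollout collection the batch is the
# number of envs, while during training it becomes the minibatch size, so the
# cached memory has to be reshaped/expanded accordingly.
import torch

d_model, mem_len = 64, 16

memory = torch.zeros(1, mem_len, d_model)   # memory kept from rollout collection: batch size 1 (one env)
x_rollout = torch.randn(1, 1, d_model)      # one new step from one env
x_train = torch.randn(12, 8, d_model)       # training minibatch: 12 sequences of length 8

ok = torch.cat([memory, x_rollout], dim=1)  # works: batch dims match -> (1, 17, 64)

try:
    torch.cat([memory, x_train], dim=1)     # fails: batch 1 vs batch 12
except RuntimeError as err:
    print(err)

# One way out is to expand the cached memory to the minibatch size first:
fixed = torch.cat([memory.expand(x_train.size(0), -1, -1), x_train], dim=1)  # (12, 24, 64)
```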

Thank you for the update, I will try it this weekend.

eric000888 avatar Apr 25 '23 13:04 eric000888

Another question: if you just use GTrXL as a feature extractor in the PPO model, will this give the same results? RecurrentPPO has a flag to use the LSTM layer or not, similar to a feature-extractor layer.
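For illustration, this is roughly what that would look like with SB3's standard custom features-extractor mechanism (a hedged sketch: `TransformerExtractor` and its hyper-parameters are made up, and a plain `nn.TransformerEncoderLayer` stands in for a GTrXL block). Note that such an extractor only sees the current observation, with no memory carried across timesteps, so it is not obviously equivalent to an attention policy with memory.

```python
import torch
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class TransformerExtractor(BaseFeaturesExtractor):
    """Hypothetical stateless extractor; an encoder layer stands in for a GTrXL block."""

    def __init__(self, observation_space, features_dim: int = 64):
        super().__init__(observation_space, features_dim)
        self.embed = nn.Linear(int(observation_space.shape[0]), features_dim)
        # Self-attention over a sequence of length 1: there is no memory across
        # timesteps, so this is NOT equivalent to a recurrent/attention policy.
        self.attn = nn.TransformerEncoderLayer(d_model=features_dim, nhead=4, batch_first=True)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        x = self.embed(observations).unsqueeze(1)  # (batch, 1, features_dim)
        return self.attn(x).squeeze(1)


model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    policy_kwargs=dict(
        features_extractor_class=TransformerExtractor,
        features_extractor_kwargs=dict(features_dim=64),
    ),
    verbose=1,
)
```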

eric000888 avatar Apr 25 '23 13:04 eric000888

Another thing: GTrXL demands more computational power, and PPO is like aiming at a moving target; I found training a GTrXL PPO to be a daunting task, especially when using multiple layers. But if you can update the gradient on the whole trajectory, you may speed up the learning process: that means you collect all actions/observations and then do one pass of backpropagation.
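For reference, with the standard PPO hyper-parameters a single backward pass over the whole rollout can be approximated by making the minibatch cover the entire rollout buffer and running one epoch per update (a sketch only; whether this actually helps GTrXL training is untested here):

```python
# Approximate "one backward pass over the whole rollout" with standard SB3 PPO
# hyper-parameters: the minibatch covers the full rollout buffer, one epoch per update.
from stable_baselines3 import PPO

n_envs, n_steps = 1, 2048

model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    n_steps=n_steps,
    batch_size=n_steps * n_envs,  # one minibatch = the whole rollout buffer
    n_epochs=1,                   # a single gradient pass per rollout
    verbose=1,
)
```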

eric000888 avatar Apr 25 '23 13:04 eric000888