
Experiments in training a transformer network to master reinforcement learning environments.

Transformer-RL

Can the transformer layers themselves act as an iterative beam search? I don't know. However, training a transformer on a set of sequences corresponding to games can, in a way, act like training a search algorithm. I have trained a GPT-like, decoder-only transformer model, and in my experiments it learns the environments better than dense feed-forward networks, with a smaller network.
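What makes a GPT-like model "decoder-only" is its causal mask: each timestep of a game can attend only to itself and earlier timesteps, never to the future. Below is a minimal pure-Python sketch of single-head causal self-attention. It assumes queries, keys and values are passed in directly (no learned projections, no multiple heads), and the function names are illustrative, not taken from this repository.

```python
import math
from typing import List

Vector = List[float]


def softmax(xs: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_self_attention(queries: List[Vector], keys: List[Vector],
                          values: List[Vector]) -> List[Vector]:
    """Single-head scaled dot-product attention with a causal mask.

    Position t only scores positions 0..t, so the output at step t depends
    only on the game history up to and including step t.
    """
    d = len(queries[0])
    outputs = []
    for t, q in enumerate(queries):
        # Future positions are never scored, which is equivalent to
        # masking their scores with -inf before the softmax.
        scores = [sum(qi * ki for qi, ki in zip(q, keys[s])) / math.sqrt(d)
                  for s in range(t + 1)]
        weights = softmax(scores)
        # Weighted sum of the visible value vectors.
        out = [sum(w * values[s][j] for s, w in enumerate(weights))
               for j in range(len(values[0]))]
        outputs.append(out)
    return outputs


# Three timesteps of a toy game, each embedded as a made-up 2-d vector.
seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(causal_self_attention(seq, seq, seq)[0])  # first step sees only itself
```

Because of the mask, changing the last timestep leaves the outputs for the earlier timesteps unchanged, which is what lets one forward pass train on every prefix of an episode at once.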

To run the CartPole environment, use the command:

python3 run_dqn.py --name="CartPole-v1"
...
Epoch: 220 | 225387 | 0.021 | 135.542 | playing games: 107/1024: 22%|███████████████                                                              | 220/1000

Environments

I have tested the agent on CartPole-v1 and Breakout-ram-v0, and it learned to play both, with rewards increasing over training.

Algorithms

So far I have implemented DQN. In the next iteration I will implement PPO and see whether it brings any improvements.
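For reference, the core pieces of DQN are a uniform replay buffer, epsilon-greedy exploration, and a bootstrapped regression target computed with a separate target network. The sketch below is a generic plain-Python illustration, not the code in run_dqn.py: the Q-network is abstracted as any callable `target_q` that returns one value per action (in this repo that role would be played by the transformer), and all names here are hypothetical.

```python
import random
from collections import deque
from typing import Callable, Deque, List, Sequence, Tuple

# (state, action, reward, next_state, done)
Transition = Tuple[Sequence[float], int, float, Sequence[float], bool]


class ReplayBuffer:
    """Fixed-size FIFO buffer of transitions, sampled uniformly at random."""

    def __init__(self, capacity: int) -> None:
        # deque with maxlen drops the oldest transition when full.
        self.buffer: Deque[Transition] = deque(maxlen=capacity)

    def push(self, transition: Transition) -> None:
        self.buffer.append(transition)

    def sample(self, batch_size: int) -> List[Transition]:
        return random.sample(list(self.buffer), batch_size)

    def __len__(self) -> int:
        return len(self.buffer)


def epsilon_greedy(q_values: Sequence[float], epsilon: float,
                   rng: random.Random) -> int:
    """Pick a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


def td_targets(batch: List[Transition],
               target_q: Callable[[Sequence[float]], Sequence[float]],
               gamma: float = 0.99) -> List[float]:
    """DQN regression targets: r + gamma * max_a' Q_target(s', a').

    Terminal transitions do not bootstrap from the next state.
    """
    targets = []
    for _, _, reward, next_state, done in batch:
        bootstrap = 0.0 if done else gamma * max(target_q(next_state))
        targets.append(reward + bootstrap)
    return targets
```

The online network is then trained to regress Q(s, a) for the taken action toward these targets, and the target network's weights are periodically copied from the online network.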

Hypothesis

I am actively looking for collaborators to research this approach. If transformers can indeed perform tree search within their own weights, this would drastically speed up processes that use Monte Carlo Tree Search. I am not suggesting that transformers have never been trained with RL; take, for example, this work from OpenAI on summarization.
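To make the hypothesis concrete: MCTS repeats a selection step like the one below at every node on every simulation, before expanding a leaf and backing values up the tree. The sketch shows the standard UCT selection rule (UCB1 applied to trees); it only illustrates the kind of search a transformer would have to approximate internally, and the `Node` structure is an assumption, not code from this repository.

```python
import math
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class Node:
    visits: int = 0
    total_value: float = 0.0
    # Maps an action to the child node reached by taking it.
    children: Dict[int, "Node"] = field(default_factory=dict)


def uct_score(parent_visits: int, child: Node, c: float = math.sqrt(2)) -> float:
    """UCB1 applied to trees: mean value plus an exploration bonus."""
    if child.visits == 0:
        return math.inf  # always try unvisited actions first
    exploit = child.total_value / child.visits
    explore = c * math.sqrt(math.log(parent_visits) / child.visits)
    return exploit + explore


def select_action(node: Node, c: float = math.sqrt(2)) -> int:
    """Return the child action with the highest UCT score."""
    return max(node.children,
               key=lambda a: uct_score(node.visits, node.children[a], c))
```

The constant `c` trades off exploiting actions with high average value against exploring rarely visited ones.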

Old-Intro

The old TensorFlow code has been moved to old-tf, and the old PyTorch code to old-pytorch.

Experiments in training a transformer network to master reinforcement learning environments. Before starting, however, I wrote a small proof of concept (PoC) in PyTorch that trains a transformer network on the RAM configurations of popular gym environments. It was a quick hack and a way to learn PyTorch; I am not a big fan of it, though maybe I am just more used to TensorFlow (more details in the README). This package will be based entirely on TensorFlow as its NN library.
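One way to feed RAM configurations to a GPT-like model is to treat each RAM byte as a token. In Breakout-ram-v0 each observation is the 128 bytes of Atari 2600 RAM, so the values already fall in 0..255. The sketch below shows one possible encoding, assuming byte tokens for each snapshot followed by an action token, with actions offset past the byte vocabulary; this scheme and the function names are hypothetical and not necessarily what the PoC does.

```python
from typing import List, Sequence, Tuple

RAM_VOCAB = 256  # each RAM byte takes a value in 0..255


def encode_step(ram: Sequence[int], action: int) -> List[int]:
    """Turn one (RAM snapshot, action) pair into token IDs.

    RAM bytes map to IDs 0..255; actions are shifted past the RAM
    vocabulary so the two kinds of token never collide.
    """
    if any(not 0 <= b < RAM_VOCAB for b in ram):
        raise ValueError("RAM values must be bytes in 0..255")
    return list(ram) + [RAM_VOCAB + action]


def encode_episode(rams: Sequence[Sequence[int]],
                   actions: Sequence[int]) -> List[int]:
    """Concatenate all steps into one sequence for a decoder-only model."""
    if len(rams) != len(actions):
        raise ValueError("need exactly one action per RAM snapshot")
    tokens: List[int] = []
    for ram, action in zip(rams, actions):
        tokens.extend(encode_step(ram, action))
    return tokens


def shift_for_next_token(tokens: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Inputs and targets for next-token prediction (targets shifted by one)."""
    return list(tokens[:-1]), list(tokens[1:])


# Toy episode with 3-byte "RAM" instead of the real 128 bytes.
tokens = encode_episode([[1, 2, 3], [4, 5, 6]], [0, 2])
inputs, targets = shift_for_next_token(tokens)
```

With a causal mask, the prediction at each action-token position is the model's choice of action given the history so far. CartPole observations are continuous floats, so they would need to be discretized or embedded differently.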

A bit more about the coding conventions used here: I like the OpenAI style of coding, where as much as possible is morphed into a function. This is more of a scripting convention, and it makes experimentation and deployment easier. It also makes the code easier on the eyes, easier to read and maintain, and more reusable. Just look at the repos from Google: it's almost impossible to read and hack that code.