attention-learn-to-route
Reimplementation in RL4CO
Hi there 👋🏼
First of all, thanks a lot for your library, it has inspired several works in our research group! We are actively developing RL4CO, a library for all things Reinforcement Learning for Combinatorial Optimization. We started the library by modularizing the Attention Model, which is the basis for several other autoregressive models. We also used some recent software (such as TorchRL, TensorDict, PyTorch Lightning and Hydra) as well as routines such as FlashAttention, and made everything as easy to use as possible in the hope of helping practitioners and researchers.
We welcome you to check RL4CO out ^^
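For a rough idea of the intended workflow, here is a minimal sketch of training the AM on TSP with RL4CO (the import paths and arguments follow the docs at the time of writing but may change between versions, so treat the exact names as illustrative):

```python
# Minimal sketch - exact import paths / arguments may differ between RL4CO versions
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.utils.trainer import RL4COTrainer

env = TSPEnv(num_loc=50)                         # TSP instances with 50 nodes
model = AttentionModel(env, baseline="rollout")  # AM trained with REINFORCE + rollout baseline
trainer = RL4COTrainer(max_epochs=100, accelerator="gpu", devices=1)
trainer.fit(model)
```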
Hi! Thanks for bringing it to my attention, I will definitely check it out! Are you able to reproduce the results from the paper with your implementation (training time, evaluation time and performance)?
Thanks for your quick answer 🚀
- In terms of performance: yes! While reimplementing the AM we carefully checked, step by step, that everything was working; in the exact same settings as the original code (same parameters, learning rate, batch sizes, validation datasets...) we actually obtained slightly better performance, with the only difference being the MHA module with the linear bias set to `True` - the bump in performance over the last 20 epochs is a step of the `MultiStepLR` scheduler.
- In terms of training time: this depends on the implementation. To be honest, your code was already really well optimized, even though it has been around for a few years! In theory, we should be similar or slightly faster since we use FlashAttention in the encoder. In practice, however, we found that we are slightly slower with the same hyperparameters (on a single RTX 3090 we can train your original AM for TSP50 in 14 hours and ours in 15 hours). The reason is that we use TensorDicts and TorchRL environments - we found the "culprits" to be data loading and, most importantly, the fact that we need to re-create TensorDicts with large batch sizes on the fly at each step, which has a pretty big impact on performance (see the TensorDict sketch after this list). One simple way to solve this is to refactor the environments in pure PyTorch. Right now we are in contact with the TorchRL team and there have been several updates (just a few hours ago, the stable TensorDict 0.2.0 was released), so we plan to solve this in the next couple of weeks ;) There is another trick that noticeably improves speed: setting `mask_inner` to `False` enables the FlashAttention routine during decoding (i.e., at each step), and the training above then takes 12.5 hours, even with the current TensorDict slowdown (see the attention sketch after this list)! This does, of course, degrade solution quality a bit, but FlashAttention with masking may be added in the near future, so it holds good promise!
- In terms of evaluation time: we found our evaluation script to be slightly faster - but if we check a single forward pass of the model, we get a similar trend as above, so we think there may have been some bottleneck in the eval script.
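To make the TensorDict point concrete, here is a small self-contained sketch (not RL4CO code, just the pattern) of the difference between rebuilding a large-batch `TensorDict` at every decoding step and updating the same one in place:

```python
import torch
from tensordict import TensorDict

batch_size, num_loc = 512, 50
td = TensorDict(
    {
        "locs": torch.rand(batch_size, num_loc, 2),
        "action_mask": torch.ones(batch_size, num_loc, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

# Costly pattern: a brand-new TensorDict is allocated at every decoding step
def step_recreate(td, action):
    mask = td["action_mask"].clone()
    mask.scatter_(-1, action.unsqueeze(-1), False)  # mark the visited node
    return TensorDict({"locs": td["locs"], "action_mask": mask}, batch_size=td.batch_size)

# Cheaper pattern: keep the same TensorDict and update its entries in place
def step_inplace(td, action):
    td["action_mask"].scatter_(-1, action.unsqueeze(-1), False)
    return td

action = torch.zeros(batch_size, dtype=torch.long)  # pretend every instance visits node 0
td = step_recreate(td, action)  # allocates a new TensorDict on each call
td = step_inplace(td, action)   # reuses the existing one
```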
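And to illustrate the `mask_inner` point: PyTorch's `scaled_dot_product_attention` can typically dispatch to the FlashAttention kernel only when no arbitrary attention mask is passed; with a dense mask it falls back to a slower backend, which is why skipping the inner mask speeds up decoding. A plain-PyTorch sketch (not our actual decoder):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, heads, num_nodes, head_dim) queries / keys / values, as in the AM decoder
q = torch.randn(8, 8, 50, 64, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

# No attention mask: SDPA is eligible for the FlashAttention kernel (fast path)
out_fast = F.scaled_dot_product_attention(q, k, v)

# Arbitrary boolean mask (e.g. hiding already-visited nodes): FlashAttention does not
# support such masks, so PyTorch falls back to a slower backend for this call
allowed = torch.ones(8, 1, 50, 50, device=device, dtype=torch.bool)  # True = may attend
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
```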
We would be more than happy to address your feedback if you check out RL4CO - feel free to contact us any time 😄
(Late) edit: now we are way more efficient as explained here!
Great! I have added a link in the readme. However, I wonder if you have also had a look at https://github.com/cpwan/RLOR; they claim an 8x speedup over this repo using PPO.
Yes, we are aware of it! From our understanding and our testing, the reported speedup actually refers to the time to reach a certain cost, as seen in their Table 4: the AM trains on TSP50 and reaches a cost of 5.80 in 24 hours, while their PPO implementation - with some training and testing tricks - takes 3 hours. So it is not a speedup per se - in fact, since their environment is in NumPy, even though vectorized, data collection is naturally a bottleneck - but rather the time to reach a target performance. Besides, the comparison pits their AM trained with PPO, with a larger batch size and learning rate and tested with the "multi-greedy" decoding scheme at inference (what in RL4CO we call multistart, i.e., the POMO decoding scheme that starts decoding from all nodes and then keeps the best trajectory), against a baseline AM evaluated with plain one-shot greedy decoding. For these reasons, we think the 8x speedup claim is a bit overstated 👀
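To make the decoding difference concrete, here is a schematic sketch (with a dummy `rollout` placeholder, not code from either repo) of why multistart greedy evaluation is naturally ahead of one-shot greedy evaluation, independently of training speed:

```python
import torch

def rollout(locs, start_node):
    """Placeholder for greedily decoding a tour from a fixed start node.
    The real models would run the trained decoder here; we return dummy costs."""
    return torch.rand(locs.shape[0]) + 0.001 * start_node

locs = torch.rand(128, 50, 2)  # a batch of TSP50 instances

# One-shot greedy: a single trajectory per instance (how the baseline AM is evaluated)
cost_greedy = rollout(locs, start_node=0)

# Multistart / "multi-greedy" (POMO-style): one greedy trajectory per start node,
# then keep the best tour per instance - strictly more inference compute, never worse
costs = torch.stack([rollout(locs, start_node=s) for s in range(locs.shape[1])])
cost_multistart = costs.min(dim=0).values

print(cost_greedy.mean().item(), cost_multistart.mean().item())
```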