feat: switch sebulba to using shard_map like mava
What?
Some mava upgrades, as promised :smile:
Quite a few nice changes come from upgrading to shard_map over pmap here, which avoids some unnecessary device transfers. I found these in Mava using JAX's transfer guard, quite a nice tool.
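For reference, the transfer guard can be used as a context manager to catch implicit host/device copies - a minimal sketch (not Mava's actual debugging code):

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.ones((4,))  # already lives on the device

# "disallow" raises on implicit transfers; "log" would just report them.
with jax.transfer_guard("disallow"):
    y = x + x  # fine: purely on-device, no transfer
    try:
        x + np.ones(4)  # mixes in host data: an implicit host->device copy, which should now raise
    except Exception as err:
        print(f"caught implicit transfer: {type(err).__name__}")
```

Running a training step under such a guard is a quick way to surface transfers that would otherwise happen silently.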
Why?
To make sebulba go brrr
How?
- Switch to `shard_map`.
- Remove the default actor device - this was causing some unnecessary transfers in the pipeline.
- Stop using flax's replicate/unreplicate and instead explicitly `put`.
- Move a `block_until_ready` from the params source to the learner. I think the `unreplicate` in the learner was doing this before; without this block we get weird and undefined seg faults.
- One issue I see is that we're now passing the same key to all the learners. We actually do this in Mava as well and I realise it's a minor issue, but I'm not entirely sure how to fix it. I quickly tried switching the sharding for the key to the data sharding, which I think should fix it, but it didn't... it's Sunday, so hopefully I'll have time to look at it during the week, or if you could find a solution that would be awesome.
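For context, the shape of the change is roughly the following - a minimal, hypothetical sketch of a shard_map'd learner step (not the actual Stoix code; `learner_step`, the loss, and the mesh name are made up):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX versions
except ImportError:
    from jax.experimental.shard_map import shard_map

# One mesh axis "d" over all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("d",))

def learner_step(params, batch):
    # Per-device gradient of a dummy loss, then averaged across the
    # "d" axis - the shard_map analogue of pmap + jax.lax.pmean.
    # (Distinct per-device RNG keys could be derived here with
    # jax.random.fold_in(key, jax.lax.axis_index("d")).)
    loss = lambda p, b: jnp.mean((b * p) ** 2)
    grads = jax.grad(loss)(params, batch)
    return jax.lax.pmean(grads, axis_name="d")

sharded_step = jax.jit(
    shard_map(
        learner_step,
        mesh=mesh,
        in_specs=(P(), P("d")),  # params replicated, batch split over "d"
        out_specs=P(),           # averaged grads come back replicated
    )
)

params = jnp.float32(2.0)
batch = jnp.arange(jax.device_count() * 4, dtype=jnp.float32)
grads = sharded_step(params, batch)
```

Because the specs describe placement directly, there's no need for the replicate/unreplicate bookkeeping that pmap encouraged.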
NOTE:
This is very much not benchmarked. I pulled in Mava's changes in about an hour; I tested it locally and it solves CartPole, but I haven't checked on a TPU or with a harder env.
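The explicit `put` plus `block_until_ready` pattern from the list above looks roughly like this (a hypothetical sketch, not the actual params-source code):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Instead of flax's replicate/unreplicate, device_put to a named
# sharding and block until the transfer has finished before handing
# the params over (the guard against the seg faults mentioned above).
mesh = Mesh(np.array(jax.devices()), axis_names=("d",))
replicated = NamedSharding(mesh, P())  # same value on every device

params = jnp.arange(4.0)
params = jax.device_put(params, replicated)  # explicit placement
params = jax.block_until_ready(params)       # don't race the async transfer
```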
Thanks so much, I'll try to review this and test it tomorrow on a GPU.
I just did a comparison, and it seems like sebulba on main is faster. Looking at all the timing statistics, it's the pipeline that is slowing things down; everything else on this branch is faster. I am trying this on a 2-GPU system. I'll try to figure out why it's slow.
So basically, because the actors are faster now, the pipeline fills up faster, which then slows the insertions down a lot - I think because there is more waiting and blocking, which I imagine increases the overheads. I'm not sure of the solution though.
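The back-pressure effect is easy to reproduce with a toy bounded queue (purely illustrative, not the Stoix pipeline; all names are made up): once the buffer is full, every actor `put` blocks until the slower learner drains an item.

```python
import queue
import threading
import time

pipeline = queue.Queue(maxsize=2)  # small buffer between actor and learner

def actor(n_steps):
    """Produce batches as fast as possible; returns time spent blocked."""
    blocked = 0.0
    for step in range(n_steps):
        t0 = time.perf_counter()
        pipeline.put(step)  # blocks while the pipeline is full
        blocked += time.perf_counter() - t0
    return blocked

def learner(n_steps, out):
    for _ in range(n_steps):
        batch = pipeline.get()
        time.sleep(0.01)  # the learner is the slow side here
        out.append(batch)

received = []
t = threading.Thread(target=learner, args=(8, received))
t.start()
blocked = actor(8)
t.join()
print(f"actor spent {blocked:.3f}s blocked on pipeline.put")
```

Raising the queue size (or dropping stale batches) trades memory and off-policyness for actor throughput, which would be one hyperparameter-style knob to try.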
I think faster actors should almost always be better - maybe it just needs some hyperparameter tuning?
So I ended up completely refactoring the sebulba architectures, taking a mix of inspiration from Cleanba and InstaDeep's one. When timing the previous systems against Cleanba, Cleanba was much faster; now Stoix is as fast and gets good performance without being as (in my opinion) chaotic as Cleanba. I'm going to evaluate the shard_map change again at some point. Thought I'd give an update almost a year later haha
Honestly I'm not convinced we even need the shard map. You're welcome to close this PR if you want