Stoix

feat: switch sebulba to using shard_map like mava

Open sash-a opened this issue 1 year ago • 4 comments

What?

Some mava upgrades, as promised :smile:
Quite a few nice changes come from switching from pmap to shard_map here, which avoids some unnecessary device transfers. I found these transfers in Mava using JAX's transfer guard, which is quite a nice tool.
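For anyone who hasn't used it: JAX's transfer guard can be switched on around a block of code so that implicit host/device transfers are flagged. A minimal sketch, assuming a recent JAX release; the `double` function is a hypothetical stand-in for an actor or learner step, not Stoix or Mava code:

```python
import numpy as np
import jax


@jax.jit
def double(x):
    return x * 2.0


# Explicit host->device transfer, done before the guard is switched on.
x = jax.device_put(np.arange(4, dtype=np.float32))

with jax.transfer_guard("disallow"):
    # x already lives on the device and 2.0 is baked into the compiled
    # function, so this call should need no host<->device transfer.
    y = double(x)
    # By contrast, double(np.arange(4.0)) or np.asarray(y) inside this block
    # would be implicit transfers, which the guard is meant to reject.

# Explicit device->host copy once we are outside the guarded region.
result = jax.device_get(y)
print(result)
```

Using "log" instead of "disallow" reports implicit transfers without raising, which is handy for auditing a whole pipeline in a single run.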

Why?

To make sebulba go brrr

How?

  • Switch to shard_map
  • Remove the default actor device; this was causing some unnecessary transfers in the pipeline
  • Stop using flax replicate/unreplicate; explicitly device_put instead
  • Move a block_until_ready from the params source to the learner. I think the unreplicate that was previously in the learner was implicitly doing this; without this block we get weird, unpredictable segfaults
  • One issue I see is that we're now passing the same key to all the learners. We actually do this in Mava too, and I realise it's a minor issue, but I'm not entirely sure how to fix it. I quickly tried switching the key's sharding to the data sharding, which I thought should fix it, but it didn't. It's Sunday; hopefully I'll have time to look at it during the week, or if you find a solution that would be awesome
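A minimal sketch of the pattern the first three bullets describe, assuming a recent JAX release and a toy linear-regression learner. The loss, `local_update`, `LR` and the "device" mesh axis name are all hypothetical, not Stoix's actual code. Params are placed once with an explicit `jax.device_put` and a replicated `NamedSharding` (in place of flax's replicate), the batch is sharded along its leading axis, and `shard_map` runs the per-device update with a `pmean` over gradients. The output is already a single global array, so there is no unreplicate step to implicitly synchronise, and the learner calls `block_until_ready` explicitly.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("device",))

# Explicit placements, used instead of flax replicate/unreplicate.
replicated = NamedSharding(mesh, P())             # full copy on every device
batch_sharded = NamedSharding(mesh, P("device"))  # leading axis split across devices

LR = 0.1  # hypothetical learning rate


def local_update(w, x, y):
    """Runs once per device: w is the full params, x/y are this device's shard."""
    residual = x @ w - y
    # Hand-written gradient of this shard's mean squared error.
    local_grad = 2.0 * (x.T @ residual) / x.shape[0]
    # Average gradients across learner devices, leaving w replicated.
    grad = jax.lax.pmean(local_grad, axis_name="device")
    return w - LR * grad


learner_step = jax.jit(
    shard_map(
        local_update,
        mesh=mesh,
        in_specs=(P(), P("device"), P("device")),
        out_specs=P(),
    )
)

rng = np.random.default_rng(0)
batch = 4 * len(devices)  # must divide evenly across devices
x_host = rng.normal(size=(batch, 3)).astype(np.float32)
y_host = rng.normal(size=(batch,)).astype(np.float32)
w_host = np.zeros(3, dtype=np.float32)

params = jax.device_put(w_host, replicated)
x = jax.device_put(x_host, batch_sharded)
y = jax.device_put(y_host, batch_sharded)

new_params = learner_step(params, x, y)
# Block in the learner so the update has finished before params are handed on.
new_params.block_until_ready()
```

The gradient is written out by hand so the sketch does not depend on how different JAX versions differentiate through `shard_map`. With equal shard sizes, the `pmean` of the per-shard gradients equals the full-batch gradient.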
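On the shared-key issue, one common workaround (not necessarily what Stoix or Mava ended up doing) is to keep the key replicated and fold the device index into it inside `shard_map`, so each learner derives its own stream. A minimal sketch, assuming a recent JAX release; `per_device_noise` and the "device" axis name are hypothetical:

```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

mesh = Mesh(np.array(jax.devices()), axis_names=("device",))


def per_device_noise(key):
    """Hypothetical learner step that needs randomness."""
    # The incoming key is identical on every device; folding in the device
    # index gives each learner its own independent stream.
    key = jax.random.fold_in(key, jax.lax.axis_index("device"))
    return jax.random.normal(key, (1, 4))


sample = jax.jit(
    shard_map(per_device_noise, mesh=mesh, in_specs=P(), out_specs=P("device"))
)

noise = sample(jax.random.PRNGKey(0))  # shape (num_devices, 4), one row per learner
```

The other standard option is to split the key into one key per device on the host and pass the stacked keys with `P("device")`. With legacy uint32 keys, each device then sees a slice of shape (1, 2) and has to index out its single key before using it.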

NOTE:

This is very much not benchmarked. I pulled in Mava's changes in about an hour and tested it locally; it solves CartPole, but I haven't checked it on a TPU or with a harder env.

sash-a · Nov 03 '24 16:11

Thanks so much, I'll try to review this and test it tomorrow on a GPU.

EdanToledo · Nov 03 '24 16:11

I just did a comparison, and it seems like sebulba on main is faster. Looking at all the timing statistics, it's the pipeline that is slowing things down; everything else on this branch is faster. I am trying this on a 2-GPU system. I'll try to figure out why it's slow.

EdanToledo · Nov 04 '24 12:11

So basically, because the actors are faster now, the pipeline fills up faster, which then slows insertions down a lot. I think this is because there is more waiting and blocking, which I imagine increases the overhead. I'm not sure what the solution is, though.
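To make the backpressure effect concrete, here is a toy, single-threaded model of an actor-to-learner pipeline with a bounded queue. It is a hypothetical illustration (the `simulate` function and its parameters are invented here), not Stoix's pipeline: each tick, actors try to push a fixed number of rollouts and the learner pops a fixed number.

```python
import queue


def simulate(actor_items_per_tick: int, learner_items_per_tick: int,
             maxsize: int, ticks: int) -> tuple[int, int]:
    """Hypothetical model: returns (items consumed, puts that hit a full queue)."""
    q: queue.Queue = queue.Queue(maxsize=maxsize)
    consumed = 0
    full_puts = 0  # a real threaded actor would block here; this model drops the item
    for _ in range(ticks):
        # Actors produce first.
        for item in range(actor_items_per_tick):
            try:
                q.put_nowait(item)
            except queue.Full:
                full_puts += 1
        # Then the learner consumes whatever is available, up to its rate.
        for _ in range(learner_items_per_tick):
            try:
                q.get_nowait()
                consumed += 1
            except queue.Empty:
                break
    return consumed, full_puts


baseline = simulate(1, 1, maxsize=4, ticks=100)
fast_actors = simulate(3, 1, maxsize=4, ticks=100)
print(baseline, fast_actors)
```

With matched rates the learner consumes 100 items and no put hits a full queue. With actors three times faster, the learner still consumes exactly 100 items, while 197 puts find the queue full. In a threaded pipeline those puts would block the actors rather than be dropped, so in this model learner throughput is unchanged and the extra actor speed becomes waiting time.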

EdanToledo · Nov 04 '24 12:11

I think faster actors should almost always be better; maybe it just needs some hyperparameter tuning?

sash-a · Nov 04 '24 14:11

So I ended up completely refactoring the sebulba architectures, taking inspiration from a mix of Cleanba and InstaDeep's implementation. When I timed the previous systems against Cleanba, Cleanba was much faster; now Stoix is just as fast and gets good performance without being as chaotic as Cleanba (in my opinion). At some point I'm going to evaluate the shard_map change again. I thought I'd give an update almost a year later, haha.

EdanToledo · Jun 13 '25 09:06

Honestly, I'm not convinced we even need shard_map. You're welcome to close this PR if you want.

sash-a · Jun 14 '25 17:06