MCTS.jl

First version of lock-based parallelization for vanilla MCTS.

Open kykim0 opened this issue 3 years ago • 25 comments

This is an initial implementation of lock-based parallelization for vanilla MCTS. It is similar to tree parallelization with node-level locking, though there may be some differences due to how the tree nodes are implemented.

I've done fairly extensive testing to check program correctness as well as performance. For this initial change, though, I focused more on correctness than on performance. Performance can be improved, e.g., by using finer-grained locks or a lock-free approach, which I will look into next.

On an Ubuntu machine with an i7-6700K CPU @ 4.00GHz (4 cores), the time to solve SimpleGridWorld with n_iterations=100_000, depth=20, exploration_constant=5.0 was roughly:

  • 1 thread: 3.137 s
  • 2 threads: 2.015 s
  • 3 threads: 1.708 s
  • 4 threads: 1.966 s
  • 5 threads: 2.698 s
  • 6 threads: 5.545 s
  • 7 threads: 6.670 s

In other words, 3-4 threads finished quickest, and too many threads took longer than a single thread, presumably because of thread contention.

kykim0, Dec 13 '21 07:12

Thank you for the contribution, @kykim0! I won't be able to review this until later this week at the earliest. @WhiffleFish, @himanshugupta1009, since you have been working on similar parallelization for POMDP MCTS, would you be able to give this a quick review in the meantime? (I know you haven't done many reviews before, so we can talk about what a review for this entails.)

@kykim0 I assume that the new code is always active, even when there is only one thread; is that correct? Can you report the run time for the current master branch so we can see whether there is a significant slowdown in the single-threaded case? Thanks!

zsunberg, Dec 13 '21 17:12

Thanks for the feedback! I discovered something quite interesting while comparing run times for the single-thread case: with a single thread, using a Channel to implement a producer-consumer interaction can be quite slow (though this likely depends on the MDP). This is presumably because, with only one thread, a lot of time is spent context switching between the producer and the consumer, whereas with multiple threads they can run in parallel. On the other hand, perhaps not surprisingly, the Channel approach was crucial for performance with multiple threads. I ended up adding the old iteration code back so that it runs when the number of threads is 1. I didn't see a good way to combine the two cases, so please let me know if you have a suggestion, though it doesn't seem too bad the way I currently have it.
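A minimal sketch of the two code paths described above, assuming a hypothetical `run_iteration!` that stands in for one MCTS iteration (here it only bumps an atomic counter). This is not the PR's code: with one thread it runs a plain loop, and with several threads a producer task feeds iteration indices into a buffered `Channel` while one consumer task per thread drains it.

```julia
using Base.Threads

# Hypothetical stand-in for one MCTS iteration (selection, expansion,
# rollout, backpropagation); here it only bumps an atomic counter.
run_iteration!(counter::Atomic{Int}, i::Int) = (atomic_add!(counter, 1); nothing)

function run_iterations!(counter::Atomic{Int}, n_iterations::Int)
    if nthreads() == 1
        # Single thread: a plain loop, so no producer/consumer task switching.
        for i in 1:n_iterations
            run_iteration!(counter, i)
        end
    else
        # Producer: a task that puts iteration indices into a buffered
        # Channel; the Channel is closed when the do-block returns.
        jobs = Channel{Int}(nthreads()) do ch
            for i in 1:n_iterations
                put!(ch, i)
            end
        end
        # Consumers: one task per thread, each taking indices until the
        # Channel is closed and empty.
        @sync for _ in 1:nthreads()
            Threads.@spawn for i in jobs
                run_iteration!(counter, i)
            end
        end
    end
    return counter[]
end
```

Start Julia with, e.g., `julia -t 4` to exercise the Channel branch; with `julia -t 1` only the plain loop runs.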

In terms of run time for the single-thread case, the test case I mentioned above takes about 1.09 s, whereas the old code takes 0.81 s, so the new code is a bit slower. I think the difference mainly comes from the logic supporting lock-based parallel MCTS, where we do something like lock(fn, lk) instead of calling fn directly. We could check whether Threads.nthreads() is 1 and preserve the old logic as much as possible, but that would make the code a bit convoluted. Let me know if this performance difference is a concern for you.
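To make the source of the overhead concrete, here is a sketch with a hypothetical `NodeStats` type (not an MCTS.jl type): the same incremental-mean update is called either directly or wrapped in `lock(f, lk)`. Even without contention, the locked path pays for acquiring and releasing the lock on every call, and possibly for the captured closure. For scale, 1.09 s vs. 0.81 s is roughly a 35% slowdown.

```julia
# Hypothetical node statistics (not an MCTS.jl type).
mutable struct NodeStats
    n::Int              # visit count
    q::Float64          # running mean of returns
    lk::ReentrantLock
end
NodeStats() = NodeStats(0, 0.0, ReentrantLock())

# Single-threaded path: update the statistics directly.
function update!(s::NodeStats, r::Float64)
    s.n += 1
    s.q += (r - s.q) / s.n
    return s
end

# Lock-based path: the same update wrapped in lock(f, lk), which acquires
# and releases the lock on every call, even when no other task is waiting.
locked_update!(s::NodeStats, r::Float64) = lock(() -> update!(s, r), s.lk)
```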

kykim0, Dec 14 '21 10:12

It's also worth mentioning that what this PR implements is in one sense node-level locking (e.g., backpropagation is done independently for different nodes), but in another sense is closer to tree-level locking (very bad), e.g., the whole vectors have to be locked when adding a new node. The latter is somewhat unavoidable given the current design of using vectors to represent tree nodes, and it is likely what slows things down when many threads are used. To really improve performance, I think we might need to redesign the tree so that each node is represented as a distinct object. That way, different parts of the tree can be accessed concurrently.
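A sketch of why a vector-based layout forces tree-level locking, using a hypothetical `VecTree` whose node data lives in parallel vectors indexed by node id (a simplification, not the actual MCTS.jl tree). Because `push!` may reallocate a vector, no other task can safely touch any node while one is being added, so a single tree-wide lock guards the insertion.

```julia
# Hypothetical vector-based tree: node data in parallel vectors.
struct VecTree
    n::Vector{Int}                 # visit count per node
    q::Vector{Float64}             # value estimate per node
    children::Vector{Vector{Int}}  # child ids per node
    lk::ReentrantLock              # one lock for the whole tree
end
VecTree() = VecTree(Int[], Float64[], Vector{Int}[], ReentrantLock())

# Adding a node push!es onto every vector; any of them may be reallocated,
# so the whole tree is locked even though only one node is created.
function add_node!(t::VecTree, parent::Int)
    lock(t.lk) do
        push!(t.n, 0)
        push!(t.q, 0.0)
        push!(t.children, Int[])
        id = length(t.n)
        parent > 0 && push!(t.children[parent], id)  # parent == 0 means root
        return id
    end
end
```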

kykim0, Dec 14 '21 14:12

I decided to bite the bullet and try the alternative design of representing each node in the tree as a separate object. This made the code (much) simpler, especially the new logic to support multithreading. It also gives much better multithreading performance since, e.g., we can lock a specific state/action node when backpropagating without locking the whole vectors. In fact, I saw about a 2x improvement with multithreading enabled for one of the test cases I tried. This was a fairly extensive change, but I'm quite happy with the new code, which is simpler and easier to reason about. Please let me know if you have any feedback!
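By contrast, a node-per-object layout lets each node carry its own lock. In this sketch (hypothetical `StateNode`, not the PR's types), expansion locks only the parent whose child list changes, and backpropagation locks one node at a time along the visited path, releasing each lock before moving to the next.

```julia
# Hypothetical node-per-object tree: each node owns its children and lock.
mutable struct StateNode
    n::Int                         # visit count
    children::Vector{StateNode}
    lk::ReentrantLock
end
StateNode() = StateNode(0, StateNode[], ReentrantLock())

# Expansion locks only the parent whose child list changes.
function add_child!(parent::StateNode)
    child = StateNode()
    lock(parent.lk) do
        push!(parent.children, child)
    end
    return child
end

# Backpropagation locks one node at a time, releasing it before the next.
function backpropagate!(path::Vector{StateNode})
    for node in path
        lock(node.lk) do
            node.n += 1
        end
    end
    return nothing
end
```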

kykim0, Dec 16 '21 07:12

@kykim0 Awesome! I am excited to look at it.

zsunberg, Dec 16 '21 14:12

Hey @kykim0, have you looked into using SpinLock instead of ReentrantLock? I've found that if waiting for locks is a major bottleneck, this change offers a sizeable performance improvement.

WhiffleFish, Dec 19 '21 20:12

Thanks @WhiffleFish for the review! I do remember reading about SpinLock, but the documentation seems to recommend ReentrantLock in general. Do you have a better sense of the difference between the two? On a related note, one thing I was planning to try (with Julia 1.7) is concurrent data structures (https://github.com/JuliaConcurrent/ConcurrentCollections.jl), which allow a lock-free approach.

kykim0, Dec 21 '21 02:12

@kykim0 In my understanding (which is admittedly quite limited, so correct me if I'm wrong), reentrant locks are safe when a single thread might try to acquire a lock more than once without unlocking it, i.e., re-enter a lock it already holds (hence "reentrant"). This seems to be more common in parallel traversals of cyclic graphs. Since we're traversing an acyclic tree, and our locking/unlocking operations do not span multiple visited nodes, reentrant locks might be unnecessarily safe if the code is correct.
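The difference in reentrancy can be checked directly. This sketch uses `trylock` instead of a second `lock`, so the non-reentrant case reports `false` rather than spinning forever. Note that in Julia a `ReentrantLock` is owned by a task, not by a thread.

```julia
using Base.Threads: SpinLock

# A ReentrantLock records its owning task, so the owner may acquire it
# again; every lock must still be matched by an unlock.
rl = ReentrantLock()
lock(rl)
reentered = trylock(rl)   # true: the same task re-acquires it
unlock(rl)
unlock(rl)

# A SpinLock has no owner. A second lock(sl) from this task would spin
# forever, so trylock is used to observe the refusal instead.
sl = SpinLock()
lock(sl)
second = trylock(sl)      # false: already held (by this very task)
unlock(sl)
```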

In contrast, a spin lock seems to loop in place checking whether the lock is free, which shortens the time between the lock being released and a waiting thread acquiring it. This would be inefficient if the thread were left spinning for a long time, but since the locks here are held only very briefly (about the duration of a few push operations), I would expect switching from reentrant locks to be worth it.
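To illustrate the busy-waiting behavior described above, here is a toy test-and-set spin lock built on `Threads.Atomic`. `ToySpinLock` and its helpers are hypothetical and for illustration only; `Threads.SpinLock` is the built-in to use in practice. A waiting task retries a compare-and-swap in a loop instead of sleeping, which only pays off when the critical section is a few operations long and never yields.

```julia
using Base.Threads

# Toy test-and-set spin lock: flag is 0 when free, 1 when held.
struct ToySpinLock
    flag::Atomic{Int}
end
ToySpinLock() = ToySpinLock(Atomic{Int}(0))

function toy_lock!(l::ToySpinLock)
    # Spin: retry until the CAS sees 0 and swaps in 1. The waiting task
    # never sleeps; it keeps checking.
    while atomic_cas!(l.flag, 0, 1) != 0
        GC.safepoint()   # allow a pending garbage collection to proceed
    end
    return nothing
end

toy_unlock!(l::ToySpinLock) = (atomic_xchg!(l.flag, 0); nothing)

# Short critical section: one push!, no I/O, no yielding while held.
function locked_push!(l::ToySpinLock, v::Vector{Int}, x::Int)
    toy_lock!(l)
    try
        push!(v, x)
    finally
        toy_unlock!(l)
    end
    return v
end
```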

I originally tried using spin locks for parallel POMCPOW because they were mentioned in this parallel MCTS paper: "...this solution requires to frequently lock and unlock parts of the tree. Hence, fast-access mutexes such as spinlocks have to be used to increase the maximum speedup."

Regarding concurrent collections, I've never used the package, so it could very well be that all this talk about locks is unnecessary; let me know if that works out!

WhiffleFish, Dec 21 '21 13:12

@WhiffleFish Thanks for the explanation! The difference makes sense. I'm wondering, though, whether we can generally assume that a given MDP is acyclic. The paper you linked seems to have used only the game of Go for testing, which is indeed acyclic, but in something like a grid world it is possible to revisit the same state during a rollout. Still, it is interesting that you saw a noticeable difference in performance using SpinLock. Perhaps we can let the user decide which lock to use based on what they know about the MDP. More recent papers (e.g., AlphaGo) seem to suggest that a lock-free approach gives better performance, so that's another approach to look at. My plan is to run more extensive experiments comparing these approaches once this initial change is submitted. I'll keep you posted; please also keep me posted on what you find :-)
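One way to let the user choose the lock, as suggested above, is to make the lock type a type parameter of the node. This is a sketch with a hypothetical `LockedNode`; both `ReentrantLock` and `Threads.SpinLock` are subtypes of `Base.AbstractLock`, so the same `lock`/`unlock` code works for either.

```julia
using Base.Threads: SpinLock

# Hypothetical node whose lock type L is chosen by the user:
# ReentrantLock (safe default) or SpinLock (cheaper, non-reentrant).
mutable struct LockedNode{L<:Base.AbstractLock}
    n::Int
    q::Float64
    lk::L
end
LockedNode{L}() where {L<:Base.AbstractLock} = LockedNode{L}(0, 0.0, L())

# Short, non-yielding critical section, so either lock type works here.
function backup!(node::LockedNode, r::Float64)
    lock(node.lk)
    try
        node.n += 1
        node.q += (r - node.q) / node.n
    finally
        unlock(node.lk)
    end
    return node
end
```

A solver could then build `LockedNode{SpinLock}()` or `LockedNode{ReentrantLock}()` depending on a user option.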

kykim0, Dec 22 '21 03:12

@kykim0 @WhiffleFish, regarding SpinLock, I think that MCTS trees will often have cycles, but I don't think that actually matters in this case. You need to use ReentrantLocks if there is any possibility that the same thread might re-acquire the same lock. But, since you are just quickly modifying one node and then releasing it before doing anything else, I do not think the same thread will ever try to re-acquire the same lock, even if the graph has cycles. Thus, my feeling is that SpinLocks are much more appropriate for locking individual nodes.

zsunberg, Dec 22 '21 18:12

Thanks @zsunberg for the review! I agree with your concern, and it does make sense to be more careful with a stable package that people already use. I don't think the new design will be faster than the vector approach in the single-thread case because, as you also mentioned, of the additional heap allocations and the extra logic to support multithreading. However, the new design is crucial for multithreading performance, as it allows finer-grained locking, and there is also a clear benefit to having one API that supports both cases. So one question to ask ourselves is how much single-thread performance loss we are willing to accept in exchange for a single API. For instance, would 0.81 s vs. 1.09 s for 100,000 iterations with a max depth of 20 on a SimpleGridWorld MDP be tolerable? This is a rather tricky question.

Having said all that, I'm starting to think your suggestion #3 may be a better way forward. I don't necessarily need to submit this change to MCTS.jl for my specific use case (AST), but I wanted to make the feature more broadly available. However, people who already use MCTS.jl are likely content with the current set of features (e.g., single-threaded execution), so it might make more sense to keep the code optimized for that use case. I took a brief look at CommonRLInterface.jl, and it seems to support AlphaZero.jl, so there is also the benefit of being able to test the parallel MCTS code on a broader set of environments. So, unless you feel strongly about the other options, I'll hold off on this PR and explore #3. Is there a good place for the new package?

Regarding SpinLock, I actually observed my test code using SimpleGridWorld deadlocking, which is something the API docs warn about when cycles exist. So that's something to be cautious about.

kykim0, Dec 23 '21 09:12

I also realize that suggestions #2 and #3 aren't necessarily mutually exclusive. If you think it's worth having a version of vanilla MCTS that's more optimized for multithreading, we can do #2 in addition to #3.

kykim0, Dec 23 '21 13:12

As far as (2) vs (3), I don't really have a strong preference. I'd say go with your gut. The main advice I have is to take baby steps and be prepared to scrap large portions of what you do and start those portions over :) We are still learning a lot as we go. If you want to, it would certainly be nice to have (2) as an intermediate step, but we may delete it if (3) is successful.

> Is there a good place for the new package?

Do you have permission to create JuliaPOMDP/MonteCarloTreeSearch.jl? I'd recommend using PkgTemplates.jl to create all the structure.

> Regarding SpinLock, I actually observed my test code using SimpleGridWorld deadlocking which is sth that the API doc warns about when cycles exist. So, something to be cautious about.

I think we should investigate this more; we may be able to avoid it. I suspect the issue is that right now you are spawning n_iterations tasks all at the same time, so the same thread could be running multiple tasks and hence acquire the same lock twice. I'll send you a gist of what I think you should do instead.

Excited to make this work! :)

zsunberg, Dec 23 '21 20:12

> I'll send you a gist of what I think you should do instead.

Actually the pattern I was thinking of won't work because you don't have any control over which threads do which tasks. So as of this moment, I think we need to use reentrant locks. I am still curious though whether only having nthreads() tasks trying to run at the same time would be better than launching n_iterations tasks at the same time. Intuitively, it seems like that many tasks could really clog up all the tubes. @WhiffleFish , when you saw the performance increase with SpinLock, how many tasks did you launch concurrently?
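The two task-launch strategies being compared can be sketched as follows, with a hypothetical `work!` standing in for one MCTS iteration. Strategy A spawns one task per iteration; strategy B spawns only `nthreads()` tasks, each handling a contiguous chunk of iterations. In both cases the scheduler, not the code, decides which thread runs each task, so which is faster is an empirical question.

```julia
using Base.Threads

# Hypothetical stand-in for one MCTS iteration.
work!(acc::Atomic{Int}) = (atomic_add!(acc, 1); nothing)

# Strategy A: one task per iteration; the scheduler multiplexes all of
# them onto the available threads.
function one_task_per_iteration!(acc::Atomic{Int}, n_iterations::Int)
    @sync for _ in 1:n_iterations
        Threads.@spawn work!(acc)
    end
    return acc[]
end

# Strategy B: only nthreads() tasks, each running a contiguous chunk.
function one_task_per_thread!(acc::Atomic{Int}, n_iterations::Int)
    nt = nthreads()
    @sync for t in 1:nt
        # Chunk t covers indices (t-1)*n÷nt + 1 through t*n÷nt.
        chunk = ((t - 1) * n_iterations ÷ nt + 1):(t * n_iterations ÷ nt)
        Threads.@spawn for _ in chunk
            work!(acc)
        end
    end
    return acc[]
end
```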

zsunberg, Dec 23 '21 21:12

> I am still curious though whether only having nthreads() tasks trying to run at the same time would be better than launching n_iterations tasks at the same time.

My intuition about several things has been wrong: it is slightly better to launch a large number of individual tasks, and SpinLock doesn't seem to offer any performance advantage in this test: https://gist.github.com/zsunberg/9743519bfd1d9bb58e5c5c55a6e759b7

zsunberg, Dec 24 '21 01:12

Thanks for the feedback (and for the nice comparison test)! I'll go ahead and start a new package, JuliaPOMDP/MonteCarloTreeSearch.jl, and also see if we can do something quick for #2. I'll keep this PR open, though, so I can come back to the code while porting things over :-)

kykim0, Dec 24 '21 05:12

Looks like I do have permission to create a package under JuliaPOMDP, but I wasn't quite sure what the canonical settings are for JuliaPOMDP packages. Could you share sample code using PkgTemplates.jl for creating one? I found https://github.com/JuliaPOMDP/POMDPTools.jl/commit/a0fb5d62f39ce8bb92af656cd857955f79eac9dc, but that seems to contain only the generated code and not the code that generated the commit. Thanks!

kykim0, Dec 24 '21 05:12

> @WhiffleFish , when you saw the performance increase with SpinLock, how many tasks did you launch concurrently?

I was using the same tasking method as @kykim0 when I saw the performance improvement with SpinLocks. In fact, I fixed the deadlocks in my own MCTS fork, and SpinLocks again offered a ~2x speed improvement over ReentrantLocks.

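```julia
using MCTS, POMDPs, POMDPModels, StaticArrays, BenchmarkTools

mdp = SimpleGridWorld()
sol = MCTSSolver(n_iterations=100_000, depth=20, exploration_constant=5.0)
planner = solve(sol, mdp)
@benchmark action(planner, SA[1,2])  # plan from grid state (1, 2)
```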

With 4 threads:

  • ReentrantLocks: 532.695 ± 33.711 ms
  • SpinLocks: 260.996 ± 20.003 ms

That being said, since this was quickly cobbled together with minimal testing, I can't vouch for how safe my code is.

WhiffleFish, Dec 24 '21 19:12

> Could you share with me sample code using PkgTemplates.jl for creating one? I found JuliaPOMDP/POMDPTools.jl@a0fb5d6 but that seems to only contain the generated code and not the code that generated the commit. Thanks!

@kykim0, sorry, I do not have example code. In general, we want the following:

  • CompatHelper
  • TagBot
  • CI (running the tests) via GitHub Actions
  • GitHub default branch should be main
  • Documentation via Documenter.jl
  • MIT license

Most things can be added later if you don't get it exactly right the first time.

zsunberg, Dec 28 '21 19:12

> In fact, I fixed the deadlocks on my own MCTS fork and SpinLocks again offered a ~2x speed improvement over ReentrantLocks.

> That being said, seeing as this was quickly cobbled together with minimal testing I can't attest to how safe my code is.

Hmmm... I guess we should figure out exactly when deadlocks will happen, i.e., what does "Recursive use" (from the SpinLock docstring) mean? This seems to always cause a deadlock if run on a single task:

```julia
l = Threads.SpinLock()
lock(l)
lock(l)  # the same task tries to re-acquire the lock: spins forever
```

But that seems pretty easy to avoid.

It also seems that trying to lock from different tasks on the same thread will cause a deadlock:

```julia
l = Threads.SpinLock()

function f(l)
    lock(l)
    sleep(5)
    unlock(l)
end

Threads.@spawn f(l)
sleep(1)
lock(l)
println("done") # never prints this if run with one thread on my machine
```

Then the question is how one could ever safely use a SpinLock. Maybe if you could ensure that the thread can never switch tasks while the lock is held? I do not know enough about the thread scheduling model to know how to ensure this.

zsunberg, Dec 28 '21 19:12

I went ahead and asked about it here: https://discourse.julialang.org/t/when-can-tasks-yield-or-how-to-use-spinlocks-safely/73740

zsunberg, Dec 28 '21 23:12

I spent some time reading up on this today, and another architecture we should consider for tree parallelism is:

  1. Get rid of all locks on nodes and/or the tree
  2. Only allow one task to modify the tree
  3. All other tasks besides the modifier can only read the tree and submit changes to the modifier through a Channel
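A minimal sketch of this single-writer architecture, with hypothetical names (`Update`, `run_single_writer`): only the writer task mutates the statistics, and worker tasks send `Update` messages through a `Channel`. The workers' tree reads are only hinted at in a comment; reading while the writer writes is unsynchronized, so readers would have to tolerate stale values.

```julia
# Hypothetical update message: add reward r to node id.
struct Update
    id::Int
    r::Float64
end

function run_single_writer(n_nodes::Int, n_workers::Int, updates_per_worker::Int)
    n = zeros(Int, n_nodes)        # visit counts, written only by the writer
    w = zeros(Float64, n_nodes)    # summed rewards, written only by the writer
    inbox = Channel{Update}(1024)

    # The single modifier: applies updates until inbox is closed and empty.
    writer = Threads.@spawn for u in inbox
        n[u.id] += 1
        w[u.id] += u.r
    end

    # Workers never touch n or w directly; they only submit Updates.
    @sync for k in 1:n_workers
        Threads.@spawn for j in 1:updates_per_worker
            # A real worker would read the tree and run a rollout here.
            put!(inbox, Update(mod1(k + j, n_nodes), 1.0))
        end
    end
    close(inbox)   # no more updates; the writer drains the rest and exits
    wait(writer)
    return n, w
end
```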

zsunberg, Dec 28 '21 23:12

I tried the fix by @WhiffleFish, which was to not unnecessarily grab the node lock in the best_sanode_UCB method, and also saw some speedup, though not as dramatic as 2x. With 4 threads I saw:

  • (a) 1.381 s for ReentrantLock w/ unnecessarily re-grabbing the node lock
  • (b) 1.120 s for ReentrantLock w/ the fix
  • (c) 920.853 ms for SpinLock w/ the fix (my machine is clearly slower :-()

The difference between (b) and (c), i.e., ReentrantLock vs. SpinLock, is interesting, but I find the difference between (a) and (b) equally interesting: it suggests that lock contention can have quite an impact on performance (obvious in hindsight, but interesting to see with actual numbers). Perhaps this means we should go with a lock-free approach if we want to optimize for performance. Ideally, I think it'd be nice to support both an approach that maintains 'data integrity' at the expense of (hopefully) a small performance loss and an approach that optimizes for performance. Based on our discussion so far, I think we want to try:

  1. An approach w/ data integrity guarantees
      • Using ReentrantLock
      • Using SpinLock (somehow ensuring deadlocks don't occur)
      • Using JuliaConcurrent/ConcurrentCollections.jl
      • Using something like https://github.com/JuliaPOMDP/MCTS.jl/pull/86#issuecomment-1002325069
  2. An approach w/o data integrity guarantees
      • Using Atomics
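For the second option, here is a sketch of lock-free node statistics using `Threads.Atomic` (hypothetical `AtomicStats`). Each field is updated atomically on its own, but the pair is not updated as one transaction, so a concurrent reader can briefly see a new visit count with an old value sum. That inconsistency is the "w/o data integrity" trade-off.

```julia
using Base.Threads

# Hypothetical lock-free node statistics.
struct AtomicStats
    n::Atomic{Int}        # visit count
    w::Atomic{Float64}    # sum of returns
end
AtomicStats() = AtomicStats(Atomic{Int}(0), Atomic{Float64}(0.0))

# Each field is updated atomically, but not both together: between the two
# calls another task may observe the new n with the old w.
function atomic_backup!(s::AtomicStats, r::Float64)
    atomic_add!(s.n, 1)
    atomic_add!(s.w, r)
    return nothing
end

# Lock-free read of the value estimate; may be momentarily inconsistent.
function q_estimate(s::AtomicStats)
    n = s.n[]
    return n == 0 ? 0.0 : s.w[] / n
end
```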

On a related note, I created a new package, https://github.com/JuliaPOMDP/MonteCarloTreeSearch.jl. I'm looking into implementing something like what I've done in this PR on top of CommonRLInterface. It'll be exciting to try our code on a broader set of environments like AlphaZero.jl!

kykim0, Jan 02 '22 02:01

Just wanted to share this very interesting thread on lock performance: https://discourse.julialang.org/t/poor-performance-of-lock-l-do-closures-caused-by-poor-performance-of-capturing-closures/42067/13. It might be worth trying the Base.@lock approach to see if it makes much difference in our use case.
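A sketch of the two forms being compared, with a hypothetical `Counter`: `lock(f, l)` takes a closure, while `Base.@lock l expr` expands to `lock(l); try expr finally unlock(l) end` inline, so no closure is created. Whether this matters for MCTS would need to be measured.

```julia
mutable struct Counter
    n::Int
    lk::ReentrantLock
end
Counter() = Counter(0, ReentrantLock())

# lock(f, l): the critical section is passed as a closure capturing c.
incr_closure!(c::Counter) = lock(() -> (c.n += 1), c.lk)

# Base.@lock: expands to lock/try/finally/unlock inline, no closure.
incr_macro!(c::Counter) = Base.@lock c.lk begin
    c.n += 1
end
```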

On a related note, I've been a bit occupied with other things, so I wasn't able to make more progress on this (i.e., porting code over to MonteCarloTreeSearch.jl), but I hope to revisit it in the next couple of weeks :-)

kykim0, Jan 31 '22 00:01

@kykim0, thanks for sharing the thread. That is interesting. Excited for when you are able to get back to the MCTS!

zsunberg, Feb 02 '22 16:02