[WIP, Feature] MCTS
Implements MCTS planners.
We design an _MCTSNode class that represents a node in the graph. It is (more or less) stateless in the sense that it does not store information internally but in a tensordict that is used to instantiate it. This makes it possible to retrieve that information via the tensordict.
We also provide a generic policy class that reads a tensordict, executes a tree search and picks an action. It returns the full tensordict containing all the leaves of the tree search (these can easily be pruned if needed).
Finally, we provide a generic MCTS class for deterministic envs that executes rollouts in an appropriate way (the MDP step with MCTS differs significantly in TorchRL with respect to more classical RL problems, say gym-like).
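A minimal sketch of that read-through pattern (the class body and keys here are illustrative, not the actual classes in this PR; it also assumes the standalone tensordict package):

import torch
from tensordict import TensorDict

class NodeSketch:
    """Illustrative only: the node keeps no tensors of its own."""

    def __init__(self, td: TensorDict):
        self.td = td  # all node statistics live in the tensordict

    @property
    def visit_count(self) -> torch.Tensor:
        # read through to the tensordict rather than an internal attribute
        return self.td["visit_count"]

node = NodeSketch(TensorDict({"visit_count": torch.zeros(())}, batch_size=[]))
print(node.visit_count)  # tensor(0.)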
cc @pmcvay
Codecov Report
Merging #629 (05c6877) into main (4e1b878) will increase coverage by 0.06%. The diff coverage is 95.55%.
Note: Current head 05c6877 differs from pull request most recent head 5399dbf. Consider uploading reports for the commit 5399dbf to get more accurate results.
@@ Coverage Diff @@
## main #629 +/- ##
==========================================
+ Coverage 87.92% 87.98% +0.06%
==========================================
Files 125 126 +1
Lines 24218 24733 +515
==========================================
+ Hits 21293 21761 +468
- Misses 2925 2972 +47
| Flag | Coverage Δ | |
|---|---|---|
| habitat-gpu | 23.93% <6.66%> (-5.24%) | ↓ |
| linux-cpu | 85.32% <95.55%> (+0.22%) | ↑ |
| linux-gpu | 86.62% <95.55%> (+0.19%) | ↑ |
| linux-outdeps-gpu | 76.09% <96.47%> (+0.43%) | ↑ |
| linux-stable-cpu | 85.21% <95.55%> (+0.22%) | ↑ |
| linux-stable-gpu | 86.50% <95.55%> (+0.19%) | ↑ |
| macos-cpu | 85.10% <95.55%> (+0.22%) | ↑ |
| olddeps-gpu | 76.91% <95.54%> (+0.40%) | ↑ |
Flags with carried forward coverage won't be shown.
| Impacted Files | Coverage Δ | |
|---|---|---|
| torchrl/envs/libs/utils.py | 16.17% <16.66%> (+0.30%) | ↑ |
| torchrl/modules/planners/common.py | 86.20% <66.66%> (-2.69%) | ↓ |
| torchrl/modules/planners/mcts.py | 94.67% <94.67%> (ø) | |
| test/test_modules.py | 99.61% <100.00%> (+0.21%) | ↑ |
| torchrl/data/tensordict/tensordict.py | 82.91% <100.00%> (+0.04%) | ↑ |
| torchrl/envs/libs/habitat.py | 77.77% <0.00%> (-22.23%) | ↓ |
| test/_utils_internal.py | 67.27% <0.00%> (-20.00%) | ↓ |
| torchrl/envs/libs/gym.py | 80.71% <0.00%> (-3.05%) | ↓ |
| test/test_libs.py | 95.67% <0.00%> (-2.71%) | ↓ |
| torchrl/envs/utils.py | 95.00% <0.00%> (-1.67%) | ↓ |
@vmoens thanks for providing an implementation of MCTS.
I am wondering whether you plan to merge this PR?
Thanks for the interest. I would love to keep working on this, but we currently have very little bandwidth to take care of that. If anyone wants to help us get this on its feet, we'd love to support this effort as much as we can!
Until then this will remain a dormant PR for the foreseeable future, I'm afraid. Hope that makes sense!
That makes sense, let me see if I can build on your PR and add an MCTS implementation. Do you have any comments or suggestions for me on completing this PR?
Sorry for the late reply @mjlaali! Thanks for offering your help.
This is the current state of this PR:
- We worked a bit on an env interface with MCTS, but this should be completely reconsidered IMO. The idea was to have an object that behaves like an env, but whose `step` function does not write the result in `td["next"]` but in `td["children", <action_taken>]`, where `<action_taken>` is a string indicating the action just taken. In other words, it's an env that supports taking multiple actions at the same stage and writing all the results in the td. Let me give an example of what we wanted to do:
env = make_env()
data = env.reset()
data["action"] = 1
data1 = env.step(data)
print(data["children", "1"])  # result of action "1" -- incidentally it is the same thing as `data1`
data["action"] = 2
data2 = env.step(data)
print(data["children", "2"])  # result of action "2" -- incidentally it is the same thing as `data2`
The env is stateless: that way you can get any input data and unroll a rollout from there without needing to modify the env internally. Not 100% sure this is the right interface but it seemed like a good way of navigating through a tree.
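To make the statelessness concrete, here is a hedged sketch of what such a `_step` could look like (`_simulate` is a hypothetical pure transition function, not part of this PR):

def _step(self, tensordict):
    # Illustrative only: everything is read from the input tensordict,
    # nothing from `self`, so any stored branch can be re-expanded later.
    state = tensordict["observation"]
    action = tensordict["action"]
    next_state, reward, done = self._simulate(state, action)  # pure function (hypothetical)
    tensordict["children", str(int(action))] = TensorDict(
        {"observation": next_state, "reward": reward, "done": done},
        batch_size=tensordict.batch_size,
    )
    return tensordict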
- We had `_MCTSNode` and similar classes that represent nodes in a tree. They are linked to a `TensorDict` representing the state. The idea was to have an object with specific behaviours whose underlying data structure is a `TensorDict`, since that is well suited to building nested structures. In the meantime we introduced `@tensorclass`, which lets us do just that, so this could be the way to go:
@tensorclass
class _MCTSNode:
    state: TensorDictBase
    prev_action: torch.Tensor
    parent: _MCTSNode
    ...  # some other data
    n_actions: int
    ...  # some other meta-data

    # we can define methods
    @property
    def visit_count(self) -> int:
        """Number of times this particular node has been accessed."""
        return self.parent._child_visit_count[self.prev_action]
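As a hedged illustration of why `@tensorclass` fits here (a simplified node, not the elided class above): instances keep `TensorDict` semantics such as stacking, which is what makes batched tree operations easy. The top-level `tensorclass` import assumes a recent tensordict version:

import torch
from tensordict import TensorDict, tensorclass

@tensorclass
class _Node:
    state: TensorDict
    prev_action: torch.Tensor

n0 = _Node(state=TensorDict({"obs": torch.zeros(3)}, []), prev_action=torch.tensor(0), batch_size=[])
n1 = _Node(state=TensorDict({"obs": torch.ones(3)}, []), prev_action=torch.tensor(1), batch_size=[])
nodes = torch.stack([n0, n1])  # a batch of two nodes, still a _Node
print(nodes.prev_action)       # tensor([0, 1])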
- Finally, we had a `policy` that defined the behaviour in the env, basically interfacing the env (simulator) with the tree (made of nodes as explained above). Have a look at the `forward` for more context.
I wrote some tests already, but everything is very rough.
- Some more thoughts:
  - At the time we were talking with users who were looking for efficient solutions for all this. Properly batching operations seemed like an easier win than writing everything in C++, and would already bring a lot of the performance increase. At a later stage, we were going to look at `torch.compile` to reduce the overhead. I still think this plan is sound: I believe that tensorclass/tensordict are appropriate data structures to navigate a tree, and writing the code in C++ is certainly a good choice if you want something super optimized, but not if you're looking for something that makes it easy for your users to hack around and modify the code.
  - One thing we worked on in tensordict is a compact representation of trees. @tcbegley did amazing work there. What we wanted is that if you have a data structure like this:
root = TensorDict({
    "a": torch.zeros(2),
}, [])
leaf = TensorDict({
    "a": torch.ones(2),
}, [])
root["leaf"] = leaf
then you could actually store the tensors `"a"` compactly as a single tensor `a`, where `root["a"]` would be located at `a[0]` and `root["leaf", "a"]` at `a[1]`. That way you can batch operations across the tree. For instance, if you want to find which node is the closest to 1 you'd do
idx = (a - 1).norm(dim=-1).argmin()
which should give you `idx == 1`, and you can then convert that back to the index `("leaf",)` in the tree.
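A hedged, self-contained version of that lookup, stacking the leaves by hand since the compact storage itself lives inside tensordict:

import torch
from tensordict import TensorDict

root = TensorDict({"a": torch.zeros(2)}, [])
root["leaf"] = TensorDict({"a": torch.ones(2)}, [])

# gather every "a" in the tree into one stacked tensor so the search is batched
locations = [(), ("leaf",)]                      # tree index of each row
a = torch.stack([root["a"], root["leaf", "a"]])  # shape [2, 2]

idx = (a - 1).norm(dim=-1).argmin()              # row closest to 1
print(locations[idx])                            # ('leaf',)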
This is basically as much as I remember right now. If anything isn't clear I'd be happy to dig more into the literature and the code.
If things don't make sense feel free to reach out or write your thoughts here! This is very very much WIP so any improvement is welcome!
@vmoens thanks for the explanation. I was a bit slow on the implementation as I am a bit busy with a few other things.
I like the idea of using the env to encapsulate tree information and reusing the policy concept to define the tree traversal strategy.
During the implementation, I found a twist to this design that may simplify it, but at the cost of making the environment stateful. I am not sure about this trade-off, so I thought I would share it here and get your feedback.
This is what I am thinking at the moment:
Either we define a new environment or use `TransformedEnv` (my preference) to augment env states with statistics of the tree nodes. The statistics can be added to `td["observation"]` or `td["mcts"]` so that tree traversal policies can be called directly on the td:
tree = TensorDictTree()  # this class stores tree information
env = TransformedEnv(GymEnv("Pendulum-v1"), TreeSearch(nnet, tree))
state = env.reset()
print(state[("observation", "n_sa")])  # number of times this state has been observed (here it is zero)
tree_path = env.rollout(max_steps=10, policy=AlphaZeroPolicy())
This differs from the proposed approach in that it does not store tree information within the TensorDict inputs, but in another class (here `TensorDictTree`).
This has a few advantages:
- We don't need to define a new type of env (`DeterministicGraphEnv`).
- We can reuse the rollout method to do the tree-search traversal.

But this comes at the cost of a stateful env (does using this env within `ParallelEnv` create issues? What if we make `TensorDictTree` thread-safe?).
A note on `TensorDictTree`: this class will not store an explicit representation of the tree, but just a data storage that keeps track of extra information for tree nodes (in this example `n_sa`). This is useful to avoid a determinism assumption on the env (i.e. the next state can be probabilistic given the action). It is worth noting that to get this flexibility, I have to define a hash value for a tensordict.
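One hedged way to define such a hash (illustrative, not code from this PR): serialize the leaf tensors in a deterministic key order and hash the bytes. Floating-point noise and collisions would need more care in a real implementation:

import hashlib
import torch
from tensordict import TensorDict

def tensordict_hash(td: TensorDict) -> str:
    # Illustrative only: hash keys and raw tensor bytes in sorted key order
    # so the same state always maps to the same tree node.
    h = hashlib.sha256()
    for key in sorted(td.keys(include_nested=True, leaves_only=True), key=str):
        h.update(str(key).encode())
        h.update(td[key].detach().cpu().contiguous().numpy().tobytes())
    return h.hexdigest()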
Finally, to do multiple rollouts from an env and to pick the best action after MCTS, we can use the rollout API:
for _ in range(num_iterations):
    tree_path = env.rollout(max_steps=10, policy=AlphaZeroPolicy())
    tree.update(tree_path)  # here we update n_sa for each state in the path
print(tree.get_root()["n_sa"])  # visit frequency of each action at the root; the one with the max frequency is the best action
Oh that looks pretty amazing, eager to see what it looks like in practice!