
[WIP, Feature] MCTS

Open vmoens opened this issue 3 years ago • 7 comments

Implements MCTS planners.

We design an _MCTSNode class that represents a node in the graph. It is (mostly) stateless in the sense that it does not store information internally: everything lives in the tensordict used to instantiate it, which makes it possible to retrieve the information through that tensordict.
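For illustration, here is a minimal sketch of what "stateless" means here; the class name, field names and tensordict import are assumptions for the example, not the PR's actual API:

import torch
from tensordict import TensorDict

class _MCTSNodeSketch:
    """Hypothetical node: it owns no data, only a view on a tensordict."""

    def __init__(self, tensordict: TensorDict):
        self._td = tensordict  # single source of truth

    @property
    def visit_count(self) -> torch.Tensor:
        # read-through: the node exposes the data, the tensordict owns it
        return self._td["visit_count"]

td = TensorDict({"visit_count": torch.zeros((), dtype=torch.long)}, [])
node = _MCTSNodeSketch(td)
td["visit_count"] += 1        # mutate the tensordict directly...
assert node.visit_count == 1  # ...and the node reflects the change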

We also provide a generic policy class that reads a tensordict, executes a tree search and picks an action. It returns the full tensordict containing all the leaves of the tree search (they can easily be pruned if needed).

Finally, we provide a generic MCTS class for deterministic envs that executes rollouts in an appropriate way (the MDP step with MCTS differs significantly in TorchRL from more classical, gym-like RL problems).
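As a toy illustration of the intended flow (all names and dynamics here are made up for the example, not the PR's API), a depth-one "search" that expands every action into td["children", ...] and then picks the best one:

import torch
from tensordict import TensorDict

def toy_step(state, action):
    # hypothetical deterministic dynamics, for illustration only
    next_state = state + action
    reward = -next_state.abs().sum()
    return next_state, reward

def one_step_planner(td, actions=(-1, 0, 1)):
    best_action, best_reward = None, -float("inf")
    for a in actions:
        next_state, reward = toy_step(td["state"], a)
        # leaves are kept in the tensordict, keyed by the action taken
        td["children", str(a)] = TensorDict(
            {"state": next_state, "reward": reward}, [])
        if reward > best_reward:
            best_action, best_reward = a, reward
    td["action"] = torch.tensor(best_action)
    return td

td = one_step_planner(TensorDict({"state": torch.tensor([2.0])}, []))
print(td["action"])                    # tensor(-1)
print(td["children", "-1", "reward"])  # the best leaf's reward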

cc @pmcvay

vmoens avatar Oct 31 '22 13:10 vmoens

Codecov Report

Merging #629 (05c6877) into main (4e1b878) will increase coverage by 0.06%. The diff coverage is 95.55%.

:exclamation: Current head 05c6877 differs from the pull request's most recent head 5399dbf. Consider uploading reports for commit 5399dbf to get more accurate results.

@@            Coverage Diff             @@
##             main     #629      +/-   ##
==========================================
+ Coverage   87.92%   87.98%   +0.06%     
==========================================
  Files         125      126       +1     
  Lines       24218    24733     +515     
==========================================
+ Hits        21293    21761     +468     
- Misses       2925     2972      +47     
Flag                 Coverage Δ
habitat-gpu          23.93% <6.66%> (-5.24%) :arrow_down:
linux-cpu            85.32% <95.55%> (+0.22%) :arrow_up:
linux-gpu            86.62% <95.55%> (+0.19%) :arrow_up:
linux-outdeps-gpu    76.09% <96.47%> (+0.43%) :arrow_up:
linux-stable-cpu     85.21% <95.55%> (+0.22%) :arrow_up:
linux-stable-gpu     86.50% <95.55%> (+0.19%) :arrow_up:
macos-cpu            85.10% <95.55%> (+0.22%) :arrow_up:
olddeps-gpu          76.91% <95.54%> (+0.40%) :arrow_up:

Flags with carried forward coverage won't be shown.

Impacted Files                          Coverage Δ
torchrl/envs/libs/utils.py              16.17% <16.66%> (+0.30%) :arrow_up:
torchrl/modules/planners/common.py      86.20% <66.66%> (-2.69%) :arrow_down:
torchrl/modules/planners/mcts.py        94.67% <94.67%> (ø)
test/test_modules.py                    99.61% <100.00%> (+0.21%) :arrow_up:
torchrl/data/tensordict/tensordict.py   82.91% <100.00%> (+0.04%) :arrow_up:
torchrl/envs/libs/habitat.py            77.77% <0.00%> (-22.23%) :arrow_down:
test/_utils_internal.py                 67.27% <0.00%> (-20.00%) :arrow_down:
torchrl/envs/libs/gym.py                80.71% <0.00%> (-3.05%) :arrow_down:
test/test_libs.py                       95.67% <0.00%> (-2.71%) :arrow_down:
torchrl/envs/utils.py                   95.00% <0.00%> (-1.67%) :arrow_down:


codecov[bot] avatar Nov 02 '22 22:11 codecov[bot]

@vmoens thanks for providing an implementation of MCTS.

I am wondering whether you plan to merge this PR?

mjlaali avatar Aug 31 '23 03:08 mjlaali

> @vmoens thanks for providing an implementation of MCTS.
>
> I am wondering whether you plan to merge this PR?

Thanks for the interest. I would love to keep working on this, but we currently have very little bandwidth to take care of that. If anyone wants to help us get this on its feet, we'd love to support this effort as much as we can!

Until then this will remain a dormant PR for the foreseeable future, I'm afraid. Hope that makes sense!

vmoens avatar Aug 31 '23 16:08 vmoens

That makes sense. Let me see if I can pick up your PR and add the MCTS implementation. Do you have any comments or suggestions for me on completing this PR?

mjlaali avatar Sep 01 '23 00:09 mjlaali

> That makes sense. Let me see if I can pick up your PR and add the MCTS implementation. Do you have any comments or suggestions for me on completing this PR?

Sorry for the late reply @mjlaali! Thanks for proposing your help.

This is the current state of this PR:

  • We worked a bit on an env interface for MCTS, but IMO this should be completely reconsidered. The idea was to have an object that behaves like an env, except that its step function does not write the result in td["next"] but in td["children", <action_taken>], where <action_taken> is a string indicating the action just taken. In other words, it's an env that supports taking multiple actions at the same stage and writing all the results in the td. Let me give an example of what we wanted to do:
env = make_env()
data = env.reset()
data["action"] = 1
data1 = env.step(data)
print(data["children", "1"])  # result of action "1" -- incidentally, the same thing as `data1`
data["action"] = 2
data2 = env.step(data)
print(data["children", "2"])  # result of action "2" -- incidentally, the same thing as `data2`

The env is stateless: that way you can take any input data and unroll a rollout from there without needing to modify the env internally. I'm not 100% sure this is the right interface, but it seemed like a good way of navigating a tree.

  • We had _MCTSNode and similar classes that represent nodes in a tree. They are linked to a TensorDict representing the state. The idea was to have an object with specific behaviours whose underlying data structure is a TensorDict, since that is well suited to building nested structures. In the meantime, we introduced @tensorclass, which lets us do just that, so this could be the way to go:
import torch
from tensordict import TensorDictBase, tensorclass

@tensorclass
class _MCTSNode:
    state: TensorDictBase
    prev_action: torch.Tensor
    parent: _MCTSNode
    ...  # some other data
    n_actions: int
    ...  # some other meta-data

    # we can define methods
    @property
    def visit_count(self) -> int:
        """Number of times this particular node has been accessed."""
        return self.parent._child_visit_count[self.prev_action]
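For context, a quick example of what @tensorclass buys you: instances get tensordict-like batch semantics (stacking, indexing, reshaping) for free. The Node fields below are made up for the example:

import torch
from tensordict import tensorclass

@tensorclass
class Node:
    state: torch.Tensor
    visit_count: torch.Tensor

# a batch of 8 nodes stored as one object
nodes = Node(
    state=torch.zeros(8, 4),
    visit_count=torch.zeros(8, dtype=torch.long),
    batch_size=[8],
)
nodes.visit_count[3] += 1    # update a single node inside the batch
print(nodes[3].visit_count)  # index one node out of the batch: tensor(1)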

  • Finally, we had a policy that defined the behaviour in the env, basically interfacing the env (simulator) with the tree (made of nodes, as explained above). Have a look at the policy's forward method for more context; a toy version is sketched below.
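To make the role of that policy concrete, here is a runnable toy version (again, everything here is hypothetical, with a trivial selection rule standing in for a real tree search):

import torch
from tensordict import TensorDict

class ToyTreePolicy:
    """Reads the state from the tensordict, consults a (fake) tree,
    and writes an action back into the tensordict."""

    def __init__(self, n_actions):
        self.n_actions = n_actions
        self.visit_counts = {}  # stands in for the real tree structure

    def __call__(self, td):
        key = tuple(td["state"].tolist())
        counts = self.visit_counts.setdefault(
            key, torch.zeros(self.n_actions, dtype=torch.long))
        action = counts.argmin()  # toy rule: try the least-visited action
        counts[action] += 1
        td["action"] = action
        return td

policy = ToyTreePolicy(n_actions=2)
td = policy(TensorDict({"state": torch.zeros(3)}, []))
print(td["action"])  # tensor(0) on the first visit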

I wrote some tests already, but everything is very rough.

  • Some more thoughts:
    • At the time, we were talking with users who were looking for efficient solutions for all this. Properly batching operations seemed like an easier win than writing everything in C++, and already brings a lot of the performance increase (see the small sketch below). At a later stage, we were going to look at torch.compile to reduce the overhead. I still think this plan is sound: I believe that tensorclass/tensordict are appropriate data structures for navigating a tree, and writing the code in C++ is certainly a good choice if you want something super optimized, but not if you're looking for something that makes it easy for your users to hack around and modify the code.
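As a concrete example of the kind of batching meant here, the UCB scores of all children of a node can be computed in one vectorized call instead of a Python loop (the numbers and the +1 smoothing are illustrative; the formula follows the standard UCT rule):

import torch

def ucb_scores(q, n_sa, c=1.0):
    # q: [A] mean value per action, n_sa: [A] per-action visit counts
    n_s = n_sa.sum()  # total visits of the parent node
    return q + c * torch.sqrt(torch.log(n_s) / (n_sa + 1))  # +1 avoids division by zero

q = torch.tensor([0.1, 0.5, 0.3])
n_sa = torch.tensor([3.0, 1.0, 2.0])
print(ucb_scores(q, n_sa).argmax())  # action selected by UCB, in one batched call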
    • One thing we worked on in tensordict was a compact representation of trees. @tcbegley did amazing work there. What we wanted is that if you have a data structure like this:
root = TensorDict({
    "a": torch.zeros(2),
}, [])
leaf = TensorDict({
    "a": torch.ones(2),
}, [])
root["leaf"] = leaf

then you could actually store the tensors "a" compactly as a single tensor a, where root["a"] would be located at a[0] and root["leaf", "a"] would be located at a[1]. That way you can batch operations across the tree. For instance, if you want to find which node's "a" is the closest to 1, you'd do

idx = (a - 1).norm(dim=-1).argmin()

which should give you idx == 1, which you can then convert back to the index ("leaf",) in the tree.
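A hand-rolled, runnable sketch of that idea (not tensordict's actual implementation): stack the per-node tensors into one big tensor and keep a side list mapping rows back to tree paths:

import torch

paths = [(), ("leaf",)]            # row i of `a` belongs to node paths[i]
a = torch.stack([torch.zeros(2),   # root["a"]
                 torch.ones(2)])   # root["leaf", "a"]

idx = (a - 1).norm(dim=-1).argmin()  # one batched op over the whole tree
print(idx.item(), paths[idx])        # 1 ('leaf',)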

This is basically as much as I remember right now. If anything isn't clear I'd be happy to dig more into the literature and the code.

If things don't make sense feel free to reach out or write your thoughts here! This is very very much WIP so any improvement is welcome!

vmoens avatar Sep 05 '23 17:09 vmoens

@vmoens thanks for the explanation. I have been a bit slow with the implementation as I am busy with a few other things.

I like the idea of using the env to encapsulate the tree information, and of reusing the policy concept to define the tree-traversal strategy.

During the implementation, I found a twist on this design that may simplify it, but at the cost of making the environment stateful. I am not sure about this trade-off, so I thought I would share it here and get your feedback.

This is what I am thinking at this moment:

Either we define a new environment or we use TransformedEnv (my preference) to augment the env state with the statistics of the tree nodes. The statistics can be added under td["observation"] or td["mcts"] so that tree-traversal policies can be called directly on the td:

tree = TensorDictTree()  # this class stores the tree information

env = TransformedEnv(GymEnv("Pendulum-v1"), TreeSearch(nnet, tree))
state = env.reset()
print(state["observation", "n_sa"])  # number of times this state has been visited (zero here)

tree_path = env.rollout(max_steps=10, policy=AlphaZeroPolicy())

This differs from the proposed approach in that it does not store the tree information within the TensorDict inputs, but in a separate class (here, TensorDictTree).

This has a few advantages:

  • We don't need to define a new type of Env (DeterministicGraphEnv).
  • We can reuse the rollout method to do the tree-search traversal.

But this comes at the cost of a stateful env (does using this env within ParallelEnv create issues? What if we make TensorDictTree thread-safe?).

A note on TensorDictTree: this class will not store an explicit representation of the tree, but is just a data store keeping track of extra information for the tree nodes (in this example, n_sa). This is useful to avoid assuming a deterministic env (i.e. the next state can be probabilistic given the action). It is worth noting that to get this flexibility, I had to define a hash_value for a tensordict; see the sketch below.
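For concreteness, a rough sketch of what I have in mind; the hash is naive and all the names here are hypothetical:

import hashlib
import torch
from tensordict import TensorDict

def td_hash(td):
    # naive content hash over a flat tensordict; a real version would need
    # to handle nested keys, dtypes and devices more carefully
    h = hashlib.sha256()
    for key in sorted(td.keys()):
        h.update(td[key].cpu().contiguous().numpy().tobytes())
    return h.hexdigest()

class TensorDictTree:
    """External store of per-node statistics, keyed by a state hash."""

    def __init__(self, n_actions):
        self.n_actions = n_actions
        self.stats = {}  # hash -> per-action visit counts (n_sa)

    def update(self, path):
        # `path` is a rollout tensordict; its batch dim indexes the steps
        for step in path.unbind(0):
            key = td_hash(step.select("observation"))
            n_sa = self.stats.setdefault(
                key, torch.zeros(self.n_actions, dtype=torch.long))
            n_sa[step["action"]] += 1  # assumes an integer action entry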

Finally, to do multiple rollouts from an env and to pick the best action after MCTS, we can use the rollout API:

for _ in range(num_iterations):
    tree_path = env.rollout(max_steps=10, policy=AlphaZeroPolicy())
    tree.update(tree_path)  # here we update n_sa for each state in the path

print(tree.get_root()["n_sa"])  # frequency of each action at the root; the most frequent action is the best one

mjlaali avatar Oct 09 '23 02:10 mjlaali

Oh that looks pretty amazing, eager to see what it looks like in practice!

vmoens avatar Oct 10 '23 06:10 vmoens