[WIP] add multiagentRNN

Open kfu02 opened this issue 1 year ago • 15 comments

Description

Per #2003, this adds multi-agent GRUs and LSTMs to TorchRL's multi-agent modules.

Modifies the MultiAgentNetBase class to accept multiple input and output tensors, which lets these recurrent multi-agent nets take and return hidden states (e.g. (input, h_x)).
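
Here is a minimal sketch of the combining step. It assumes NumPy arrays as stand-ins for torch tensors and uses a hypothetical helper, `stack_agent_outputs`, which is not TorchRL's API. Each agent's network may return either one tensor (MLP/ConvNet style) or a tuple such as `(y, h, c)` (RNN style), and the base class stacks each element along the agent dim.

```python
import numpy as np


def stack_agent_outputs(per_agent_outputs, agent_dim=-2):
    """Stack one result per agent along ``agent_dim`` (hypothetical helper).

    Each entry is what a single agent's network returned: either one array
    (MLP/ConvNet style) or a tuple of arrays (RNN style, e.g. (y, h, c)).
    """
    first = per_agent_outputs[0]
    if isinstance(first, np.ndarray):
        return np.stack(per_agent_outputs, axis=agent_dim)
    # Tuple outputs: stack the i-th element of every agent's tuple together,
    # so the result is still a tuple with one stacked array per element.
    return tuple(
        np.stack(items, axis=agent_dim) for items in zip(*per_agent_outputs)
    )


# Usage: 3 agents, each running an LSTM-like net that returns (y, h, c).
B, T, L, H, A = 4, 5, 2, 8, 3
per_agent = [
    (np.zeros((B, T, H)), np.zeros((L, B, H)), np.zeros((L, B, H)))
    for _ in range(A)
]
y, h, c = stack_agent_outputs(per_agent)
print(y.shape, h.shape)  # (4, 5, 3, 8) (2, 4, 3, 8)
```

Stacking every tuple element on the same dim keeps a single code path for MLP-style and RNN-style nets.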

Test gist: https://gist.github.com/kfu02/87ae6c6d99e681d474f4977a9653b329

Motivation and Context

Why is this change required? What problem does it solve? If it fixes an open issue, please link to the issue here. You can use the syntax close #15213 if this solves the issue #15213

close #2003

  • [x] I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds core functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to change)
  • [ ] Documentation (update in the documentation)
  • [ ] Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • [x] I have read the CONTRIBUTION guide (required)
  • [x] My change requires a change to the documentation.
  • [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
  • [x] I have updated the documentation accordingly.

kfu02 · Feb 22 '24

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1948

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · Feb 22 '24

Hi @vmoens, this draft PR adds multi-agent RNNs as you proposed in https://github.com/pytorch/rl/pull/1921#issuecomment-1955117158

I have one blocking issue: the LSTMNet in TorchRL returns a tuple (output, hidden_states...), but MultiAgentNetBase.forward() expects the network built by build_single_net() to return a single tensor (as both the MLP and ConvNet do).

I am not sure what the best solution is here. Should I override forward()? Or add an MLP at the end of MultiAgentRNN (as that's the most typical use case)? I also don't fully understand how LSTMNet receives the previous hidden state as input.

I would appreciate any clarity you could provide!

kfu02 · Feb 22 '24

I made a number of fixes to this PR; you can test them with this gist: https://gist.github.com/vmoens/c24c36b1efcbb159638dc0bf4cb12f15

We have an argument that indicates the agent dim for the input, but we should add one for the output too. The only reason we assume the output agent dim is -2 is that the CNN and MLP both end with a linear layer, which has a single feature dim.
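
Why -2 only holds for a single trailing feature dim can be sketched with NumPy stacking. This is an illustration under an assumption, not TorchRL code. `output_agent_dim` is a hypothetical helper. The conv case assumes a net that ends in a conv layer without flattening, so each agent emits `[B, C, H, W]`.

```python
import numpy as np


def output_agent_dim(feature_ndim):
    """Negative index of the agent dim when each agent's output has
    ``feature_ndim`` trailing feature dims (the agent dim sits just before them)."""
    return -(feature_ndim + 1)


n_agents = 3

# Net ending in a linear layer: per-agent output [B, F], one feature dim.
linear_out = [np.zeros((4, 8)) for _ in range(n_agents)]
dim = output_agent_dim(1)
print(dim, np.stack(linear_out, axis=dim).shape)  # -2 (4, 3, 8)

# Hypothetical net ending in a conv layer with no flatten: [B, C, H, W].
conv_out = [np.zeros((4, 16, 5, 5)) for _ in range(n_agents)]
dim = output_agent_dim(3)
print(dim, np.stack(conv_out, axis=dim).shape)  # -4 (4, 3, 16, 5, 5)
```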

As you mentioned, we must also allow the MARL modules to accept more than one tensor as input (e.g. here the input will be either a Tensor or a (Tensor, Tuple[Tensor, Tensor]) pair). I will make a follow-up edit with that.
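
One possible way to normalise the two accepted input forms is shown below, using a hypothetical `split_inputs` helper that is not part of TorchRL. Strings stand in for tensors because only the nesting matters here. The sketch assumes the hidden part is either a tuple like (h, c) or a single tensor.

```python
def split_inputs(inputs):
    """Split a MARL module input into (x, hidden_states).

    Accepts a plain tensor ``x`` or a pair ``(x, hidden)``, where ``hidden``
    is a tuple such as (h, c) for an LSTM or a single tensor (e.g. a GRU's h).
    Always returns ``hidden`` as a tuple, empty when none was given.
    """
    if isinstance(inputs, tuple):
        if len(inputs) != 2:
            raise ValueError("expected x or (x, hidden)")
        x, hidden = inputs
        if not isinstance(hidden, tuple):
            hidden = (hidden,)
        return x, hidden
    return inputs, ()


print(split_inputs("x"))                # ('x', ())
print(split_inputs(("x", ("h", "c"))))  # ('x', ('h', 'c'))
```

Always returning a tuple lets the forward pass loop over hidden states without special-casing the no-hidden-state path.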

vmoens · Feb 26 '24

I think I got it working, but I cannot guarantee that it behaves as intended. Here are the signatures you should expect:

centralized = True
  |- share_params = True
  |---- input
  |           |- x [Batch, Time, Agents, Features]
  |           |- hidden[0]  [Layers, Batch, Hidden]  <====== No agents
  |           |- hidden[1] [Layers, Batch, Hidden]
  |---- output
  |           |- y [Batch, Time, Agents, Hidden]
  |           |- hidden[0]  [Layers, Batch, Hidden]
  |           |- hidden[1]  [Layers, Batch, Hidden]
  |- share_params = False
  |---- input
  |           |- x  [Batch, Time, Agents, Features]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  |---- output
  |           |- y  [Batch, Time, Agents, Hidden]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
centralized = False
  |- share_params = True
  |---- input
  |           |- x  [Batch, Time, Agents, Features]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  |---- output
  |           |- y  [Batch, Time, Agents, Hidden]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  |- share_params = False
  |---- input
  |           |- x  [Batch, Time, Agents, Features]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  |---- output
  |           |- y  [Batch, Time, Agents, Hidden]
  |           |- hidden[0]  [Layers, Batch, Agents, Hidden]
  |           |- hidden[1]  [Layers, Batch, Agents, Hidden]
  

So everything is essentially the same, except that there is no agent dim in the hidden states when the net is centralized and shared: in that case we flatten the input, pass it through the net as in the non-MARL case, and expand the output.
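
The tree above can be encoded as a small shape-checking function, which could serve as a starting point for unit tests. This is a sketch. The function name is hypothetical and it only restates the shapes listed in the tree.

```python
def expected_lstm_shapes(B, T, A, F, H, L, centralized, share_params):
    """Return the (x, h, c) input and (y, h, c) output shapes from the tree.

    B=batch, T=time, A=agents, F=input features, H=hidden size, L=layers.
    """
    x = (B, T, A, F)
    y = (B, T, A, H)  # y always carries the agent dim
    if centralized and share_params:
        # One net sees all agents' features flattened together, so its
        # hidden states are those of a single (non-MARL) LSTM: no agent dim.
        hidden = (L, B, H)
    else:
        hidden = (L, B, A, H)
    return {"input": (x, hidden, hidden), "output": (y, hidden, hidden)}


shapes = expected_lstm_shapes(B=32, T=10, A=4, F=6, H=64, L=1,
                              centralized=True, share_params=True)
print(shapes["output"])  # ((32, 10, 4, 64), (1, 32, 64), (1, 32, 64))
```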

@matteobettini looking for feedback here! :)

I updated the gist above

vmoens · Feb 26 '24

> So pretty much all the same except that there's no Agent dim in the hidden state when it's shared and centralized since we flatten / expand the input / output and pass it through the net as in non-MARL cases

You have a typo: sometimes hidden is the last dim of y and sometimes it is Features. Also, is hidden == Hidden?

It might be useful to declare the dimensions at the beginning of that tree.

For the only case that differs (centralized=True, share_params=True), our convention was to expand the output so that it has a multi-agent dimension. Why did you decide not to do this for all outputs? It seems strange to me that we do it for y but not for the others. I would do all or nothing.
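
The flatten/expand convention for the centralized, shared case is sketched below. NumPy's `broadcast_to` stands in for torch's `Tensor.expand`. A single linear layer plus tanh stands in for the real shared network. `centralized_shared_forward` is a hypothetical name, and the code is not the PR's implementation.

```python
import numpy as np


def centralized_shared_forward(x, weight):
    """Sketch of the centralized=True, share_params=True path.

    x: [B, A, F] per-agent features; weight: [A * F, H] for a single shared net
    (one linear layer + tanh stands in for the real network).
    Returns y: [B, A, H], an expanded (broadcast) view, so every agent
    receives the same output.
    """
    B, A, F = x.shape
    flat = x.reshape(B, A * F)        # concatenate all agents' features
    out = np.tanh(flat @ weight)      # one net sees every agent: [B, H]
    # Add the agent dim back without copying: [B, 1, H] -> [B, A, H]
    return np.broadcast_to(out[:, None, :], (B, A, out.shape[-1]))


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 4))        # B=5, A=3, F=4
w = rng.normal(size=(3 * 4, 8))       # H=8
y = centralized_shared_forward(x, w)
print(y.shape)  # (5, 3, 8)
```

The expanded y only looks per-agent. Every agent slice holds the same values.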

matteobettini · Feb 27 '24

Another issue relates to how to store the hidden states in the grouping API.

Imagine I have 2 agent groups with cardinalities N_A and N_B; then I might have a td that looks like:

TensorDict(
    group_a: TensorDict(
        batch_size=(B, T, N_A)
    ),
    group_b: TensorDict(
        batch_size=(B, T, N_B)
    ),
    batch_size=(B, T)
)

Now, if both groups use LSTMs, x and y can go in their group tds, but the hidden states are problematic: their shape ([Layers, Batch, Agents, Hidden]) does not start with the group's batch_size (B, T, N_A), so we cannot store them with that structure.
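
The constraint can be checked in plain Python. A TensorDict entry's leading dims must match the TensorDict's batch_size. `fits_batch_size` is a hypothetical helper that mirrors this rule. The shapes come from the non-shared case of the tree above.

```python
def fits_batch_size(shape, batch_size):
    """True if ``shape`` can be stored under a TensorDict with ``batch_size``:
    the entry's leading dims must equal the batch_size."""
    return tuple(shape[: len(batch_size)]) == tuple(batch_size)


B, T, N_A, F, H, L = 32, 10, 4, 6, 64, 1
group_a_batch = (B, T, N_A)

print(fits_batch_size((B, T, N_A, F), group_a_batch))  # True: x
print(fits_batch_size((B, T, N_A, H), group_a_batch))  # True: y
print(fits_batch_size((L, B, N_A, H), group_a_batch))  # False: hidden state
```

The hidden state fails for two reasons: it leads with Layers, and it has no time dim at all.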

matteobettini · Feb 27 '24

> Why did you decide not to do this for all outputs?

I did not "decide"; it's just not possible. If you expand the hidden states, you will feed expanded hidden states back in, flatten them, and get even bigger hidden states (which will actually break, because they are too big). You can only expand the outputs you're not feeding back recursively.
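
The shape blow-up can be illustrated with NumPy. This sketch assumes the module would flatten a fed-back hidden state over agents in the same way it flattens x. `broadcast_to` stands in for torch's `expand`.

```python
import numpy as np

L, B, A, H = 2, 4, 3, 8

# Hidden state produced by the single shared LSTM (centralized, shared).
h = np.zeros((L, B, H))

# Suppose we expanded it to look per-agent, like y: [L, B, A, H].
h_expanded = np.broadcast_to(h[:, :, None, :], (L, B, A, H))

# Fed back in and flattened over agents like x, it would carry
# A * H features instead of the H the LSTM expects.
h_fed_back = h_expanded.reshape(L, B, A * H)
print(h_fed_back.shape, "expected", h.shape)  # (2, 4, 24) expected (2, 4, 8)
```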

I fixed the diagram.

vmoens · Feb 28 '24

> Another issue relates to how to store the hidden states in the grouping API.

I agree, but that's the LSTM format, so I would not change it. These classes are meant to be used without tensordict, so tensordict formatting should not dictate the data format. This class is supposed to be the multi-agent (MA) version of nn.LSTM.

If we want to use it with a tensordict, we can build a class similar to the LSTM net we have for single agents.

vmoens · Feb 28 '24

> You can only expand the output you're not feeding back recursively.

Got it. So is it not possible to just say that, when centralized=True and share_params=True, the hidden state should simply be indexed along the first element of the agent dim instead of being flattened?
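
The proposal is sketched below, with NumPy's `broadcast_to` standing in for torch's `Tensor.expand`. Indexing the first agent does recover the original state, and the expanded tensor is a stride-0 view. However, it appears to hold A distinct copies, and any materialising copy multiplies memory by A. This is one reading of the usage and memory concerns raised in the reply.

```python
import numpy as np

L, B, A, H = 2, 4, 3, 8
h = np.arange(L * B * H, dtype=np.float32).reshape(L, B, H)

# Expand to [L, B, A, H]: a view with stride 0 along the agent dim.
h_exp = np.broadcast_to(h[:, :, None, :], (L, B, A, H))
print(h_exp.strides[2], np.shares_memory(h_exp, h))  # 0 True

# Indexing the first agent recovers the original hidden state.
h_back = h_exp[:, :, 0, :]
print(np.array_equal(h_back, h))  # True

# But any copy (e.g. storing it in a buffer) materializes A duplicates.
h_copy = np.array(h_exp)
print(h_copy.nbytes // h.nbytes)  # 3
```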

Also, how difficult do you think it would be to add a GRU to this PR as well? It has been found to work better than LSTM in RL, and since we have already added the infrastructure, it may require only minimal additions.

matteobettini · Feb 29 '24

@kfu02 went for LSTM, but GRU would be easier. I'm not really in favour of expanding a tensor and then taking the first index. It's confusing (it lets people think the content differs across agents) and dangerous from both a usage and a memory perspective.

We implemented something like that for indices in replay buffers, and I wish we hadn't! It's extremely hard to maintain.

vmoens · Feb 29 '24

> @kfu02 went for LSTM but GRU would be easier.

I can switch the current code to use a GRU. I was not aware that LSTMs were outperformed by GRUs, thanks @matteobettini!

kfu02 · Feb 29 '24

Now that we have LSTM working I think we can do both!
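
Supporting both mostly changes the hidden-state structure. A GRU carries a single hidden tensor h, whereas an LSTM carries the pair (h, c). The NumPy sketch below implements one GRU step following the gate equations documented for torch.nn.GRU, with biases omitted for brevity. It is illustrative only and is not the code in this PR.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gru_step(x, h, W, U):
    """One GRU step without biases.

    x: [B, F] input, h: [B, H] previous hidden state.
    W[g]: [F, H] input weights and U[g]: [H, H] recurrent weights for the
    reset ('r'), update ('z') and candidate ('n') gates.
    Returns the new hidden state only: a single [B, H] array.
    """
    r = sigmoid(x @ W["r"] + h @ U["r"])        # reset gate
    z = sigmoid(x @ W["z"] + h @ U["z"])        # update gate
    n = np.tanh(x @ W["n"] + r * (h @ U["n"]))  # candidate state
    return (1.0 - z) * n + z * h


rng = np.random.default_rng(0)
B, F, H = 4, 6, 8
W = {g: rng.normal(size=(F, H)) for g in "rzn"}
U = {g: rng.normal(size=(H, H)) for g in "rzn"}
h = np.zeros((B, H))
for _ in range(3):  # unroll three time steps
    h = gru_step(rng.normal(size=(B, F)), h, W, U)
print(h.shape)  # (4, 8)
```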

vmoens · Feb 29 '24

> I was not aware that LSTMs were outperformed by GRUs

At least, that is what we observed: https://arxiv.org/abs/2303.01859

matteobettini · Mar 01 '24

@vmoens @kfu02 what is the status of this? Ready for review?

matteobettini · Mar 27 '24

> @vmoens @kfu02 what is the status of this? Ready for review?

Apologies: I still have to add docstrings and unit tests, and this has slipped down my priorities. I will most likely complete it over the weekend.

kfu02 · Mar 28 '24

@vmoens do you think we could reopen this on a new branch?

matteobettini · May 20 '24

Sure, we should get this feature; I just haven't had time to work on it yet. Should I try to finish it on my own? I had the impression you were on it.

vmoens · May 20 '24

I am keen on the feature, but I don't have bandwidth at the moment due to deadlines, and it seems @kfu02 won't work on this anymore.

I think it just needs tests and docs, and then it should be good to go.

My personal concern is also a loss of readability in the multi-agent modules after this change.

We can reopen it in another PR, and then whichever of us picks it up first can work on shipping it.

matteobettini · May 20 '24

We can work on readability. This is an unmerged PR: the proposed changes have not been properly tested or documented yet, so I would not jump straight to discarding them over readability issues before this piece reaches a mature stage.

Having one parent class that accounts for all nets makes it easier to check that we have all the features covered across networks. I noticed multiple inconsistencies across MLP and CNN, so if we scale things up to other networks, we need to build an API that is consistent and testable for all of them. Also, there are some non-trivial tricks I want to apply in the future to make things more readable and faster to execute, but that will be considerably harder if I have to patch four or more classes independently.
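
The single-parent design argued for above is sketched minimally below in plain Python. All class names are hypothetical. `build_single_net` mirrors the `build_single_net()` hook mentioned earlier in the thread, but this is not TorchRL's MultiAgentNetBase. Placeholder dicts stand in for real networks. The point is that parameter sharing lives in one place and a single check covers every subclass.

```python
from abc import ABC, abstractmethod


class MultiAgentNetSketch(ABC):
    """Hypothetical parent class: handles the per-agent bookkeeping once.

    Subclasses only describe how to build one agent's network, so behaviour
    such as parameter sharing is implemented (and testable) in one place.
    """

    def __init__(self, n_agents, centralized, share_params):
        self.n_agents = n_agents
        self.centralized = centralized  # stored for subclasses; unused here
        self.share_params = share_params
        n_nets = 1 if share_params else n_agents
        self.nets = [self.build_single_net() for _ in range(n_nets)]

    @abstractmethod
    def build_single_net(self):
        """Return the network used by one agent (or by all, if shared)."""

    def net_for(self, agent_idx):
        if not 0 <= agent_idx < self.n_agents:
            raise IndexError(agent_idx)
        return self.nets[0] if self.share_params else self.nets[agent_idx]


class MLPSketch(MultiAgentNetSketch):
    def build_single_net(self):
        return {"kind": "mlp"}  # placeholder for a real MLP


class GRUSketch(MultiAgentNetSketch):
    def build_single_net(self):
        return {"kind": "gru"}  # placeholder for a real GRU


# The same sharing check covers every subclass.
for cls in (MLPSketch, GRUSketch):
    shared = cls(n_agents=3, centralized=False, share_params=True)
    separate = cls(n_agents=3, centralized=False, share_params=False)
    assert shared.net_for(0) is shared.net_for(2)
    assert separate.net_for(0) is not separate.net_for(2)
```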

vmoens · May 20 '24