[WIP] add multiagentRNN
Description
Per #2003, adds multi-agent GRUs and LSTMs to torchRL's multi-agent modules.
Modifies the `MultiAgentNetBase` class to take in multiple input and output tensors, which allows these recurrent multi-agent nets to input/output hidden states (e.g. `(input, h_x)`).
Test gist: https://gist.github.com/kfu02/87ae6c6d99e681d474f4977a9653b329
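To make the intended input/output signature concrete, here is a hypothetical sketch of a per-agent recurrent module (the `share_params=False`, `centralized=False` case). The class name and the loop-and-stack strategy are illustrative assumptions, not torchrl's actual implementation:

```python
import torch
import torch.nn as nn

class ToyMultiAgentLSTM(nn.Module):
    """Illustrative sketch only: one nn.LSTM per agent
    (share_params=False, centralized=False)."""

    def __init__(self, n_agents, in_features, hidden_size, num_layers=1):
        super().__init__()
        self.n_agents = n_agents
        self.lstms = nn.ModuleList(
            nn.LSTM(in_features, hidden_size, num_layers, batch_first=True)
            for _ in range(n_agents)
        )

    def forward(self, x, hidden=None):
        # x: [Batch, Time, Agents, Features]
        # hidden (if given): tuple of [Layers, Batch, Agents, Hidden] tensors
        ys, hs, cs = [], [], []
        for i, lstm in enumerate(self.lstms):
            hx = None
            if hidden is not None:
                h0, c0 = hidden
                # slice out this agent's hidden state
                hx = (h0[:, :, i].contiguous(), c0[:, :, i].contiguous())
            y_i, (h_i, c_i) = lstm(x[:, :, i], hx)
            ys.append(y_i); hs.append(h_i); cs.append(c_i)
        y = torch.stack(ys, dim=2)   # [Batch, Time, Agents, Hidden]
        h = torch.stack(hs, dim=2)   # [Layers, Batch, Agents, Hidden]
        c = torch.stack(cs, dim=2)
        return y, (h, c)

net = ToyMultiAgentLSTM(n_agents=3, in_features=4, hidden_size=8)
x = torch.randn(2, 5, 3, 4)          # [Batch=2, Time=5, Agents=3, Features=4]
y, (h, c) = net(x)
print(y.shape, h.shape, c.shape)
# torch.Size([2, 5, 3, 8]) torch.Size([1, 2, 3, 8]) torch.Size([1, 2, 3, 8])
```
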
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax `close #15213` if this solves the issue #15213.
close #2003
- [x] I have raised an issue to propose this change (required for new features and bug fixes)
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds core functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation (update in the documentation)
- [ ] Example (update in the folder of examples)
Checklist
Go over all the following points, and put an `x` in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!
- [x] I have read the CONTRIBUTION guide (required)
- [x] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
- [x] I have updated the documentation accordingly.
Hi @vmoens, this draft PR adds multi-agent RNNs as you proposed in https://github.com/pytorch/rl/pull/1921#issuecomment-1955117158
I have one blocking issue: the `LSTMNet` in torchRL returns a tuple of `(output, hidden_states...)`, but the `forward()` call in `MultiAgentNetBase` expects the network from `build_single_net()` to output a single tensor (as the MLP and ConvNet both do).
I am not sure what the best solution is here. Override the forward call? Add an MLP to the end of `MultiAgentRNN` (as that's the most typical use case)? To be honest, I'm also not fully understanding how the `LSTMNet` gets the previous hidden state as input either.
I would appreciate any clarity you could provide!
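The mismatch can be reproduced with plain torch modules (this is an illustration of the issue, not torchrl code): an MLP-style network returns one tensor, whereas `nn.LSTM` returns a nested tuple, so any base class that stacks "the output tensor" per agent needs special handling for the recurrent case.

```python
import torch
import torch.nn as nn

# MLP-style module: forward returns a single tensor, which a base class
# like MultiAgentNetBase can stack across agents directly.
mlp = nn.Linear(4, 8)
mlp_out = mlp(torch.randn(2, 4))
print(type(mlp_out))              # <class 'torch.Tensor'>

# LSTM-style module: forward returns (output, (h_n, c_n)), so naively
# stacking "the output" across agents fails without extra handling.
lstm = nn.LSTM(4, 8, batch_first=True)
lstm_out = lstm(torch.randn(2, 5, 4))
print(type(lstm_out), len(lstm_out))   # <class 'tuple'> 2
```
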
I did a bunch of fixes for this PR; you can test it with this gist: https://gist.github.com/vmoens/c24c36b1efcbb159638dc0bf4cb12f15
We have an arg that indicates the agent dim for the input, but we should add one for the output too. The only reason we consider it to be -2 is that CNN and MLP both end with a linear layer, which has just one dim of features.
Like you were mentioning: we must also allow the MARL modules to accept more than one tensor as input (e.g. in this case the input will be either a `Tensor` or a `(Tensor, Tuple[Tensor, Tensor])` pair). I will make a follow-up edit with that.
I think I got it working, but I cannot guarantee that it does what it's supposed to do. Here are the signatures you should expect:
```
centralized = True
|- share_params = True
|---- input
|  |- x         [Batch, Time, Agents, Features]
|  |- hidden[0] [Layers, Batch, Hidden]  <====== No agents
|  |- hidden[1] [Layers, Batch, Hidden]
|---- output
|  |- y         [Batch, Time, Agents, Hidden]
|  |- hidden[0] [Layers, Batch, Hidden]
|  |- hidden[1] [Layers, Batch, Hidden]
|- share_params = False
|---- input
|  |- x         [Batch, Time, Agents, Features]
|  |- hidden[0] [Layers, Batch, Agents, Hidden]
|  |- hidden[1] [Layers, Batch, Agents, Hidden]
|---- output
|  |- y         [Batch, Time, Agents, Hidden]
|  |- hidden[0] [Layers, Batch, Agents, Hidden]
|  |- hidden[1] [Layers, Batch, Agents, Hidden]
centralized = False
|- share_params = True
|---- input
|  |- x         [Batch, Time, Agents, Features]
|  |- hidden[0] [Layers, Batch, Agents, Hidden]
|  |- hidden[1] [Layers, Batch, Agents, Hidden]
|---- output
|  |- y         [Batch, Time, Agents, Hidden]
|  |- hidden[0] [Layers, Batch, Agents, Hidden]
|  |- hidden[1] [Layers, Batch, Agents, Hidden]
|- share_params = False
|---- input
|  |- x         [Batch, Time, Agents, Features]
|  |- hidden[0] [Layers, Batch, Agents, Hidden]
|  |- hidden[1] [Layers, Batch, Agents, Hidden]
|---- output
|  |- y         [Batch, Time, Agents, Hidden]
|  |- hidden[0] [Layers, Batch, Agents, Hidden]
|  |- hidden[1] [Layers, Batch, Agents, Hidden]
```
So pretty much all the same, except that there's no Agent dim in the hidden state when it's shared and centralized, since we flatten / expand the input / output and pass it through the net as in non-MARL cases.
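One way to read the `centralized=True, share_params=True` row is the following shape experiment (plain torch, illustrative only): the agent dim is folded into the feature dim before the single shared `nn.LSTM`, so its hidden state never carries an Agents axis, while the output is expanded back to resemble one.

```python
import torch
import torch.nn as nn

B, T, A, F, H, L = 2, 5, 3, 4, 8, 1
lstm = nn.LSTM(A * F, H, L, batch_first=True)   # one shared, centralized net

x = torch.randn(B, T, A, F)
x_flat = x.reshape(B, T, A * F)       # fold the agent dim into features
y, (h, c) = lstm(x_flat)              # hidden: [Layers, Batch, Hidden], no Agents
y = y.unsqueeze(2).expand(B, T, A, H) # expand output to restore the agent dim
print(y.shape, h.shape)
# torch.Size([2, 5, 3, 8]) torch.Size([1, 2, 8])
```
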
@matteobettini looking for feedback here! :)
I updated the gist above
> So pretty much all the same except that there's no Agent dim in the hidden state when it's shared and centralized since we flatten / expand the input / output and pass it through the net as in non-MARL cases
You have a typo: sometimes `hidden` is the last dim of `y`, and sometimes it is `Features`. Also, is `hidden` == `Hidden`?
Maybe it would be useful to declare the dimensions at the beginning of that tree.
For the only case that is different (`centralized=True, share_params=True`), our convention was to expand the output to resemble the existence of the multi-agent dimension. Why did you decide not to do this for all outputs? It seems strange to me that we do it for `y` but not for the others. I would do all or nothing.
Another issue relates to how to store the hidden states in the grouping API.
Imagine I have 2 agent groups with cardinality `N_A` and `N_B`; then I might have a td that looks like

```
TensorDict(
    group_a: TensorDict(
        batch_size=(B, T, N_A)
    ),
    group_b: TensorDict(
        batch_size=(B, T, N_B)
    ),
    batch_size=(B, T)
)
```
Now, if both groups use LSTMs, `x` and `y` can go in their tds, but the hidden states are more problematic, as their shape means we cannot use that structuring.
> Why did you decide not to do this for all outputs?
I did not "decide", it's just not possible. If you expand it, you will feed back hidden states that are expanded, flatten them, and get even bigger hidden states (if that doesn't break - actually it will break because they're too big). You can only expand the output you're not feeding back recursively.
I fixed the diagram.
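The blow-up described above can be seen with a small shape experiment (pure torch, illustrative): if the shared, centralized hidden state were expanded with a fake agent dim and then fed back through the flattening step, it would come out A times too large for the next step.

```python
import torch

L_, B, A, H = 1, 2, 3, 8
h = torch.zeros(L_, B, H)                        # true hidden: [Layers, Batch, Hidden]
h_expanded = h.unsqueeze(2).expand(L_, B, A, H)  # expanded to fake an agent dim
# feeding the expanded state back and flattening the agent dim into
# features yields a state A times too large for the shared net:
h_refed = h_expanded.reshape(L_, B, A * H)
print(h.shape[-1], h_refed.shape[-1])            # 8 24
```
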
> Another issue relates to how to store the hidden states in the grouping API.
I agree, but that's the LSTM format, so I would not change that. The goal of these classes is to be usable without tensordict, so tensordict formatting should not impact the data format. This class is supposed to be the MA version of `nn.LSTM`.
If we want to use it with a tensordict we can build a class similar to the LSTM net we have for single agents.
> I did not "decide", it's just not possible. If you expand it, you will feed back hidden states that are expanded, flatten them, and get even bigger hidden states (if that doesn't break - actually it will break because they're too big). You can only expand the output you're not feeding back recursively.
Got it. So it is not possible to just say that when you have `centralised=True, share=True`, the hidden state should just be indexed along the first element of the agent dim instead of flattened?
Also, how difficult do you think it would be to add GRU to this PR as well? It has been found to work better than LSTM in RL, and since we already added the infrastructure, maybe it requires minimal additions.
@kfu02 went for LSTM, but GRU would be easier. I'm not super in favour of expanding a tensor and then taking the first index: it's confusing (it lets people think the content is different) and dangerous from both a usage and a memory perspective.
We implemented something like that for indices in replay buffers and I wish we didn't! It's extremely hard to maintain
> @kfu02 went for LSTM but GRU would be easier. I'm not super in favour of expanding a tensor and then taking the first index. It's confusing (let people think the content is different) and dangerous both from a usage and memory perspective
> We implemented something like that for indices in replay buffers and I wish we didn't! It's extremely hard to maintain
I can switch the current code to use a GRU. I was not aware that LSTMs were outperformed by GRUs, thanks @matteobettini !
Now that we have LSTM working I think we can do both!
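The reason GRU is the easier case can be seen directly in the torch APIs: `nn.GRU` carries a single hidden tensor where `nn.LSTM` carries a `(h, c)` tuple, so the multi-agent plumbing has one fewer state to flatten, slice, and stack (illustrative comparison, not torchrl code):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 5, 4)              # [Batch, Time, Features]

lstm = nn.LSTM(4, 8, batch_first=True)
y_lstm, (h, c) = lstm(x)              # LSTM carries two states: (h, c)

gru = nn.GRU(4, 8, batch_first=True)
y_gru, h_gru = gru(x)                 # GRU carries a single state: h
print(h.shape, c.shape, h_gru.shape)
# torch.Size([1, 2, 8]) torch.Size([1, 2, 8]) torch.Size([1, 2, 8])
```
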
> I was not aware that LSTMs were outperformed by GRUs
At least that is what we observed https://arxiv.org/abs/2303.01859
@vmoens @kfu02 what is the status of this? Ready for review?
> @vmoens @kfu02 what is the status of this? ready for review?
Apologies, I still have to add docstrings and unit tests, and it has fallen behind in my priorities. I will complete this over the weekend most likely.
@vmoens do you think we could reopen this on a new branch?
Sure we should get this feature, I just didn't have the time to work on it yet. Should I try to solve it on my own? I had the impression you guys were on it
I am keen on the feature, but I don't have bandwidth at this moment due to deadlines, and it seems @kfu02 won't work on this anymore.
I think it just misses testing and docs and it should be good to go.
My personal concern is also a loss in readability in the multi-agent modules after this change.
We can reopen it in another PR, and then the first one of us that picks this up can work to ship it.
We can work on readability. This is an unmerged PR; the proposed changes have not been properly tested or documented yet, so I would not jump straight to discarding the changes because of readability issues before we get this piece to a mature stage.
Having one parent class that accounts for all nets makes it easier to check that we have all the features covered across networks. I noticed multiple inconsistencies across MLP and CNN, so if we scale things up to other networks, we need to build an API that is consistent and testable for all. Also, there are some non-trivial tricks I want to apply in the future to make things more readable and faster to execute, but that will be considerably harder if I have to patch four or more classes independently.