
RNN for continuous CQL algorithm

Open BFAnas opened this issue 3 years ago • 15 comments

  • [X] I have marked all applicable categories:
    • [ ] exception-raising bug
    • [ ] RL algorithm bug
    • [ ] documentation request (i.e. "X is missing from the documentation.")
    • [X] new feature request
  • [X] I have visited the source website
  • [X] I have searched through the issue tracker for duplicates
  • [X] I have mentioned version numbers, operating system and environment, where applicable:
    import tianshou, torch, numpy, sys
    print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
    0.4.5 1.10.1+cu102 1.20.3 3.9.7 (default, Sep 16 2021, 13:09:58)  [GCC 7.5.0] linux
    

This is a request for RNN support in the continuous CQL algorithm. Thanks for this awesome lib!

BFAnas avatar Jan 21 '22 09:01 BFAnas

@thkkk

Trinkle23897 avatar Jan 21 '22 14:01 Trinkle23897

I think the problem is in the RecurrentActorProb class. From this part of the code, it seems to expect an input of shape [bsz, len*dim]:

self.nn = nn.LSTM(
    input_size=int(np.prod(state_shape)),
    hidden_size=hidden_layer_size,
    num_layers=layer_num,
    batch_first=True,
)

But this part of the code suggests that the obs passed to self.nn is of shape [bsz, len, dim]:

# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
obs, (hidden, cell) = self.nn(obs)

Could you investigate this?

BFAnas avatar Feb 22 '22 12:02 BFAnas

I think we can still use this setting by setting the buffer's stack_num to a value > 1. In short, when training RNN+CQL we use [bsz, len, dim] to train a recurrent network with trajectory length len, and in the test phase we use [bsz, dim] because at that point the network can maintain a hidden state to infer the action.
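
A minimal sketch of that setting, assuming the 0.4.x ReplayBuffer API (where stack_num controls how many consecutive frames sample() stacks together):

import numpy as np
from tianshou.data import Batch, ReplayBuffer

# the env emits single frames of shape (3,); the buffer re-stacks them
buf = ReplayBuffer(size=100, stack_num=4)  # len = 4
for i in range(10):
    buf.add(Batch(obs=np.full(3, i), act=0, rew=0.0, done=False,
                  obs_next=np.full(3, i + 1), info={}))
batch, indices = buf.sample(batch_size=8)
print(batch.obs.shape)  # (8, 4, 3), i.e. [bsz, len, dim]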

Trinkle23897 avatar Feb 25 '22 02:02 Trinkle23897

@Trinkle23897 Thank you for your answer. I understand that, but that's not the problem. To explain better: I see a part of the code where the expected input has shape [bsz, len*dim], whereas the input actually passed has shape [bsz, len, dim]. In this part of the code, self.nn expects an input of shape [bsz, len*dim] (note that int(np.prod(state_shape)) = len*dim):

self.nn = nn.LSTM(
    input_size=int(np.prod(state_shape)),
    hidden_size=hidden_layer_size,
    num_layers=layer_num,
    batch_first=True,
)

And later, the obs passed to self.nn has shape [bsz, len, dim], which is therefore different from the shape that self.nn expects. Do you agree?

BFAnas avatar Feb 28 '22 12:02 BFAnas

note that int(np.prod(state_shape)) = len*dim

I don't think so. state_shape should always be a single frame, i.e., int(np.prod(state_shape)) = dim. If that's not the case, you should modify it accordingly outside.

Trinkle23897 avatar Feb 28 '22 16:02 Trinkle23897

You mean I should pass dim instead of len*dim? Even when I'm working with stack_num != 1? But in any case, self.nn is getting an obs of shape [bsz, len, dim] when it expects [bsz, int(np.prod(state_shape))], whatever that is.

BFAnas avatar Feb 28 '22 16:02 BFAnas

In [16]: m = nn.LSTM(input_size=3, hidden_size=10, num_layers=1, batch_first=True)

In [17]: s = torch.zeros([64, 1, 3])

In [18]: ns, (h, c) = m(s)

In [19]: ns.shape, h.shape, c.shape
Out[19]: (torch.Size([64, 1, 10]), torch.Size([1, 64, 10]), torch.Size([1, 64, 10]))

In [20]: s = torch.zeros([64, 16, 3])

In [21]: ns, (h, c) = m(s)

In [22]: ns.shape, h.shape, c.shape
Out[22]: (torch.Size([64, 16, 10]), torch.Size([1, 64, 10]), torch.Size([1, 64, 10]))

The input of self.nn.forward is always a 3-dim tensor, not a 2-dim one.
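
For the [bsz, dim] evaluation case, the 2-dim input is lifted to the 3-dim layout nn.LSTM requires by inserting a length-1 time axis. A sketch of the idea (this mirrors the unsqueeze step in RecurrentActorProb.forward, though the exact internals may differ by version):

import torch

obs = torch.zeros([64, 3])   # [bsz, dim] at evaluation time
if obs.dim() == 2:
    obs = obs.unsqueeze(-2)  # -> [bsz, 1, dim]: a length-1 sequence for the LSTM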

Trinkle23897 avatar Feb 28 '22 16:02 Trinkle23897

If I have an observation of shape [bsz, len, dim] what is the state_shape argument that I should pass to RecurrentActorProb?

BFAnas avatar Feb 28 '22 17:02 BFAnas

It should be dim. Let's take the Atari example: the observation space is (4, 84, 84), where 4 is len. However, when defining the recurrent network, state_shape should be 84*84 instead of 4*84*84; the trajectory length is determined by the replay buffer's sampling method.
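
A hedged sketch of that convention (argument names follow tianshou 0.4.x's RecurrentActorProb; check the signature in your version):

from tianshou.utils.net.continuous import RecurrentActorProb

actor = RecurrentActorProb(
    layer_num=1,
    state_shape=84 * 84,  # single-frame dim, NOT 4 * 84 * 84
    action_shape=6,
    hidden_layer_size=128,
)
# the trajectory length (len = 4 here) comes from the buffer's stack_num,
# not from state_shape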

Trinkle23897 avatar Feb 28 '22 18:02 Trinkle23897

Okay, thanks for the support. It is a little bit confusing, since state_shape for the normal ActorProb is equal to obs.shape; maybe you could consider making ActorProb and RecurrentActorProb coherent in this regard. Also, more ambitiously, maybe you could make RecurrentActorProb and RecurrentCritic constructible from a RecurrentNet, the way ActorProb and Critic are constructed from a Net.
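
To illustrate the asymmetry being described, a rough side-by-side (tianshou 0.4.x-style APIs; signatures may differ in other versions):

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, RecurrentActorProb

# feed-forward: the Net preprocessor takes the full stacked obs shape ...
net = Net(state_shape=(4, 84, 84), hidden_sizes=[128, 128])
actor = ActorProb(net, action_shape=6)

# ... while the recurrent actor takes the single-frame dim directly
actor_rnn = RecurrentActorProb(layer_num=1, state_shape=84 * 84, action_shape=6)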

BFAnas avatar Feb 28 '22 18:02 BFAnas

But here comes the problem: there are two ways to perform this kind of obs-stacking (a sketch of the second case follows below):

  1. gym.Env outputs a single frame -- stacking is done by buffer.sample();
  2. gym.Env outputs stacked frames via the FrameStack env wrapper -- either no stacking at all, or de-stack -> save to buffer -> re-stack by buffer.sample();

I cannot make any assumption here, so that's why the current code looks the way it does.
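
Case 2 in a sketch (using gym's stock FrameStack wrapper; the exact obs shape depends on the underlying env and any preprocessing wrappers):

import gym
import numpy as np
from gym.wrappers import FrameStack

env = FrameStack(gym.make("PongNoFrameskip-v4"), num_stack=4)
obs = env.reset()
# frames are already stacked by the env, before anything reaches the buffer
print(np.asarray(obs).shape)  # (4, 210, 160, 3) for raw Pong frames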

Trinkle23897 avatar Feb 28 '22 18:02 Trinkle23897

To make CQL work with RNN, I changed tmp_obs and tmp_obs_next in cql.py > CQLPolicy > learn as follows:

# repeat each stacked observation num_repeat_actions times along a new axis,
# then fold that axis into the batch dimension:
# [bsz, len, dim] -> [bsz, 1, len, dim] -> [bsz, R, len, dim] -> [bsz*R, len, dim]
tmp_obs = obs.unsqueeze(1) \
    .repeat(1, self.num_repeat_actions, 1, 1) \
    .view(batch_size * self.num_repeat_actions, obs.shape[-2], obs.shape[-1])
tmp_obs_next = obs_next.unsqueeze(1) \
    .repeat(1, self.num_repeat_actions, 1, 1) \
    .view(batch_size * self.num_repeat_actions,
          obs_next.shape[-2], obs_next.shape[-1])  # use obs_next's own shape

Now the code executes without errors, but maybe I'm missing something else necessary for RNN to work correctly.
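
A toy check of the reshape above, shapes only (hypothetical sizes, no CQL involved):

import torch

bsz, num_repeat, length, dim = 4, 10, 16, 3
obs = torch.zeros(bsz, length, dim)              # [bsz, len, dim]
tmp = obs.unsqueeze(1) \
         .repeat(1, num_repeat, 1, 1) \
         .view(bsz * num_repeat, length, dim)
print(tmp.shape)                                 # torch.Size([40, 16, 3])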

BFAnas avatar Feb 28 '22 18:02 BFAnas

Glad to hear that!

Trinkle23897 avatar Feb 28 '22 18:02 Trinkle23897

Which task would you recommend for testing this solution? Ideally it should be a task from the d4rl datasets where SAC has been tried with RNNs and works correctly, since CQL inherits from SAC.

BFAnas avatar Mar 04 '22 10:03 BFAnas

Which task would you recommend for testing this solution? Ideally it should be a task from the d4rl datasets where SAC has been tried with RNNs and works correctly, since CQL inherits from SAC.

I think the task for testing CQL is the same as the task for testing SAC, e.g., Pendulum for a unit test or halfcheetah-medium in d4rl. I don't know whether the presence of an RNN will affect the choice of tasks.

thkkk avatar Mar 05 '22 01:03 thkkk