
[BUG] Planning (`MPCPlannerBase`) should consider `done`

Open FrankTianTT opened this issue 2 years ago • 9 comments

Describe the bug

In the current implementation, none of the subclasses of MPCPlannerBase take into account the done signal returned by the env during planning, which makes MPC ineffective in a large class of environments. For example, CEM scores the candidate action sequences by summing the rewards over the whole planning horizon:

            sum_rewards = optim_tensordict.get(self.reward_key).sum(
                dim=TIME_DIM, keepdim=True
            )

Specifically, in one type of environment, done signals that the agent has entered a dangerous (failure) state, while the reward is usually positive in every non-dangerous state. This includes many gym-mujoco environments, such as InvertedPendulum and Hopper. In these environments, the MPC algorithm needs to detect the done returned by the environment and find the action sequence that maximizes the cumulative reward before done.
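To make the failure mode concrete, here is a toy illustration in plain PyTorch with made-up tensors (not a real rollout). It assumes rewards and dones shaped [num_candidates, horizon] and a model that keeps emitting reward after done; the helper return_before_done is hypothetical. It discards every reward received after a candidate's first done, so a candidate that falls early no longer ties with one that stays alive.

```python
import torch


def return_before_done(rewards: torch.Tensor, dones: torch.Tensor) -> torch.Tensor:
    """Sum rewards over time (dim 1), ignoring steps after the first done.

    rewards: float tensor [num_candidates, horizon]
    dones: bool tensor [num_candidates, horizon]
    The reward of the step that returned done is still counted.
    """
    # done_before[c, t] is True if candidate c returned done at some step < t
    done_before = (dones.long().cumsum(dim=1) - dones.long()) > 0
    return rewards.masked_fill(done_before, 0.0).sum(dim=1)


# Two candidates with a constant reward of 1 per step, as in InvertedPendulum.
rewards = torch.ones(2, 5)
dones = torch.zeros(2, 5, dtype=torch.bool)
dones[0, 1] = True  # candidate 0 "falls" at step 1

naive = rewards.sum(dim=1)
truncated = return_before_done(rewards, dones)
print(naive)      # tensor([5., 5.]): both candidates look equally good
print(truncated)  # tensor([2., 5.]): the surviving candidate now wins
```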

To Reproduce

Just try CEM on InvertedPendulum.

Reason and Possible fixes

For CEM, a simple fix could be:

            dones = optim_tensordict.get(("next", "done"))
            rewards = optim_tensordict.get(self.reward_key)
            assert rewards.shape == dones.shape
            for candidate in range(self.num_candidates):
                if (~dones[candidate]).all():  # no done
                    continue

                if len(dones[candidate].shape) == 1:  # per-step done and reward are scalars
                    idx = torch.nonzero(dones[candidate]).min()  # first done step
                    rewards[candidate, idx:] = 0
                elif len(dones[candidate].shape) == 2:  # per-step done and reward have a trailing dim of size 1
                    idx = torch.nonzero(dones[candidate, :, 0]).min()  # first done step
                    rewards[candidate, idx:, 0] = 0
                else:
                    raise ValueError("Unsupported shape for done and reward")

            sum_rewards = rewards.sum(
                dim=TIME_DIM, keepdim=True
            )

I'm more than happy to submit my changes, but they may need some cleanup to match the style and conventions of the codebase. There is also likely a more efficient way to do this.
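The per-candidate loop above can likely be vectorized. Below is a minimal sketch, assuming the same layout as the loop (candidates on dim 0, time on dim 1, optional trailing singleton dim) and the same semantics: the reward at the first done step and at every later step is set to zero. The function name truncate_rewards is hypothetical.

```python
import torch


def truncate_rewards(rewards: torch.Tensor, dones: torch.Tensor) -> torch.Tensor:
    """Zero rewards from each candidate's first done step onwards.

    rewards: [num_candidates, horizon] or [num_candidates, horizon, 1]
    dones: bool tensor with the same shape as rewards
    Returns a new tensor; the inputs are not modified.
    """
    if rewards.shape != dones.shape:
        raise ValueError("rewards and dones must have the same shape")
    if rewards.dim() not in (2, 3):
        raise ValueError("Unsupported shape for done and reward")
    # after_done[c, t] is True once candidate c has returned done at any step <= t
    after_done = dones.long().cumsum(dim=1) > 0
    return rewards.masked_fill(after_done, 0.0)


rewards = torch.arange(1.0, 7.0).reshape(2, 3, 1)  # candidate 0: 1, 2, 3; candidate 1: 4, 5, 6
dones = torch.tensor([[False, True, False], [False, False, False]]).unsqueeze(-1)
print(truncate_rewards(rewards, dones).squeeze(-1))
# tensor([[1., 0., 0.],
#         [4., 5., 6.]])
sum_rewards = truncate_rewards(rewards, dones).sum(dim=1, keepdim=True)
```

A single cumsum over the time dimension replaces the num_candidates Python iterations of the loop.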

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
  • [x] I have read the documentation (required)
  • [x] I have provided a minimal working example to reproduce the bug (required)

FrankTianTT, Oct 29 '23 13:10

Thanks for raising this, let me think about it, I guess we can do that in a vectorized way!

vmoens, Oct 30 '23 21:10

> Thanks for raising this, let me think about it, I guess we can do that in a vectorized way!

@vmoens Yes, I also think these loops may be inefficient. However, I have found another issue, which requires changing the env rollout. Perhaps the reward manipulation can be done in this new rollout function.

The new problem is that the default rollout function stops as soon as any one env returns done:

            optim_tensordict = self.env.rollout(
                max_steps=self.planning_horizon,
                policy=policy,
                auto_reset=False,
                tensordict=optim_tensordict,
            )

but we need a rollout function that stops only when all envs are done. Such a rollout function could look like this (it is a method of MPCPlannerBase rather than EnvBase):

    def reward_truncated_rollout(self, policy, tensordict):
        # needs: import torch; from torchrl.envs.utils import step_mdp
        tensordicts = []
        # per env: has a done been returned at any previous step?
        ever_done = torch.zeros(
            *tensordict.batch_size, 1, dtype=torch.bool, device=self.device
        )
        for _ in range(self.planning_horizon):
            tensordict = policy(tensordict)
            tensordict = self.env.step(tensordict)

            # envs that were already done earn no reward from now on
            tensordict.get(("next", "reward"))[ever_done] = 0
            tensordicts.append(tensordict)

            ever_done |= tensordict.get(("next", "done"))
            if ever_done.all():
                break
            # move the "next" entries to the root so the policy sees the new state
            tensordict = step_mdp(tensordict, keep_other=True)
        out_td = torch.stack(tensordicts, len(tensordict.batch_size)).contiguous()
        out_td.refine_names(..., "time")

        return out_td

This rollout performs the reward truncation in a vectorized way (using the ever_done mask). With it, the only change needed in planning is replacing optim_tensordict = self.env.rollout(...) with optim_tensordict = self.reward_truncated_rollout(...); no further manipulation of optim_tensordict.get(self.reward_key) is required.
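The control flow of this rollout (keep stepping, zero the reward of envs that were already done, stop once every env has been done at least once or the horizon is reached) can be checked without torchrl. In the sketch below, toy_step is a hypothetical stand-in for policy + env.step that returns per-env rewards and dones of shape [batch, 1]; it is not the torchrl API.

```python
import torch


def toy_step(t: int, done_at: torch.Tensor):
    """Hypothetical stand-in for policy + env.step.

    Every env gets reward 1 per step; env i returns done at step done_at[i].
    Returns (reward, done), both shaped [batch, 1].
    """
    reward = torch.ones(done_at.shape[0], 1)
    done = (done_at == t).unsqueeze(-1)
    return reward, done


def toy_reward_truncated_rollout(done_at: torch.Tensor, horizon: int) -> torch.Tensor:
    ever_done = torch.zeros(done_at.shape[0], 1, dtype=torch.bool)
    rewards = []
    for t in range(horizon):
        reward, done = toy_step(t, done_at)
        reward[ever_done] = 0  # envs already done before this step earn nothing
        rewards.append(reward)
        ever_done |= done
        if ever_done.all():  # every env has been done at least once
            break
    return torch.stack(rewards, dim=1)  # [batch, steps, 1]


out = toy_reward_truncated_rollout(torch.tensor([1, 2]), horizon=10)
print(out.shape)                # torch.Size([2, 3, 1]): stopped after 3 of 10 steps
print(out.sum(dim=1).view(-1))  # tensor([2., 3.])
```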

FrankTianTT, Nov 01 '23 08:11

> The new problem is that the default rollout function stops as soon as any one env returns done:

There is a break_when_any_done argument that can be used to control this. Would that solve it?

vmoens, Nov 01 '23 12:11

> There is a break_when_any_done argument that can be used to control this. Would that solve it?

@vmoens Unfortunately not: with break_when_any_done=False, all envs are reset when any one env returns done. However, we do not want to reset any of them. (Resetting only the env that returned done would be acceptable, but rollout resets all of them.)

FrankTianTT, Nov 01 '23 14:11

> @vmoens Unfortunately not: with break_when_any_done=False, all envs are reset when any one env returns done.

In theory, rollout will only reset the envs that are done. You can check in the doc how this is done: we assign a "_reset" key corresponding to each done that indicates what should be reset. Which env are you referring to, exactly?

> However, we do not want to reset any of them. (Resetting only the env that returned done would be acceptable, but rollout resets all of them.)

Is it a problem if the envs are reset (assuming this is done properly)? Do you mean that env.step() should simply stop being called once a sub-env is done? We do not cover this as of now, but it could be a feature for a future release. It won't be easy to achieve, though (at least efficiently): "_reset" brings a lot of overhead, which we hope to mitigate by calling it only rarely.

vmoens, Nov 01 '23 18:11

> In theory, rollout will only reset the envs that are done. You can check in the doc how this is done: we assign a "_reset" key corresponding to each done that indicates what should be reset.

Yep, you are right, my fault.

> Which env are you referring to, exactly?

ModelBasedEnvBase; more specifically, the example code in https://github.com/pytorch/rl/pull/1657, where we write a new _reset that does not support resetting individual envs:

    def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        tensordict = TensorDict(
            {},
            batch_size=self.batch_size,
            device=self.device,
        )
        tensordict = tensordict.update(self.state_spec.rand())
        tensordict = tensordict.update(self.observation_spec.rand())
        return tensordict

It seems feasible to make the _reset function respect the "_reset" key.
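Respecting "_reset" essentially means drawing fresh states only for the entries flagged in the mask and leaving the others untouched. A minimal plain-PyTorch sketch of that masking step, assuming a state tensor of shape [batch, state_dim] and a boolean mask of shape [batch, 1] (like a done entry); partial_reset is hypothetical and not part of ModelBasedEnvBase, and treating a missing mask as "reset every env" is an assumption of the sketch. In a real _reset, the fresh values would come from something like self.state_spec.rand(), as in the snippet above.

```python
from typing import Optional

import torch


def partial_reset(
    state: torch.Tensor, fresh: torch.Tensor, reset_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Replace only the rows of `state` flagged in `reset_mask`.

    state, fresh: [batch, state_dim]
    reset_mask: bool [batch, 1]; torch.where broadcasts it over state_dim
    """
    if reset_mask is None:  # no mask given: reset every env
        return fresh.clone()
    return torch.where(reset_mask, fresh, state)


state = torch.zeros(3, 2)
fresh = torch.ones(3, 2)  # e.g. freshly sampled initial states
reset_mask = torch.tensor([[False], [True], [False]])
print(partial_reset(state, fresh, reset_mask))
# tensor([[0., 0.],
#         [1., 1.],
#         [0., 0.]])
```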

> Is it a problem if the envs are reset (assuming this is done properly)? Do you mean that env.step() should simply stop being called once a sub-env is done?

My expectation for the rollout function is that, when multiple environments run at the same time, the rollout continues even if one environment returns done, until either 1) all environments have returned done or 2) max_steps is reached.

Note that even if we implement the "_reset" handling in the _reset function, the original rollout still differs from what I expected: it only stops when max_steps is reached and never stops early (because the done environments get reset). This reduces the efficiency of CEM, because each rollout must run all max_steps steps, whereas early termination is very common in some environments, such as Hopper-v4.

In short, all of the above problems stem from making MPCPlannerBase handle done. If the environment never returns done, the algorithm works fine. To support done, at least two issues must be addressed:

  • reset only the individual env that returns done, rather than all of them
  • truncate the reward once an env has returned done

In my opinion, a dedicated rollout method on MPCPlannerBase seems reasonable, because this special behaviour is probably not needed in EnvBase.

BTW, to make planning more efficient, we should wrap the rollout in torch.no_grad(), because CEM does not need gradients.
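A minimal sketch of the no_grad idea using only core PyTorch: torch.no_grad() works both as a context manager and as a decorator, and inside it autograd records nothing, so no graph is built during candidate scoring. The model and score_candidates names are hypothetical stand-ins for the learned model and the planner's rollout.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)  # hypothetical stand-in for a learned reward model


@torch.no_grad()
def score_candidates(actions: torch.Tensor) -> torch.Tensor:
    """Score [num_candidates, action_dim] actions without building a graph."""
    return model(actions).squeeze(-1)


actions = torch.randn(8, 4)
scores = score_candidates(actions)
elites = actions[scores.topk(k=2).indices]  # CEM-style elite selection
print(scores.requires_grad)          # False: autograd recorded nothing
print(model(actions).requires_grad)  # True: outside no_grad a graph is built
```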

FrankTianTT, Nov 02 '23 03:11

  • I can definitely adapt ModelBasedEnvBase to support partial resets; this should have been done a long time ago.
  • To make the rollout stop early, we can design a special transform that keeps track of the dones, wdyt?
    base_env = MyBatchedEnv()
    env = TransformedEnv(base_env, AggregateDone(rule="all")) # rule can be "all", in which case `done` is `True` once every env has been done at least once, or "any", in which case a single done suffices.
    env.rollout(1_000) # will stop at 1000 or when all envs have encountered a done, whichever comes first
    
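The bookkeeping such a transform would need is small. Below is a plain-PyTorch sketch of the aggregation rule only; AggregateDone is a proposal in this thread, and the function aggregate_done, its signature and the [batch, 1] shapes are assumptions made for illustration, not an existing torchrl API.

```python
import torch


def aggregate_done(ever_done: torch.Tensor, done: torch.Tensor, rule: str = "all"):
    """Update the per-env "has been done at least once" flags.

    ever_done, done: bool [batch, 1]
    Returns (new_ever_done, stop), where `stop` is a Python bool telling
    the rollout whether to break.
    """
    new_ever_done = ever_done | done
    if rule == "all":
        stop = bool(new_ever_done.all())
    elif rule == "any":
        stop = bool(new_ever_done.any())
    else:
        raise ValueError(f"unknown rule {rule!r}")
    return new_ever_done, stop


ever_done = torch.zeros(3, 1, dtype=torch.bool)
for t, done_env in enumerate([0, 2, 1]):  # env `done_env` returns done at step t
    done = torch.zeros(3, 1, dtype=torch.bool)
    done[done_env] = True
    ever_done, stop = aggregate_done(ever_done, done, rule="all")
    print(t, stop)
# 0 False
# 1 False
# 2 True
```

With rule="any", the same sequence would stop at step 0.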

vmoens, Nov 02 '23 15:11

> To make the rollout stop early, we can design a special transform that keeps track of the dones, wdyt?

You are right, and we can even do the reward truncation in this special transform, so that only a tiny change to the planning function is needed (adding torch.no_grad()). I'm glad to submit a PR for that.

BTW, there is a small difference between the CEM in torchrl and the one in mbrl-lib: mbrl-lib updates the mean and std of the actions with momentum, rather than by direct assignment:

change from

        for _ in range(self.optim_steps):
            ...
            container.set_(("stats", "_action_means"), best_actions.mean(dim=K_DIM, keepdim=True))
            container.set_(("stats", "_action_stds"), best_actions.std(dim=K_DIM, keepdim=True))
            ...

to

        for _ in range(self.optim_steps):
            ...
            self.update_stats(
                best_actions.mean(dim=K_DIM, keepdim=True),
                best_actions.std(dim=K_DIM, keepdim=True),
                container
            )
            ...

    def update_stats(self, means, stds, container):
        self.alpha = 0.1  # should be set in __init__

        new_means = self.alpha * container.get(("stats", "_action_means")) + (1 - self.alpha) * means
        new_stds = self.alpha * container.get(("stats", "_action_stds")) + (1 - self.alpha) * stds
        container.set_(("stats", "_action_means"), new_means)
        container.set_(("stats", "_action_stds"), new_stds)

To restore the original behaviour, just set self.alpha = 0. Do you think I should add this to the same PR, or create a new one?
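The update above is an exponential moving average of the elite statistics: new = alpha * old + (1 - alpha) * elite. A standalone sketch with plain tensors instead of the container (so the function name and arguments are hypothetical) makes it easy to check that alpha = 0 reduces to the current direct assignment.

```python
import torch


def momentum_update(old: torch.Tensor, new: torch.Tensor, alpha: float) -> torch.Tensor:
    """Blend the previous statistic with the one computed from the current elites."""
    return alpha * old + (1 - alpha) * new


old_mean = torch.tensor([0.0, 1.0])
elite_mean = torch.tensor([1.0, 3.0])
print(momentum_update(old_mean, elite_mean, alpha=0.1))  # tensor([0.9000, 2.8000])
print(momentum_update(old_mean, elite_mean, alpha=0.0))  # tensor([1., 3.]), i.e. direct assignment
```

A larger alpha keeps more of the previous sampling distribution, which smooths the updates across optimization steps.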

FrankTianTT, Nov 03 '23 07:11

A new PR would defo make a lot of sense for this!

vmoens, Nov 03 '23 10:11