imitation
Add support for stateful policies
rollout.PolicyCallable takes an observation and outputs an action, so it only supports stateless policies. By contrast, BasePolicy.predict takes an observation, a mask (whether the observation is terminal), and a state (which is reset when the mask is true), so it supports stateful policies such as LSTMs.
We should make PolicyCallable also take mask and state to be more flexible. This would be a fairly easy change to make, as it's only used at a handful of call sites.
In particular, taking rollouts of TabularPolicy always uses the timestep-0 policy because of this limitation.
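
A rough sketch of what a stateful callable could look like, to make the proposal concrete. This is not the existing imitation API: StatefulPolicyCallable, wrap_stateless, timestep_policy and run_episode are hypothetical names, and the (obs, state, mask) -> (action, new_state) convention is an assumption loosely modelled on BasePolicy.predict. The toy time-indexed policy stands in for TabularPolicy and only works correctly when its state (the timestep) is threaded through the rollout.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

# Hypothetical signature: (observation, state, mask) -> (action, new_state).
# `mask` is True when the state should be reset (e.g. at an episode boundary).
StatefulPolicyCallable = Callable[[Any, Optional[Any], bool], Tuple[Any, Any]]


def wrap_stateless(policy: Callable[[Any], Any]) -> StatefulPolicyCallable:
    """Adapt an existing observation -> action callable to the stateful form."""

    def stateful(obs: Any, state: Optional[Any], mask: bool) -> Tuple[Any, Any]:
        # A stateless policy ignores state and mask and never produces a state.
        return policy(obs), None

    return stateful


def timestep_policy(actions_by_timestep: Sequence[Any]) -> StatefulPolicyCallable:
    """Toy time-indexed policy: the state is simply the current timestep."""

    def policy(obs: Any, state: Optional[int], mask: bool) -> Tuple[Any, int]:
        # Restart from timestep 0 when there is no state yet or mask is set.
        t = 0 if (state is None or mask) else state
        # Clamp so that episodes longer than the table reuse the last entry.
        idx = min(t, len(actions_by_timestep) - 1)
        return actions_by_timestep[idx], t + 1

    return policy


def run_episode(policy: StatefulPolicyCallable, observations: Sequence[Any]) -> List[Any]:
    """Minimal rollout loop that threads state and mask through each call."""
    state: Optional[Any] = None
    mask = True  # first step of the episode: ask the policy to reset its state
    actions = []
    for obs in observations:
        action, state = policy(obs, state, mask)
        actions.append(action)
        mask = False
    return actions
```

With this shape, existing stateless callables can be passed through wrap_stateless unchanged, so the call sites would only need to carry state and mask between steps.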