
[Bug Report] Discrepancies between `get_dataset` and `qlearning_dataset`

Open vmoens opened this issue 1 year ago • 1 comment

I'm having a hard time figuring out how the `qlearning_dataset` is built. As mentioned by @odelalleau in https://github.com/Farama-Foundation/D4RL/issues/182, the "terminals" key in some environments is never True.

Moreover,

>>> dataset1 = env.get_dataset()
>>> dataset2 = d4rl.qlearning_dataset(env)
>>> dataset3 = d4rl.qlearning_dataset(env, terminate_on_end=True)
>>> # obs 91 matches
>>> dataset1["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset2["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset3["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> # obs 92 does not
>>> dataset1["observations"][92]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset2["observations"][92]
array([ 2.9757538 ,  2.9996927 ,  2.674898  , -0.15739861], dtype=float32)
>>> dataset3["observations"][92]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)

I'm a bit puzzled by what this means. It looks like qlearning_dataset gives me a dataset where each consecutive step is different, whereas qlearning_dataset(..., terminate_on_end=True) gives me something similar to get_dataset, where some consecutive steps are identical. What should we do with this?
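
One thing that might shed light on this is to look at the raw flags and buffer lengths around the index where the outputs diverge (assuming the raw dataset exposes the usual "timeouts" and "terminals" keys):

>>> # where do the raw flags sit around the divergence point?
>>> dataset1["timeouts"][88:94]
>>> dataset1["terminals"][88:94]
>>> # and do the three buffers even have the same length?
>>> len(dataset1["observations"]), len(dataset2["observations"]), len(dataset3["observations"])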

When does one trajectory stop, and when does the next one start?

vmoens avatar Mar 10 '23 10:03 vmoens

For what it's worth, here's my takeaway of how this works:

  • qlearning_dataset(), by default, gets rid of the timeouts by ignoring the corresponding transitions. The reason for this is that typically you would have a transition of the form (s, a, r, s', timeout=True) where s' is actually the first step of the next episode. Such a transition is thus "invalid" and it is thrown away. This is fine when you're doing 1-step Q-learning based on the resulting dataset, but be careful if you intend to do multi-step Q-learning or anything else looking further down the trajectory, because it means you will be switching between episodes with no way to know (the done flag won't be set here).
  • qlearning_dataset(terminate_on_end=True) will keep this invalid transition, but be aware that it will not set the done flag (contrary to what the docstring claims). So in general it's a pretty bad idea, except for datasets with fake timeouts like maze2d, where the episode doesn't actually end on timeout. (A sketch of both behaviours is given after this list.)
  • One consequence of the above is that the done flag is set only when terminal == True. As you noticed, some datasets don't have any terminals (e.g. maze2d, which is actually a single trajectory). One thing to be aware of is that the next observation s' will be invalid when done == True, since it will be the first state of the next episode (this generally does not matter in 1-step Q-learning since we don't bootstrap when done == True, but it can matter if you're doing something else). Some datasets provide a next_observations field that can be used to access the last observation (on both timeout and terminal), but the qlearning_dataset() function doesn't use it.
  • It is important to realize that neither timeout == True nor terminal == True indicates with certainty that an episode has ended! I already gave the example of maze2d for the former, and a typical example of the latter is antmaze, where all states with a reward are marked with terminal == True even though the episode continues until a timeout is reached! The interpretation of terminal is thus "if I were to reach this state during evaluation, the episode would end" (antmaze finishes as soon as you reach a reward at test time), rather than "the episode has ended in the offline dataset" (otherwise you will end up with tons of 1-step episodes).
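
To make the first two bullets concrete, here is roughly what the default filtering amounts to, written against the raw get_dataset() dict. This is only a sketch, not the actual D4RL implementation (the function name is made up), and it assumes the standard "observations"/"timeouts"/"terminals" keys:

import numpy as np

def sketch_qlearning_filter(raw, terminate_on_end=False):
    # Rough sketch (NOT the actual D4RL code) of how qlearning_dataset
    # treats timeout steps; `raw` is the dict returned by env.get_dataset().
    obs = raw["observations"]
    timeouts = raw["timeouts"].astype(bool)
    terminals = raw["terminals"].astype(bool)
    obs_t, next_obs, dones = [], [], []
    for i in range(len(obs) - 1):
        if timeouts[i] and not terminate_on_end:
            # obs[i + 1] is the first state of the next episode, so the whole
            # (s, a, r, s') transition is dropped by default ...
            continue
        obs_t.append(obs[i])
        next_obs.append(obs[i + 1])
        # ... and even when it is kept (terminate_on_end=True), done only
        # reflects "terminals", never "timeouts".
        dones.append(bool(terminals[i]))
    return np.array(obs_t), np.array(next_obs), np.array(dones)

Under this reading, the default dataset2 diverges from get_dataset() as soon as a transition has been dropped, because everything after it shifts down by one, whereas terminate_on_end=True keeps the buffer aligned (apart from the very last step, which has no next observation).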

My recommendation is that unless you're doing standard 1-step Q-learning, you should write your own function to build the dataset you need instead of relying on qlearning_dataset(), so you can decide exactly how to handle timeouts, terminals, and the possibly missing last observation, all in a dataset-dependent manner.
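
For example, something along these lines could serve as a starting point (a sketch only; the function name and the episode_end argument are made up, and it assumes the standard keys returned by get_dataset()):

import numpy as np

def make_transitions(raw, episode_end):
    # Sketch of a dataset-specific builder (not D4RL API).
    #   raw:         the dict returned by env.get_dataset()
    #   episode_end: boolean array you build per dataset, True at the last
    #                step of each episode (e.g. terminals | timeouts for
    #                locomotion, timeouts only for antmaze, neither for
    #                maze2d, which is one long trajectory)
    obs = raw["observations"]
    n = len(obs)
    has_next = "next_observations" in raw
    out = {k: [] for k in
           ("observations", "actions", "rewards", "next_observations", "dones")}
    for i in range(n):
        if has_next:
            next_obs = raw["next_observations"][i]   # always valid when stored
        elif i == n - 1 and not episode_end[i]:
            continue                                 # no s' and no episode end: drop
        elif episode_end[i]:
            # The true s' is not stored (obs[i + 1] belongs to the next
            # episode). Re-use obs[i] as a placeholder; with done=True it is
            # never bootstrapped from in 1-step Q-learning.
            next_obs = obs[i]
        else:
            next_obs = obs[i + 1]
        out["observations"].append(obs[i])
        out["actions"].append(raw["actions"][i])
        out["rewards"].append(raw["rewards"][i])
        out["next_observations"].append(next_obs)
        out["dones"].append(bool(episode_end[i]))
    return {k: np.asarray(v) for k, v in out.items()}

For antmaze, for instance, episode_end could be raw["timeouts"].astype(bool), whereas for maze2d (a single long trajectory) it could be all False.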

odelalleau avatar Mar 10 '23 13:03 odelalleau