[IDEA] Easiest way to implement Hindsight Relabeling?
Hi all,
I was wondering if anyone had thoughts on how to implement hindsight experience replay (HER) relabeling in a replay buffer for goal-conditioned implementations. In essence, it would look a lot like the trajectory buffer, with a few key differences:
- We insert (in a batched format) trajectories from an agent's experience that are labeled with (s, a, s', r, g)
- Either at sample time or insertion time, we extract the goals achieved during the trajectory and populate the replay buffer with n copies of the trajectory, each with modified g values as if the agent had "intended" to accomplish those goals
- The goal relabeling often requires the entire trajectory for context (e.g. you might need to look ahead to the future to know what to label the current state with).
It would be nice if you didn't have to relabel every time you sample, but could instead re-insert relabeled trajectories into the replay buffer to be sampled later. Also, while we need the full trajectory for relabeling context, it might not be needed for the actual learning algorithm (e.g. if we were doing a goal-conditioned DQN).
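Concretely, the per-step pytree the buffer would hold might look something like the sketch below (the field names are purely illustrative; the buffer just stores whatever pytree it is given):

```python
from typing import NamedTuple
import jax.numpy as jnp

class GCTransition(NamedTuple):
    # Illustrative field names only -- the buffer stores any pytree.
    obs: jnp.ndarray
    action: jnp.ndarray
    reward: jnp.ndarray
    next_obs: jnp.ndarray
    goal: jnp.ndarray           # the "intended" goal g, the field HER rewrites
    achieved_goal: jnp.ndarray  # what the agent actually reached at this step
```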
Does anyone have any thoughts about how to best go about implementing something like this using flashbax?
Thanks in advance!
I would start from the trajectory buffer class and build from there. @EdanToledo do you have any advice on this?
I haven't read the hindsight relabeling paper, so there might be context I'm missing, but this sounds achievable with just the trajectory buffer and no extra functionality. Correct me if I'm wrong here:
> - We insert (in a batched format) trajectories from an agent's experience that are labeled with (s, a, s', r, g)
> - Either at sample time or insertion time, we extract the goals achieved during the trajectory and populate the replay buffer with n copies of the trajectory, each with modified g values as if the agent had "intended" to accomplish those goals
> - The goal relabeling often requires the entire trajectory for context (e.g. you might need to look ahead to the future to know what to label the current state with).
The goal field would just be part of the pytree the buffer is initialised with. When inserting the data, you just do the data manipulation before insertion: make copies of the trajectory, alter the goals, and then insert them as a batch.
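A rough sketch of what that could look like, assuming the trajectory pytree is a dict with hypothetical `goal` / `achieved_goal` / `reward` fields, per-leaf shape `(add_batch, time, ...)`, and using the simple "final" HER strategy (relabel to the goal achieved at the last step). Note the combined batch still has to match the `add_batch_size` the buffer was built with:

```python
import jax
import jax.numpy as jnp

def her_final_relabel(traj, goal_tolerance=1e-3):
    # "Final" HER strategy: pretend the intended goal was whatever the agent
    # actually achieved at the last step of each trajectory.
    final_achieved = traj["achieved_goal"][:, -1:, :]          # (B, 1, goal_dim)
    new_goal = jnp.broadcast_to(final_achieved, traj["goal"].shape)
    # Sparse goal-reaching reward as in the HER paper: 0 on success, -1 otherwise.
    dist = jnp.linalg.norm(traj["achieved_goal"] - new_goal, axis=-1)
    new_reward = jnp.where(dist < goal_tolerance, 0.0, -1.0)
    return {**traj, "goal": new_goal, "reward": new_reward}

def add_with_relabeling(buffer, state, traj):
    relabeled = her_final_relabel(traj)
    # Stack the original and relabeled copies along the add-batch axis and
    # insert them in one call. The total batch dimension must still equal the
    # add_batch_size the buffer was configured with.
    combined = jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=0), traj, relabeled
    )
    return buffer.add(state, combined)
```

The "future" strategy from the HER paper would instead sample relabel goals from later timesteps within the same trajectory, but the insertion pattern is the same.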
> The goal relabeling often requires the entire trajectory for context (e.g. you might need to look ahead to the future to know what to label the current state with).
If you do the relabeling at insertion time, then you have the entire trajectory. At sample time it becomes more difficult, especially since right now we have no functionality to modify individual data items in the buffer once they have been inserted.
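That said, one workaround at sample time is to relabel the sampled copy rather than the stored data. A rough sketch, reusing the hypothetical `her_final_relabel` from above and assuming the sampled container exposes its pytree as `.experience` (double-check against the current API):

```python
def sample_with_relabeling(buffer, state, rng):
    # Leaves of the sampled pytree have shape
    # (sample_batch_size, sample_sequence_length, ...).
    batch = buffer.sample(state, rng)
    # Relabel the sampled copy only; the data stored in the buffer is untouched.
    return her_final_relabel(batch.experience)
```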
@EdanToledo @SimonDuToit Thanks a lot for your comments here.
@EdanToledo, I am a bit confused by what you said about having the entire trajectory at insertion time. I was under the impression that you insert length-$n$ trajectory segments, not full trajectories. Then:
(from notebook on trajectory buffer) > "The trajectory buffer receives batches of trajectories, saves them while maintaining their temporal ordering, which allows sampling to return trajectories also."
I took this to mean that the trajectory segments are stitched together, keeping their order, so that you can sample the full trajectory. In other words, I'm asking whether you can sample a longer trajectory segment than you insert. If yes, then I don't see how you have access to the full trajectory (i.e. from start state to end state, which is longer than the insertion length) at insertion time. But you could probably get it at sample time, with a large enough sample sequence length. Right?
My ultimate need is something like the following: collect rollouts from the env for the full trajectory (or, because of JAX's static shapes, a fixed number of time steps larger than the full trajectory length, padded with new env starts), add these to the buffer with a variety of relabelings (maybe 10:1 or so), and sample from them. However, I'm using Craftax, which has a very large observation space, so a full trajectory takes up too much GPU memory to process at once. Rolling out for a few (8) time steps, sending those to a CPU replay buffer, rolling out for 8 more, and stitching them together to recover the full trajectory seems to be the best option.
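For concreteness, this is roughly the kind of configuration I have in mind (argument names taken from the flashbax trajectory-buffer docs as I remember them, so treat them as approximate):

```python
import flashbax as fbx

num_envs = 8  # e.g. the number of parallel Craftax environments

# Each add() pushes a short 8-step rollout per env row; sampling pulls longer
# sequences so that (most of) a full trajectory is available for relabeling.
buffer = fbx.make_trajectory_buffer(
    add_batch_size=num_envs,
    max_length_time_axis=50_000,   # per-row capacity along the time axis
    min_length_time_axis=64,
    sample_sequence_length=64,     # longer than the 8-step chunks we insert
    period=1,
    sample_batch_size=32,
)
```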
Thanks a lot for any advice and for correcting me where needed!