flashbax
Index trajectories at a particular index from the buffer
Can we get the trajectory at a particular index from the TrajectoryBuffer or ItemBuffer? Something like buffer.get_trajectory(state, trajectory_index) would be useful, since buffer.sample(state, rng_key) only returns a random slice from the buffer.
You would have to specify the batch index as well as the starting time index, but this is fairly easy functionality to add.
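Until something like that exists, here is a minimal sketch of the idea. It assumes the buffer state keeps its experience as a pytree whose leaves all have shape [add_batch_size, max_length_time_axis, ...] and that the time axis is circular. `get_trajectory` is a hypothetical helper, not part of the flashbax API. It takes the batch index and the starting time index the reply mentions, plus a sequence length, and wraps time indices modulo the buffer length.

```python
import jax
import jax.numpy as jnp


def get_trajectory(experience, batch_index, start_time_index, sequence_length):
    """Hypothetical helper: slice one trajectory out of an experience pytree.

    Assumes every leaf has shape [add_batch_size, max_length_time_axis, ...]
    and that the time axis is a circular buffer, so indices wrap past the end.
    Does NOT check whether the selected time steps have actually been written.
    """

    def slice_leaf(leaf):
        max_length = leaf.shape[1]
        # Time indices for the requested window, wrapped modulo the buffer length.
        time_indices = (start_time_index + jnp.arange(sequence_length)) % max_length
        # Pick the batch row, then gather the requested time steps from it.
        return jnp.take(leaf[batch_index], time_indices, axis=0)

    return jax.tree_util.tree_map(slice_leaf, experience)


# Toy experience pytree: add_batch_size=2, max_length_time_axis=5.
experience = {
    "obs": jnp.arange(2 * 5 * 3).reshape(2, 5, 3),
    "reward": jnp.arange(2 * 5, dtype=jnp.float32).reshape(2, 5),
}

# Batch row 1, starting at time step 3, length 4: time steps 3, 4, 0, 1.
traj = get_trajectory(experience, batch_index=1, start_time_index=3, sequence_length=4)

# sequence_length sets an array shape, so it must be static under jit;
# batch_index and start_time_index can be traced values.
get_trajectory_jit = jax.jit(get_trajectory, static_argnames="sequence_length")
traj_jit = get_trajectory_jit(experience, 1, 3, sequence_length=4)
```

In a real buffer you would pass the state's experience field and probably also consult its current write index and "is full" flag, so you don't read across the write pointer into stale or unwritten data. This sketch deliberately leaves that check out.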