D4RL
Ways to recover trajectories in dataset
First, thank you for sharing the repo!
The dataset seems to consist of state-action pairs. Is there a way to recover the entire rollout of a policy?
I am not sure whether there is a better way to do this, but I am currently extracting independent rollouts by modifying the 'terminals' flags: a transition is marked terminal whenever its next state does not match the state of the following transition in d4rl.qlearning_dataset.
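import numpy as np
import d4rl
# load d4rl qlearning dataset (env is a D4RL gym environment created beforehand)
dataset = d4rl.qlearning_dataset(env)
# a rollout ends wherever the stored next state differs from the following state in any dimension
terminal_state = np.where(np.any(dataset['next_observations'][:-1] != dataset['observations'][1:], axis=1))[0]
dataset['terminals'][terminal_state] = 1
# the final transition always closes the last rollout
dataset['terminals'][-1] = 1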
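To make the boundary check easier to verify without downloading a dataset, here is a minimal sketch on synthetic arrays. The helper names `rollout_boundaries` and `split_rollouts` are hypothetical (they are not part of D4RL), and the sketch assumes every entry of the dataset dict is an array with one row per transition. It uses `np.any` so that a boundary is detected even when only some state dimensions change.

```python
import numpy as np

def rollout_boundaries(observations, next_observations):
    """Return the index of the last transition of each rollout."""
    # Step t ends a rollout when next_observations[t] is not observations[t + 1]
    # in at least one dimension.
    mismatch = np.any(next_observations[:-1] != observations[1:], axis=1)
    ends = np.flatnonzero(mismatch)
    # The final transition always closes the last rollout.
    return np.append(ends, len(observations) - 1)

def split_rollouts(dataset):
    """Slice every array in the dataset dict into per-rollout chunks."""
    ends = rollout_boundaries(dataset["observations"], dataset["next_observations"])
    starts = np.concatenate(([0], ends[:-1] + 1))
    return [
        {key: value[start:end + 1] for key, value in dataset.items()}
        for start, end in zip(starts, ends)
    ]

# Synthetic data: two rollouts of length 3 and 2. At the boundary only the
# first state dimension changes (3 -> 0), the second stays at 0.
toy = {
    "observations": np.array([[0., 0.], [1., 0.], [2., 0.], [0., 0.], [0., 1.]]),
    "next_observations": np.array([[1., 0.], [2., 0.], [3., 0.], [0., 1.], [0., 2.]]),
    "actions": np.arange(5),
}
rollouts = split_rollouts(toy)
lengths = [len(r["actions"]) for r in rollouts]
```

Note that comparing consecutive states only works if no two rollouts happen to be stitched together at identical states, so treat the result as a heuristic.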
Hi,
You can try the following function, which will return an iterator through trajectories in the dataset: https://github.com/rail-berkeley/d4rl/blob/master/d4rl/__init__.py#L137
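For illustration only, here is a rough sketch of what an iterator over trajectories can look like. It is not the library's implementation: the name `iter_trajectories` is hypothetical, and it assumes the dataset dict contains a `terminals` array and, optionally, a `timeouts` array, each with one entry per transition.

```python
import numpy as np

def iter_trajectories(dataset):
    """Yield one dict per trajectory, split on terminal or timeout flags."""
    n = len(dataset["terminals"])
    # Assumption: if there is no 'timeouts' array, no episode was cut off by a time limit.
    timeouts = dataset.get("timeouts", np.zeros(n, dtype=bool))
    start = 0
    for t in range(n):
        done = bool(dataset["terminals"][t]) or bool(timeouts[t])
        # The last transition closes whatever partial trajectory remains.
        if done or t == n - 1:
            yield {key: value[start:t + 1] for key, value in dataset.items()}
            start = t + 1

# Synthetic data: a terminal at step 2, a timeout at step 4, one leftover step.
toy = {
    "observations": np.arange(6, dtype=float).reshape(6, 1),
    "rewards": np.ones(6),
    "terminals": np.array([0, 0, 1, 0, 0, 0], dtype=bool),
    "timeouts": np.array([0, 0, 0, 0, 1, 0], dtype=bool),
}
lengths = [len(traj["rewards"]) for traj in iter_trajectories(toy)]
```

This sketch also yields the trailing, unfinished trajectory; whether to keep or drop it depends on the use case.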