Implement "Curious Replay" into Dreamer?
Hi, I recently saw the paper Curious Replay for Model-based Adaptation, which proposes a fairly straightforward curiosity-based sampling scheme for the replay buffer.
Are there any plans to potentially integrate this into the SheepRL Dreamer implementations?
Thanks
Hi @defrag-bambino, that is something we have always had in mind to implement but never did. Yes, we should definitely try to add it to SheepRL. As far as I remember, we need to save the world model loss and sample given both the insertion time and the world model loss, right? Would you like to work on a PR, if you have the time?
I have only limited time. Also, I am not that familiar with the SheepRL implementation. For example, I've looked through your ReplayBuffer code but am not entirely sure where I need to change things. But maybe you can guide me a little bit.
The required changes are actually pretty straightforward. When adding data (i.e. state transitions) to the replay buffer, we need to additionally store its priority and visit count. Upon "visiting" (i.e. sampling) this datapoint, its visit count is incremented and its priority is updated using the world model loss. Here is the pseudo-code from the Curious Replay paper:
Here is the priority calculation function from the official implementation:
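import numpy as np

def _calculate_priority_score(model_loss, visit_count, hyper):
    # Count-based term decays with each visit; loss-based term grows with model error.
    return (hyper['c'] * np.power(hyper['beta'], visit_count)) \
        + np.power((model_loss + hyper['epsilon']), hyper['alpha'])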
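To get a feel for how the two terms of this formula interact, here is a small standalone sketch. The hyperparameter values are illustrative assumptions chosen for this example, not values read from any SheepRL config, and `priority` is just a local re-statement of the formula above.

```python
import numpy as np

# Illustrative hyperparameters (assumed for this example only).
hyper = {"c": 1e4, "beta": 0.7, "alpha": 0.7, "epsilon": 0.01}

def priority(model_loss, visit_count, hyper):
    # Count-based bonus: large for rarely sampled data, decays geometrically.
    count_term = hyper["c"] * np.power(hyper["beta"], visit_count)
    # Loss-based term: larger when the world model fits the data poorly.
    loss_term = np.power(model_loss + hyper["epsilon"], hyper["alpha"])
    return count_term + loss_term

# Same loss, increasing number of visits.
visits = np.arange(0, 50, 10)
scores = priority(model_loss=2.0, visit_count=visits, hyper=hyper)
for v, s in zip(visits, scores):
    print(f"visits={v:2d}  priority={s:.3f}")
```

With these values the count term dominates for fresh data and fades after a few dozen visits, after which the priority is driven almost entirely by the model loss.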
Now, the changes we need to make in SheepRL/dreamer_v3.py are (?):
- Upon adding to the replay buffer: store initial priorities p and visit counts v in the step_data, i.e.
step_data["replay_priority"] = p_max
step_data["replay_visit_count"] = 0
- Sample from the replay buffer according to the stored priorities. Also return the indices of these samples. I'm not sure where in buffers.py this belongs or how to do it.
- Train the world model using the sampled data and use the resulting loss to compute the new priorities (do we currently get a loss for each element of the batch or only a single number?). Then write the new priorities back to the replay buffer using the samples' indices.
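The three steps above can be sketched in isolation with plain NumPy. Everything here is hypothetical: `CuriousPriorityTable` is not a SheepRL class, the default hyperparameters and `p_max` are assumed values, and the sketch presumes the training step can provide one loss value per sampled element (i.e. the world model loss reduced over everything except the batch dimension).

```python
import numpy as np

class CuriousPriorityTable:
    """Hypothetical helper (not part of SheepRL): priorities and visit counts per slot."""

    def __init__(self, capacity, p_max=1e5, c=1e4, beta=0.7, alpha=0.7,
                 epsilon=0.01, seed=0):
        self.priority = np.zeros(capacity, dtype=np.float64)
        self.visits = np.zeros(capacity, dtype=np.int64)
        self.size = 0  # number of filled slots
        self.pos = 0   # next slot to write (circular)
        self.p_max = p_max  # assumed initial priority for new data
        self.c, self.beta, self.alpha, self.epsilon = c, beta, alpha, epsilon
        self.rng = np.random.default_rng(seed)

    def add(self):
        """Step 1: new data starts with priority p_max and zero visits."""
        idx = self.pos
        self.priority[idx] = self.p_max
        self.visits[idx] = 0
        self.pos = (self.pos + 1) % len(self.priority)
        self.size = min(self.size + 1, len(self.priority))
        return idx

    def sample(self, batch_size):
        """Step 2: draw indices with probability proportional to priority."""
        p = self.priority[: self.size]
        return self.rng.choice(self.size, size=batch_size, p=p / p.sum())

    def update(self, idx, model_loss):
        """Step 3: bump visit counts and recompute priorities from per-sample losses."""
        np.add.at(self.visits, idx, 1)  # counts duplicates within the batch too
        self.priority[idx] = (
            self.c * np.power(self.beta, self.visits[idx])
            + np.power(np.asarray(model_loss) + self.epsilon, self.alpha)
        )

table = CuriousPriorityTable(capacity=8)
for _ in range(5):
    table.add()
idx = table.sample(batch_size=4)
per_sample_loss = np.full(len(idx), 0.5)  # stand-in for the world model loss
table.update(idx, per_sample_loss)
print(idx, table.visits[:5], table.priority[:5])
```

In a real integration the arrays would probably live next to the buffer's own storage, so that overwriting a slot also resets its priority and visit count.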
I'm in the process of implementing this now. Right now, dreamer_v3 uses a SequentialReplayBuffer. This will have to change. Does the implementation rely on the sampled data being sequential? Or could I use the regular ReplayBuffer class?
Hi @defrag-bambino. Yes, Dreamer-V3 needs sequential data, so you should modify the SequentialReplayBuffer
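As a rough idea of what prioritized sampling could look like when the data must stay sequential, here is a standalone NumPy sketch. It does not use SheepRL's actual SequentialReplayBuffer internals: the function name, its arguments, and the choice to weight a sequence by the priority of its first step are all assumptions made for illustration.

```python
import numpy as np

def sample_sequence_indices(priority, size, pos, full, seq_len, batch_size, rng):
    """Hypothetical sketch: draw sequence starts in proportion to priority.

    priority: 1-D array with one priority per buffer slot
    size:     buffer capacity
    pos:      next write position (the oldest element once the buffer is full)
    full:     whether the circular buffer has wrapped around
    """
    # Number of start positions whose window does not cross the write pointer.
    n_valid = (size if full else pos) - seq_len + 1
    if n_valid <= 0:
        raise ValueError("not enough data for one sequence")
    first = pos if full else 0  # oldest stored element
    starts = (first + np.arange(n_valid)) % size
    p = priority[starts]
    chosen = rng.choice(starts, size=batch_size, p=p / p.sum())
    # Shape (seq_len, batch_size): row t holds step t of every sampled sequence.
    return (chosen[None, :] + np.arange(seq_len)[:, None]) % size

rng = np.random.default_rng(0)
priority = np.ones(10)
idx = sample_sequence_indices(priority, size=10, pos=4, full=True,
                              seq_len=3, batch_size=5, rng=rng)
print(idx)
```

The returned index grid can be used both to gather the batch and, after the world model update, to write the new priorities back to the same slots.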
Ok. I've implemented something now and it seems to work. It's not very clean yet, e.g. just hardcoded changes rather than something configurable. Here is the fork. And here are some experiments on the tasks from the paper's Figure A19 that showed the largest difference between CR and vanilla sampling. Note that I used a replay ratio of 0.1, which I think led to the quicker convergence? Here is a full example command. Each experiment ran only once and only using vector observations!
sheeprl fabric.accelerator=cuda fabric.devices=1 fabric.precision=16-mixed exp=dreamer_v3 algo=dreamer_v3_S env=dmc env.wrapper.domain_name=quadruped env.wrapper.task_name=walk algo.total_steps=1000000 algo.cnn_keys.encoder=\[\] algo.mlp_keys.encoder=\["state"\] algo.cnn_keys.decoder=\[\] algo.mlp_keys.decoder=\["state"\] env.num_envs=16 num_threads=16 checkpoint.every=1000 metric.log_every=1000 algo.replay_ratio=0.1 env.screen_size=256
Vanilla DreamerV3 is blue, with Curious Replay orange.
On Quadruped-Walk and Hopper-Hop it performs better than the baseline, as in the paper. Interestingly, on Pendulum-Swingup it performs better than the baseline rather than worse, which is the opposite of what the paper reports. Even more so on Cartpole-Swingup-Sparse, where the baseline fails to make any progress. However, this may well be due to only running one experiment each. Also, these swingup tasks are probably too easy and not what Dreamer is made for. We'll definitely have to test it on more complex envs, such as Crafter.
This is awesome! I'll try it out on lightning.ai with crafter in the next few days! Thank you @defrag-bambino
Hi @defrag-bambino, sorry for the laaate response. Have you managed to implement this in buffers.py? I would create a new buffer class extending SequentialReplayBuffer and override the sample() method or, better, the _get_samples() one.
I think so, yeah. I don't quite remember. But it's probably not a separate class, rather just a proof of concept edited into the SequentialReplayBuffer. You can check it out in my fork. I'll probably get back to this once I finish something else, but that'll be another 2-3 months.