stable-baselines3

How to correctly add state representation methods such as an autoencoder (AE)?

Open angel-ayala opened this issue 9 months ago • 3 comments

❓ Question

Hi, I would like to thank you for the effort to keep this repo updated. I have already worked with the library and am somewhat familiar with the ecosystem, mainly the NN models, policies, and algorithms. I intend to extend three existing algorithms (SAC, TD3, and PPO) with representation learning techniques such as SPR, but I would first like to try a vanilla autoencoder, specifically a VAE.

I saw that I must extend each algorithm class to include the VAE model and perform joint optimization of the critic and the encoder/decoder stage in the train method. However, I was wondering whether this is enough or whether I need to consider other aspects.

The AE architecture comprises three components: an online encoder, a decoder, and a target encoder.
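To make the three components concrete, here is a minimal sketch in plain PyTorch. It is an illustration under assumptions, not code from stable-baselines3: the names `Encoder`, `Decoder`, `make_target`, `soft_update`, and `vae_loss`, the layer sizes, and the `beta` weight are all hypothetical, and the target encoder is assumed to be a frozen copy of the online encoder that tracks it by Polyak averaging.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Online encoder: maps an observation to the mean and log-variance of a latent Gaussian."""

    def __init__(self, obs_dim: int, latent_dim: int):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU())
        self.mu = nn.Linear(64, latent_dim)
        self.log_var = nn.Linear(64, latent_dim)

    def forward(self, obs: torch.Tensor):
        h = self.body(obs)
        return self.mu(h), self.log_var(h)


class Decoder(nn.Module):
    """Decoder: reconstructs the observation from a latent sample."""

    def __init__(self, latent_dim: int, obs_dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(), nn.Linear(64, obs_dim))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


def make_target(online: nn.Module) -> nn.Module:
    """Target encoder: a frozen copy of the online encoder (assumption of this sketch)."""
    target = copy.deepcopy(online)
    for p in target.parameters():
        p.requires_grad_(False)
    return target


@torch.no_grad()
def soft_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- (1 - tau) * target + tau * online."""
    for p, p_targ in zip(online.parameters(), target.parameters()):
        p_targ.mul_(1.0 - tau).add_(p, alpha=tau)


def vae_loss(encoder: Encoder, decoder: Decoder, obs: torch.Tensor, beta: float = 1.0):
    """Reconstruction (MSE) plus beta-weighted KL term; also returns the detached parts for logging."""
    mu, log_var = encoder(obs)
    z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)  # reparameterization trick
    recon_loss = F.mse_loss(decoder(z), obs)
    kl = -0.5 * torch.mean(torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
    return recon_loss + beta * kl, recon_loss.detach(), kl.detach()


encoder, decoder = Encoder(obs_dim=6, latent_dim=3), Decoder(latent_dim=3, obs_dim=6)
target_encoder = make_target(encoder)
loss, recon, kl = vae_loss(encoder, decoder, torch.randn(5, 6))
loss.backward()                                    # gradients reach the online encoder and decoder only
soft_update(encoder, target_encoder, tau=0.005)    # the target encoder is never trained directly
```

stable-baselines3 also provides a `polyak_update` helper in `stable_baselines3.common.utils`, which could probably replace the hand-written `soft_update` here.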

My main concerns are:

  • Critic inference: whether to use the online encoder or the target encoder to process the current and next observations.
  • Action inference: using a detached version of the encoder output to prevent gradient propagation (or making this configurable).
  • Logging the loss value of the reconstruction.
  • AE model saving and loading.
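The four concerns above can be sketched in one update step. This is a hedged illustration in plain PyTorch, not the stable-baselines3 implementation: it uses a deterministic autoencoder and a single-critic, DDPG-style update for brevity (no twin critics, target noise, or delayed policy updates as in TD3), and every name in it (`mlp`, `train_step`, `save_ae`, `load_ae`, the sizes, and the hyperparameters) is hypothetical. It assumes the online encoder handles the current observation, the target encoder handles the next observation, and the actor sees a detached latent.

```python
import copy
import io

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
obs_dim, act_dim, latent_dim = 8, 2, 4  # hypothetical sizes


def mlp(inp: int, out: int) -> nn.Sequential:
    return nn.Sequential(nn.Linear(inp, 32), nn.ReLU(), nn.Linear(32, out))


encoder, decoder = mlp(obs_dim, latent_dim), mlp(latent_dim, obs_dim)
actor = nn.Sequential(mlp(latent_dim, act_dim), nn.Tanh())
critic = mlp(latent_dim + act_dim, 1)
target_encoder = copy.deepcopy(encoder).requires_grad_(False)
target_critic = copy.deepcopy(critic).requires_grad_(False)

# Joint optimization: the critic optimizer also owns the encoder and decoder parameters.
critic_opt = torch.optim.Adam(
    [*critic.parameters(), *encoder.parameters(), *decoder.parameters()], lr=3e-4
)
actor_opt = torch.optim.Adam(actor.parameters(), lr=3e-4)


def train_step(obs, actions, rewards, next_obs, dones, gamma=0.99, tau=0.005):
    # Concern 1: next observation goes through the target encoder, without gradients.
    with torch.no_grad():
        next_z = target_encoder(next_obs)
        next_q = target_critic(torch.cat([next_z, actor(next_z)], dim=1))
        target_q = rewards + gamma * (1.0 - dones) * next_q

    # Current observation goes through the online encoder; critic and AE losses are summed.
    z = encoder(obs)
    critic_loss = F.mse_loss(critic(torch.cat([z, actions], dim=1)), target_q)
    recon_loss = F.mse_loss(decoder(z), obs)
    critic_opt.zero_grad()
    (critic_loss + recon_loss).backward()
    critic_opt.step()

    # Concern 2: the actor sees a detached latent, so its loss cannot reach the encoder.
    z_pi = encoder(obs).detach()
    actor_loss = -critic(torch.cat([z_pi, actor(z_pi)], dim=1)).mean()
    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()

    # Polyak update of both target networks.
    with torch.no_grad():
        for online, target in ((encoder, target_encoder), (critic, target_critic)):
            for p, p_targ in zip(online.parameters(), target.parameters()):
                p_targ.mul_(1.0 - tau).add_(p, alpha=tau)

    # Concern 3: return scalar losses so the caller can log them.
    return {
        "critic_loss": critic_loss.item(),
        "recon_loss": recon_loss.item(),
        "actor_loss": actor_loss.item(),
    }


def save_ae(file) -> None:
    # Concern 4: persist the AE through state dicts.
    torch.save(
        {
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict(),
            "target_encoder": target_encoder.state_dict(),
        },
        file,
    )


def load_ae(file) -> None:
    state = torch.load(file)
    encoder.load_state_dict(state["encoder"])
    decoder.load_state_dict(state["decoder"])
    target_encoder.load_state_dict(state["target_encoder"])


batch = (
    torch.randn(16, obs_dim),            # obs
    torch.rand(16, act_dim) * 2 - 1,     # actions in [-1, 1]
    torch.randn(16, 1),                  # rewards
    torch.randn(16, obs_dim),            # next_obs
    torch.zeros(16, 1),                  # dones
)
losses = train_step(*batch)
buffer = io.BytesIO()
save_ae(buffer)
buffer.seek(0)
load_ae(buffer)
```

Inside a stable-baselines3 algorithm's `train` method, the returned values would typically be passed to `self.logger.record`, for example under a key such as `train/recon_loss`. If the encoder and decoder are registered as submodules of the policy, their parameters appear in the policy's `state_dict()`, which may remove the need for separate save and load helpers; that is worth verifying against the installed version.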

I have already started and was able to create a custom model and policy following the documentation. I would really appreciate any guidance on these aspects so that I don't make a huge mess or get unexpected outcomes.


angel-ayala, Feb 05 '25 21:02