DecisionTransformerInterpretability
Vectorize the get trajectory minibatches method of the memory class (useful for the TrajPPO model)
I recently wrote a version of get_minibatches in the memory class of the ppo subpackage.
https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/c84edb381c53b3f9ef2fa9517e34914a52e15fbd/src/ppo/memory.py#L210-L393
TL;DR: This is important for sampling sections of trajectories, which is necessary for online training of trajectory models (as opposed to models that respond only to the latest observation). I have a few ideas for what to do here:
Keep the logic more or less the same, but vectorize it. It's far too sequential, and it doesn't have to be. Obviously, write lots of tests.
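A minimal sketch of what "same logic, but vectorized" could look like, under assumptions that are not taken from the linked memory.py: the buffer stores observations as an array of shape (T, N, D) (timesteps, parallel environments, features), `dones[t, n]` is True when the episode in environment n ended at step t, and windows are left-padded with zeros plus a boolean mask. The function names `gather_windows`, `sample_trajectory_windows` and `gather_windows_reference` are hypothetical, not the repository's API. The trick is to give every step an episode id via a cumulative sum over `dones`, so "does this step belong to the same episode as the window's last step?" becomes a single broadcasted comparison instead of a Python loop. The slow loop version is kept only as a test oracle.

```python
import numpy as np


def gather_windows(obs, dones, env_idx, end_t, window):
    """Gather fixed-length trajectory windows ending at (end_t[b], env_idx[b]).

    Hypothetical layout: obs is (T, N, D) and dones is (T, N) bool, where
    dones[t, n] means the episode in env n ended at step t.
    Returns windows (B, window, D), left-padded with zeros, and a bool mask
    (B, window) that is True where the step belongs to the same episode.
    """
    T = dones.shape[0]
    d = dones.astype(np.int64)
    # Episode id of each step: number of episode ends strictly before t.
    ep_id = np.cumsum(d, axis=0) - d                      # (T, N)

    offsets = np.arange(-window + 1, 1)                   # (L,), last offset is 0
    time_idx = end_t[:, None] + offsets[None, :]          # (B, L)
    clipped = np.clip(time_idx, 0, T - 1)                 # safe indices for gathering
    env_b = env_idx[:, None]                              # (B, 1), broadcasts over L

    # A step is valid if it exists and lies in the same episode as end_t.
    same_episode = ep_id[clipped, env_b] == ep_id[end_t, env_idx][:, None]
    mask = (time_idx >= 0) & same_episode                 # (B, L)

    gathered = obs[clipped, env_b]                        # (B, L, D)
    windows = np.where(mask[..., None], gathered, 0)      # zero out padding
    return windows, mask


def sample_trajectory_windows(obs, dones, batch_size, window, rng):
    """Sample random (env, end step) pairs and gather their windows in one shot."""
    T, N = dones.shape
    env_idx = rng.integers(0, N, size=batch_size)
    end_t = rng.integers(0, T, size=batch_size)
    windows, mask = gather_windows(obs, dones, env_idx, end_t, window)
    return windows, mask, env_idx, end_t


def gather_windows_reference(obs, dones, env_idx, end_t, window):
    """Slow loop version, kept only as a test oracle for gather_windows."""
    B, D = len(env_idx), obs.shape[-1]
    out = np.zeros((B, window, D), dtype=obs.dtype)
    mask = np.zeros((B, window), dtype=bool)
    for b in range(B):
        n, s, k = env_idx[b], end_t[b], window - 1
        # Walk backwards from the end step until the window is full,
        # the buffer starts, or the previous step ended an episode.
        while k >= 0:
            out[b, k] = obs[s, n]
            mask[b, k] = True
            if s == 0 or dones[s - 1, n]:
                break
            s, k = s - 1, k - 1
    return out, mask
```

Comparing the vectorized gather against the loop oracle on many random buffers (random `dones`, window lengths, batch sizes) is one cheap way to get the "lots of tests" for free. Left-padding keeps the most recent step in the last position, which is convenient if the model reads its prediction from the final token; that is a design assumption here, not something the issue specifies.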