
Support for weighted dataset mixtures

Status: Open. akuramshin opened this issue 6 months ago (0 comments).

I created a new MixtureDataset dataset type that supports weighted sampling from a list of datasets. It is inspired by the octo codebase, which uses TensorFlow's sample_from_datasets to build OXE (Open X-Embodiment) mixtures.

Since PyTorch has no equivalent of sample_from_datasets, I wrote a minimal reimplementation that uses NumPy's RandomState, so the sampling order is reproducible for a given seed.
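A minimal sketch of how a seeded, weighted mixture over map-style datasets could work. This is not the PR's actual code: the class layout, the `length` parameter, and the choice to precompute the whole sampling order up front are assumptions made for illustration. Each mixture index is assigned a source dataset with probability proportional to its weight, then a uniformly random element inside that dataset, all drawn from one `np.random.RandomState(seed)`.

```python
import numpy as np
from typing import Any, Optional, Protocol, Sequence


class Dataset(Protocol):
    """Minimal map-style dataset interface (stand-in for torch.utils.data.Dataset)."""

    def __getitem__(self, index: int) -> Any: ...
    def __len__(self) -> int: ...


class MixtureDataset:
    """Hypothetical weighted mixture over map-style datasets.

    The (dataset, sample) order is precomputed once from a seeded RandomState,
    so the same seed always yields the same sequence of samples.
    """

    def __init__(
        self,
        datasets: Sequence[Dataset],
        weights: Sequence[float],
        seed: int = 0,
        length: Optional[int] = None,
    ):
        if len(datasets) != len(weights):
            raise ValueError("datasets and weights must have the same length")
        w = np.asarray(weights, dtype=np.float64)
        if np.any(w < 0) or w.sum() <= 0:
            raise ValueError("weights must be non-negative with a positive sum")
        self._datasets = list(datasets)
        self._probs = w / w.sum()  # normalize weights into sampling probabilities
        # Assumption: default mixture length is the total size of all datasets.
        self._length = length if length is not None else sum(len(d) for d in self._datasets)

        rng = np.random.RandomState(seed)
        # Which dataset each mixture index draws from, proportional to the weights.
        self._dataset_ids = rng.choice(len(self._datasets), size=self._length, p=self._probs)
        # A uniform element index in [0, len(chosen dataset)) for each mixture index.
        sizes = np.array([len(d) for d in self._datasets])
        self._sample_ids = np.floor(
            rng.random_sample(self._length) * sizes[self._dataset_ids]
        ).astype(np.int64)

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, index: int) -> Any:
        if not 0 <= index < self._length:
            raise IndexError(index)
        dataset = self._datasets[self._dataset_ids[index]]
        return dataset[int(self._sample_ids[index])]
```

Precomputing the index arrays keeps the mixture map-style, so a standard data loader can shuffle or shard it by index. The trade-off is O(length) memory for the two index arrays, whereas TensorFlow's sample_from_datasets draws lazily from iterators.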

One design decision concerns which normalization statistics (norm stats) to save with the checkpoints. For now, I save the stats from one of the datasets in the mixture. If users want different norm stats at inference time, they can pass them to _policy_config.create_trained_policy().

I added tests in data_loader_test.py; let me know if they are insufficient and more should be added.

I also added an accumulation_steps parameter to the AdamW optimizer factory to support gradient accumulation. I saw here that you had plans to add gradient accumulation, but I could not find an implementation. Let me know if this should not be part of this PR.
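A minimal NumPy sketch of what an accumulation_steps option does, under stated assumptions: the GradientAccumulator class and its step() method are hypothetical, and a plain SGD update stands in for AdamW so the example stays self-contained. Gradients from accumulation_steps micro-batches are averaged, and the wrapped optimizer is applied only once per group. In the JAX/optax ecosystem, optax.MultiSteps wraps an optimizer to the same effect; the PR description does not say whether it was used.

```python
import numpy as np
from typing import Callable, Optional, Tuple


class GradientAccumulator:
    """Hypothetical wrapper: average gradients over `accumulation_steps`
    micro-batches, then apply one update of the wrapped optimizer."""

    def __init__(
        self,
        update_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],
        accumulation_steps: int,
    ):
        if accumulation_steps < 1:
            raise ValueError("accumulation_steps must be >= 1")
        self._update_fn = update_fn
        self._k = accumulation_steps
        self._acc: Optional[np.ndarray] = None  # running sum of micro-batch gradients
        self._count = 0

    def step(self, params: np.ndarray, grads: np.ndarray) -> Tuple[np.ndarray, bool]:
        """Returns (params, applied); params change only when applied is True."""
        self._acc = grads.copy() if self._acc is None else self._acc + grads
        self._count += 1
        if self._count < self._k:
            return params, False  # keep accumulating, no parameter update yet
        mean_grad = self._acc / self._k
        self._acc, self._count = None, 0  # reset for the next group of micro-batches
        return self._update_fn(params, mean_grad), True


def mse_grad(w: np.ndarray, X: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of the mean squared error of a linear model."""
    return 2 * X.T @ (X @ w - y) / len(y)
```

Usage: with equal-sized micro-batches and a mean-reduced loss, four accumulated micro-batches of 8 produce the same update as one batch of 32.

```python
rng = np.random.RandomState(0)
X, y = rng.randn(32, 3), rng.randn(32)
sgd = lambda p, g: p - 0.1 * g  # stand-in for the AdamW update
acc = GradientAccumulator(sgd, accumulation_steps=4)
w = np.zeros(3)
for Xb, yb in zip(np.split(X, 4), np.split(y, 4)):
    w, applied = acc.step(w, mse_grad(w, Xb, yb))
```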

akuramshin, May 10 '25