refactor: code repetition in train methods

Open · janfb opened this issue 1 year ago • 1 comment

Description:

The current implementation of the SBI library contains significant code duplication across the train(...) methods of SNPE, SNRE, and SNLE. These methods share much of their functionality, including:

  • Building the neural network
  • Resuming training
  • Managing the training and validation loops

This redundancy increases the complexity of the codebase, making it harder to maintain and more prone to inconsistencies and bugs, particularly during updates or enhancements.

To address this, we propose refactoring these methods by introducing a unified train function in the base class. This common train function would handle the shared aspects of the training process, while accepting specific losses and other relevant keyword arguments as parameters to handle the differences between SNPE, SNRE, and SNLE.

Example redundancies

  • SNPE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snpe/snpe_base.py#L340-L379
  • SNLE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snle/snle_base.py#L214-L244
  • SNRE: https://github.com/sbi-dev/sbi/blob/9e224dadd7af9a0b431880bad180e500e09c3200/sbi/inference/snre/snre_base.py#L228-L260

Proposed Steps

  • [ ] Identify and abstract the common code segments across the train methods of SNPE, SNRE, and SNLE.
  • [ ] Design a generic train function in the base class that accepts specific losses and other necessary arguments unique to each method. Parts shared by some, but not all, methods should be offloaded into separate class methods that child classes can override if required.
  • [ ] Refactor the existing train methods to utilize the new generic function, passing their specific requirements as arguments.
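
To make the proposed steps concrete, here is a minimal sketch of the pattern, not the actual sbi code. All names (BaseTrainer, _loss, _build_network, PosteriorLikeTrainer, LikelihoodLikeTrainer) are hypothetical. To stay dependency-free, it assumes a single float stands in for the neural network, plain (theta, x) pairs stand in for tensor batches, and a finite-difference gradient replaces autograd. The base class owns network building, resuming, and the training and validation loops; child classes supply only their loss and override a hook where they differ.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence, Tuple

Pair = Tuple[float, float]  # toy (theta, x) pair standing in for tensor batches


class BaseTrainer(ABC):
    """Hypothetical base class that owns everything the train methods share."""

    def __init__(self) -> None:
        self.weight: Optional[float] = None  # toy stand-in for a neural network
        self.epochs_trained = 0
        self.val_losses: List[float] = []

    @abstractmethod
    def _loss(self, weight: float, theta: float, x: float) -> float:
        """Method-specific loss; the only piece every child class must supply."""

    def _build_network(self) -> float:
        """Overridable hook: initial value of the toy 'network'."""
        return 0.0

    def _grad(self, weight: float, theta: float, x: float, eps: float = 1e-6) -> float:
        # Central finite difference, used only to avoid an autograd dependency.
        upper = self._loss(weight + eps, theta, x)
        lower = self._loss(weight - eps, theta, x)
        return (upper - lower) / (2 * eps)

    def train(
        self,
        train_data: Sequence[Pair],
        val_data: Sequence[Pair],
        max_epochs: int = 200,
        learning_rate: float = 0.05,
        resume_training: bool = False,
    ) -> float:
        # Shared step 1: build the network, unless resuming an earlier run.
        if self.weight is None or not resume_training:
            self.weight = self._build_network()
            self.epochs_trained = 0
            self.val_losses = []

        weight = self.weight
        for _ in range(max_epochs):
            # Shared step 2: training loop (plain SGD, one pair at a time).
            for theta, x in train_data:
                weight -= learning_rate * self._grad(weight, theta, x)
            # Shared step 3: validation loop with the same method-specific loss.
            val_loss = sum(self._loss(weight, t, x) for t, x in val_data) / len(val_data)
            self.val_losses.append(val_loss)
            self.epochs_trained += 1

        self.weight = weight
        return weight


class PosteriorLikeTrainer(BaseTrainer):
    """Toy analogue of a posterior method: predict theta from x."""

    def _loss(self, weight: float, theta: float, x: float) -> float:
        return (weight * x - theta) ** 2


class LikelihoodLikeTrainer(BaseTrainer):
    """Toy analogue of a likelihood method: predict x from theta."""

    def _loss(self, weight: float, theta: float, x: float) -> float:
        return (weight * theta - x) ** 2

    def _build_network(self) -> float:
        # Overrides the shared hook to show where methods may differ.
        return 1.0


# Toy data generated by x = 2 * theta.
train_data = [(0.5, 1.0), (1.0, 2.0), (1.5, 3.0)]
val_data = [(1.0, 2.0)]

posterior = PosteriorLikeTrainer()
posterior.train(train_data, val_data, max_epochs=50)
posterior.train(train_data, val_data, max_epochs=10, resume_training=True)
```

In this sketch the two child classes contain no loop code at all, which is the goal of the refactor. Whether the loss is best injected as an overridden method, as shown here, or passed to train as a callable argument is a design decision left open by the issue.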

We encourage contributors to discuss strategies for this refactoring and help with the implementation. This effort will improve the library’s maintainability and ensure consistency across its components.

If you identify other areas where significant code duplication can be reduced, please create a new issue (see, e.g., #921).

janfb · Jan 24 '24 16:01

This will become even more relevant when we have a common dataloader interface and agnostic loss functions for all SBI methods. But I am removing the hackathon label for now as it will not be done before the release.

janfb · Jul 22 '24 07:07