# Remove unnecessary deep copies in SBI training/inference workflow
## Problem
The SBI codebase currently creates deep copies of neural networks and posteriors in three places:
- `train()` returns `deepcopy(self._neural_net)`
- `build_posterior()` returns `deepcopy(self._posterior)`
- `build_posterior()` stores `self._model_bank.append(deepcopy(self._posterior))`
These deep copies can consume significant memory (10-100MB+ per round for modern density estimators).
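To put rough numbers on this, the footprint of one such copy can be estimated from the parameter count. A minimal sketch, using a toy network as a stand-in for a real density estimator (PyTorch assumed, as in sbi):

```python
import torch
from copy import deepcopy

# Toy stand-in for a density estimator; real flows are often of this order.
net = torch.nn.Sequential(
    torch.nn.Linear(50, 2048), torch.nn.ReLU(),
    torch.nn.Linear(2048, 2048), torch.nn.ReLU(),
    torch.nn.Linear(2048, 50),
)
snapshot = deepcopy(net)  # every deepcopy duplicates all parameter tensors

n_params = sum(p.numel() for p in net.parameters())
print(f"{n_params:,} parameters ~ {n_params * 4 / 1e6:.0f} MB per float32 copy")
```

With three copy sites per round, ten rounds of this quickly add up to hundreds of MB.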
## Details
- Model bank is completely unused: `_model_bank` is write-only - no code reads from it
- Neural networks aren't reused between rounds: SNPE only needs the data (`_theta_roundwise`, `_x_roundwise`) from previous rounds, not the networks
- No modification risk: Users interact with posteriors through stable APIs (`sample()`, `log_prob()`)
The deep copies appear to be legacy from when multi-round inference was handled internally. Now that users manage rounds explicitly, they serve no purpose.
## Proposed Solution

- Remove `_model_bank` entirely - it's likely legacy code
- Return networks/posteriors without copying:
```python
def train(self, *args, **kwargs):  # signature unchanged
    ...
    return self._neural_net  # No deepcopy


def build_posterior(self, *args, **kwargs):  # signature unchanged
    ...
    # Remove: self._model_bank.append(deepcopy(self._posterior))
    return self._posterior  # No deepcopy
```
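For reference, this is the explicit multi-round workflow the proposal targets; the loop below follows the pattern from the sbi tutorials, with a toy simulator and placeholder hyperparameters:

```python
import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))
x_o = torch.zeros(1, 3)


def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)  # toy simulator


inference = SNPE(prior=prior)
proposal = prior
for _ in range(2):  # num_rounds
    theta = proposal.sample((500,))
    x = simulator(theta)
    # The user keeps the returned objects themselves, round by round.
    density_estimator = inference.append_simulations(theta, x, proposal=proposal).train()
    posterior = inference.build_posterior(density_estimator)
    proposal = posterior.set_default_x(x_o)
```

From the user's perspective, nothing in this loop requires the returned objects to be independent copies; they only need to stay valid while the current round's proposal is in use.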
## Impact
- Memory savings: ~30x reduction for multi-round inference (e.g., 300MB → 10MB for 10 rounds)
- No functionality loss: All tests should pass unchanged
- Backward compatibility: We could also add a `return_copy=False` parameter to preserve backward compatibility (see the sketch below)
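A minimal sketch of the opt-in variant (the `return_copy` name is taken from this proposal and does not exist in sbi today; the class is a simplified stand-in for the trainer):

```python
from copy import deepcopy


class TrainerWithOptInCopy:
    """Simplified stand-in for the trainer; only the return behaviour is shown."""

    def __init__(self, posterior=None):
        self._posterior = posterior

    def build_posterior(self, return_copy: bool = False):
        # ... in the real trainer, self._posterior would be built here ...
        # Copying is opt-in; the default hands back a reference and saves memory.
        return deepcopy(self._posterior) if return_copy else self._posterior
```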
## Questions
- Any hidden use cases for the model bank we're missing?
- Should we make copying opt-in via parameter or just remove it?
Amazing! Love this! And yes, I think most of this is either legacy or based on decisions that were not thought through all that much. The only thing I am not sure about is:
> SNPE only needs the data (`_theta_roundwise`, `_x_roundwise`) from previous rounds, not the networks.
I think SNPE-B needs the network from previous rounds for proposal log-probs, no?
Yes, on closer inspection, SNPE-C needs the snapshots of the neural-net weights from previous rounds when evaluating the MoG-based proposal in the non-atomic case, e.g., here:
https://github.com/sbi-dev/sbi/blob/ebcd68e0c9a6772626d625a8cbfe6fcffa3f2820/sbi/inference/trainers/npe/npe_c.py#L178
and then
https://github.com/sbi-dev/sbi/blob/ebcd68e0c9a6772626d625a8cbfe6fcffa3f2820/sbi/inference/trainers/npe/npe_c.py#L432-L435
And SNPE-B needs to evaluate the log prob of the proposals from previous rounds for calculating the importance weights:
https://github.com/sbi-dev/sbi/blob/ebcd68e0c9a6772626d625a8cbfe6fcffa3f2820/sbi/inference/trainers/npe/npe_b.py#L114-L120
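Schematically, those importance weights are prior-over-proposal ratios, so the proposal of each earlier round has to stay evaluable with the weights it had back then. A rough sketch of the idea (not the sbi implementation linked above):

```python
import torch


def snpe_b_style_weights(theta, prior, proposal_prev_round):
    """Importance weights for samples drawn from an earlier round's proposal.

    Evaluating proposal_prev_round.log_prob(theta) requires the network
    weights as they were in that round, which is why a frozen snapshot is
    needed rather than a reference to the live, retrained net.
    """
    log_w = prior.log_prob(theta) - proposal_prev_round.log_prob(theta)
    return torch.exp(log_w)
```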
The question then is where we create the deepcopy: in `train()` or in `build_posterior()`?
We can definitely remove the `_model_bank`. Then we could return a deepcopy in `train()` and return the plain posterior with the deep-copied net attached in `build_posterior()`?
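A sketch of that split, assuming the usual flow where the user passes the trained estimator back into `build_posterior()` (names and classes below are simplified stand-ins, not the actual sbi implementation):

```python
from copy import deepcopy


class PosteriorSketch:
    """Stand-in for a posterior object wrapping a trained density estimator."""

    def __init__(self, net, prior):
        self.net = net
        self.prior = prior


class TrainerSketch:
    def __init__(self, prior, neural_net=None):
        self._prior = prior
        self._neural_net = neural_net  # trained in train()

    def train(self):
        # ... training loop updates self._neural_net in place ...
        # Snapshot here, so SNPE-B/C keep a frozen copy of each round's proposal net.
        return deepcopy(self._neural_net)

    def build_posterior(self, density_estimator=None):
        net = density_estimator if density_estimator is not None else self._neural_net
        # Wrap the (already snapshotted) net directly: no second deepcopy,
        # and no self._model_bank.append(...).
        self._posterior = PosteriorSketch(net, self._prior)
        return self._posterior
```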