pytorch-lightning
Stream outputs from Trainer.predict()
Description & Motivation
I would like to request a feature that allows streaming the outputs of Trainer.predict() so that they can be processed one by one as they are produced. This would make it possible to handle predictions on large datasets without keeping all of them in memory at once.
Pitch
Ideally, Trainer.predict() would yield the result of each batch as it is produced, instead of returning all results together at the end.
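The difference between the current behaviour and the requested one can be sketched in plain Python. This is a minimal illustration, not Lightning code: `predict_eager`, `predict_stream`, and `model_fn` are hypothetical stand-ins, and the "model" is just a function applied to each batch. `predict_eager` mirrors the current behaviour, in which the batch outputs are collected and returned together once the loop has finished; `predict_stream` is a generator that hands each batch output to the caller as soon as it is computed.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

B = TypeVar("B")  # batch type
P = TypeVar("P")  # prediction type


def predict_eager(model_fn: Callable[[B], P], batches: Iterable[B]) -> List[P]:
    """Hypothetical stand-in for the current behaviour: every output is
    kept in memory and returned only after the whole loop has run."""
    outputs = []
    for batch in batches:
        outputs.append(model_fn(batch))
    return outputs


def predict_stream(model_fn: Callable[[B], P], batches: Iterable[B]) -> Iterator[P]:
    """Hypothetical streaming variant: each output is yielded as soon as it
    is computed, so the caller decides what (if anything) to keep."""
    for batch in batches:
        yield model_fn(batch)


if __name__ == "__main__":
    batches = [[1, 2], [3, 4], [5, 6]]
    double = lambda batch: [2 * x for x in batch]

    print(predict_eager(double, batches))  # [[2, 4], [6, 8], [10, 12]]

    # Consume outputs one by one, keeping only a running total.
    total = 0
    for out in predict_stream(double, batches):
        total += sum(out)
    print(total)  # 42
```

Because `predict_stream` is lazy, no batch is processed until the caller asks for the next output, and the caller can discard each output after using it.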
Alternatives
Post-process the results in predict_step(). However, it would be more flexible to also be able to do this outside of predict_step(), e.g., when different runs call for different types of aggregation.
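To illustrate the flexibility argument: if predictions arrived as an iterator, the aggregation could be chosen by the caller instead of being hard-coded in `predict_step()`. The sketch below is hypothetical plain Python; `fake_predictions` stands in for a streamed `Trainer.predict()`, and the two aggregators are only examples.

```python
from typing import Dict, Iterable, Iterator, List


def fake_predictions(n_batches: int, batch_size: int) -> Iterator[List[float]]:
    """Hypothetical stand-in for a streamed predict(): yields one batch of
    scores at a time."""
    for b in range(n_batches):
        yield [float(b * batch_size + i) for i in range(batch_size)]


def running_mean(stream: Iterable[List[float]]) -> float:
    """Aggregation 1: mean over all scores with O(1) extra memory."""
    total, count = 0.0, 0
    for batch in stream:
        total += sum(batch)
        count += len(batch)
    return total / count if count else float("nan")


def histogram(stream: Iterable[List[float]], bin_width: float) -> Dict[int, int]:
    """Aggregation 2: counts per bin; memory grows with the number of bins,
    not with the number of predictions."""
    counts: Dict[int, int] = {}
    for batch in stream:
        for score in batch:
            key = int(score // bin_width)
            counts[key] = counts.get(key, 0) + 1
    return counts


if __name__ == "__main__":
    # A generator can only be consumed once, so each aggregation gets a fresh stream.
    print(running_mean(fake_predictions(3, 4)))  # 5.5
    print(histogram(fake_predictions(3, 4), bin_width=4.0))  # {0: 4, 1: 4, 2: 4}
```

The prediction code stays the same in both cases; only the consumer changes.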
Additional context
In my use case, the activations of a certain hidden layer are sparse, and I would like to collect the sparsified activations to reduce memory usage.
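A minimal sketch of the collection step described above, assuming each streamed output is a batch of dense activation vectors. Plain Python lists are used here; with PyTorch one would more likely rely on a tensor method such as `Tensor.to_sparse()`. `sparsify` and `collect_sparse` are hypothetical helpers: only entries above a threshold are kept as (index, value) pairs, so memory scales with the number of active units rather than with the layer width.

```python
from typing import Iterable, List, Tuple

SparseRow = Tuple[List[int], List[float]]  # (indices of kept entries, their values)


def sparsify(row: List[float], threshold: float = 0.0) -> SparseRow:
    """Keep only entries whose magnitude exceeds `threshold`."""
    indices = [i for i, v in enumerate(row) if abs(v) > threshold]
    return indices, [row[i] for i in indices]


def collect_sparse(stream: Iterable[List[List[float]]], threshold: float = 0.0) -> List[SparseRow]:
    """Consume a stream of dense activation batches and keep only the sparse
    form; each dense batch can be freed as soon as it has been converted."""
    collected: List[SparseRow] = []
    for batch in stream:
        collected.extend(sparsify(row, threshold) for row in batch)
    return collected


if __name__ == "__main__":
    # Two hypothetical batches of hidden-layer activations (width 5).
    batches = iter([
        [[0.0, 1.5, 0.0, 0.0, 0.2], [0.0, 0.0, 0.0, 0.0, 0.0]],
        [[3.0, 0.0, 0.0, -0.7, 0.0]],
    ])
    for indices, values in collect_sparse(batches):
        print(indices, values)
    # [1, 4] [1.5, 0.2]
    # [] []
    # [0, 3] [3.0, -0.7]
```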
cc @borda