Generate masks on-the-fly in batches
When using many masks, pre-generating all of them can use too much memory. In the dianna-deeprank branch (see https://github.com/dianna-ai/dianna/blob/c1823d7278be7602cba19ac56e69c5f13ec5a980/dianna/methods/rise.py#L190), we implemented on-the-fly mask generation for RISE. This could be a useful addition to the main DIANNA code.
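To put the memory concern in numbers, here is a small back-of-the-envelope helper. The mask counts, input sizes and float32 dtype below are hypothetical illustration values, not the settings used in DIANNA or the dianna-deeprank branch.

```python
from math import prod


def mask_memory_bytes(n_masks, mask_shape, itemsize=4):
    """Bytes needed to hold n_masks masks of mask_shape (float32 by default)."""
    return n_masks * prod(mask_shape) * itemsize


# Hypothetical sizes: 8000 masks for a 224x224 image, stored as float32.
all_at_once_2d = mask_memory_bytes(8000, (224, 224))      # 1_605_632_000 bytes, ~1.6 GB
per_batch_2d = mask_memory_bytes(100, (224, 224))         # 20_070_400 bytes, ~20 MB

# Hypothetical 3D input (a 64x64x64 grid): the total grows much faster.
all_at_once_3d = mask_memory_bytes(8000, (64, 64, 64))    # 8_388_608_000 bytes, ~8.4 GB
```

Generating masks per batch keeps the mask memory proportional to the batch size rather than to the total number of masks.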
It works as follows: masks are generated per batch, and after RISE is run on each new batch of masks, a running average is used to update the RISE output. Additionally, we implemented an early-stopping feature: the number of masks is no longer fixed but is treated as a maximum. When the RISE output no longer changes (within some tolerance), the computation stops and reports how many masks were actually used.
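A minimal NumPy sketch of the scheme described above (per-batch mask generation, running average of the saliency map, early stopping with a mask maximum). This is not the code from the dianna-deeprank branch: the function names `generate_mask_batch` and `rise_on_the_fly`, the `model` callable signature, the max-absolute-difference stopping criterion and the nearest-neighbour mask upsampling (original RISE uses bilinear upsampling) are all assumptions made for illustration.

```python
import numpy as np


def generate_mask_batch(batch_size, input_shape, n_cells=8, p_keep=0.5, rng=None):
    """Generate one batch of random RISE-style masks (hypothetical helper).

    Simplification: low-resolution binary grids are upsampled by
    nearest-neighbour repetition and randomly shifted, instead of the
    bilinear upsampling used in the original RISE method.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = input_shape
    cell_h = int(np.ceil(h / n_cells))
    cell_w = int(np.ceil(w / n_cells))
    # Binary grid: each cell is kept with probability p_keep.
    grid = (rng.random((batch_size, n_cells + 1, n_cells + 1)) < p_keep).astype(np.float32)
    # Upsample each grid to slightly larger than the input.
    up = grid.repeat(cell_h, axis=1).repeat(cell_w, axis=2)
    masks = np.empty((batch_size, h, w), dtype=np.float32)
    for i in range(batch_size):
        # Random shift so cell boundaries do not always line up.
        dy = rng.integers(0, cell_h)
        dx = rng.integers(0, cell_w)
        masks[i] = up[i, dy:dy + h, dx:dx + w]
    return masks


def rise_on_the_fly(model, x, max_masks=8000, batch_size=100, tol=1e-3,
                    p_keep=0.5, n_cells=8, rng=None):
    """Estimate a RISE saliency map with masks generated per batch (sketch).

    `model` is a hypothetical callable mapping a (b, H, W) array of masked
    inputs to a (b,) array of scores. Returns (saliency, n_masks_used).
    """
    rng = np.random.default_rng() if rng is None else rng
    saliency = np.zeros(x.shape, dtype=np.float64)
    n_used = 0
    n_batches = 0
    while n_used < max_masks:
        b = min(batch_size, max_masks - n_used)
        masks = generate_mask_batch(b, x.shape, n_cells, p_keep, rng)
        scores = model(masks * x)
        # Score-weighted mean of this batch's masks, normalised by p_keep.
        batch_mean = np.tensordot(scores, masks, axes=1) / (b * p_keep)
        # Running average over all masks seen so far.
        new_saliency = (n_used * saliency + b * batch_mean) / (n_used + b)
        change = np.max(np.abs(new_saliency - saliency))
        saliency, n_used = new_saliency, n_used + b
        n_batches += 1
        # Early stopping (skipped after the first batch, which starts from zeros).
        if n_batches > 1 and change < tol:
            break
    return saliency, n_used
```

Here `max_masks` plays the role of the old fixed mask count, now only an upper bound, and the returned `n_masks_used` reports how many masks were actually needed. Peak mask memory depends on `batch_size` alone; a smaller `tol` trades run time for a more converged saliency map.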