lfads-torch
Out of GPU memory when running posterior sampling post-training
https://github.com/arsedler9/lfads-torch/blob/6f58d249e1334c272146938aa1b3b27686580588/lfads_torch/post_run/analysis.py#L55
Hi!
I'm analysing a very large dataset and kept running out of GPU memory during posterior sampling after training. I narrowed it down to the fact that the results for all trials are accumulated in GPU memory and only transferred to CPU memory once every trial has been processed. I solved it by changing the line above to:
return [(s / num_samples).cpu() for s in sums]
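For context, here is a minimal sketch of the pattern the fix addresses. `average_posterior_samples` and `sample_fn` are hypothetical names for illustration, not part of the lfads-torch API; the point is that running sums live on the sampling device, and the final `.cpu()` call moves the averaged tensors to CPU RAM so they do not accumulate in GPU memory across trials:

```python
import torch


def average_posterior_samples(sample_fn, num_samples):
    # Hypothetical helper illustrating the fix: keep running sums on the
    # sampling device, then move only the averaged result to CPU RAM.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sums = None
    for _ in range(num_samples):
        # sample_fn returns a tuple of tensors for one posterior sample
        outputs = sample_fn(device)
        if sums is None:
            sums = [o.detach().clone() for o in outputs]
        else:
            for s, o in zip(sums, outputs):
                s += o
    # The one-line change from the issue: average, then transfer to CPU
    return [(s / num_samples).cpu() for s in sums]


# Usage with a stand-in sampler that returns constant tensors
def sample_fn(device):
    return (torch.full((2, 3), 2.0, device=device),)


avgs = average_posterior_samples(sample_fn, num_samples=4)
print(avgs[0].device)  # averaged result now lives on the CPU
```

Without the `.cpu()` call, each trial's averaged tensors would remain resident on the GPU until the whole dataset finished, which is what exhausted memory on a large dataset.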