[BUG] cellpose causes numerical issues in unrelated code in the same process.
Describe the bug
The file vit_sam.py contains the line
torch.backends.cuda.matmul.allow_tf32 = True
right among the imports at the top. This globally changes the PyTorch configuration, which can cause numerical accuracy issues in completely unrelated PyTorch code running in the same process, as soon as anything directly or indirectly imports this file.
See https://docs.pytorch.org/docs/stable/notes/numerical_accuracy.html#tensorfloat-32-tf32-on-nvidia-ampere-and-later-devices.
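To make the side effect concrete, the flag flips as a pure import side effect (a minimal sketch; I'm assuming the module path is cellpose.vit_sam and a recent PyTorch where TF32 matmuls are off by default):

```python
import torch

# PyTorch default: TF32 matmuls disabled
assert torch.backends.cuda.matmul.allow_tf32 is False

import cellpose.vit_sam  # assumed module path; merely importing it flips the global flag

# From here on, every float32 matmul in this process may run in TF32 on Ampere+ GPUs
assert torch.backends.cuda.matmul.allow_tf32 is True
```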
I would recommend making this configurable or an explicit opt-in of some kind. According to https://github.com/pytorch/pytorch/issues/69921, a context manager might not be the way to do it unless you can guarantee single-threadedness.
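For illustration, an explicit opt-in could be as simple as gating the assignment behind an environment variable (the variable name below is hypothetical, not an existing cellpose setting):

```python
import os
import torch

# Hypothetical opt-in: only change the process-wide flag if the user explicitly asks for it.
# CELLPOSE_ALLOW_TF32 is an illustrative name, not an existing cellpose option.
if os.environ.get("CELLPOSE_ALLOW_TF32", "0") == "1":
    torch.backends.cuda.matmul.allow_tf32 = True
```

That keeps the default behaviour identical to plain PyTorch and avoids the thread-safety concerns of a context manager.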
@klondenberg-bioptimus Have you experimented with this and found this to be an issue?
We recently changed the default model weight precision to bfloat16, so I doubt this would have a significant impact for most people
Hi @mrariden, this definitely caused problems, yes. We spent quite some time hunting down completely unrelated test failures in a large codebase of ours, which were caused by this line.
The problem is not how it affects cellpose; the problem is that this is a process-wide change that affects all PyTorch-based code in the same process.
In our case, we were running a few hundred tests in CI using the pytest framework. One of the tests used cellpose (and that test passed), but it indirectly imported this file. All subsequent tests in the same process then ran with allow_tf32 set to True, which means that matrix multiplications in PyTorch are done with reduced precision on modern GPUs (but not on older ones that don't support TF32). This led to numerous test failures in unrelated code whose root cause was hard to find. Whether or not cellpose itself uses 32-bit floats at all does not matter for these failures.
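As a stopgap on our side, a fixture along these lines could snapshot and restore the flag around each test (a sketch of a workaround only, not something cellpose provides and not necessarily what we deployed):

```python
import pytest
import torch

@pytest.fixture(autouse=True)
def restore_tf32_flag():
    # Snapshot the process-wide matmul TF32 flag before each test and restore it
    # afterwards, so an import that flips it cannot leak into later tests.
    saved = torch.backends.cuda.matmul.allow_tf32
    yield
    torch.backends.cuda.matmul.allow_tf32 = saved
```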
If you use bfloat16 now, the setting is unnecessary in any case, and the fix would be to simply remove it without replacement.
In a similar vein, I have also experienced instabilities and some hard-to-debug race conditions when trying to run multiple cellpose models on different threads and/or devices. This did not happen with cellpose 3 or with other neural network architectures. So far it is hard to reproduce reliably, but it always ends up happening, and it happens faster the more threads are used.