Low GPU utilization during multi-GPU inference due to synchronization
I mentioned this to @eric-czech in my org and he offered to fix it, so I'm putting the details for him here:
Boltz1 does multi-GPU inference using a PyTorch Lightning Trainer. Each prediction runs on a single GPU, and multiple GPUs are used to run multiple predictions at once (batch size = 1).
However, the Trainer forces a synchronization point after each set of N predictions (where N is the number of GPUs). If some predictions take longer than others, every GPU waits until the slowest prediction finishes before the whole group moves on to the next set.
It seems like a good speed improvement could be had by spawning a process per device and having each one read work off a shared multiprocessing.Queue.
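To make that concrete, here is a minimal sketch of the queue-based scheduling, independent of Boltz's actual prediction code; run_prediction() is a placeholder rather than a real Boltz function, and the GPU count is hard-coded where it would come from torch.cuda.device_count() in practice:

# Minimal sketch of the proposed scheduling: one worker process per GPU pulls
# inputs off a shared queue, so a slow prediction on one device never blocks
# the others. run_prediction() is a placeholder, not a real Boltz function.
import multiprocessing as mp
from pathlib import Path


def run_prediction(input_path, device):
    # Placeholder: in practice, load the model onto `device` once per worker
    # (not once per item) and run a single prediction for `input_path`.
    pass


def worker(device_index, queue):
    device = f"cuda:{device_index}"
    while True:
        input_path = queue.get()
        if input_path is None:  # sentinel: no more work
            break
        run_prediction(input_path, device)


def main():
    ctx = mp.get_context("spawn")  # spawn is required for CUDA in child processes
    num_gpus = 4  # e.g. torch.cuda.device_count()

    queue = ctx.Queue()
    for path in sorted(Path("yamls").glob("*.yaml")):
        queue.put(path)
    for _ in range(num_gpus):
        queue.put(None)  # one sentinel per worker

    procs = [ctx.Process(target=worker, args=(i, queue)) for i in range(num_gpus)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()


if __name__ == "__main__":
    main()

The point is that each worker pulls its next input as soon as it finishes the previous one, so a slow prediction only occupies one GPU instead of stalling all of them.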
To start putting together a test case for this, I made copies of two different input YAMLs, one that does a large prediction (slow) and another that should be fast (because it disables the MSA); a small script for regenerating such a set is sketched after the examples below:
$ ls yamls
001-small.yaml 004-large.yaml 005-small.yaml 007-large.yaml 008-small.yaml 010-large.yaml 011-small.yaml 013-large.yaml 014-small.yaml 016-large.yaml 017-small.yaml 019-large.yaml 020-small.yaml
002-large.yaml 004-small.yaml 006-large.yaml 007-small.yaml 009-large.yaml 010-small.yaml 012-large.yaml 013-small.yaml 015-large.yaml 016-small.yaml 018-large.yaml 019-small.yaml 021-large.yaml
003-small.yaml 005-large.yaml 006-small.yaml 008-large.yaml 009-small.yaml 011-large.yaml 012-small.yaml 014-large.yaml 015-small.yaml 017-large.yaml 018-small.yaml 020-large.yaml
$ cat yamls/001-small.yaml
version: 1
sequences:
  - protein:
      id: A
      sequence: AQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAA
      msa: empty
$ cat yamls/002-large.yaml
version: 1
sequences:
  - protein:
      id: A
      sequence: YFDTAVSRPGRGEPRFISVGYVDDTQFVRFDSDAASPRGEPRAPWVEQEGPEYWDRETQKYKRQAQADRVSLRNLRGYYNQ
  - protein:
      id: B
      sequence: MARSVTLVFLVLVSLTGLYAIQKTPQIQVYSRHPPENGKPNILNCYVTQFHPPHIEIQMLKNGKKIPKVEMSDMSFSKDWSFYILAHTEFTPTETDTYACRVKHASMAEPKTVYWDRDM
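For reference, the copies themselves are nothing special; a few lines like the following regenerate a directory of interleaved fast/slow inputs (small.yaml and large.yaml are assumed to contain the two examples above, and the exact numbering in my listing was ad hoc):

# Build a directory of mixed fast/slow inputs similar to the listing above.
import shutil
from pathlib import Path

out_dir = Path("yamls")
out_dir.mkdir(exist_ok=True)
for i in range(1, 22):
    for kind in ("small", "large"):
        shutil.copy(f"{kind}.yaml", out_dir / f"{i:03d}-{kind}.yaml")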
If you then run this on a machine with multiple GPUs:
time boltz predict yamls --use_msa_server --override
you will frequently see only a small number of GPUs in use at once, which seems likely (though I haven't confirmed it) to be due to the synchronization issue described above.
The first time you run the command it performs a slow initial step to query a server for MSAs (and also downloads the model weights), so run it once to warm those caches before trusting any timings.
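If it helps, one simple way to watch the utilization pattern while the command runs is to poll nvidia-smi, e.g.:

# Poll per-GPU utilization every few seconds while the predict command runs.
# Ctrl-C to stop.
import subprocess
import time

while True:
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=index,utilization.gpu",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    )
    # One line per poll, e.g. "0, 97 | 1, 0 | 2, 0 | 3, 96"
    print(result.stdout.strip().replace("\n", " | "), flush=True)
    time.sleep(5)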
Are you sure about this? I'm pretty sure Lightning doesn't put a sync barrier in predict_step.
No, not sure. @gcorso mentioned that he thought this was how it worked, but I haven't confirmed it (it's also possible I misunderstood what he meant).
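One way to settle it would be a standalone experiment with a toy LightningModule (no Boltz involved): make some predict steps artificially slow and compare per-rank completion times. A rough sketch, assuming 2 GPUs and the ddp strategy:

# If the rank that gets the fast items prints its timestamps well ahead of the
# other rank, there is no per-step barrier in predict; if the two ranks stay in
# lockstep, something is synchronizing them between steps.
import time

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.tensor(idx)


class SleepModule(pl.LightningModule):
    def predict_step(self, batch, batch_idx):
        # Mimic the small/large yamls: even items are slow, odd items are fast.
        time.sleep(5.0 if int(batch.item()) % 2 == 0 else 0.5)
        print(f"rank={self.global_rank} item={int(batch.item())} "
              f"done at {time.strftime('%H:%M:%S')}", flush=True)
        return batch


if __name__ == "__main__":
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", logger=False)
    trainer.predict(SleepModule(), DataLoader(ToyDataset(), batch_size=1))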