boltz

Low GPU utilization on multi-GPU inference due to synchronization

[Open] timodonnell opened this issue 7 months ago • 2 comments

I mentioned this to @eric-czech in my org, and he offered to fix it. Putting the details for Eric here:

Boltz1 does multi-GPU inference using a PyTorch Lightning Trainer. Each prediction runs on one GPU, and multiple GPUs are used to run multiple predictions at once (batch size = 1).

However, the Trainer forces a synchronization point after each set of N predictions (where N is the number of GPUs). If some predictions take longer than others, all GPUs wait until the slowest prediction in the set is finished before any of them move on to the next set.
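To get a feel for what such a per-set barrier could cost, here is a toy model that simply assumes the behavior described above (the replies below question whether Lightning actually does this). The durations are made-up relative costs, not Boltz measurements, and both function names are illustrative. `lockstep_makespan` assumes every GPU waits for the slowest item in each set of N; `queue_makespan` assumes each GPU takes the next item as soon as it is free.

```python
import heapq
from typing import List


def lockstep_makespan(durations: List[float], n_gpus: int) -> float:
    # Items run in sets of n_gpus; each set lasts as long as its slowest item.
    return sum(max(durations[i:i + n_gpus]) for i in range(0, len(durations), n_gpus))


def queue_makespan(durations: List[float], n_gpus: int) -> float:
    # Greedy work queue: the GPU that frees up first takes the next item (input order).
    free_at = [0.0] * n_gpus
    heapq.heapify(free_at)
    for d in durations:
        start = heapq.heappop(free_at)
        heapq.heappush(free_at, start + d)
    return max(free_at)


durations = [10, 1] * 4  # alternating slow and fast inputs (arbitrary units)
print(lockstep_makespan(durations, 2), queue_makespan(durations, 2))  # 40 22.0
```

In this idealized model, alternating slow (10) and fast (1) inputs on 2 GPUs take 40 units in lockstep but 22 with a work queue, which is the total work (44) split evenly across the two GPUs.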

It seems like a good speedup could be had by, for example, spawning one process per device, with each process reading work off a shared multiprocessing.Queue.
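A minimal standard-library sketch of that idea. `fake_predict` is a hypothetical stand-in for one Boltz prediction (it just sleeps longer for "large" inputs), and `run_pool` and `_worker` are illustrative names, not Boltz APIs. Each worker owns one device index and pulls the next input as soon as it finishes the current one, so a slow prediction only holds up its own device. The start method defaults to "fork" so the sketch runs as a plain script on Linux or macOS; with CUDA you would more likely pass "spawn" and call `run_pool` under `if __name__ == "__main__":`.

```python
import multiprocessing as mp
import time
from typing import Callable, List, Tuple

_SENTINEL = None  # a worker stops when it dequeues this


def fake_predict(device: int, item: str) -> str:
    """Hypothetical stand-in for one Boltz prediction on one device."""
    time.sleep(0.2 if "large" in item else 0.01)  # "large" inputs are slow
    return f"{item} done on device {device}"


def _worker(device: int, tasks, results, predict: Callable[[int, str], str]) -> None:
    # A real version would bind this process to GPU `device` here (assumption).
    # The worker pulls the next input as soon as it is free, so a slow
    # prediction only delays this device, never the others.
    while True:
        item = tasks.get()
        if item is _SENTINEL:
            break
        results.put((item, predict(device, item)))


def run_pool(items: List[str], n_devices: int,
             predict: Callable[[int, str], str] = fake_predict,
             start_method: str = "fork") -> List[Tuple[str, str]]:
    """Run `predict` over `items` with one process per device and a shared queue."""
    ctx = mp.get_context(start_method)
    tasks, results = ctx.Queue(), ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(d, tasks, results, predict))
             for d in range(n_devices)]
    for p in procs:
        p.start()
    for item in items:
        tasks.put(item)
    for _ in procs:
        tasks.put(_SENTINEL)  # one stop signal per worker, queued after all work
    # Drain results before join() so no worker blocks on a full results queue.
    out = [results.get() for _ in items]
    for p in procs:
        p.join()
    return out
```

For example, `run_pool(["001-small.yaml", "002-large.yaml", "003-small.yaml"], n_devices=2)` returns one `(input, message)` pair per input. Results arrive in completion order rather than input order, so a real version would need to key outputs by input name.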

To start putting together a test case, I made copies of two input YAMLs: one that runs a large (slow) prediction and one that should be fast because it disables the MSA ("msa: empty"):

$ ls yamls
001-small.yaml	004-large.yaml	005-small.yaml	007-large.yaml	008-small.yaml	010-large.yaml	011-small.yaml	013-large.yaml	014-small.yaml	016-large.yaml	017-small.yaml	019-large.yaml	020-small.yaml
002-large.yaml	004-small.yaml	006-large.yaml	007-small.yaml	009-large.yaml	010-small.yaml	012-large.yaml	013-small.yaml	015-large.yaml	016-small.yaml	018-large.yaml	019-small.yaml	021-large.yaml
003-small.yaml	005-large.yaml	006-small.yaml	008-large.yaml	009-small.yaml	011-large.yaml	012-small.yaml	014-large.yaml	015-small.yaml	017-large.yaml	018-small.yaml	020-large.yaml

$ cat yamls/001-small.yaml
version: 1
sequences:
  - protein:
      id: A
      sequence: AQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAA
      msa: empty

$ cat yamls/002-large.yaml
version: 1
sequences:
  - protein:
      id: A
      sequence: YFDTAVSRPGRGEPRFISVGYVDDTQFVRFDSDAASPRGEPRAPWVEQEGPEYWDRETQKYKRQAQADRVSLRNLRGYYNQ
  - protein:
      id: B
      sequence: MARSVTLVFLVLVSLTGLYAIQKTPQIQVYSRHPPENGKPNILNCYVTQFHPPHIEIQMLKNGKKIPKVEMSDMSFSKDWSFYILAHTEFTPTETDTYACRVKHASMAEPKTVYWDRDM
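If you want to build a similar mix of slow and fast inputs from the two templates, here is a hypothetical helper. It writes NNN-small.yaml and NNN-large.yaml for each index, which is close to but not exactly the set listed above; the template paths and the function name are placeholders, not part of Boltz.

```python
import shutil
from pathlib import Path
from typing import List


def make_mixed_inputs(small: Path, large: Path, out_dir: Path, n_pairs: int) -> List[Path]:
    """Hypothetical helper: copy two templates into NNN-small.yaml / NNN-large.yaml."""
    out_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for i in range(1, n_pairs + 1):
        for kind, src in (("small", small), ("large", large)):
            dst = out_dir / f"{i:03d}-{kind}.yaml"  # zero-padded so ls sorts by index
            shutil.copyfile(src, dst)
            written.append(dst)
    return written
```

For example, `make_mixed_inputs(Path("small.yaml"), Path("large.yaml"), Path("yamls"), 20)` would write 40 files into yamls/.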

If you then run this on a machine with multiple GPUs:

time boltz predict yamls --use_msa_server --override

You can frequently see only a small number of GPUs in use at once, which seems likely (though I haven't confirmed it) to be due to the synchronization issue described above.
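One way to check that observation more systematically is to sample per-GPU utilization while the command runs. This sketch assumes NVIDIA GPUs with nvidia-smi on PATH; the query flags are standard nvidia-smi options, while the 10% "busy" threshold and the helper names are arbitrary choices for illustration.

```python
import subprocess
import time
from typing import Dict

# Standard nvidia-smi query flags; assumes NVIDIA GPUs with nvidia-smi on PATH.
QUERY = ["nvidia-smi", "--query-gpu=index,utilization.gpu",
         "--format=csv,noheader,nounits"]


def parse_utilization(csv_text: str) -> Dict[int, int]:
    """Parse lines like "0, 87" into {gpu_index: utilization_percent}."""
    util: Dict[int, int] = {}
    for line in csv_text.strip().splitlines():
        idx, pct = (field.strip() for field in line.split(",", 1))
        if pct.isdigit():  # skip entries such as "[N/A]"
            util[int(idx)] = int(pct)
    return util


def busy_gpus(util: Dict[int, int], threshold: int = 10) -> int:
    """Count GPUs at or above an (arbitrary) utilization threshold."""
    return sum(1 for pct in util.values() if pct >= threshold)


def watch(interval_s: float = 1.0, samples: int = 60) -> None:
    """Print how many GPUs look busy, once per interval, while boltz runs."""
    for _ in range(samples):
        out = subprocess.run(QUERY, capture_output=True, text=True, check=True).stdout
        util = parse_utilization(out)
        print(f"{busy_gpus(util)}/{len(util)} GPUs busy: {util}")
        time.sleep(interval_s)
```

Running `watch()` in a second terminal during the boltz run gives a rough timeline. A pattern where the busy count repeatedly drops toward 1 and then jumps back to N would fit the per-set barrier hypothesis, but would not prove it.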

The first time you run the command above, it performs a slow initial step that queries a server for MSAs (and downloads the model weights), so run it once before trusting any timings.
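That warm-up-then-time procedure can be wrapped in a few lines of Python. `time_after_warmup` and `BOLTZ_CMD` are hypothetical names; the command itself is the one from this issue. The measurement is whole-command wall time, so it includes process startup and model loading, not just inference.

```python
import subprocess
import time
from typing import List, Sequence

# The command from this issue; assumes boltz is installed and GPUs are available.
BOLTZ_CMD = ["boltz", "predict", "yamls", "--use_msa_server", "--override"]


def time_after_warmup(cmd: Sequence[str], runs: int = 1) -> List[float]:
    """Run `cmd` once untimed, then return the wall time of each of `runs` more runs."""
    subprocess.run(list(cmd), check=True)  # warm-up: MSA queries, weight download
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        subprocess.run(list(cmd), check=True)
        timings.append(time.perf_counter() - start)
    return timings
```

For example, `print(time_after_warmup(BOLTZ_CMD, runs=3))` prints three wall-clock timings in seconds, which makes before/after comparisons of a scheduling change easier.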

timodonnell, May 30 '25 15:05

Are you sure about this? I'm pretty sure Lightning doesn't put a sync barrier in predict_step.

jwohlwend, May 30 '25 15:05

No, not sure. @gcorso mentioned that he thought this was how it worked, but I haven't confirmed it (it's also possible I misunderstood what he meant).

timodonnell, May 30 '25 16:05