
Training duration & NaNs during training

lhatsk opened this issue 2 years ago • 27 comments

First of all, great work!

I'm wondering what training times I can expect for a single target. I'm currently at 1 min/it (sample), which seems too slow (V100S, fp16 and DeepSpeed activated, crop size 256). The official implementation takes around 20 s for a comparable sample (single GPU; about 16 s with an A100). I haven't tested how much overhead DeepSpeed introduces. Gradient accumulation should help to reduce this.

Is it actually possible to train with batch size > 1 on a single GPU? I'm assuming it would work with fixed_size=True. I just vaguely remember that they did some dimensionality juggling with the template/recycling dimensions, which might interfere.

Thanks!

lhatsk avatar Nov 22 '21 20:11 lhatsk

Are you using PyTorch Lightning's built-in timer? If so, it's a running average of total time elapsed between iterations, including time spent loading data. Usually for me, the time starts in the 40-second range (as the dataloader front-loads a couple of batches) and then dwindles to ~17 seconds per iteration for crops of size 256 on our 11GB 2080 Ti's.

If iterations are still much slower for you and you have more video RAM than we do, you might want to disable a recent change we made to reduce memory fragmentation (the biggest downside of using PyTorch as opposed to compiled JAX, in our case): in openfold/config.py, in the section for the extra MSA stack, disable clear_cache_between_blocks. That will give you a (probably modest) speedup. Playing around with your DeepSpeed config (e.g. disabling CPU offloading) might also net you some performance. Finally, running the training script with the --benchmark flag might also help. If issues persist, I'd be interested in seeing a profile of module runtimes & your config, if you've changed it from the default.

Where are you getting the 20-second estimate for the official implementation, by the way? The official implementation doesn't support training.

The code fully supports training with larger batch sizes (with even more dimensionality juggling), so you can give it a try (again with a tweak in the config). On our cards, though, even one crop of size 256 per GPU at a time is kind of pushing it.
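For reference, here's a minimal sketch of where those two knobs live in the config. The exact key paths inside openfold/config.py are assumptions on my part and may move between commits, so double-check them against your copy:

from openfold.config import model_config

# Build a training config ("model_1" is just an example preset name)
config = model_config("model_1", train=True)

# Assumed key path: disable per-block cache clearing in the extra MSA stack
config.model.extra_msa.extra_msa_stack.clear_cache_between_blocks = False

# Assumed key path: per-GPU batch size used by the data loaders (default is 1)
config.data.data_module.data_loaders.batch_size = 2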

gahdritz avatar Nov 22 '21 20:11 gahdritz

Thanks for the quick response! I will give the tweaks a try. I didn't change the DeepSpeed config. I'm using SLURM to train on 4 V100s in parallel. The time measurement is based on the SLURM output. I don't have new estimates. It looks like training died overnight (?). No updates in 7 hours:

Epoch 0: 37%|█ WARNING:root:distogram loss is NaN. Skipping..., loss=3.85, v_num=2.06e+7]

I just tested it on a single GPU without DeepSpeed and I get down to 10.5 s/it on a V100S, and 9.5 s/it without clear_cache_between_blocks.

Do you always train with fp16? I have had big problems making it stable, wondering if this is the cause of NaNs.

I have cards with 32 and 40 GB of VRAM; training with batch size = 2 works, but for some reason it's much slower. It looks like data loader workers are blocking. That also seems to be the main issue: data loading/preprocessing is slowing everything down. Some alignments are multiple GB in size, so parsing/clustering/subsampling takes some time. I will truncate them.

Regarding the 20-second time estimate for the official implementation: I re-created the necessary features for training and built a bare-bones training loop to play with optimizations. But I only ran experiments on a single target to see how far I could get in terms of training time. It looks like I was still way off. Getting fp16 to work was a pain (I still get NaNs; bfloat16 works fine). But it was also my first experience with JAX, which I was happy to abandon in favor of OpenFold and PyTorch...

lhatsk avatar Nov 23 '21 09:11 lhatsk

I'm sorry to hear that you're getting NaNs.

Two thoughts:

  1. The model should be written in such a way that it just skips over examples that yield NaN loss (see the sketch after this list), so those shouldn't interfere with training, let alone cause a crash.
  2. NaN examples should be quite rare---we haven't seen any for a while, despite testing the model on a wide range of proteins from the official training set and always running the model with fp16. Are you able to determine where in the model the loss is being pushed over the edge, or where it might be hanging?
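For anyone following along, here is a hedged sketch of the skip-on-NaN check described in point 1. The real check sits near the bottom of openfold/utils/loss.py and may differ in its details:

import logging
import torch

def zero_out_nonfinite_loss(loss, name):
    # If a loss comes out NaN/inf, log it and swap in a zero that still
    # participates in autograd, so the bad example never updates the weights.
    if torch.isnan(loss) or torch.isinf(loss):
        logging.warning(f"{name} loss is NaN. Skipping...")
        return loss.new_tensor(0.0, requires_grad=True)
    return loss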

As for the dataloading, here are a few things to try.

  1. Are you running a recent commit? A while ago we massively sped up the dataloader, so it should no longer be a bottleneck, unless you're using much larger MSAs than typically come out of the database search from the original.
  2. Try increasing the number of dataloader workers in the config, if you have the CPUs for it.

Finally, what is your single-GPU time with DeepSpeed? DDP training is pretty thorny to get working without DeepSpeed (PyTorch's native DDP implementation doesn't interact well with activation checkpointing), so we typically run the model with DeepSpeed enabled.

gahdritz avatar Nov 23 '21 17:11 gahdritz

Thanks for getting back to me!

Regarding the NaNs: I haven't been able to investigate further (no access atm, will do once I fix the other issues)

Dataloading is indeed the issue. I'm running the latest commit. I have now updated to pytorch-lightning 1.5.2, which seems to be a tad faster (https://issueexplorer.com/issue/PyTorchLightning/pytorch-lightning/10389). I now truncate the MSAs to 12k entries, and I dumped the raw data to disk (process_mmcif is skipped now). That saves a second or so per iteration.
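In case it's useful, here is a rough sketch (not OpenFold code) of the kind of MSA truncation I mean, assuming plain .a3m files where every sequence is a header line followed by one sequence line:

def truncate_a3m(in_path, out_path, max_seqs=12000):
    # Copy only the first max_seqs (header, sequence) pairs of an .a3m file.
    kept, n_seqs = [], 0
    with open(in_path) as f:
        for line in f:
            if line.startswith(">"):
                n_seqs += 1
                if n_seqs > max_seqs:
                    break
            kept.append(line)
    with open(out_path, "w") as f:
        f.writelines(kept)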

The problem seems to be that the workers are not used (num_workers is set to 8). I see that multiple processes are spawned, but utilization in htop remains around 1 and all but one are asleep. That's why batch_size=2 is twice as slow: data loading is still essentially sequential. Very strange. I haven't used pytorch-lightning before, and I never had issues with multiple workers in vanilla PyTorch.

By dumping the data to disk, I can get down to 8.5 s per sample on a single GPU (V100S), and 15-16 s for batch_size=2.

lhatsk avatar Nov 23 '21 17:11 lhatsk

Hm. Peculiar. The num_workers parameter is handed straight to a native PyTorch DataLoader, so I can't say off the top of my head why that might be happening. I'll look into whether I'm getting good worker thread utilization later today.
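For concreteness, "handed straight to a native PyTorch DataLoader" means essentially this (the dataset below is a stand-in, not OpenFold's):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset; OpenFold's dataset yields feature dicts instead of raw tensors
dataset = TensorDataset(torch.zeros(64, 4))

# num_workers controls how many worker processes prefetch samples in parallel
loader = DataLoader(dataset, batch_size=2, num_workers=8)

for (batch,) in loader:
    pass  # each batch has shape [2, 4]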

gahdritz avatar Nov 23 '21 17:11 gahdritz

Sorry for the delay. Despite testing with a number of values of batch_size and num_workers, I am unable to reproduce the behavior you described. How are you changing the batch size?

gahdritz avatar Nov 25 '21 04:11 gahdritz

No problem! I set batch_size = 2 in config.py. What speedup do you see with larger batches?

3 epochs in, I get NaN distogram losses again and it looks like the network can't recover. Every sample has a NaN loss afterwards and the overall loss is pretty much stuck. NaN caused by the distogram seems very strange---distances exploding and exceeding the fp16 range? I will try to catch it; unfortunately it takes hours to get to this point.

You mentioned the official training set earlier; where can I find the instances? Is the train/validation split known? I'm currently working with the trRosetta training set.

lhatsk avatar Nov 25 '21 09:11 lhatsk

With enough workers, the speedup is about linear in the size of the batches.

scripts/download_all_data.sh downloads the AlphaFold training set. The validation set is from CAMEO.

What exactly happens when you get NaN loss? The model should simply print that NaN loss occurred and skip the example in question, such that the weights are never updated with NaN gradients (this behavior is defined at the bottom of openfold/utils/loss.py). The fact that it's happening in just a single loss suggests to me that the NaNs are arising during the forward pass and should theoretically be caught by our loss-skipping hack; if it weren't just one loss, I'd advise you to play around with the loss_scale_window and min_loss_scale parameters of the fp16 category of your DeepSpeed config (details here). Since it is just one, though, I'm kind of confused by this (maybe the projection weights in the distogram head are becoming NaN somehow? Seems unlikely to me). The only thing I can say is that, if it had to happen in just one loss, it would be this one---the squared sum of differences during the distance calculation tends to be one of the largest intermediate activations in the network. If you could checkpoint the model after the second epoch or so and gather more data on this, that would be great.
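For reference, the fp16 section I'm referring to looks roughly like this (written here as a Python dict; the values are only illustrative, not recommendations):

ds_config = {
    "fp16": {
        "enabled": True,
        "loss_scale": 0,            # 0 means dynamic loss scaling
        "initial_scale_power": 16,  # initial scale of 2 ** 16
        "loss_scale_window": 1000,  # overflow-free steps before the scale is raised
        "min_loss_scale": 1,        # floor for the dynamic loss scale
    },
}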

gahdritz avatar Nov 25 '21 18:11 gahdritz

So actually, I see both happening in different runs. In my latest run all losses go to NaN (already in the first epoch after 1600 samples):

Epoch 0:  30%|██████████████████████████████ | 812/2733 [2:23:33<5:39:36, 10.61s/it, loss=4.12, v_num=28]
NaN 6832.0
WARNING:root:distogram loss is NaN. Skipping...
WARNING:root:fape loss is NaN. Skipping...
WARNING:root:lddt loss is NaN. Skipping...
WARNING:root:masked_msa loss is NaN. Skipping...
WARNING:root:supervised_chi loss is NaN. Skipping...
Epoch 0:  30%|██████████████████████████████ | 813/2733 [2:23:38<5:39:13, 10.60s/it, loss=3.93, v_num=28]
NaN 3872.0
WARNING:root:distogram loss is NaN. Skipping...
WARNING:root:fape loss is NaN. Skipping...
WARNING:root:lddt loss is NaN. Skipping...
WARNING:root:masked_msa loss is NaN. Skipping...
WARNING:root:supervised_chi loss is NaN. Skipping...
Epoch 0:  30%|██████████████████████████████ | 814/2733 [2:23:43<5:38:50, 10.59s/it, loss=3.71, v_num=28]

But once it starts skipping samples, all following samples also result in NaNs. Same happened when only the distogram loss became NaN.

I activated anomaly detection but it's not triggered...

lhatsk avatar Nov 26 '21 19:11 lhatsk

The layer norm in the first triangular update causes the NaN in my case: https://github.com/aqlaboratory/openfold/blob/main/openfold/model/triangular_multiplicative_update.py#L79

related to this? https://github.com/pytorch/pytorch/issues/66707

I thought LayerNorm (same as BatchNorm) would be kept in fp32.

Edit: the max value in z is already inf before layer_norm_in; the inf is the result of outer_product_mean.
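In case it helps someone else chase this down, a minimal sketch (not OpenFold tooling) of one way to find the first module whose output goes non-finite, using forward hooks:

import torch

def register_nonfinite_hooks(model):
    # Attach a forward hook to every submodule; print the name of any module
    # whose output contains inf/NaN values.
    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for o in outs:
                if torch.is_tensor(o) and not torch.isfinite(o).all():
                    print(f"Non-finite output detected in: {name}")
                    break
        return hook

    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))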

lhatsk avatar Nov 28 '21 18:11 lhatsk

I've figured out what's happening.

Like you said, at some point in the model activations become NaN. In general, for certain inputs, this seems unavoidable---it's a fundamental limitation of fp16 training in a model this large. However, the impact of such inputs can be lessened if it's possible to skip arbitrary training examples. In vanilla PyTorch Lightning, this can easily be accomplished by returning None from the training step. With DeepSpeed enabled, this is not allowed.

The workaround we're currently using is to just return 0 in cases where the loss is NaN. In most cases, this works (although it does screw with average loss calculation). However, when this happens in just one rank of DeepSpeed distributed training, training hangs; DeepSpeed expects all ranks to update the same subset of the model's parameters, which obviously isn't the case if one of the ranks returns a fresh zero tensor. This means that, if skipping the backward pass for single ranks outright isn't possible, we need to find some way to generate a zero tensor that nevertheless interacts with all of the parameter tensors the loss would usually interact with.

I haven't found a good workaround yet, but the following is technically functional (albeit slow). Replace training_step in train_openfold.py with:

def _run_training_step(self, batch):
    # Run the model
    outputs = self(batch)
    # Remove the recycling dimension
    batch = tensor_tree_map(lambda t: t[..., -1], batch)
    # Compute loss
    loss = self.loss(outputs, batch)
    return loss

def training_step(self, batch, batch_idx):
    if(self.ema.device != batch["aatype"].device):
        self.ema.to(batch["aatype"].device)

    # Do a test run without grad enabled to see if loss is NaN...
    with torch.no_grad():
        loss = self._run_training_step(batch)

    loss_is_valid = not (torch.isnan(loss) or torch.isinf(loss))

    # If loss is NaN, zero the inputs to tamp down activations
    if(not loss_is_valid):
        logging.warning("loss is NaN. Returning 0 loss...")
        for k in batch:
            batch[k] *= 0

    loss = self._run_training_step(batch)

    # If loss is supposed to be NaN, we don't want to update the weights
    if(not loss_is_valid):
        loss = loss * 0

    return {"loss": loss}

Additionally, delete the lines that zero individual losses near the bottom of openfold/utils/loss.py (the block that starts if(torch.isnan...). It works in all situations that I can think of, but it requires an extra grad-less forward pass through the model, which takes an additional ~6 seconds per iteration on my hardware. In theory, it should be possible (easy, even) to run the real forward pass, check the outputs, and then "reset" the model for a second, zero'd forward pass only if necessary, meaning that you'd only have to endure the slowdown for "bad" inputs. However, I still haven't figured out how to do that without opening up DeepSpeed's internals, which I really don't want to have to do. In any case, hopefully this works as a stopgap while I figure it out.

Alternatively, if you have the hardware for it, I should note that this problem can be completely sidestepped by training with bfloat16 precision, for which I just added support. I don't have the GPUs to test it myself, but feel free to be our guinea pig.
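For anyone trying it, a hedged sketch of what bf16 looks like at the PyTorch Lightning level (the exact flag the training script exposes for this may differ):

import pytorch_lightning as pl

# bf16 keeps fp32's exponent range, so the overflow-style NaNs discussed above
# largely go away; it needs A100-class (or newer) hardware.
trainer = pl.Trainer(
    gpus=1,
    precision="bf16",
    max_epochs=1,
)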

gahdritz avatar Nov 29 '21 23:11 gahdritz

Thanks for the workaround!

I think I managed to solve the NaN issue in my case (it has now finished one epoch successfully; it used to explode at around 50% of the first). I keep outer_product_mean in fp32.

I replaced layer_norm and linear in OuterProductMean with FP32 versions (init for linear_out is still missing!):

import torch.nn as nn
import torch.nn.functional as F


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class Fp32Linear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.linear(
            input.float(),
            self.weight.float(),
            bias=self.bias.float() if self.bias is not None else None,
        )
        return output.type_as(input)


class OuterProductMean(nn.Module):
    [..]
        self.layer_norm = Fp32LayerNorm(c_m)
        self.linear_1 = Fp32Linear(c_m, c_hidden)
        self.linear_2 = Fp32Linear(c_m, c_hidden)
        self.linear_out = Fp32Linear(c_hidden ** 2, c_z)  # , init="final")

and cast explicitly here:

    z = z + self.outer_product_mean(
        m.float(), mask=msa_mask.float(), chunk_size=chunk_size
    ).half()

I don't quite understand why .half() works and .type_as() doesn't. Will report back tomorrow if it indeed fixes the issue.

The batch / worker issue remains though.

lhatsk avatar Nov 30 '21 19:11 lhatsk

I still haven't been able to reproduce the batch/worker issue.

gahdritz avatar Nov 30 '21 19:11 gahdritz

No NaNs in 5 epochs, no visible slowdown.

batch size = 2 is ~10% faster than batch size = 1. CPU utilization is at 1.8. Do you get the full 8 / 800% utilization?

Do you already have an intuition about what losses we can achieve / should aim for? I haven't seen anything lower than 3.29 so far.

lhatsk avatar Dec 01 '21 15:12 lhatsk

I usually get around 500-700% utilization.

I've done small overfitting experiments, and I bottomed out near zero loss.

gahdritz avatar Dec 01 '21 19:12 gahdritz

All losses go to NaN again after 3.5 days of training. I'm not sure where it happens; it could be because of all the other fp16 LayerNorms. fp16 is a real pain. I think I will launch another run with all of the other LayerNorms replaced as well.

lhatsk avatar Dec 06 '21 10:12 lhatsk

Hi, I recently transferred from my own server to a SLURM platform for training, but there was a problem with the DeepSpeed + SLURM configuration. How did you configure the SLURM script file?

panganqi avatar Dec 19 '21 17:12 panganqi

I use the OpenFold DeepSpeed configuration, and the SLURM script is just:

#!/bin/bash

# SLURM SUBMIT SCRIPT
#SBATCH --nodes=4
#SBATCH --time=168:00:00
#SBATCH --gres=gpu:v100s:1
#SBATCH --ntasks-per-node=1
#SBATCH --mem=0
srun python3 train_openfold.py ...

which is launched via sbatch

lhatsk avatar Dec 21 '21 16:12 lhatsk

I'm still getting NaNs and training collapses after a while. I have now access to A100 GPUs and will switch to bfloat16.

lhatsk avatar Dec 21 '21 16:12 lhatsk

Hello, I seem to be experiencing the same situation as you: multiple processes are spawned, all but one are asleep, and after a few minutes all are woken up. Setting batch_size to 2 is slower than batch_size=1. It would be helpful if you could tell me whether you solved it, and if so, how! @lhatsk

ZwormZ avatar Jan 18 '22 15:01 ZwormZ

Unfortunately, I haven't been able to solve this issue. Due to memory constraints I'm currently training with batch_size=1

lhatsk avatar Jan 20 '22 10:01 lhatsk

DeepSpeed hanging on NaN is solved with a solution based on this: https://github.com/PyTorchLightning/pytorch-lightning/issues/4956#issue-755959012

def on_after_backward(self) -> None:
    if torch.isnan(self.trainer.callback_metrics['loss']) or torch.isinf(self.trainer.callback_metrics['loss']):
        logging.warning(f'detected inf or nan values in gradients. not updating model parameters')
        self.zero_grad()

The inf check in openfold/utils/loss.py has to be removed.

lhatsk avatar Feb 10 '22 11:02 lhatsk

Hi, I changed batch_size to 2 in config.py, but I get the following error:

  File "/data1/zjc/openfold-install/openfold/model/model.py", line 189, in embed_templates
    t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 3

Is this a potential bug? Thanks in advance.

gofreelee avatar Aug 11 '22 09:08 gofreelee

That's issue #197. I'll be fixing it soon.

gahdritz avatar Aug 11 '22 21:08 gahdritz

Is training on bigger batches supported in OpenFold 1.0? The experimentally resolved loss and the violation loss complain about wrong dimensionality with a bigger batch size.

I'm also wondering if the sum in experimentally_resolved_loss is correct. Wouldn't you only want to sum over the last column? That's the change I made to fix the dimensionality:

loss = loss / (eps + torch.sum(atom37_atom_exists, dim=-1))

I also had to average the violation loss in the end:

return loss.mean()

drmsd also couldn't handle the extra batch dimension; I just removed it for now.

Unfortunately, I still don't see any reduction in training time, which is odd.

On an unrelated note: I see heavy regressions in my training from OpenFold 1.0 compared to my last version from some time in January (5 LDDT-CA points). So far I have been unable to track down the issue, except for a slightly larger learning rate, which I have now changed. I just straight-ported my stuff to 1.0.

lhatsk avatar Aug 14 '22 14:08 lhatsk

Hello, I did as you said, but I eventually hit this error:

    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

How should I modify it, please?

gofreelee avatar Aug 20 '22 13:08 gofreelee

I just want to provide some input since it has been a few months from the last posting here. It seems to me that in its current state, the code really doesn't support batch_size > 1. I know Gustaf mentioned in other threads that it should theoretically work, but I think there are some blocking issues with the code at the moment that assume batch_size=1. Of course if anyone knows that I am mistaken, please correct me!

There are some issues in experimentally_resolved_loss and drmsd/lddt_ca as noted by the comments above. I also ran into issues with data collation, where the OpenFoldBatchCollator tried to stack 2 sequence tensors in a batch in its call to dict_multimap, but the sequences were of different dimensionality so the stack failed. It's not clear to me where any padding of different length proteins occurs in the code, or if this issue was more related to some malformed CAMEO validation data, so I couldn't really find a solution to that bug at the moment. I'll be sticking to batch_size=1 for now, and hopefully my comment might help others who are looking into the same thing.
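To make the collation failure concrete, here's a hedged sketch (not OpenFold's collator) of the padding step that torch.stack needs when the crops in a batch have different lengths:

import torch

def pad_and_stack(tensors, pad_value=0):
    # Pad each tensor along its first (residue) dimension to the longest
    # length in the batch, then stack into one batched tensor.
    max_len = max(t.shape[0] for t in tensors)
    padded = []
    for t in tensors:
        pad_shape = (max_len - t.shape[0],) + tuple(t.shape[1:])
        padded.append(torch.cat([t, t.new_full(pad_shape, pad_value)], dim=0))
    return torch.stack(padded, dim=0)

# e.g. two "sequences" of lengths 5 and 3 with a feature dim of 4
print(pad_and_stack([torch.ones(5, 4), torch.ones(3, 4)]).shape)  # torch.Size([2, 5, 4])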

jonathanking avatar Mar 15 '23 14:03 jonathanking