
[WIP] Support FSDP

Open mehdidc opened this issue 1 year ago • 23 comments

This PR adds FSDP (https://pytorch.org/docs/stable/fsdp.html) support for training large models that cannot fit in memory.

The code already works, but still needs to be improved, so this is still a draft.

Some scaling plots with samples per second per GPU, done on JUWELS Booster.

G-14: (scaling plot)

X-14 (15B visual, 5B text): (scaling plot)

I also tried G-14 as visual encoder together with a pre-trained T5-XXL as text encoder.

Putting down again some remarks and possible improvements discussed earlier on Discord:

  • I see some hanging issues starting from a large number of nodes (256 nodes, 1024 GPUs on JUWELS Booster): not a single iteration completes. I don't see anything special in the NCCL debugging info, except that it does a lot of all_gather, which is expected from FSDP
  • CPU offloading and gradient checkpointing are supported
  • each time encode_image, encode_text, or logit_scale was accessed without going through the forward function (which happens when clipping the logit scale, or at evaluation), an exception was raised (see https://github.com/pytorch/pytorch/issues/82461 for reference). The workaround I found is to modify the forward function so that it can encode both text and image (as currently done), or text only, or image only, or be used for clipping the logit scale. It would be better if we find a cleaner solution. The solution proposed in the PyTorch issue above is to wrap the modules (here, the text and image encoders) with FSDP, but then we need to change some internals, as part of the text encoder cannot be wrapped: it is an nn.Parameter, and FSDP needs an nn.Module. We could use CustomTextCLIP to wrap the text encoder in its entirety as proposed by @rwightman, but then we need to deal with logit_scale.
  • The list of layers to FSDP-wrap is important, as it affects the peak memory (documented here: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#id2), and the layer names are model dependent, e.g. in T5 I am FSDP-wrapping T5Block, and in the CLIP class I am wrapping ResidualAttentionBlock (see the sketch below). So we need a way to parametrize this; for the moment the names are hardcoded. If we just use the default auto-wrap policy of FSDP, we get OOM.
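
For reference, a minimal sketch of the wrapping idea described in the last bullet (illustrative, not the PR's exact code; assumes a distributed process group is already initialized and a GPU is available):

```python
import functools

import open_clip
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from open_clip.transformer import ResidualAttentionBlock

model = open_clip.create_model("ViT-B-32").cuda()

# Wrap each residual block in its own FSDP unit instead of relying on the
# default size-based policy, which led to OOM for the large models above.
# For a T5 text encoder, T5Block would be added to this set.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={ResidualAttentionBlock},
)

fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    cpu_offload=CPUOffload(offload_params=False),  # CPU offloading is optional
)
```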

mehdidc avatar Jan 17 '23 15:01 mehdidc

One thing that will break in this implementation is the ability to set a separate weight decay for the LN parameters or the biases, e.g. here - https://github.com/mlfoundations/open_clip/blob/16e229c596cafaec46a4defaf27e0e30ffcca12d/src/training/main.py#L270. Because FSDP will wrap the VisionTransformer or TextTransformer into one FSDP block, these names will not be retained and the filter will fail silently. You can print the layer names after the model is wrapped in FSDP to verify this.

To make sure FSDP retains the original names, you will need to pass an additional field in the FSDP constructor called use_orig_params=True. See my discussion in the PyTorch forums here - https://discuss.pytorch.org/t/setting-different-weight-decay-values-for-parameters-within-one-fsdp-unit/169862/2. This feature needs PyTorch nightly.
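
A minimal sketch of the effect (assuming a nightly build where use_orig_params exists and a process group is already initialized; not the PR's actual code):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)).cuda()

# Without use_orig_params, FSDP flattens and renames parameters, so name-based
# filters ("ln" in name, "bias" in name, ...) silently stop matching anything.
# With use_orig_params=True the original parameter names are retained (note that
# the local shards may still show up flattened, i.e. with ndim == 1).
fsdp_model = FSDP(model, use_orig_params=True)
for name, p in fsdp_model.named_parameters():
    print(name, p.ndim)
```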

orchidmajumder avatar Jan 25 '23 19:01 orchidmajumder

@orchidmajumder good point, I believe that arg is also required if we want to use torch.compile with FSDP, at least in its current state

rwightman avatar Jan 26 '23 03:01 rwightman

Can you rebase on master?

rom1504 avatar Jan 30 '23 23:01 rom1504

Thanks @orchidmajumder @rwightman, will look into that! @rom1504 just rebased.

mehdidc avatar Jan 31 '23 11:01 mehdidc

Update: the layer names to FSDP-wrap are not hardcoded anymore; they can now be provided on the CLI, with defaults that already work with the models we have.

mehdidc avatar Feb 02 '23 12:02 mehdidc

Update: following this thread https://github.com/huggingface/accelerate/issues/807, full/partial locking now works. Currently getting some throughput numbers with mt5-xxl-ViT-G-14.

mehdidc avatar Feb 04 '23 12:02 mehdidc

Update: I mentioned earlier that training was hanging with a large number of nodes (e.g., 256 on JUWELS Booster). After checking lower numbers of nodes, it seems that the duration of the starting-up phase (before the first "INFO | Train Epoch" is displayed) is long and proportional to the number of nodes, which is problematic: e.g. for 128 nodes the starting-up phase takes 24 min, and for 64 nodes it takes 11 min. Will open an issue on PyTorch. So it is probably not actually hanging for 256 nodes, I just did not run it for long enough, but it is a lot of time if it indeed takes ~48 min.

(GPU usage screenshots: starting_up, starting_up_2)

This is the GPU usage for a 128-node run of mt5-xxl-ViT-G-14. GPU usage is low for the first 24 min, then it goes above 99%, which coincides with the first "INFO | Train Epoch" message in the logs.

mehdidc avatar Feb 09 '23 09:02 mehdidc

Hi @mehdidc, based on your code I tried to use the ViT-e-14 model; OpenCLIP hangs after the first epoch step with FSDP enabled. Did you meet the same issue?

nkflash avatar Feb 16 '23 06:02 nkflash

Hey @nkflash, thanks, I actually noticed that as well, even with smaller models. I am on it.

EDIT: found a fix, will push soon

mehdidc avatar Feb 17 '23 21:02 mehdidc

@nkflash pushed, could you please try again? I can confirm that it worked for me

mehdidc avatar Feb 18 '23 01:02 mehdidc

Thanks, @orchidmajumder, use_orig_params is working as expected. So with PyTorch nightly we can already use it. If we want to also support the current PyTorch stable version (1.13), wrapping layer norms in their own FSDP units using the option I added, --fsdp-layers-to-wrap, would also work, but it won't handle other cases, e.g. biases from MLP layers; we would need to wrap those separately as well. So I am not sure we can support the current stable version (1.13) without more complications in the code. I think for now we just need to document that for the user (unless we find a better solution), i.e. the closest we can get to the "correct" behavior on 1.13 is to FSDP-wrap the layer norms, but in that case biases from MLPs will still be decayed; otherwise one needs PyTorch nightly or the next stable release.
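
For the stable-1.13 fallback described above, the wrapping idea amounts to roughly the following (a sketch of what --fsdp-layers-to-wrap does conceptually, not the PR's exact code):

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from open_clip.transformer import ResidualAttentionBlock

# Giving LayerNorm its own FSDP unit keeps the "ln_*" parameter names visible on
# stable 1.13, so the no-decay filter still catches them; MLP biases, however,
# stay inside the flattened block parameters and still get weight-decayed.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={ResidualAttentionBlock, nn.LayerNorm},
)
```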

The other thing that needs to be changed is:

https://github.com/mlfoundations/open_clip/blob/6ed7dd66b52c4dfda6c6fc6bad50e5857ee63123/src/training/main.py#L271

since FSDP flattens everything, p.ndim is always < 2, so everything would be excluded in the current code, which means nothing would be weight-decayed (every parameter ends up in the weight_decay=0 group). I found out that, for ViTs at least, the only additional case that p.ndim < 2 covers is visual.class_embedding (https://github.com/mlfoundations/open_clip/blob/6ed7dd66b52c4dfda6c6fc6bad50e5857ee63123/src/open_clip/transformer.py#L366); the rest is covered by the other clauses. Or is this supposed to cover something else? @rwightman @rom1504 @mitchellnw @gabrielilharco What about making the exclusions parametrizable, e.g. with regexps, perhaps to be more explicit about the parameters to decay?
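
To make the suggestion concrete, here is a hedged sketch of regexp-based exclusions (function and pattern names are illustrative, not the PR's actual interface):

```python
import re

import open_clip
import torch

def build_param_groups(named_parameters, exclude_patterns, weight_decay):
    """Illustrative only: split parameters into no-decay / decay groups using
    name-based regexps instead of p.ndim, which is meaningless once FSDP has
    flattened the parameters."""
    exclude_re = [re.compile(pat) for pat in exclude_patterns]
    no_decay, decay = [], []
    for name, p in named_parameters:
        if not p.requires_grad:
            continue
        (no_decay if any(r.search(name) for r in exclude_re) else decay).append(p)
    return [
        {"params": no_decay, "weight_decay": 0.0},
        {"params": decay, "weight_decay": weight_decay},
    ]

# Example patterns covering the cases discussed above, applied to the unwrapped model.
model = open_clip.create_model("ViT-B-32")
patterns = [r"\bbn", r"\bln_", r"\.bias$", r"logit_scale", r"class_embedding", r"positional_embedding"]
optimizer = torch.optim.AdamW(
    build_param_groups(model.named_parameters(), patterns, weight_decay=0.2), lr=1e-3
)
```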

mehdidc avatar Feb 18 '23 19:02 mehdidc

the p.ndim < 2 check should also cover logit_scale

mitchellnw avatar Feb 18 '23 20:02 mitchellnw

Yes, I was thinking of that as well, but saw that the exclude filter already has 'logit_scale' in n.

mehdidc avatar Feb 18 '23 20:02 mehdidc

@mehdidc the position and class token embeddings are typically not decayed either, but it looks like that was never done in OpenCLIP, hrmm. I have no_weight_decay methods in timm that return lists of names to exclude from decay, to cover the dim >= 2 cases like position embeddings, etc.
https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L506 https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/maxxvit.py#L1229

I feel that name-based decay by itself is error prone and won't generalize well to other models (like timm vision towers); the unflattened dim is the strongest signal (but sometimes you need to use names for some layers, like embeddings, etc.).

I feel the best approach would be to build a set of names to not decay (from both shape and name) before wrapping in FSDP, while that info is still available, so move the current code a bit earlier.
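
A rough sketch of that ordering (assuming use_orig_params=True so the names collected here still match after wrapping; the helper name is hypothetical):

```python
def collect_no_decay_names(model):
    # Build the exclusion set from shapes and names *before* FSDP flattens anything.
    no_decay = {
        name
        for name, p in model.named_parameters()
        if p.ndim < 2 or name.endswith(".bias") or "logit_scale" in name
    }
    # Models can also contribute explicit names (timm-style no_weight_decay()).
    if hasattr(model, "no_weight_decay"):
        no_decay.update(model.no_weight_decay())
    return no_decay

# Usage: compute the set on the unwrapped model, wrap afterwards, then build the
# optimizer groups by checking each parameter name against no_decay_names.
# no_decay_names = collect_no_decay_names(model)      # before FSDP(model, ...)
# fsdp_model = FSDP(model, use_orig_params=True)      # names retained
```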

rwightman avatar Feb 19 '23 00:02 rwightman

On the other topic, I feel it's fine to support nightlies only; I'm already exclusively using nightlies to train the convnext models because it's the only way to get decent bfloat16 support for convolutions. I'm going to add torch.compile soon to see how that works; the nn.MHA is quite a bit faster on nightlies (it has a fused kernel).

rwightman avatar Feb 19 '23 00:02 rwightman

@rwightman Thanks for the suggestion, I moved the code a bit earlier; it is fixed now.

mehdidc avatar Feb 19 '23 10:02 mehdidc

Update: @rwightman @rom1504 @mitchellnw @gabrielilharco @JeniaJitsev just for info, regarding the starting-up phase I mentioned earlier (https://github.com/mlfoundations/open_clip/pull/358#issuecomment-1423851399), I found out that it is not only proportional to the number of nodes but also to model size, but I found a fix. Read below if you want more info.

So e.g. with 256 nodes on JUWELS Booster, it took 13 min for ViT-B/32, 16 min for ViT-L/14, and 28 min for ViT-g/14, which is a lot of wasted time. After checking the trace, I saw that it was not hanging, stuff was happening, and found that FSDP does something special on the first forward pass (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L210). After profiling, I found that it has two for loops (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L235) which each take about 12 s (for ViT-B/32), and this is done for each FSDP unit (the number of FSDP units is proportional to model size, since we FSDP-wrap residual blocks). If you add up the total, that explains the startup time. The for loops iterate over all pairs of ranks, about 1M iterations in total for the two loops, which shouldn't be slow; the problem is a repeated access to a tensor that lives on the GPU (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L237), which slows things down. I will open an issue/PR; a simple .cpu() before the for loop solves the problem, and then it's a matter of seconds.
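
For illustration, the fix boils down to hoisting the device-to-host access out of the hot loop (a generic sketch of the pattern, not the actual PyTorch patch):

```python
import torch

# Stand-in for the rank-order tensor that FSDP keeps on the GPU.
world_indices = torch.arange(1_000_000, device="cuda")

# Slow: each element access of a CUDA tensor inside a Python loop forces a
# device synchronization and a tiny device-to-host copy.
# total = sum(int(world_indices[i]) for i in range(world_indices.numel()))

# Fast: copy the tensor to the CPU once before iterating.
world_indices_cpu = world_indices.cpu()
total = sum(int(world_indices_cpu[i]) for i in range(world_indices_cpu.numel()))
```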

mehdidc avatar Feb 19 '23 12:02 mehdidc

I have also observed the delay with FSDP on AWS clusters and actually thought FSDP hung above a certain number of nodes, so I didn't pursue it further - thanks for the amazing deep dive @mehdidc.

orchidmajumder avatar Feb 19 '23 19:02 orchidmajumder

@nkflash pushed, could you please try again? I can confirm that it worked for me

I checked out the head code; it works well now.

nkflash avatar Feb 20 '23 06:02 nkflash

Update: now that the problem with large node counts is solved, here are updated scaling plots up to 1024 GPUs:

G-14:

(scaling plot)

I also tested freezing a subset of layers, with MT5-XXL as text encoder (last 5 blocks trainable, rest frozen) and G-14 as visual encoder (last block trainable, rest frozen), patch dropout 0.5: (scaling plot)

mehdidc avatar Mar 06 '23 17:03 mehdidc

Update: the first fully trained model with FSDP is finished. I started with a ViT-B/32 on LAION-400M, 32 epochs (96 GPUs, local batch size of 896, global batch size of 86016, lr of 0.001); zero-shot accuracy on ImageNet is 63.6%, with ~90K samples/s throughput. Training was done using PyTorch nightly (torch-2.0.0.dev20230218+cu117).

{"dataset": "wds/imagenet1k", "model": "ViT-B-32", "pretrained": "epoch_32.pt", "task": "zeroshot_classification", "metrics": {"acc1": 0.63606, "acc5": 0.87912, "mean_per_class_recall": 0.6360399999999999}, "language": "en"}

(plot)

it's similar to what we get in https://arxiv.org/pdf/2212.07143.pdf (Table 13)

mehdidc avatar Mar 12 '23 19:03 mehdidc

Hi @mehdidc, one more thing: the current scaler by default uses torch.cuda.amp. If we enable FSDP CPU offload, it can error out, since that scaler cannot work on CPU, so maybe we need to change it to ShardedGradScaler.
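
A minimal sketch of the swap (ShardedGradScaler mirrors the torch.cuda.amp.GradScaler API, so the rest of the loop stays the same):

```python
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

# Replacement for torch.cuda.amp.GradScaler when training with FSDP; unlike the
# vanilla scaler it also handles sharded / CPU-offloaded parameters.
scaler = ShardedGradScaler()

# ...inside the training loop, same calls as before:
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()
```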

nkflash avatar Apr 18 '23 09:04 nkflash

Hi @nkflash, thank you very much, I just changed that.

mehdidc avatar May 17 '23 11:05 mehdidc