open_clip
[WIP] Support FSDP
This PR adds FSDP (https://pytorch.org/docs/stable/fsdp.html) support for training models that are too large to fit in memory.
The code already works but still needs improvement, so this is still a draft.
Some scaling plots showing samples per second per GPU, measured on JUWELS Booster.
G-14:
X-14 (15B visual, 5B text):
I also tried G-14 as visual encoder together with a pre-trained T5-XXL as text encoder.
Repeating here some remarks and possible improvements discussed earlier on Discord:
- I see hanging issues starting from a large number of nodes (256 nodes, 1024 GPUs on JUWELS Booster): not a single iteration completes. I don't see anything special in the NCCL debugging info, except that it does a lot of `all_gather`, which is expected from FSDP.
- CPU offloading and gradient checkpointing are supported.
- Each time `encode_image`, `encode_text`, or `logit_scale` was accessed without going through the forward function (which happens when clipping the logit scale, or at evaluation), an exception was raised (see https://github.com/pytorch/pytorch/issues/82461 for reference). The workaround I found is to modify the forward function so that it can encode both text and image (as currently done), text only, or image only, or be used for clipping the logit scale. It would be better to find a cleaner solution. The solution given in the PyTorch issue above is to wrap the modules (here the text and image encoders) with FSDP, but we would then need to change some internals, because part of the text encoder cannot be wrapped: it is an `nn.Parameter`, and FSDP needs an `nn.Module`. We could use `CustomTextCLIP` to wrap the text encoder in its entirety, as proposed by @rwightman, but then we need to deal with `logit_scale`.
- The list of layers to FSDP-wrap is important because it affects peak memory (documented here: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html#id2), and the layer names are model dependent. E.g., in T5 I am FSDP-wrapping `T5Block`, and in the CLIP class I am wrapping `ResidualAttentionBlock`. So we need a way to parametrize this; for the moment they are hardcoded. If we just use the default auto policy of FSDP, we get OOM.
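For reference, a minimal sketch of the forward-function workaround described in the list above. `ToyCLIP`, its layer sizes, and the `clamp_logit_scale_to` argument are hypothetical stand-ins, not the code in this PR; the only point is that every parameter access (image encoding, text encoding, logit-scale clipping) is routed through `forward`, which is the entry point FSDP hooks into.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCLIP(nn.Module):
    """Tiny stand-in for a CLIP model (hypothetical, for illustration only)."""

    def __init__(self, image_dim=12, vocab_size=32, embed_dim=8):
        super().__init__()
        self.visual = nn.Linear(image_dim, embed_dim)
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

    def encode_image(self, image):
        return F.normalize(self.visual(image), dim=-1)

    def encode_text(self, text):
        # Mean-pool token embeddings; a real text tower is a transformer.
        return F.normalize(self.token_embedding(text).mean(dim=1), dim=-1)

    def forward(self, image=None, text=None, clamp_logit_scale_to=None):
        # Optional in-place clipping of the logit scale, done inside forward.
        if clamp_logit_scale_to is not None:
            with torch.no_grad():
                self.logit_scale.clamp_(0, clamp_logit_scale_to)
        # Encode only the modalities that were passed in.
        image_features = self.encode_image(image) if image is not None else None
        text_features = self.encode_text(text) if text is not None else None
        return image_features, text_features, self.logit_scale.exp()


model = ToyCLIP()
image_only = model(image=torch.randn(2, 12))
text_only = model(text=torch.randint(0, 32, (2, 5)))
clip_only = model(clamp_logit_scale_to=math.log(100))
```

With this shape, evaluation code calls `model(image=...)` or `model(text=...)` instead of `model.encode_image(...)`, and the training loop calls `model(clamp_logit_scale_to=...)` instead of touching `model.logit_scale` directly.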
One thing that will break in this implementation is the ability to set a separate WD for LN parameters or biases, e.g. here - https://github.com/mlfoundations/open_clip/blob/16e229c596cafaec46a4defaf27e0e30ffcca12d/src/training/main.py#L270. Because FSDP will wrap the VisionTransformer or TextTransformer into one FSDP block, these names will not be retained and the filter will fail silently. You can print the layer names after the model is wrapped in FSDP to verify this.
To make sure FSDP retains the original names, you need to pass an additional argument to the FSDP constructor, `use_orig_params=True`. See my discussion on the PyTorch forum here - https://discuss.pytorch.org/t/setting-different-weight-decay-values-for-parameters-within-one-fsdp-unit/169862/2. This feature needs PyTorch nightly.
@orchidmajumder good point, I believe that arg is also required if we want to use torch.compile with FSDP, at least in its current state
Can you rebase on master?
Thanks @orchidmajumder @rwightman, I will look into that! @rom1504 just rebased.
Update: layer names to FSDP-wrap are not hardcoded anymore, they can now be provided in the CLI with defaults that will work already with models we have.
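A minimal sketch of how layer names given on the command line could be turned into an FSDP auto-wrap policy. `--fsdp-layers-to-wrap` is the option name used later in this thread; its `nargs`, the defaults shown, the toy model, and the `resolve_layer_classes` helper are assumptions for illustration, not the PR's actual code.

```python
import argparse
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ResidualAttentionBlock(nn.Module):
    """Toy stand-in for the open_clip block of the same name."""

    def __init__(self, width=8):
        super().__init__()
        self.ln = nn.LayerNorm(width)
        self.mlp = nn.Linear(width, width)

    def forward(self, x):
        return x + self.mlp(self.ln(x))


class ToyTower(nn.Module):
    def __init__(self, depth=2, width=8):
        super().__init__()
        self.blocks = nn.ModuleList([ResidualAttentionBlock(width) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def parse_args(argv):
    parser = argparse.ArgumentParser()
    # Hypothetical defaults: class names of the blocks to wrap in their own FSDP unit.
    parser.add_argument("--fsdp-layers-to-wrap", nargs="+",
                        default=["ResidualAttentionBlock", "T5Block"])
    return parser.parse_args(argv)


def resolve_layer_classes(model, class_names):
    """Return the module classes present in `model` whose class name is listed."""
    wanted = set(class_names)
    return {type(m) for m in model.modules() if type(m).__name__ in wanted}


def build_wrap_policy(model, class_names):
    layer_classes = resolve_layer_classes(model, class_names)
    if not layer_classes:
        raise ValueError(f"none of {class_names} found in the model")
    return functools.partial(transformer_auto_wrap_policy,
                             transformer_layer_cls=layer_classes)


args = parse_args([])
model = ToyTower()
policy = build_wrap_policy(model, args.fsdp_layers_to_wrap)
# In a real run, with torch.distributed initialized, the policy would be passed
# as FullyShardedDataParallel(model, auto_wrap_policy=policy).
```

Names that do not occur in the model (here `T5Block`) are simply ignored, which is what lets one default list serve several architectures.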
Update: following this thread https://github.com/huggingface/accelerate/issues/807, full / partial locking now works. Currently getting some throughput numbers with mt5-xxl-ViT-G-14
Update: I mentioned earlier that training was hanging on a large number of nodes (e.g., 256 on JUWELS Booster). After checking with lower numbers of nodes, it seems that the startup phase (before the first "INFO | Train Epoch" is displayed) is long and its duration is proportional to the number of nodes, which is problematic. E.g., for 128 nodes the startup phase takes 24 min, and for 64 nodes it takes 11 min. I will open an issue on PyTorch. So it is probably not really hanging for 256 nodes; I just did not run it for long enough. Still, it is a lot of time if it takes 48 min.
This is the GPU usage for a 128-node run of `mt5-xxl-ViT-G-14`.
GPU usage is low for the first 24 min, then it rises to > 99%, which coincides with the first "INFO | Train Epoch" message in the logs.
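The 48 min figure for 256 nodes follows from assuming the startup time is strictly proportional to the node count. A small sketch of that arithmetic, using only the two measurements quoted above (the function names are made up for illustration); a straight line through both points gives a slightly larger estimate.

```python
# Measured startup times from the comment above: nodes -> minutes.
measured = {64: 11.0, 128: 24.0}


def proportional_estimate(nodes, ref_nodes=128):
    """Scale the reference measurement linearly through the origin."""
    return measured[ref_nodes] * nodes / ref_nodes


def two_point_estimate(nodes):
    """Extrapolate along the straight line through both measurements."""
    (n0, t0), (n1, t1) = sorted(measured.items())
    slope = (t1 - t0) / (n1 - n0)  # minutes per additional node
    return t1 + slope * (nodes - n1)


print(proportional_estimate(256))  # 48.0
print(two_point_estimate(256))     # 50.0
```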
Hi @mehdidc, based on your code I tried to use the ViT-e-14 model; open_clip hangs after the first epoch step with FSDP enabled. Do you see the same issue?
hey @nkflash, thanks I actually noticed that as well, even with smaller models, I am on it.
EDIT: found a fix, will push soon
@nkflash pushed, could you please try again? I can confirm that it worked for me
Thanks, @orchidmajumder, `use_orig_params` is working as expected. So with PyTorch nightly, we can already use it. If we also want to support the current PyTorch stable version (1.13), wrapping layer norms in their own FSDP units using the option I added, `--fsdp-layers-to-wrap`, would also work, but it won't handle other cases, e.g. biases from MLP layers, which would also need to be wrapped separately. So I am not sure we can support the current stable version (1.13) without more complications in the code. I think for now we just need to document this for the user (unless we find a better solution): the closest one can get to the "correct" behavior is to FSDP-wrap layer norms, but in that case biases from MLPs will be decayed; otherwise one needs PyTorch nightly or the next stable version.
The other thing that needs to be changed is:
https://github.com/mlfoundations/open_clip/blob/6ed7dd66b52c4dfda6c6fc6bad50e5857ee63123/src/training/main.py#L271
Since FSDP flattens everything, `p.ndim` is always < 2, so everything would be excluded in the current code, which means nothing will be weight-decayed.
I found out that, for ViTs at least, the only additional case that `p.ndim < 2` covers is `visual.class_embedding` (https://github.com/mlfoundations/open_clip/blob/6ed7dd66b52c4dfda6c6fc6bad50e5857ee63123/src/open_clip/transformer.py#L366); the rest is covered by the other clauses.
Or is this supposed to cover something else? @rwightman @rom1504 @mitchellnw @gabrielilharco
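A small illustration of why a shape-based check stops being informative after flattening. The `torch.cat` below only mimics FSDP's flat parameter (it is not FSDP itself), and the two-layer model is a made-up example: before flattening the linear weight is 2-D, afterwards every value lives in one 1-D tensor, so a `p.ndim < 2` clause matches everything and all parameters land in the excluded group.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.LayerNorm(3))

# Dimensionality of each parameter while the model is still unwrapped.
ndim_before = {name: p.ndim for name, p in model.named_parameters()}
print(ndim_before)  # {'0.weight': 2, '0.bias': 1, '1.weight': 1, '1.bias': 1}

# Mimic flattening: all parameters concatenated into a single 1-D tensor.
flat_param = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
print(flat_param.ndim, flat_param.numel())  # 1 21
```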
What about making exclusions parametrizable, e.g. with regexps? Perhaps to be more explicit about the parameters to decay.
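A rough sketch of what regexp-based exclusions could look like. The patterns, the helper name `split_by_decay`, and the parameter names are illustrative assumptions, not an agreed-upon interface.

```python
import re

# Hypothetical patterns for parameters that should NOT be weight-decayed.
NO_DECAY_PATTERNS = [
    r"(^|\.)ln_\w*\.",        # layer norm weights and biases
    r"\.bias$",               # any bias
    r"^logit_scale$",
    r"class_embedding$",
    r"positional_embedding$",
]


def split_by_decay(param_names, patterns=NO_DECAY_PATTERNS):
    """Split names into (decay, no_decay) according to the regex patterns."""
    compiled = [re.compile(p) for p in patterns]
    no_decay = [n for n in param_names if any(c.search(n) for c in compiled)]
    excluded = set(no_decay)
    decay = [n for n in param_names if n not in excluded]
    return decay, no_decay


names = [
    "visual.class_embedding",
    "visual.transformer.resblocks.0.ln_1.weight",
    "visual.transformer.resblocks.0.mlp.c_fc.weight",
    "visual.transformer.resblocks.0.mlp.c_fc.bias",
    "logit_scale",
]
decay, no_decay = split_by_decay(names)
print(decay)  # ['visual.transformer.resblocks.0.mlp.c_fc.weight']
```

Note that this relies purely on names, so it only works if the original parameter names are still visible when the optimizer is built.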
the `p.ndim < 2` check should also cover `logit_scale`
Yes, I was thinking of that as well, but saw that there is already `'logit_scale' in n` in `exclude`
@mehdidc the position, token, class embeddings are typically not decayed either, but it looks like that was never done in OpenCLIP, hrmm. I have `no_weight_decay` methods in timm that return lists of names to exclude from decay, to cover the dim >= 2 cases like position embeddings, etc.
https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L506
https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/maxxvit.py#L1229
I feel that name-based decay by itself is error prone and won't generalize well to other models (like timm vision towers); the unflattened dim is the strongest signal (but sometimes you need to use names for some layers like embeddings, etc.).
I feel the best approach would be to build a `set` of names (from both shape and name) to not decay before wrapping in FSDP, when that info is available, so move the current code a bit earlier.
On the other topic, I feel it's fine to support nightlies only, I'm already exclusively using nightlies to train the convnext models because it's the only way to get decent bfloat16 support for convolutions. I'm going to add torch.compile soon to see how that works, the nn.MHA is quite a bit faster on nightlies (has a fused kernel).
@rwightman Thanks for the suggestion, I moved the code a bit earlier, and it is fixed now.
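A minimal sketch of the "collect the names before wrapping" approach suggested above, shown on an unwrapped toy model. `ToyModel`, `collect_no_decay_names`, and the weight-decay value are assumptions for illustration; after real FSDP wrapping, matching by name only works if the original parameter names are preserved (`use_orig_params=True`, as discussed earlier).

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, width=8):
        super().__init__()
        self.ln_1 = nn.LayerNorm(width)
        self.fc = nn.Linear(width, width)


class ToyModel(nn.Module):
    def __init__(self, width=8):
        super().__init__()
        self.class_embedding = nn.Parameter(torch.zeros(width))
        self.block = ToyBlock(width)
        self.logit_scale = nn.Parameter(torch.ones([]))


def collect_no_decay_names(model):
    """Run BEFORE FSDP wrapping, while parameters still have their original shapes."""
    return {
        name
        for name, p in model.named_parameters()
        if p.ndim < 2 or "ln" in name or "bias" in name or "logit_scale" in name
    }


model = ToyModel()
no_decay_names = collect_no_decay_names(model)

# ... FSDP wrapping would happen here in a real training script ...

named = list(model.named_parameters())
param_groups = [
    {"params": [p for n, p in named if n in no_decay_names], "weight_decay": 0.0},
    # 0.2 is only an illustrative weight-decay value.
    {"params": [p for n, p in named if n not in no_decay_names], "weight_decay": 0.2},
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```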
Update: @rwightman @rom1504 @mitchellnw @gabrielilharco @JeniaJitsev just for info, regarding the startup phase I mentioned earlier (https://github.com/mlfoundations/open_clip/pull/358#issuecomment-1423851399), I found out that it is proportional not only to the number of nodes but also to model size, and I found a fix. Read below if you want more info.
So e.g. with 256 nodes on JUWELS Booster, it took 13 min for ViT-B/32, 16 min for ViT-L/14, and 28 min for ViT-g/14, which is a lot of wasted time.
After checking the trace, I saw that it was not hanging, stuff was happening, and found that FSDP does something special on the first forward pass (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L210). After profiling, I found that there are two for loops (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L235) which take (for ViT-B/32) about 12 s each, and this is done for each FSDP unit (the number of FSDP units is proportional to model size, as we FSDP-wrap residual blocks). Adding these up explains the long startup. The for loops basically iterate over all pairs of ranks, about 1M iterations in total for the two loops, which shouldn't be slow; the problem is a repeated access to a Tensor that lives on the GPU (https://github.com/pytorch/pytorch/blob/85e0fd0280948a342a916429448fed2486e82aa5/torch/distributed/fsdp/_exec_order_utils.py#L237), which slows things down. I will open an issue/PR; a simple `.cpu()` before the for loop solves the problem, and it is then a matter of seconds.
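To illustrate the pattern (this is not the FSDP code, just a made-up loop with the same shape): comparing tensor elements inside a Python loop forces each value onto the host, and when the tensor lives on the GPU that means one synchronizing copy per access. Copying the tensor to the CPU once before the loop gives the same result without the per-element transfers. The function name and the data are hypothetical.

```python
import torch


def count_matching_pairs(indices, move_to_cpu):
    """Count ordered pairs (r1, r2) whose entries are equal."""
    if move_to_cpu:
        indices = indices.cpu()  # single device-to-host copy up front
    world_size = indices.shape[0]
    matches = 0
    for r1 in range(world_size):
        for r2 in range(world_size):
            # Using the comparison in `if` converts a 0-dim tensor to bool;
            # on a GPU tensor each such conversion is a blocking transfer.
            if indices[r1] == indices[r2]:
                matches += 1
    return matches


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
indices = torch.tensor([0, 1, 0, 2], device=device)

slow = count_matching_pairs(indices, move_to_cpu=False)
fast = count_matching_pairs(indices, move_to_cpu=True)
print(slow, fast)  # 6 6
```

On a CPU-only machine both variants behave the same; the difference only shows up in wall-clock time when `indices` is a CUDA tensor and the loop is long.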
I have also observed the delay with FSDP on AWS clusters and actually thought FSDP hangs over a certain number of nodes and didn't pursue it further - thanks for the amazing deep-dive @mehdidc .
> @nkflash pushed, could you please try again? I can confirm that it worked for me

I checked out the head of the branch, and it works well now.
Update: now that the problem with large numbers of nodes is solved, here are updated scaling plots up to 1024 GPUs:
G-14:
I also tested freezing a subset of layers, with MT5-XXL as text encoder (last 5 blocks trainable, the rest frozen) and G-14 as visual encoder (last block trainable, the rest frozen), with patch dropout 0.5.
Update: the first fully trained model with FSDP is finished. I started with a ViT-B/32 on LAION-400M, 32 epochs (96 GPUs, local batch size of 896, global batch size of 86016, lr of 0.001). Zero-shot accuracy on ImageNet is 63.6% with ~90K samples/s throughput. Training was done using pytorch-nightly (`torch-2.0.0.dev20230218+cu117`).
{"dataset": "wds/imagenet1k", "model": "ViT-B-32", "pretrained": "epoch_32.pt", "task": "zeroshot_classification", "metrics": {"acc1": 0.63606, "acc5": 0.87912, "mean_per_class_recall": 0.6360399999999999}, "language": "en"}
it's similar to what we get in https://arxiv.org/pdf/2212.07143.pdf (Table 13)
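A quick sanity check of the numbers reported for this run, using only values quoted above (the throughput is the approximate ~90K figure, so the per-GPU number is approximate too; the JSON string is a shortened copy of the result record).

```python
import json

num_gpus = 96
local_batch_size = 896

# Global batch size is the per-GPU batch size times the number of GPUs.
global_batch_size = num_gpus * local_batch_size
print(global_batch_size)  # 86016

# Approximate per-GPU throughput from the ~90K samples/s total.
approx_total_throughput = 90_000
print(approx_total_throughput / num_gpus)  # 937.5

# Top-1 accuracy from the evaluation record, as a percentage.
record = json.loads('{"model": "ViT-B-32", "metrics": {"acc1": 0.63606, "acc5": 0.87912}}')
top1_percent = round(record["metrics"]["acc1"] * 100, 1)
print(top1_percent)  # 63.6
```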
Hi @mehdidc, one more thing: the current scaler uses torch.cuda.amp by default. If we enable FSDP CPU offload, it could raise an error, since that scaler cannot work on CPU, so maybe it needs to be changed to ShardedGradScaler.
Hi @nkflash thank you very much I just changed that
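A minimal sketch of how the scaler choice could be made. The helper name `select_grad_scaler_cls` and the `use_fsdp` flag are hypothetical and the actual change in the PR may look different; `ShardedGradScaler` is the FSDP-aware scaler that ships with PyTorch.

```python
import torch
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def select_grad_scaler_cls(use_fsdp):
    """Pick the gradient scaler class matching the parallelism mode."""
    # With FSDP, use the sharding-aware scaler; otherwise keep the
    # regular AMP scaler that the training script used so far.
    return ShardedGradScaler if use_fsdp else torch.cuda.amp.GradScaler


fsdp_scaler_cls = select_grad_scaler_cls(use_fsdp=True)
default_scaler_cls = select_grad_scaler_cls(use_fsdp=False)
# In the training script one would then instantiate it when AMP is enabled,
# e.g. scaler = select_grad_scaler_cls(use_fsdp)().
```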