ColossalAI-Examples
[feature] New example: MAE pretraining on ImageNet 1000 dataset
Colossal-AI implementation of MAE (Masked Autoencoders Are Scalable Vision Learners, arXiv).
As an example, we only cover the pretraining phase, using the ImageNet 1000 (mini) dataset. The helpers under the util/ subdirectory are taken from facebookresearch/deit and are used under the Apache License 2.0.
About the coding style
The coding style is a little different from that of other examples such as run_resnet_cifar10_with_engine.py: the configuration file config/pretrain.py handles rich initialization logic and default values.
The DeiT and MAE codebases have a really complicated and intertwined initialization process. By making full use of Colossal-AI's dynamic Python configuration ability, we can keep things simple enough for newcomers to understand.
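To give a flavor of what that looks like, here is a minimal sketch of a dynamic Python config (a simplified illustration, not the actual contents of config/pretrain.py; the names are placeholders and the values are the MAE paper's defaults):

```python
# Illustrative sketch of a dynamic Colossal-AI config module.
# Because configs are plain Python, defaults and derived values can be
# computed with ordinary code instead of being hard-coded.

TOTAL_BATCH_SIZE = 4096      # effective batch size across all GPUs
BASE_LR = 1.5e-4             # MAE paper's base learning rate
NUM_EPOCHS = 800
WARMUP_EPOCHS = 40

# Derived value: MAE scales the learning rate linearly with batch size.
LEARNING_RATE = BASE_LR * TOTAL_BATCH_SIZE / 256
```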
Hi, since we want to support hybrid-parallel MAE, could you try to support tensor parallelism (TP) and pipeline parallelism (PP) as well? You can refer to the tutorial.
A pure data-parallel (DP) MAE with Colossal-AI has already been finished: https://github.com/lucasliunju/mae-colossalai
Okay, thank you for the help! I'm still a beginner, so I'm glad to check out those links.
Hey everybody! I managed to support (limited) tensor parallelism; check it by running:
torchrun --standalone --nnodes 1 --nproc_per_node 4 main_pretrain.py --config ./config/pretrain_1d_tp2.py
I tuned the model inside models_mae_tp.py. More information is in README.md.
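For reference, tensor parallelism in Colossal-AI configs of this generation is declared through a `parallel` dict; a config named pretrain_1d_tp2.py presumably contains something along these lines (a hypothetical excerpt, check the file itself for the real contents):

```python
# Hypothetical excerpt of config/pretrain_1d_tp2.py: request 1D tensor
# parallelism with a tensor-parallel group size of 2.
parallel = dict(
    tensor=dict(mode='1d', size=2),
)
```

With 4 processes and a tensor-parallel size of 2, the remaining factor of 2 becomes the data-parallel size.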
Added save & load model functionality with colossalai.utils.checkpointing.
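For anyone curious, here is a minimal sketch of how such checkpoint helpers are typically called; the exact import path and signatures are an assumption, so verify them against the Colossal-AI version you have installed:

```python
# Sketch of save/load with Colossal-AI checkpoint utilities (assumed API;
# check colossalai.utils / colossalai.utils.checkpointing in your release).
import torch
from colossalai.utils import save_checkpoint, load_checkpoint

model = torch.nn.Linear(8, 8)  # stand-in for the MAE model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

# Persist model (and optionally optimizer) state together with the epoch.
save_checkpoint('mae_demo.pt', epoch=100, model=model, optimizer=optimizer)

# Restore the state in place; the stored epoch is returned for resuming.
last_epoch = load_checkpoint('mae_demo.pt', model=model, optimizer=optimizer)
```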
Hi @ofey404, thank you for your contribution! Would you please provide train logs in different parallelism settings?
Several epochs, or a full run? A full 800-epoch run might take a long time to finish...
Perhaps you could provide some basic validation first, such as on CIFAR or a subset of ImageNet (e.g., ImageNet100), and use the server's idle time at night to verify about 30% of the ImageNet epochs. That way we can offer this to users with some confidence. The full convergence verification can then come afterwards.
ImageNet100 on Kaggle is 16 GB, while the ImageNet 1000 (mini) dataset I used is only 2 GB. CIFAR-10 might be a good candidate for basic validation.
The problem is that the original MAE pretraining stage doesn't contain validation; only the main training process has validation. So I'd better implement the main process too.