lightning-thunder
Support NeMo NeVA Model
🚀 Feature
NeMo's NeVA (LLaVA) is a multimodal language model.
Initial examine report:
Found 49 distinct operations, of which 39 (79.6%) are supported
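For reference, a minimal sketch of how such a report can be produced with thunder's examine utility (the script actually used for this measurement is not included here; the placeholder model and input below stand in for the real NeVA module and a multimodal training batch):

```python
# Minimal sketch of an examine call, assuming thunder's examine utility.
# The model and input are placeholders; the real measurement would use
# the NeMo NeVA module and one multimodal training batch instead.
import torch
from thunder.examine import examine

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
x = torch.randn(4, 16)

# Traces the callable with the given inputs and reports how many of the
# encountered operations thunder supports (the numbers quoted above).
examine(model, x)
```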
Work items
- #145 (but it looks like #584 will be enough for this model).
- #331
- #327
- #338
- #328
- #326
- #339
- #340
- #341
- #342
- #329
- #601
- #660
- #674
- #678
- #717
- #124
- https://github.com/NVIDIA/NeMo/pull/9689
- #643
- tentative: #750
- #752
- #812
- #824
- #825
- #826
- #858
- #891
- #896
- #872
- #753
- #1004
- #1040
- #1187
- #1242
- #1250
- #1251
- #1248
- #1252
Running the model
Required data
First, download the freely available data and place it in a data directory; the run command below expects it under ./data/multimodal/tiny-neva/.
NeMo installation
Dependencies
python3 -m pip install --no-deps \
huggingface-hub==0.23.2
NeMo branch
To keep the whole thunder team on the same NeMo revision, and to avoid a pile of "modify <X> file to call thunder.jit()" instructions, we temporarily maintain our own NeMo branch for thunder. You can grab it by cloning https://github.com/tfogal/NeMo.git; make sure you have checked out the tfogal/thunder-nemo branch.
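Conceptually, the branch gates the thunder integration behind the NEMO_THUNDER_NEVA environment variable; below is a rough sketch of the kind of change it carries. The placeholder module is illustrative only; in the branch the real change lives inside NeMo's NeVA training code.

```python
# Rough sketch of the kind of patch the branch carries (assumption: the
# actual change wraps the NeVA module inside NeMo's training code; the
# Linear module here is only a stand-in).
import os
import torch
import thunder

model = torch.nn.Linear(8, 8)  # placeholder for the NeVA module being trained

if os.environ.get("NEMO_THUNDER_NEVA") == "thunder":
    model = thunder.jit(model)  # compile the module with thunder
```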
To install NeMo, run python3 -m pip install -e . from the root of the checked-out directory.
Running the network
rm -fr foo-neva-train; mkdir -p foo-neva-train
HYDRA_FULL_ERROR=1 \
THUNDER_ANNOTATE_TRACES=1 \
NEMO_THUNDER_NEVA=thunder \
python3 ./examples/multimodal/multimodal_llm/neva/neva_pretrain.py \
trainer.precision=bf16-mixed \
model.megatron_amp_O2=True \
model.mcore_gpt=False \
trainer.num_nodes=1 \
trainer.devices=1 \
trainer.val_check_interval=10 \
trainer.limit_val_batches=5 \
trainer.log_every_n_steps=1 \
++exp_manager.max_time_per_run=00:00:03:00 \
trainer.max_steps=20 \
model.micro_batch_size=2 \
model.global_batch_size=4 \
model.tensor_model_parallel_size=1 \
model.pipeline_model_parallel_size=1 \
exp_manager.create_checkpoint_callback=False \
model.data.data_path=./data/multimodal/tiny-neva/dummy.json \
model.data.image_folder=./data/multimodal/tiny-neva/images \
model.tokenizer.library=sentencepiece \
model.tokenizer.model=./data/multimodal/tiny-neva/tokenizer_add_special.model \
model.num_layers=2 \
model.hidden_size=5120 \
model.ffn_hidden_size=13824 \
model.num_attention_heads=40 \
model.normalization=rmsnorm \
model.data.num_workers=0 \
model.data.conv_template=llama_2 \
model.mm_cfg.vision_encoder.from_pretrained=openai/clip-vit-large-patch14 \
model.mm_cfg.llm.from_pretrained=null \
model.use_flash_attention=false \
exp_manager.exp_dir=./foo-neva-train
Note that the latest version of the tfogal/thunder-nemo branch allows running with dynamo+thunder by setting NEMO_THUNDER_NEVA=dynamo.
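Presumably that mode corresponds to registering thunder as a torch.compile backend rather than calling thunder.jit directly; a hedged sketch, assuming thunder's ThunderCompiler dynamo backend and again using a placeholder module:

```python
# Hedged sketch of the dynamo+thunder path; assumes thunder.dynamo.ThunderCompiler,
# with the Linear module standing in for the NeVA module.
import torch
from thunder.dynamo import ThunderCompiler

model = torch.nn.Linear(8, 8)  # placeholder for the NeVA module
model = torch.compile(model, backend=ThunderCompiler())
```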
cc @apaz-cli @tfogal
Can you share the script for the examine call?
> Can you share the script for the examine call?
@athitten when you have a minute
Adding the updated command, which uses megatron_amp_O2=True and model.mcore_gpt=True (NeMo models will default to the Megatron-Core implementations, hence this setting). With megatron_amp_O2=True, precision=bf16 should already do mixed-precision training with the main copy of the weights in FP32, but just to be safe the command also specifies precision=bf16-mixed.
python3 ./examples/multimodal/multimodal_llm/neva/neva_pretrain.py \
trainer.precision=bf16-mixed \
model.megatron_amp_O2=True \
model.mcore_gpt=True \
trainer.num_nodes=1 \
trainer.devices=1 \
trainer.val_check_interval=10 \
trainer.limit_val_batches=5 \
trainer.log_every_n_steps=1 \
++exp_manager.max_time_per_run=00:00:03:00 \
trainer.max_steps=20 \
model.micro_batch_size=2 \
model.global_batch_size=4 \
model.tensor_model_parallel_size=1 \
model.pipeline_model_parallel_size=1 \
exp_manager.create_checkpoint_callback=False \
model.data.data_path=./data/multimodal/tiny-neva/dummy.json \
model.data.image_folder=./data/multimodal/tiny-neva/images \
model.tokenizer.library=sentencepiece \
model.tokenizer.model=./data/multimodal/tiny-neva/tokenizer_add_special.model \
model.num_layers=2 \
model.hidden_size=5120 \
model.ffn_hidden_size=13824 \
model.num_attention_heads=40 \
model.normalization=rmsnorm \
model.data.num_workers=0 \
model.data.conv_template=llama_2 \
model.mm_cfg.vision_encoder.from_pretrained=openai/clip-vit-large-patch14 \
model.mm_cfg.llm.from_pretrained=null \
model.use_flash_attention=false \
exp_manager.exp_dir=./foo-neva-train
This might be helpful: the full config with default values for all parameters can be found here. Only the parameters specified in the run command are overridden; all others fall back to the defaults in the config.
> Adding the updated command
Thanks, @athitten!
I have edited the original issue to mostly reflect the updated command. Unfortunately #753 blocks setting model.mcore_gpt=True, so for now that one's still False... but let's prioritize that one!
Yes, it's important to prioritize getting thunder working with mcore_gpt=True, as it will be the default for NeMo models once we deprecate the legacy path.