Added Mistral model support

## Motivation

This PR adds support for training Mistral models.

## Modifications

- Added a distributed model implementation for the Mistral model architecture
- Added the Mistral chat template to the template registry
- Added training scripts (see the quick-start commands right below)
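For convenience, these are the commands that exercise the new pieces end to end (the same commands whose output appears in the test and training logs later in this thread):

```bash
# Accuracy test for the new tensor-parallel Mistral target-model implementation
python tests/test_target_modeling/test_mistral_tp.py

# Online EAGLE3 training for Mistral-Small-24B via the added example script
bash examples/run_mistral_small_24B_eagle3_online.sh
```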
## Accuracy Test
```
python tests/test_target_modeling/test_mistral_tp.py
test_mistral_tp (__main__.TestMistralTP) ... rank 1: bind to device 1
rank 0: bind to device 0
/home/ubuntu/workspace/forks/SpecForge/venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
warnings.warn( # warn only once
[rank1]:[W901 17:53:44.564655749 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Saved model to /tmp/tmpp_pm18t_
/home/ubuntu/workspace/forks/SpecForge/venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
warnings.warn( # warn only once
[rank0]:[W901 17:53:44.619454270 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Loading model from /tmp/tmpp_pm18t_
Loading model from /tmp/tmpp_pm18t_
[rank0]:[W901 17:53:47.376029869 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W901 17:53:48.134036842 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
ok
----------------------------------------------------------------------
Ran 1 test in 11.484s
OK
```
```
python tests/test_preprocessing.py
test_assistant_span_boundaries (__main__.TestPreprocessing)
Test that assistant span boundaries are correctly identified without truncation. ... ok
test_conversation_preprocessing_basic (__main__.TestPreprocessing)
Test basic conversation preprocessing with assistant response identification. ... ok
test_multiple_turns_conversation (__main__.TestPreprocessing)
Test conversation with multiple user-assistant turns. ... ok
test_preformatted_conversation (__main__.TestPreprocessing)
Test preprocessing of pre-formatted conversation strings. ... ok
----------------------------------------------------------------------
Ran 4 tests in 3.120s
OK
```
### mistralai/Mistral-Small-24B-Instruct-2501 training
## Checklist
- [x] Format your code according to the Code Formatting with Pre-Commit.
- [x] Add unit tests as outlined in the Running Unit Tests.
- [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
- [x] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR.
Could you fix the code format? @ValeGian
> Fix code format
Done with https://github.com/sgl-project/SpecForge/pull/208/commits/06cdfeb6af6c6cd661479762a004767bd4b521de
May I ask if you ran the training on the device mentioned above? When I use your script to train the model on 8×H20 GPUs (96 GB each), it results in an OOM (out‑of‑memory) error. @ValeGian
> May I ask if you ran the training on the device mentioned above? When I use your script to train the model on 8×H20 GPUs (96 GB each), it results in an OOM (out‑of‑memory) error. @ValeGian
The model itself is around 47 GB on disk. I ran the training on a node with 8xH200 GPUs (an AWS p5en.48xlarge instance).
I just tried a smaller node and was able to run a test training on 2xH100 by modifying examples/run_mistral_small_24B_eagle3_online.sh so that `--tp-size` is 2.
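For reference, this is roughly what the torchrun invocation inside that script looks like with the tensor-parallel size set to 2 (a sketch reconstructed from the MLflow parameters reported further down and the flag names used by the other example scripts; the actual script may contain additional flags):

```bash
# Sketch of examples/run_mistral_small_24B_eagle3_online.sh adapted to 2 GPUs.
# Flag values are taken from the MLflow parameter table below; $ROOT_DIR is the repo root.
NUM_GPUS=2

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_online.py \
    --target-model-path mistralai/Mistral-Small-24B-Instruct-2501 \
    --draft-model-config $ROOT_DIR/configs/mistral-small-24B-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
    --output-dir $ROOT_DIR/outputs/mistral-Small-24B-eagle3 \
    --num-epochs 2 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template mistral-small-24B \
    --cache-dir $ROOT_DIR/cache \
    --embedding-key model.embed_tokens.weight \
    --tp-size 2
```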
```
bash examples/run_mistral_small_24B_eagle3_online.sh
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
W0912 10:54:03.059000 25881 torch/distributed/run.py:774]
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] *****************************************
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] *****************************************
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
import pynvml # type: ignore[import]
rank 0: bind to device 0
rank 0: device mesh: DeviceMesh('cuda', [[0, 1]], mesh_dim_names=('dp', 'tp'))
rank 0: Initialized distributed environment
rank 0: draft_accumulation_steps=8 // 1 // 1=8
Set draft model tie_word_embeddings to False
rank 1: bind to device 1
rank 1: device mesh: DeviceMesh('cuda', [[0, 1]], mesh_dim_names=('dp', 'tp'))
rank 1: Initialized distributed environment
rank 1: draft_accumulation_steps=8 // 1 // 1=8
Set draft model tie_word_embeddings to False
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 48210.39it/s]
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 26259.16it/s]
rank 1: Initialized target model
rank 0: Initialized target model
WARNING:specforge.modeling.draft.llama3_eagle:Using flex attention on draft model training!
WARNING:specforge.modeling.draft.llama3_eagle:Using flex attention on draft model training!
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 34239.22it/s]
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 18139.31it/s]
rank 0: Initialized draft model
rank 1: Initialized draft model
dataset is cached at /home/ubuntu/SpecForge/cache/processed_dataset/e0db0026cc75db208ec5b318141dd0eb.pkl
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
warnings.warn( # warn only once
[rank1]:[W912 10:54:26.648956214 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 1] using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Loading vocab mapping from the cached file at: /home/ubuntu/SpecForge/cache/vocab_mapping/e0db0026cc75db208ec5b318141dd0eb.pt
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
warnings.warn( # warn only once
[rank0]:[W912 10:54:26.764850022 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 0] using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
rank 0: Initialized train dataloader
rank 0: Auto-calculated total_steps: 30170 (num_epochs=2 * steps_per_epoch=15085)
rank 0: Loaded vocab mappingdataset is cached at /home/ubuntu/SpecForge/cache/processed_dataset/e0db0026cc75db208ec5b318141dd0eb.pkl
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py:430: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.SHARD_GRAD_OP since the world size is 1.
warnings.warn(
rank 0: Initialized Eagle3 FSDP model
rank 0: Initialized optimizer and scheduler
Loading vocab mapping from the cached file at: /home/ubuntu/SpecForge/cache/vocab_mapping/e0db0026cc75db208ec5b318141dd0eb.pt
rank 1: Initialized train dataloader
rank 1: Auto-calculated total_steps: 30170 (num_epochs=2 * steps_per_epoch=15085)
rank 1: Loaded vocab mapping
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py:430: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.SHARD_GRAD_OP since the world size is 1.
warnings.warn(
rank 1: Initialized Eagle3 FSDP model
rank 1: Initialized optimizer and scheduler
Starting training from epoch 0
Training Epoch 0: 0%| | 84/120675 [00:22<6:29:14, 5.16it/s, loss=0.00, acc=0.00]
```
with the following `nvidia-smi` output:
```
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.148.08 Driver Version: 570.148.08 CUDA Version: 12.8 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 |
| N/A 51C P0 587W / 700W | 67691MiB / 81559MiB | 99% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 |
| N/A 63C P0 592W / 700W | 67691MiB / 81559MiB | 93% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 22069 C ...u/SpecForge/.venv/bin/python3 67682MiB |
| 1 N/A N/A 22070 C ...u/SpecForge/.venv/bin/python3 67682MiB |
+-----------------------------------------------------------------------------------------+
```
Do you want me to update examples/run_mistral_small_24B_eagle3_online.sh so that the tensor-parallel size is 2?
@ZhengHSI note that the training for which I reported the curves was tuned to run on the 8xH200 node; the complete set of parameters logged to MLflow was:
Name | Value
--- | ---
attention_backend | flex_attention
batch_size | 1
build_dataset_num_proc | 192
cache_dir | /home/ubuntu/SpecForge/cache
cache_key | None
chat_template | mistral-small-24B
dist_timeout | 20
dp_size | 2
draft_accumulation_steps | 4
draft_global_batch_size | 8
draft_micro_batch_size | 1
draft_model_config | /home/ubuntu/SpecForge/configs/mistral-small-24B-eagle3.json
embedding_key | model.embed_tokens.weight
eval_data_path | None
eval_interval | 1
is_preformatted | False
is_vlm | False
learning_rate | 0.0001
log_steps | 50
max_grad_norm | 0.5
max_length | 2048
max_pixels | 802816
min_pixels | 50176
mlflow_experiment_name | EAGLE3-mistral-Small-24B
mlflow_run_name | None
mlflow_tracking_uri | <MLflow URI>
num_epochs | 2
output_dir | /home/ubuntu/SpecForge/outputs/mistral-Small-24B-eagle3
profile | False
profile_num_steps | 4
profile_record_shapes | False
profile_start_step | 30
report_to | mlflow
resume | False
save_interval | 1
seed | 0
swanlab_key | None
swanlab_name | None
swanlab_project | None
target_model_path | mistralai/Mistral-Small-24B-Instruct-2501
total_steps | None
tp_size | 4
train_data_path | /home/ubuntu/SpecForge/cache/dataset/sharegpt.jsonl
ttt_length | 7
verbose | False
wandb_key | None
wandb_name | None
wandb_project | None
warmup_ratio | 0.015
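(Side note: the batch parameters above also explain the `draft_accumulation_steps=8 // 1 // 1=8` line in the 2xH100 log earlier in this thread; the accumulation steps appear to be derived as global batch // micro batch // dp size, which is consistent with both runs.)

```bash
# Assumed relationship between the batch parameters above (consistent with both runs):
#   draft_accumulation_steps = draft_global_batch_size // draft_micro_batch_size // dp_size
echo $(( 8 / 1 / 2 ))   # 8xH200 run: dp_size=2, tp_size=4 -> 4 (matches the table above)
echo $(( 8 / 1 / 1 ))   # 2xH100 run: dp_size=1, tp_size=2 -> 8 (matches the earlier log)
```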
I didn't upload an updated configuration, since the examples folder keeps pretty much the same configuration for every training script, even for larger models such as meta-llama/Llama-4-Scout-17B-16E:
```bash
torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_online.py \
    --target-model-path meta-llama/Llama-4-Scout-17B-16E \
    --draft-model-config $ROOT_DIR/configs/llama4-scout-17B-16E-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
    --output-dir $ROOT_DIR/outputs/llama4-scout-17B-16E-eagle3 \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template llama4 \
    --cache-dir $ROOT_DIR/cache \
    --embedding-key language_model.model.embed_tokens.weight \
    --tp-size $NUM_GPUS
```
Thanks for your answer. It would be better to update the script — your current script does not set the tp size, which causes tensor parallelism not to be enabled and leads to OOM. Please modify the script accordingly. @ValeGian
In addition, I tried training several times, but the loss and accuracy have always remained at 0. I saw in your previous answer that you also encountered this situation.
I'll look into this. It's odd, since I previously completed a few training runs successfully and shared the MLflow run for one of them. I'll try again, since a few merges from main and some minor commits have landed since then.
@ZhengHSI I confirmed that recent merges from main broke the PR; you can find the fixes in commit https://github.com/sgl-project/SpecForge/pull/208/commits/ab36686db3a8aeb2aebc55d8f8a04d6b05e58122. I verified correct behavior using visualize_loss_mask.
I also updated the default Tensor Parallelism for the script in commit https://github.com/sgl-project/SpecForge/pull/208/commits/26022f14b723ed65c69f1c54f85ad3aad1bca9eb.
Running it on a node with 2xH100 GPUs, I got:
```
Training Epoch 0: 10%|█████████████▍ | 12578/120675 [33:09<6:16:52, 4.78it/s, loss=2.74, acc=0.42]
```
Leaving it to run for some more steps, I got the following MLflow charts:
@ZhengHSI any update about this?
@ZhengHSI it seems like the latest merge from main broke the tests.
@ZhengHSI is there any action on my side to allow closing this PR?