
Added Mistral model support

ValeGian opened this pull request 4 months ago · 12 comments

Motivation

This PR adds support for training Mistral models.

Modifications

  • Added a distributed model implementation for the Mistral architecture
  • Added the Mistral chat template to the template registry
  • Added training scripts (see the commands below)
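
For reference, these changes are exercised with the commands used in the test runs below:

# accuracy test for the tensor-parallel Mistral target model
python tests/test_target_modeling/test_mistral_tp.py
# preprocessing tests
python tests/test_preprocessing.py
# example online EAGLE3 training script for Mistral-Small-24B
bash examples/run_mistral_small_24B_eagle3_online.sh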

Accuracy Test

python tests/test_target_modeling/test_mistral_tp.py
test_mistral_tp (__main__.TestMistralTP) ... rank 1: bind to device 1
rank 0: bind to device 0
/home/ubuntu/workspace/forks/SpecForge/venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank1]:[W901 17:53:44.564655749 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Saved model to /tmp/tmpp_pm18t_
/home/ubuntu/workspace/forks/SpecForge/venv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank0]:[W901 17:53:44.619454270 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Loading model from /tmp/tmpp_pm18t_
Loading model from /tmp/tmpp_pm18t_
[rank0]:[W901 17:53:47.376029869 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W901 17:53:48.134036842 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
ok

----------------------------------------------------------------------
Ran 1 test in 11.484s

OK
python tests/test_preprocessing.py 
test_assistant_span_boundaries (__main__.TestPreprocessing)
Test that assistant span boundaries are correctly identified without truncation. ... ok
test_conversation_preprocessing_basic (__main__.TestPreprocessing)
Test basic conversation preprocessing with assistant response identification. ... ok
test_multiple_turns_conversation (__main__.TestPreprocessing)
Test conversation with multiple user-assistant turns. ... ok
test_preformatted_conversation (__main__.TestPreprocessing)
Test preprocessing of pre-formatted conversation strings. ... ok

----------------------------------------------------------------------
Ran 4 tests in 3.120s

OK

[Screenshots: mistralai/Mistral-Small-24B-Instruct-2501 training curves]

Checklist

  • [x] Format your code according to the Code Formatting with Pre-Commit.
  • [x] Add unit tests as outlined in the Running Unit Tests.
  • [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • [x] Please feel free to join our Slack channel at https://sgl-fru7574.slack.com/archives/C09784E3EN6 to discuss your PR.

ValeGian · Sep 01 '25 18:09

Could you fix the code formatting, @ValeGian?

sleepcoo · Sep 04 '25 03:09

Fix code format

Done with https://github.com/sgl-project/SpecForge/pull/208/commits/06cdfeb6af6c6cd661479762a004767bd4b521de

ValeGian · Sep 04 '25 09:09

May I ask if you ran the training on the device mentioned above? When I use your script to train the model on 8×H20 GPUs (96 GB each), it results in an OOM (out‑of‑memory) error. @ValeGian

ZhengHSI · Sep 12 '25 08:09

May I ask if you ran the training on the device mentioned above? When I use your script to train the model on 8×H20 GPUs (96 GB each), it results in an OOM (out‑of‑memory) error. @ValeGian

The model itself is around 47 GB on disk. I ran the training on a node with 8xH200 (an AWS p5en.48xlarge instance).
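
Rough arithmetic (assuming bf16 weights): 24B parameters x 2 bytes ≈ 48 GB, which matches the on-disk size. Without tensor parallelism, each rank presumably has to hold the full target model plus the draft model, its optimizer state, and activations, which can push past 96 GB per GPU.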

I just tried on a smaller node and was able to run a test training on 2xH100 by modifying examples/run_mistral_small_24B_eagle3_online.sh so that the tensor-parallel size is 2; the invocation is sketched below.
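
Roughly, the resulting invocation looks like this (a sketch only: flag names are taken from the Llama4 example quoted further down and the values from the MLflow parameters below, so the actual script may differ slightly):

NUM_GPUS=2  # 2xH100 for this test

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_online.py \
    --target-model-path mistralai/Mistral-Small-24B-Instruct-2501 \
    --draft-model-config $ROOT_DIR/configs/mistral-small-24B-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
    --output-dir $ROOT_DIR/outputs/mistral-Small-24B-eagle3 \
    --num-epochs 2 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template mistral-small-24B \
    --cache-dir $ROOT_DIR/cache \
    --embedding-key model.embed_tokens.weight \
    --tp-size $NUM_GPUS  # the relevant change: shard the target model across both GPUs

With the target weights sharded two ways, roughly 24 GB of weights sit on each GPU, which is consistent with the ~67 GiB per GPU reported by nvidia-smi below.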

bash examples/run_mistral_small_24B_eagle3_online.sh
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
W0912 10:54:03.059000 25881 torch/distributed/run.py:774]
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] *****************************************
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0912 10:54:03.059000 25881 torch/distributed/run.py:774] *****************************************
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
rank 0: bind to device 0
rank 0: device mesh: DeviceMesh('cuda', [[0, 1]], mesh_dim_names=('dp', 'tp'))
rank 0: Initialized distributed environment
rank 0: draft_accumulation_steps=8 // 1 // 1=8
Set draft model tie_word_embeddings to False
rank 1: bind to device 1
rank 1: device mesh: DeviceMesh('cuda', [[0, 1]], mesh_dim_names=('dp', 'tp'))
rank 1: Initialized distributed environment
rank 1: draft_accumulation_steps=8 // 1 // 1=8
Set draft model tie_word_embeddings to False
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 48210.39it/s]
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 26259.16it/s]
rank 1: Initialized target model
rank 0: Initialized target model
WARNING:specforge.modeling.draft.llama3_eagle:Using flex attention on draft model training!
WARNING:specforge.modeling.draft.llama3_eagle:Using flex attention on draft model training!
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 34239.22it/s]
Fetching 22 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 18139.31it/s]
rank 0: Initialized draft model
rank 1: Initialized draft model
dataset is cached at /home/ubuntu/SpecForge/cache/processed_dataset/e0db0026cc75db208ec5b318141dd0eb.pkl
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank1]:[W912 10:54:26.648956214 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
Loading vocab mapping from the cached file at: /home/ubuntu/SpecForge/cache/vocab_mapping/e0db0026cc75db208ec5b318141dd0eb.pt
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py:4807: UserWarning: No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user.
  warnings.warn(  # warn only once
[rank0]:[W912 10:54:26.764850022 ProcessGroupNCCL.cpp:5023] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 as device used by this process is currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. You can specify device_id in init_process_group() to force use of a particular device.
rank 0: Initialized train dataloader
rank 0: Auto-calculated total_steps: 30170 (num_epochs=2 * steps_per_epoch=15085)
rank 0: Loaded vocab mapping
dataset is cached at /home/ubuntu/SpecForge/cache/processed_dataset/e0db0026cc75db208ec5b318141dd0eb.pkl
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py:430: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.SHARD_GRAD_OP since the world size is 1.
  warnings.warn(
rank 0: Initialized Eagle3 FSDP model
rank 0: Initialized optimizer and scheduler
Loading vocab mapping from the cached file at: /home/ubuntu/SpecForge/cache/vocab_mapping/e0db0026cc75db208ec5b318141dd0eb.pt
rank 1: Initialized train dataloader
rank 1: Auto-calculated total_steps: 30170 (num_epochs=2 * steps_per_epoch=15085)
rank 1: Loaded vocab mapping
/home/ubuntu/SpecForge/.venv/lib/python3.12/site-packages/torch/distributed/fsdp/_init_utils.py:430: UserWarning: FSDP is switching to use `NO_SHARD` instead of ShardingStrategy.SHARD_GRAD_OP since the world size is 1.
  warnings.warn(
rank 1: Initialized Eagle3 FSDP model
rank 1: Initialized optimizer and scheduler
Starting training from epoch 0
Training Epoch 0:   0%|                                                                                                                                    | 84/120675 [00:22<6:29:14,  5.16it/s, loss=0.00, acc=0.00]

with the following GPU utilization (nvidia-smi):

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.148.08             Driver Version: 570.148.08     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:07:00.0 Off |                    0 |
| N/A   51C    P0            587W /  700W |   67691MiB /  81559MiB |     99%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:08:00.0 Off |                    0 |
| N/A   63C    P0            592W /  700W |   67691MiB /  81559MiB |     93%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A           22069      C   ...u/SpecForge/.venv/bin/python3      67682MiB |
|    1   N/A  N/A           22070      C   ...u/SpecForge/.venv/bin/python3      67682MiB |
+-----------------------------------------------------------------------------------------+

Do you want me to update examples/run_mistral_small_24B_eagle3_online.sh so that the tensor-parallel size defaults to 2?

ValeGian · Sep 12 '25 11:09

@ZhengHSI note that the training for which I reported the curves was tuned for the 8xH200 node; the complete set of parameters logged to MLflow was:

Name | Value
attention_backend | flex_attention
batch_size | 1
build_dataset_num_proc | 192
cache_dir | /home/ubuntu/SpecForge/cache
cache_key | None
chat_template | mistral-small-24B
dist_timeout | 20
dp_size | 2
draft_accumulation_steps | 4
draft_global_batch_size | 8
draft_micro_batch_size | 1
draft_model_config | /home/ubuntu/SpecForge/configs/mistral-small-24B-eagle3.json
embedding_key | model.embed_tokens.weight
eval_data_path | None
eval_interval | 1
is_preformatted | False
is_vlm | False
learning_rate | 0.0001
log_steps | 50
max_grad_norm | 0.5
max_length | 2048
max_pixels | 802816
min_pixels | 50176
mlflow_experiment_name | EAGLE3-mistral-Small-24B
mlflow_run_name | None
mlflow_tracking_uri | <MLflow URI>
num_epochs | 2
output_dir | /home/ubuntu/SpecForge/outputs/mistral-Small-24B-eagle3
profile | False
profile_num_steps | 4
profile_record_shapes | False
profile_start_step | 30
report_to | mlflow
resume | False
save_interval | 1
seed | 0
swanlab_key | None
swanlab_name | None
swanlab_project | None
target_model_path | mistralai/Mistral-Small-24B-Instruct-2501
total_steps | None
tp_size | 4
train_data_path | /home/ubuntu/SpecForge/cache/dataset/sharegpt.jsonl
ttt_length | 7
verbose | False
wandb_key | None
wandb_name | None
wandb_project | None
warmup_ratio | 0.015
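
For reference, these values line up with the accumulation computation printed in the training log above, which appears to be draft_global_batch_size // draft_micro_batch_size // dp_size:

8xH200 run (table above): dp_size=2 x tp_size=4 = 8 GPUs, accumulation = 8 // 1 // 2 = 4
2xH100 run (log above):   dp_size=1 x tp_size=2 = 2 GPUs, accumulation = 8 // 1 // 1 = 8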

I didn't upload the updated configuration because I saw that the examples folder keeps pretty much the same configuration for every training script, even for larger models such as meta-llama/Llama-4-Scout-17B-16E:

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3_online.py \
    --target-model-path meta-llama/Llama-4-Scout-17B-16E \
    --draft-model-config $ROOT_DIR/configs/llama4-scout-17B-16E-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt.jsonl \
    --output-dir $ROOT_DIR/outputs/llama4-scout-17B-16E-eagle3 \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template llama4 \
    --cache-dir $ROOT_DIR/cache \
    --embedding-key language_model.model.embed_tokens.weight \
    --tp-size $NUM_GPUS

ValeGian · Sep 12 '25 11:09

Thanks for your answer. It would be better to update the script: as it stands it does not set the TP size, so tensor parallelism is not enabled and training runs out of memory. Please modify the script accordingly. @ValeGian

ZhengHSI · Sep 15 '25 03:09

In addition, I tried training several times, but the loss and accuracy have always remained at 0 (see the attached screenshot). I saw in your previous answer that you also encountered this situation.

ZhengHSI · Sep 15 '25 03:09

In addition, I tried training several times, but the loss and accuracy have always remained at 0 (see the attached screenshot). I saw in your previous answer that you also encountered this situation.

I’ll look into this. It’s odd, since I previously completed a few training runs successfully and shared the MLflow run for one of them. I'll try again, since a few merges from main plus some minor commits have landed since then.

ValeGian · Sep 15 '25 08:09

@ZhengHSI I confirmed that recent merges from main broke the PR; the fixes are in commit https://github.com/sgl-project/SpecForge/pull/208/commits/ab36686db3a8aeb2aebc55d8f8a04d6b05e58122. I verified that it works correctly using visualize_loss_mask.

I also updated the script's default tensor-parallel size in commit https://github.com/sgl-project/SpecForge/pull/208/commits/26022f14b723ed65c69f1c54f85ad3aad1bca9eb.

Running it on a node with 2xH100, I got:

Training Epoch 0:  10%|█████████████▍     | 12578/120675 [33:09<6:16:52,  4.78it/s, loss=2.74, acc=0.42]

Leaving it to run for some steps, I got the following MLflow charts (screenshots attached).

ValeGian · Sep 22 '25 19:09

@ZhengHSI any update on this?

ValeGian · Sep 30 '25 19:09

@ZhengHSI it seems the latest merge from main broke the tests.

ValeGian · Oct 02 '25 18:10

@ZhengHSI is there any action needed on my side to get this PR closed?

ValeGian · Oct 07 '25 20:10