
How to implement LLaDA training from SMDM?

Open yuecao0119 opened this issue 9 months ago • 8 comments

Thank you for your outstanding work. Following your recommendation, I attempted to implement LLaDA training based on the SMDM code repository, but I am unsure how to set these parameters for the 8B model.

I am looking forward to your reply; it would be of great help to me.

model_para_config = {  # model size -> parameter count in millions; TODO: value for 8B?
    '6': 6.294784, '19': 18.880896, '34': 33.563136, '48': 47.786688, '66': 65.54944,
    '85': 85.21408, '75': 75.38752, '113': 113.265408, '142': 141.581568, '170': 169.897728,
    '180': 179.856768, '206': 205.550464, '231': 231.24416, '268': 268.469248, '302': 302.027776,
    '336': 335.586304, '472': 471.90656, '551': 550.55744, '571': 571.001728, '629': 629.20832,
    '666': 666.168448, '717': 717.285888, '761': 761.335168, '831': 830.541312, '944': 943.796736,
    '1028': 1027.677952, '1233': 1233.213184, '1476': 1476.487168, '1678': 1677.826048, '2121': 2121.39328
}

# Hyperparameters 32 GPU
num_of_devices = 8
global_batch_size = int(args.batch_size / args.nodes_num)
learning_rate = 4e-4  
if args.model <= 20:
    micro_batch_size = 32
elif args.model <= 50:
    micro_batch_size = 16
elif args.model <= 1000:
    micro_batch_size = 8
else:
    micro_batch_size = 4
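# Expected sequence length when a fraction args.ssl_ratio of sequences has a length uniformly distributed between 1 and 2048
average_length = 2048 * (1 - args.ssl_ratio) + (1 + 2048) * 0.5 * args.ssl_ratio
# Steps that exhaust the compute budget via FLOPs ≈ 6 * params * tokens (params in millions, so args.flops is in units of 1e18 FLOPs)
max_step = int(args.flops * 1e12 / (6 * model_para_config[f'{args.model}'] * global_batch_size * average_length) / args.nodes_num)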
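For the 8B question, here is a minimal sketch, assuming the values in model_para_config are non-embedding parameter counts in millions (including RMSNorm weights) for a LLaMA-style block with multi-head attention (n_query_groups == n_head), no biases, and a SwiGLU MLP with three weight matrices. Under that assumption the formula reproduces entries such as '6' (6 layers, n_embd=256, intermediate_size=1024), '19' (8 layers, n_embd=384, intermediate_size=1536) and '34' (8 layers, n_embd=512, intermediate_size=2048) exactly; those layer shapes are inferred from the arithmetic, not read from SMDM's config.py. The 8B dimensions (32 layers, hidden size 4096, MLP size 12288) come from the LLaDA-8B-Base config.json and should be double-checked there. Both helper names are hypothetical.

```python
# Hypothetical helpers: reproduce the apparent convention behind model_para_config
# and the max_step budget formula from the snippet above.

def non_embedding_params_m(n_layer: int, n_embd: int, intermediate_size: int) -> float:
    """Non-embedding parameters in millions (assumed convention)."""
    attn = 4 * n_embd * n_embd            # q, k, v, o projections, no bias
    mlp = 3 * n_embd * intermediate_size  # gate, up and down projections (SwiGLU)
    norms = 2 * n_embd                    # two RMSNorm weight vectors per block
    total = n_layer * (attn + mlp + norms) + n_embd  # plus the final RMSNorm
    return total / 1e6


def max_steps(flops_1e18: float, params_m: float, global_batch_size: int,
              average_length: float, nodes_num: int) -> int:
    """Same expression as max_step above; flops_1e18 is the budget in 1e18 FLOPs."""
    return int(flops_1e18 * 1e12 / (6 * params_m * global_batch_size * average_length) / nodes_num)


# LLaDA-8B dimensions as listed in its config.json (verify before relying on this).
params_8b = non_embedding_params_m(n_layer=32, n_embd=4096, intermediate_size=12288)
print(f"'8B' entry candidate: {params_8b}")  # 6979.588096
```

Following the rounding of the existing keys, such an entry would presumably be keyed '6980'; note that the micro_batch_size branches above would then select 4.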

yuecao0119 · Mar 05 '25 06:03

I am also looking for this. It would also be great if the same approach could be applied to an already existing model.

prp-e · Mar 05 '25 19:03

@NieShenRuc Hello. Could you please clarify this? It is very important to me. In addition, what should the following settings in config.py be for the 8B model, especially block_size? The model's config.json does not cover all of them.

dict(  # TODO check 8B
    name="Diff_LLaMA_8xxxM",  # 8xxx ?
    block_size=,
    vocab_size=,
    padding_multiple=,
    n_layer=32,
    n_head=32,
    n_embd=,
    rotary_percentage=1.0,
    parallel_residual=False,
    bias=False,
    _norm_class="FusedRMSNorm",
    norm_eps=1e-5,  # Llama 2 uses 1e-5, Llama 1 uses 1e-6
    _mlp_class="LLaMAMLP",
    intermediate_size=,
    n_query_groups=,
),

yuecao0119 · Mar 06 '25 06:03

Thank you for your interest!

SMDM and LLaDA only share the same loss function. SMDM is an earlier research project that laid the groundwork for LLaDA, and the SMDM code does not include LLaDA configurations.

If you would like to train LLaDA from scratch or fine-tune it, and you already have code for training an autoregressive model, you can easily adapt it to train LLaDA by making a few modifications following the instructions in GUIDELINES.md.
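Since the loss function is the piece SMDM and LLaDA share, here is a minimal NumPy sketch of a masked-diffusion training loss of the kind described for LLaDA: draw a masking probability per sequence, replace each token with a mask token with that probability, and compute cross-entropy only on masked positions, reweighted by 1/p. This is not the code from GUIDELINES.md. The mask token id 126336, the eps floor on the masking probability, normalizing by the total number of tokens, and the function names are all assumptions to check against GUIDELINES.md and the tokenizer; NumPy is used only to keep the sketch dependency-light, whereas a real trainer would use framework tensors.

```python
import numpy as np

MASK_ID = 126336  # assumption: LLaDA's [MASK] token id; verify with the tokenizer


def forward_process(input_ids: np.ndarray, rng: np.random.Generator, eps: float = 1e-3):
    """Mask each token independently with a per-sequence probability p in [eps, 1)."""
    b, l = input_ids.shape
    t = rng.random(b)                                              # t ~ U[0, 1) per sequence
    p_mask = np.repeat(((1 - eps) * t + eps)[:, None], l, axis=1)  # floor at eps, so p > 0
    masked = rng.random((b, l)) < p_mask
    noisy = np.where(masked, MASK_ID, input_ids)
    return noisy, masked, p_mask


def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)          # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def masked_diffusion_loss(logits, input_ids, masked, p_mask):
    """Cross-entropy on masked positions only, weighted by 1 / p_mask and
    averaged over all b * l positions (unmasked positions contribute 0)."""
    logp = log_softmax(logits)                                     # (b, l, V)
    token_nll = -np.take_along_axis(logp, input_ids[..., None], axis=-1)[..., 0]
    weighted = np.where(masked, token_nll / p_mask, 0.0)
    return weighted.sum() / input_ids.size


rng = np.random.default_rng(0)
vocab = 8
input_ids = rng.integers(0, vocab, size=(2, 5))
noisy, masked, p_mask = forward_process(input_ids, rng)
# A real model would map `noisy` to logits; random logits stand in for it here.
logits = rng.normal(size=(2, 5, vocab))
print(masked_diffusion_loss(logits, input_ids, masked, p_mask))
```

Compared with an autoregressive trainer, the changes are the masking step before the forward pass, a model without a causal attention mask, and a loss restricted to masked tokens instead of next-token targets.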

As for the SMDM code, please note that block_size refers to the sequence length, which is set to 4096 for LLaDA. Please refer to https://huggingface.co/GSAI-ML/LLaDA-8B-Base/blob/main/config.json for more details of LLaDA.
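To fill in the config.py entry asked about above, a minimal sketch that maps fields from LLaDA-8B-Base's config.json onto the SMDM/TinyLlama-style Config arguments. The config.json key names (d_model, n_layers, n_heads, n_kv_heads, mlp_hidden_size, vocab_size, max_sequence_length, rms_norm_eps) are assumptions based on the OLMo-style config LLaDA appears to use, so check them against the file. padding_multiple is left out because it has no obvious counterpart there. The helper names are hypothetical.

```python
import json

# Hypothetical mapping: config.json key -> SMDM Config keyword argument.
CONFIG_JSON_TO_SMDM = {
    "max_sequence_length": "block_size",   # sequence length (4096 for LLaDA)
    "vocab_size": "vocab_size",
    "n_layers": "n_layer",
    "n_heads": "n_head",
    "d_model": "n_embd",
    "mlp_hidden_size": "intermediate_size",
    "n_kv_heads": "n_query_groups",        # equal to n_heads means plain multi-head attention
    "rms_norm_eps": "norm_eps",
}


def smdm_config_from_llada(config_json: dict) -> dict:
    """Pick the mapped fields out of a parsed config.json; extra keys are ignored."""
    missing = [key for key in CONFIG_JSON_TO_SMDM if key not in config_json]
    if missing:
        raise KeyError(f"config.json is missing expected keys: {missing}")
    return {smdm_key: config_json[src] for src, smdm_key in CONFIG_JSON_TO_SMDM.items()}


def load_smdm_config(path: str) -> dict:
    """Read a local copy of config.json and translate it."""
    with open(path, encoding="utf-8") as f:
        return smdm_config_from_llada(json.load(f))
```

The resulting dict can be merged into the dict(...) entry alongside the fixed fields (rotary_percentage, _norm_class, _mlp_class and so on).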

nieshenx · Mar 06 '25 08:03

Thank you very much for your answer.

I nevertheless adapted the SMDM training code. With block_size=2048, training only runs with batch_size=1; with block_size=4096, it runs out of memory (OOM) even at batch_size=1. I used FSDP:

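import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# Wrap each LLaDALlamaBlock (the LLaDA transformer block class) as its own FSDP unit
strategy = FSDPStrategy(
    auto_wrap_policy={LLaDALlamaBlock},
    activation_checkpointing_policy=None,  # no activation checkpointing
    state_dict_type="full",
    limit_all_gathers=True,
    cpu_offload=True,
)

# precision, logger, main and the data paths are defined elsewhere in the training script
fabric = L.Fabric(devices=8, num_nodes=1, strategy=strategy, precision=precision, loggers=logger)
fabric.launch(main, train_data_dir, val_data_dir, resume)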

My GPUs are A800 80GB. Shouldn't that theoretically be enough to train the 8B model?
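As a rough back-of-envelope check (a sketch, not a measurement), the model states for mixed-precision AdamW training are commonly estimated at about 16 bytes per parameter: bf16 weights and gradients (2 + 2 bytes) plus fp32 master weights and two fp32 Adam moments (4 + 4 + 4 bytes). The roughly 8B total parameter count and the even split across 8 GPUs under full FSDP sharding are assumptions; activations, all-gather buffers and fragmentation are not included, and with cpu_offload=True even less of this stays on the GPU.

```python
# Rough per-GPU memory for model states under full sharding (FSDP FULL_SHARD / ZeRO-3).
# Assumptions: bf16 params and grads, fp32 master weights, AdamW with two fp32 moments.
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam first moment": 4,
    "fp32 Adam second moment": 4,
}


def model_state_gib_per_gpu(n_params: float, n_gpus: int) -> float:
    """Model-state memory in GiB per GPU, assuming an even shard across n_gpus."""
    total_bytes = n_params * sum(BYTES_PER_PARAM.values())
    return total_bytes / n_gpus / 2**30


n_params = 8.0e9  # assumption: roughly 8B parameters including embeddings
print(f"total: {model_state_gib_per_gpu(n_params, 1):.1f} GiB")                # 119.2 GiB
print(f"per GPU (8-way sharded): {model_state_gib_per_gpu(n_params, 8):.1f} GiB")  # 14.9 GiB
```

Under these assumptions the sharded model states leave most of the 80GB for activations, so at 4096 tokens per sequence without activation checkpointing the activations are the likely remaining cost to look at.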

yuecao0119 · Mar 06 '25 13:03

LLaDA was trained on H800 GPUs, which also have 80GB of memory, with a per-GPU batch size of 4. We did not use SMDM but a framework optimized for our cluster, so I am not sure whether SMDM can successfully train an 8B model.

The SMDM framework is adapted from TinyLlama; you may want to check whether anyone has used the TinyLlama framework to train an 8B model.

nieshenx · Mar 07 '25 03:03

@NieShenRuc Hello,

I noticed that configuration_llada.py defines an ActivationCheckpointingStrategy. Did you enable activation checkpointing during training, and if so, which strategy did you use?
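For context on what such a setting controls, a minimal sketch, assuming OLMo-style semantics in which "whole_layer" recomputes every transformer block and a "one_in_N" option recomputes every N-th block. The enum members and the helper below are hypothetical and not read from configuration_llada.py. A checkpointed block keeps only its input during the forward pass and recomputes its internal activations during backward, so checkpointing more blocks lowers activation memory at the cost of extra compute.

```python
from enum import Enum
from typing import List, Optional


class CheckpointStrategy(Enum):  # hypothetical names, modeled on OLMo-style options
    WHOLE_LAYER = "whole_layer"
    ONE_IN_TWO = "one_in_two"
    ONE_IN_THREE = "one_in_three"
    ONE_IN_FOUR = "one_in_four"


# Checkpoint every N-th block (block index divisible by N).
_EVERY_N = {
    CheckpointStrategy.WHOLE_LAYER: 1,
    CheckpointStrategy.ONE_IN_TWO: 2,
    CheckpointStrategy.ONE_IN_THREE: 3,
    CheckpointStrategy.ONE_IN_FOUR: 4,
}


def checkpointed_blocks(strategy: Optional[CheckpointStrategy], n_layers: int) -> List[int]:
    """Indices of blocks whose activations are recomputed in the backward pass."""
    if strategy is None:
        return []
    every_n = _EVERY_N[strategy]
    return [i for i in range(n_layers) if i % every_n == 0]


for s in [None, *CheckpointStrategy]:
    name = s.value if s is not None else "none"
    print(name, len(checkpointed_blocks(s, n_layers=32)), "of 32 blocks checkpointed")
```

In the Lightning Fabric setup shown earlier in the thread, block-level checkpointing would instead be requested by passing the block class to FSDPStrategy's activation_checkpointing_policy argument, which that snippet sets to None.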

yuecao0119 · Mar 13 '25 07:03

@yuecao0119 I was wondering if you have since found a working solution to fine-tune the LLaDA 8B model? If so, would you mind sharing details that helped you make it work?

user50lab · Aug 04 '25 07:08

@user50lab

Hi there, we’ve built dllm, a lightweight finetuning framework for diffusion language models on top of the 🤗 Transformers Trainer. Give it a try if you’d like to finetune LLaDA / LLaDA-MoE and Dream.

ZHZisZZ · Sep 21 '25 23:09