How to implement LLaDA training from the SMDM code
Thank you for your outstanding work. Following your recommendation, I attempted to implement LLaDA training from the SMDM code repository, but I am unsure how to set these parameters for the 8B model.
I look forward to your reply; it would be a great help to me.
model_para_config = {  # TODO for 8B ?
    '6': 6.294784, '19': 18.880896, '34': 33.563136, '48': 47.786688, '66': 65.54944,
    '85': 85.21408, '75': 75.38752, '113': 113.265408, '142': 141.581568, '170': 169.897728,
    '180': 179.856768, '206': 205.550464, '231': 231.24416, '268': 268.469248, '302': 302.027776,
    '336': 335.586304, '472': 471.90656, '551': 550.55744, '571': 571.001728, '629': 629.20832,
    '666': 666.168448, '717': 717.285888, '761': 761.335168, '831': 830.541312, '944': 943.796736,
    '1028': 1027.677952, '1233': 1233.213184, '1476': 1476.487168, '1678': 1677.826048, '2121': 2121.39328
}
# Hyperparameters 32 GPU
num_of_devices = 8
global_batch_size = int(args.batch_size / args.nodes_num)
learning_rate = 4e-4
if args.model <= 20:
    micro_batch_size = 32
elif args.model <= 50:
    micro_batch_size = 16
elif args.model <= 1000:
    micro_batch_size = 8
else:
    micro_batch_size = 4
average_length = 2048 * (1 - args.ssl_ratio) + (1 + 2048) * 0.5 * args.ssl_ratio
max_step = int(args.flops * 1e12 / (6 * model_para_config[f'{args.model}'] * global_batch_size * average_length) / args.nodes_num)
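For anyone adapting the step budget above, here is a minimal sketch of the same arithmetic as a standalone function. It assumes the values in model_para_config are parameter counts in millions, which would put args.flops in units of 1e18 FLOPs under the usual 6 * parameters * tokens estimate. The function and argument names are hypothetical, and the numbers in the usage lines are placeholders, not a recommended 8B setting.

```python
def average_length(seq_len: int, ssl_ratio: float) -> float:
    """Expected tokens per sequence, mirroring the average_length line above."""
    return seq_len * (1 - ssl_ratio) + (1 + seq_len) * 0.5 * ssl_ratio


def max_steps(flops_e18: float, params_millions: float, batch_size: int,
              nodes_num: int, seq_len: int = 2048, ssl_ratio: float = 0.0) -> int:
    """Training steps that fit in a FLOPs budget (same formula as max_step above)."""
    # Per-node batch, exactly as in the snippet: batch_size split across nodes.
    global_batch_size = int(batch_size / nodes_num)
    avg_len = average_length(seq_len, ssl_ratio)
    return int(flops_e18 * 1e12
               / (6 * params_millions * global_batch_size * avg_len)
               / nodes_num)


# Placeholder inputs: a 1000M-parameter model, 100e18 FLOPs, batch size 256.
steps_one_node = max_steps(100, 1000.0, 256, nodes_num=1)
steps_four_nodes = max_steps(100, 1000.0, 256, nodes_num=4)
print(steps_one_node, steps_four_nodes)
```

When batch_size divides evenly by nodes_num, the two divisions by nodes_num cancel, so the step count depends only on the total batch size. Adding an 8B entry to model_para_config would then mean supplying that model's parameter count in the same units as the existing entries.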
I am also looking for this. It would be great if the same could be applied to an already existing model, too.
@NieShenRuc Hello. Could you please clarify this? This is very important to me. In addition, what should the following settings be for the 8B model in config.py? Especially block_size? The config.json file of the model is not comprehensive.
dict(  # TODO check 8B
    name="Diff_LLaMA_8xxxM",  # 8xxx ?
    block_size=,
    vocab_size=,
    padding_multiple=,
    n_layer=32,
    n_head=32,
    n_embd=,
    rotary_percentage=1.0,
    parallel_residual=False,
    bias=False,
    _norm_class="FusedRMSNorm",
    norm_eps=1e-5,  # Llama 2 use 1e-5. Llama 1 use 1e-6
    _mlp_class="LLaMAMLP",
    intermediate_size=,
    n_query_groups=,
),
Thank you for your interest!
SMDM and LLaDA only share the same loss function. SMDM is an earlier research project that laid the groundwork for LLaDA, and the SMDM code does not include LLaDA configurations.
If you would like to train LLaDA from scratch or fine-tune it, and you already have code for training an autoregressive model, you can easily adapt it to train LLaDA by making a few modifications following the instructions in GUIDELINES.md.
As for the SMDM code, please note that block_size refers to the sequence length, which is set to 4096 for LLaDA. Please refer to https://huggingface.co/GSAI-ML/LLaDA-8B-Base/blob/main/config.json for more details of LLaDA.
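For readers who already have autoregressive training code, the change largely amounts to replacing the next-token loss with a masked-diffusion loss. The following is a minimal sketch of that kind of objective, not the authoritative version: the function names are hypothetical, mask_token_id must be taken from the model's own config, and GUIDELINES.md remains the reference for the exact procedure.

```python
import torch
import torch.nn.functional as F


def forward_process(input_ids, mask_token_id, eps=1e-3):
    """Mask tokens independently, with one masking ratio sampled per sequence."""
    b, l = input_ids.shape
    t = torch.rand(b, device=input_ids.device)
    # Masking probability in [eps, 1), broadcast over the sequence dimension.
    p_mask = ((1 - eps) * t + eps)[:, None].expand(b, l)
    masked = torch.rand(b, l, device=input_ids.device) < p_mask
    noisy = torch.where(masked, torch.full_like(input_ids, mask_token_id), input_ids)
    return noisy, masked, p_mask


def masked_diffusion_loss(logits, input_ids, masked, p_mask):
    """Cross-entropy on masked positions only, reweighted by 1 / p_mask."""
    token_loss = F.cross_entropy(
        logits[masked], input_ids[masked], reduction="none"
    ) / p_mask[masked]
    # Normalize by the total number of tokens in the batch.
    return token_loss.sum() / input_ids.numel()
```

In a training step, the model would be called on the noisy batch without a causal attention mask, and the resulting logits passed to masked_diffusion_loss together with the clean input_ids.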
Thank you very much for your answer.
I went ahead and modified the training code based on the SMDM code. With block_size=2048, it only runs with batch_size=1. With block_size=4096, it unfortunately hits OOM even with batch_size=1. I used FSDP:
strategy = FSDPStrategy(
    auto_wrap_policy={LLaDALlamaBlock},
    activation_checkpointing_policy=None,
    state_dict_type="full",
    limit_all_gathers=True,
    cpu_offload=True,
)
fabric = L.Fabric(devices=8, num_nodes=1, strategy=strategy, precision=precision, loggers=logger)
fabric.launch(main, train_data_dir, val_data_dir, resume)
My GPUs are A800 80G cards. In theory, shouldn't that be enough to train the 8B model?
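A rough back-of-the-envelope check may help frame the question. The sketch below assumes a common rule of thumb of about 16 bytes per parameter for mixed-precision Adam training (2 for weights, 2 for gradients, 12 for fp32 master weights and the two optimizer moments) and a round 8e9 parameters; both figures are assumptions, and activations are not counted.

```python
def training_state_gib(n_params: float, bytes_per_param: int = 16, n_gpus: int = 8):
    """Estimate weight + gradient + optimizer memory, total and per GPU, in GiB."""
    total = n_params * bytes_per_param / 2**30
    # With full sharding, each GPU holds roughly 1 / n_gpus of this state.
    return total, total / n_gpus


total, per_gpu = training_state_gib(8e9)
print(f"total: {total:.1f} GiB, per GPU when sharded over 8: {per_gpu:.1f} GiB")
```

Under these assumptions the sharded training state is well below 80 GB per card, which suggests the OOM is more likely to come from activations. Those grow with block_size, and the strategy above sets activation_checkpointing_policy=None, so no activations are recomputed.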
LLaDA was trained on H800 GPUs, which also have 80 GB of memory, with a batch size of 4 per GPU. We did not use SMDM; we used a framework optimized for our cluster, so I am not sure whether SMDM can successfully train an 8B model.
The SMDM framework is modified from TinyLLaMA, so you may want to check whether anyone has used the TinyLLaMA framework to train an 8B model.
@NieShenRuc Hello,
I noticed that there is ActivationCheckpointingStrategy in configuration_llada.py. Did you enable it during training? Which one did you enable?
@yuecao0119 I was wondering if you have since found a working solution to fine-tune the LLaDA 8B model? If so, would you mind sharing details that helped you make it work?
@user50lab
Hi there, we’ve built dllm, a lightweight finetuning framework for diffusion language models on top of the 🤗 Transformers Trainer. Give it a try if you’d like to finetune LLaDA / LLaDA-MoE and Dream.