Train llama 3.1 with GRIT

Open ThisisXXZ opened this issue 1 year ago • 5 comments

I'm trying to train Llama 3.1 with the GRIT pipeline.

At first, I simply changed --model_name_or_path and ran the training code (the training script I used is as follows):

#!/bin/bash
#SBATCH --time=6:00:00
#SBATCH --job-name=grit_train
#SBATCH --gres=gpu:h100-96:2
#SBATCH --mem=60G
#SBATCH --output=/home/e/e1347696/unified_encoder_decoder/logs/grit_train_out.log
#SBATCH --error=/home/e/e1347696/unified_encoder_decoder/logs/grit_train_err.log

source ~/.bashrc
conda activate grit_eval

export CUDA_HOME='/usr/local/cuda-12.1'
# CUDA_VISIBLE_DEVICES=$(python train/gritlm/mig_uuid_setup.py)
export CUDA_VISIBLE_DEVICES=0,1

cd /home/e/e1347696/unified_encoder_decoder

# nvidia-smi 

deepspeed \
    --num_gpus=2 \
    --module train.gritlm.training.run \
    --output_dir results/GritLM-7B-training \
    --model_name_or_path model/Llama-3.1-8B \
    --train_data data/grit_training_data \
    --max_example_num_per_dataset 1000 \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --max_steps 1253 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 256 \
    --per_device_generative_bs 32 \
    --dataloader_drop_last \
    --normalized \
    --temperature 0.02 \
    --train_group_size 2 \
    --negatives_cross_device \
    --query_max_len 256 \
    --passage_max_len 1024 \
    --mode unified \
    --logging_steps 1 \
    --bf16 \
    --pooling_method mean \
    --use_unique_indices \
    --loss_gen_type mixed \
    --attn bbcc \
    --attn_implementation sdpa \
    --no_gen_gas \
    --gradient_checkpointing \
    --save_steps 1000 \
    --split_emb \
    --deepspeed scripts/configs/config_8gpusds_m7.json

But I got an error: TypeError: LlamaModel.forward() got an unexpected keyword argument 'is_causal'. I looked into it and found several related issues (#34, #32 and #19). Just to confirm: if I want to train a Llama 3.1 model with GRIT, can I just

  • reuse the provided modeling file directly by putting modeling_gritlm7b.py into the Llama 3.1 model folder, or do I need to
  • change the modeling file for llama3.1 so that it could accept is_causal arg and thus influence attention behavior?

Thank you so much!
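The TypeError means only that the stock forward() signature has no is_causal parameter, so Python rejects the keyword before any attention code runs. A minimal pure-Python sketch of the difference, assuming a patched model simply threads the flag down to each layer; StockModel, PatchedModel and attention_mode are hypothetical stand-ins, not classes from transformers or the GritLM repo:

```python
from typing import List


class StockModel:
    """Hypothetical stand-in for an unmodified modeling file: forward() has no is_causal."""

    def forward(self, input_ids: List[int]) -> str:
        # Attention is always causal; there is no way to request anything else.
        return "causal"


class PatchedModel:
    """Hypothetical patched model: forward() accepts is_causal and passes it to every layer."""

    def __init__(self, num_layers: int = 2) -> None:
        self.num_layers = num_layers

    def attention_mode(self, layer_idx: int, is_causal: bool) -> str:
        # A real attention layer would build its attention mask from this flag here.
        return "causal" if is_causal else "bidirectional"

    def forward(self, input_ids: List[int], is_causal: bool = True) -> List[str]:
        # Defaulting to True keeps ordinary generative calls unchanged.
        return [self.attention_mode(i, is_causal) for i in range(self.num_layers)]


if __name__ == "__main__":
    try:
        StockModel().forward([1, 2, 3], is_causal=False)
    except TypeError as err:
        print("stock model:", err)  # ... got an unexpected keyword argument 'is_causal'
    print(PatchedModel().forward([1, 2, 3], is_causal=False))  # ['bidirectional', 'bidirectional']
```

Defaulting the flag to True is a design choice in this sketch: existing callers that never pass is_causal keep the usual causal behaviour, and only the embedding path has to opt in to bidirectional attention.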

ThisisXXZ, Nov 09 '24 13:11

change the modeling file for llama3.1 so that it could accept is_causal arg and thus influence attention behavior?

Muennighoff, Nov 09 '24 19:11

I thought is_causal was the argument that controls whether the model uses bidirectional attention. Since the original modeling file does not accept such an argument, do we need to implement it ourselves?

I’m still learning, so please kindly correct me if I’m mistaken. Thank you so much!
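Under that reading, is_causal=True means token i may attend only to positions up to i, while is_causal=False lets every token attend to every non-padding token. A minimal sketch of the resulting boolean mask, assuming a padding mask with 1 for real tokens and 0 for padding; build_attention_mask is a hypothetical helper, and real modeling files usually turn such a mask into an additive float mask with large negative values at blocked positions:

```python
from typing import List


def build_attention_mask(padding_mask: List[int], is_causal: bool) -> List[List[bool]]:
    """Return mask[i][j] == True when query position i may attend to key position j.

    padding_mask[j] is 1 for a real token and 0 for padding (hypothetical helper, not a library API).
    """
    n = len(padding_mask)
    return [
        # A key is visible if it is a real token and, in causal mode, not in the future.
        [bool(padding_mask[j]) and (j <= i or not is_causal) for j in range(n)]
        for i in range(n)
    ]


if __name__ == "__main__":
    pad = [1, 1, 1, 0]  # three real tokens followed by one padding token
    for row in build_attention_mask(pad, is_causal=True):
        print([int(v) for v in row])
    # [1, 0, 0, 0]
    # [1, 1, 0, 0]
    # [1, 1, 1, 0]
    # [1, 1, 1, 0]
    for row in build_attention_mask(pad, is_causal=False):
        print([int(v) for v in row])  # [1, 1, 1, 0] on every row
```

The only difference between the two modes is the j <= i term; the padding constraint applies in both.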

ThisisXXZ, Nov 10 '24 02:11

yes you're right; i meant to say that that is the option you have to go with; sorry i should have removed the ?

Muennighoff, Nov 10 '24 02:11

Thank you so much! Nah, you don't need to remove "the"; it's just that I wasn't very confident about it :)

I've checked the code in modeling_mistral_gritlm.py and have a few other questions:

  • Do I need to modify any other settings besides implementing the is_causal argument?
  • I noticed that FlashAttention is used in modeling_mistral_gritlm.py but not in the Llama 3 modeling file. If I implement my own version of the is_causal argument, can I simply add the bidirectional part to the Attention class?

Really appreciate your prompt guidance, even on weekends! You are truly one of the most helpful and fastest-responding authors I have contacted.

ThisisXXZ, Nov 10 '24 03:11

  1. only is_causal
  2. yeah any attention mechanism should be fine as long as you implement the masking for it
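To illustrate point 2: two quite different attention implementations give the same result as long as they honour the same mask. The pure-Python, single-head sketch below compares an eager version (full score matrix, blocked positions set to -inf, then softmax) with a streaming version that skips blocked keys and keeps an online softmax, loosely in the spirit of FlashAttention's tiling. Both function names are hypothetical; this is not the GritLM or transformers code.

```python
import math
from typing import List

Matrix = List[List[float]]


def eager_attention(q: Matrix, k: Matrix, v: Matrix, mask: List[List[bool]]) -> Matrix:
    """Materialize all scores, fill blocked positions with -inf, then softmax each row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = [
            sum(a * b for a, b in zip(qi, kj)) * scale if mask[i][j] else -math.inf
            for j, kj in enumerate(k)
        ]
        m = max(scores)  # assumes every row has at least one allowed key
        weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total for d in range(len(v[0]))])
    return out


def streaming_attention(q: Matrix, k: Matrix, v: Matrix, mask: List[List[bool]]) -> Matrix:
    """Visit keys one at a time, skip blocked ones, and keep a running (online) softmax."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        running_max, denom = -math.inf, 0.0
        acc = [0.0] * len(v[0])
        for j, kj in enumerate(k):
            if not mask[i][j]:
                continue  # here masking means never touching this key at all
            s = sum(a * b for a, b in zip(qi, kj)) * scale
            new_max = max(running_max, s)
            correction = math.exp(running_max - new_max)  # rescale the earlier partial sums
            p = math.exp(s - new_max)
            denom = denom * correction + p
            acc = [a * correction + p * x for a, x in zip(acc, v[j])]
            running_max = new_max
        out.append([a / denom for a in acc])
    return out


if __name__ == "__main__":
    x = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
    bidirectional = [[True] * 3 for _ in range(3)]
    causal = [[j <= i for j in range(3)] for i in range(3)]
    for mask in (bidirectional, causal):
        a = eager_attention(x, x, x, mask)
        b = streaming_attention(x, x, x, mask)
        print(all(math.isclose(p, r, abs_tol=1e-12) for ra, rb in zip(a, b) for p, r in zip(ra, rb)))  # True
```

The attention kernel itself never needs to know about "causal" or "bidirectional"; switching modes is entirely a matter of which mask is passed in.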

Muennighoff, Nov 10 '24 03:11