transformers icon indicating copy to clipboard operation
transformers copied to clipboard

Error when loading an FSDP + LoRA checkpoint

Open — A-zhudong opened this issue 7 months ago · 7 comments

System Info

  • transformers version: 4.42.0
  • Platform: Linux-5.15.0-105-generic-x86_64-with-glibc2.35
  • Python version: 3.9.19
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.32.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

No response

Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Step 1: train from scratch (no checkpoint), loading Llama-2 with the following script:

"""
Copyright (c) Facebook, Inc. and its affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import os
import glob
import argparse
import torch
import random
import warnings
import numpy as np
import pandas as pd
from pymatgen.core.structure import Structure
from pathlib import Path

from dataclasses import dataclass
import transformers
from transformers import ( 
    LlamaForCausalLM,
    LlamaTokenizer, 
    Trainer, 
    TrainingArguments,
    BitsAndBytesConfig,
    AutoModelForCausalLM
)

from trl import SFTTrainer
from torch.utils.data import Dataset

from peft import (
    LoraConfig, 
    get_peft_model, 
    prepare_model_for_kbit_training
)

# Padding value for ``labels`` in DataCollatorForSupervisedDataset.
# NOTE(review): -100 is presumably the loss's ignore_index (PyTorch
# CrossEntropyLoss default), so padded label positions do not count — confirm.
IGNORE_INDEX = -100
# Token budget per example: passed as ``max_length`` (with truncation) when
# tokenizing prompts and as the tokenizer's ``model_max_length``.
MAX_LENGTH = 2048

def get_crystal_string(cif_str):
    """Serialize a CIF structure into the plain-text crystal format used for training.

    The sites are first shifted by a random fractional vector drawn from
    ``np.random``, so repeated calls on the same CIF yield different
    coordinates. The returned text is laid out as:

    * line 1 — the three lattice lengths, one decimal place each;
    * line 2 — the three lattice angles, truncated to ``int``;
    * then, for every site, the species on one line followed by its
      fractional coordinates (two decimals each) on the next.
    """
    struct = Structure.from_str(cif_str, fmt="cif")
    shift = np.random.uniform(size=(3,))
    struct.translate_sites(indices=range(len(struct.sites)), vector=shift)

    params = struct.lattice.parameters
    length_line = " ".join(f"{a:.1f}" for a in params[:3])
    angle_line = " ".join(str(int(b)) for b in params[3:])

    site_blocks = []
    for species, coords in zip(struct.species, struct.frac_coords):
        coord_line = " ".join(f"{c:.2f}" for c in coords)
        site_blocks.append(str(species) + "\n" + coord_line)

    return length_line + "\n" + angle_line + "\n" + "\n".join(site_blocks)

class CifDataset(Dataset):
    """Map-style dataset turning CIF rows from CSV file(s) into LM training examples.

    Every ``__getitem__`` call re-tokenizes the row with fresh randomness:
    roughly two thirds of the time a *generation* example (optional property
    prompt followed by the full crystal string), otherwise an *infill*
    example in which one element is replaced by ``[MASK]`` and must be
    predicted.

    Args:
        csv_fn: Path to a CSV file, or a glob pattern matching several. Rows
            need a ``cif`` or ``cif_str`` column; property columns are read
            only when ``w_attributes`` is truthy.
        format_options: Stored on the instance as ``self.format_options``;
            not read by any method of this class. Defaults to a fresh ``{}``.
        llama_tokenizer: Tokenizer used to build examples; it must expose
            ``eos_token`` and ``pad_token_id``.
        w_attributes: When truthy, generation prompts may include a random
            subset of material properties.

    Raises:
        ValueError: if ``csv_fn`` neither exists nor matches any file.
    """

    # Phrase placed before each conditioning attribute in generation prompts.
    # NOTE(review): "elements" is never sampled by generation_task.
    _PROMPT_LOOKUP = {
        "formation_energy_per_atom": "The formation energy per atom is",
        "band_gap": "The band gap is",
        "pretty_formula": "The chemical formula is",
        "e_above_hull": "The energy above the convex hull is",
        "elements": "The elements are",
        "spacegroup.number": "The spacegroup number is",
    }

    # Optional conditioning attributes; a random subset is drawn per example.
    _ALL_ATTRIBUTES = [
        "formation_energy_per_atom",
        "band_gap",
        "e_above_hull",
        "spacegroup.number",
    ]

    def __init__(
        self,
        csv_fn,
        format_options=None,
        llama_tokenizer=None,
        w_attributes=False,
    ):
        super().__init__()

        # Sorted so row order (and hence index -> example) is reproducible.
        csv_files = sorted(glob.glob(csv_fn))
        # glob treats "[", "*" and "?" as wildcards, so an existing literal
        # path containing them can match nothing; fall back to the path.
        if not csv_files and os.path.exists(csv_fn):
            csv_files = [csv_fn]
        if not csv_files:
            raise ValueError(f"CSV file {csv_fn} does not exist")

        df = pd.concat([pd.read_csv(fn) for fn in csv_files])
        print('length of all data: ', len(df))
        self.inputs = df.to_dict(orient="records")

        self.llama_tokenizer = llama_tokenizer

        # Fresh dict per instance rather than one shared mutable default.
        self.format_options = {} if format_options is None else format_options
        self.w_attributes = w_attributes

    def crystal_string(self, input_dict):
        """Return the text serialization of the row's CIF (``cif`` or ``cif_str``)."""
        k = 'cif' if 'cif' in input_dict else 'cif_str'
        return get_crystal_string(input_dict[k])

    def generation_task(self, input_dict):
        """Tokenize a generation example: description prompt + crystal string + EOS."""
        prompt = "Below is a description of a bulk material. "

        # Drawn even when w_attributes is falsy.
        num_attributes = random.randint(0, len(self._ALL_ATTRIBUTES))
        if num_attributes > 0 and self.w_attributes:
            attributes = random.sample(self._ALL_ATTRIBUTES, num_attributes)
            attributes = ["pretty_formula"] + attributes

            for attr in attributes:
                if attr == "elements":
                    prompt += f"{self._PROMPT_LOOKUP[attr]} {', '.join(input_dict[attr])}. "
                elif attr in ["formation_energy_per_atom", "band_gap", "e_above_hull"]:
                    prompt += f"{self._PROMPT_LOOKUP[attr]} {round(float(input_dict[attr]), 4)}. "
                else:
                    prompt += f"{self._PROMPT_LOOKUP[attr]} {input_dict[attr]}. "

        prompt += (
            "Generate a description of the lengths and angles of the lattice vectors "
            "and then the element type and coordinates for each atom within the lattice:\n"
        )

        crystal_str = self.crystal_string(input_dict)

        tokens = self.llama_tokenizer(
            prompt + crystal_str  + self.llama_tokenizer.eos_token,
            return_tensors="pt",
            max_length=MAX_LENGTH,
            truncation=True,
        )

        return tokens

    @staticmethod
    def _mask_species(crystal_str, species):
        """Replace every site label equal to ``species`` with ``[MASK]``.

        Whole lines are compared, so masking "C" leaves "Ca" or "Cl" intact
        (a plain ``str.replace`` would turn them into "[MASK]a"/"[MASK]l").
        Lattice and coordinate lines are numeric and never match a label.
        """
        return "\n".join(
            "[MASK]" if line == species else line
            for line in crystal_str.split("\n")
        )

    def infill_task(self, input_dict):
        """Tokenize an infill example: masked crystal string + the masked element + EOS."""
        prompt = (
            'Below is a partial description of a bulk material where one '
            'element has been replaced with the string "[MASK]":\n'
        )

        k = 'cif' if 'cif' in input_dict else 'cif_str'
        structure = Structure.from_str(input_dict[k], fmt="cif")
        species = [str(s) for s in structure.species]
        species_to_remove = random.choice(species)

        crystal_string = self.crystal_string(input_dict)

        partial_crystal_str = self._mask_species(crystal_string, species_to_remove)

        infill_str = prompt + partial_crystal_str + "\n"

        infill_str += (
            "Generate an element that could replace [MASK] in the bulk material:\n"
        )

        infill_str += str(species_to_remove) + self.llama_tokenizer.eos_token

        tokens = self.llama_tokenizer(
            infill_str,
            return_tensors="pt",
            max_length=MAX_LENGTH,
            truncation=True,
        )

        return tokens

    def tokenize(self, input_dict):
        """Build one training example, picking generation (p≈0.66) or infill.

        ``labels`` is the same tensor as ``input_ids``; the ``*_lens`` values
        count tokens that differ from the tokenizer's pad id.
        """
        if random.random() < 0.66:
            tokens = self.generation_task(input_dict)
        else:
            tokens = self.infill_task(input_dict)

        input_ids = labels = tokens.input_ids[0]
        input_ids_lens = labels_lens = tokens.input_ids.ne(
            self.llama_tokenizer.pad_token_id).sum().item()
        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        """Return a freshly tokenized example for row ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError("Index out of range")

        vals = self.inputs[index]
        vals = self.tokenize(vals)
        return vals

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Pad variable-length tokenized examples into one training batch.

    Sequences are padded at the end: ``input_ids`` with the tokenizer's pad
    id and ``labels`` with ``IGNORE_INDEX``. ``attention_mask`` is True
    wherever the padded ``input_ids`` differ from the pad id.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        pad = torch.nn.utils.rnn.pad_sequence
        id_seqs = [inst["input_ids"].clone().detach() for inst in instances]
        label_seqs = [inst["labels"].clone().detach() for inst in instances]

        batch_ids = pad(
            id_seqs, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        batch_labels = pad(label_seqs, batch_first=True, padding_value=IGNORE_INDEX)
        mask = batch_ids.ne(self.tokenizer.pad_token_id)
        return {
            "input_ids": batch_ids,
            "labels": batch_labels,
            "attention_mask": mask,
        }

def setup_datasets(args, llama_tokenizer, transform_args=None):
    """Build the train/val ``CifDataset`` pair from ``args.data_path``.

    Args:
        args: Parsed CLI namespace. Reads ``data_path`` (joined with ``/``),
            ``format_permute_composition``, ``format_permute_structure`` and
            ``w_attributes``.
        llama_tokenizer: Tokenizer handed to both datasets.
        transform_args: Unused; kept for call compatibility. Defaults to
            ``None`` instead of a shared mutable ``{}``.

    Returns:
        dict with keys ``"train"`` and ``"val"``. Both datasets share the
        same ``format_options`` dict.
    """
    format_options = {
        "permute_composition": args.format_permute_composition,
        "permute_structure": args.format_permute_structure,
    }
    # Split name -> CSV file under args.data_path (built in this order).
    split_files = {"train": "train_10.csv", "val": "val_10.csv"}
    return {
        split: CifDataset(
            str(args.data_path / csv_name),
            format_options,
            llama_tokenizer=llama_tokenizer,
            w_attributes=args.w_attributes,
        )
        for split, csv_name in split_files.items()
    }


def setup_training_args(args):
    """Create the run's output directory and return its ``TrainingArguments``.

    Side effects:
        * creates ``args.expdir / args.run_name`` (parents included);
        * sets ``WANDB_DISABLED=True`` when ``args.debug`` is truthy;
        * sets ``ACCELERATE_MIXED_PRECISION=no``;
        * overwrites ``args.batch_size`` with 1.

    The arguments enable FSDP (``fsdp=True``) with bf16, disable fp16 and
    gradient checkpointing, evaluate every ``args.eval_freq`` steps, and
    report to Weights & Biases.
    """
    run_dir = args.expdir / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)

    if args.debug:
        os.environ["WANDB_DISABLED"] = "True"
    # NOTE(review): bf16=True is requested below; TrainingArguments may
    # rewrite this variable during its own setup — confirm which value wins.
    os.environ["ACCELERATE_MIXED_PRECISION"] = "no"

    # NOTE(review): hard-coded override — the --batch-size CLI value is
    # discarded and the caller's namespace is mutated.
    args.batch_size = 1
    per_device_bs = args.batch_size

    return TrainingArguments(
        output_dir=run_dir,
        run_name=args.run_name,
        report_to="wandb",
        # Parallelism and precision.
        fsdp=True,
        fp16=False,
        bf16=True,
        gradient_checkpointing=False,
        ddp_find_unused_parameters=False,
        # Optimisation schedule.
        num_train_epochs=args.num_epochs,
        learning_rate=args.lr,
        lr_scheduler_type=args.lr_scheduler,
        warmup_steps=args.num_warmup_steps,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.grad_accum,
        per_device_train_batch_size=per_device_bs,
        per_device_eval_batch_size=per_device_bs,
        # Evaluation, logging and checkpoint cadence.
        evaluation_strategy="steps",
        eval_steps=args.eval_freq,
        logging_steps=10,
        save_steps=args.save_freq,
        # Data loading.
        dataloader_num_workers=8,
        remove_unused_columns=False,
        # Batches carry no "crystal_ids" key; the original author set this
        # only to make the Trainer behave as desired.
        label_names=["crystal_ids"],
    )

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict,
    llama_tokenizer,
    model,
):
    """Add special tokens, resize the model's embeddings, and seed the new rows.

    The embedding matrices are always resized to ``len(llama_tokenizer)``.
    When tokens were actually added, each new row of the input and output
    embedding matrices is set to the mean of that matrix's pre-existing rows.

    Note: the new vocabulary size is not rounded, so it may not be a
    multiple of 64.
    """
    added = llama_tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(llama_tokenizer))

    if added <= 0:
        return

    tables = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    for table in tables:
        table[-added:] = table[:-added].mean(dim=0, keepdim=True)

def setup_model(args, rank):
    """Load Llama-2 in 4-bit NF4, attach LoRA adapters, and prepare the tokenizer.

    Args:
        args: Parsed CLI namespace; reads ``lora_rank``, ``lora_alpha`` and
            ``lora_dropout``.
        rank: Device index; the whole base model is placed on it through
            ``device_map={"": rank}``.

    Returns:
        ``(model, llama_tokenizer)``. ``model`` is the PEFT-wrapped model,
        with embeddings resized to the tokenizer after any missing special
        tokens were added.
    """
    # NOTE(review): weights are read from this hard-coded local directory and
    # ``args.model_name`` ("7b", "7b-chat", ...) is not consulted. Hub ids for
    # Llama-2 follow f"meta-llama/Llama-2-{size}-{'chat-' if chat else ''}hf".
    model_string = 'Models/Llama-2-7b-h/Llama-2-7b-hf/'

    # 4-bit NF4 weights with nested quantization; compute and storage in bf16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_storage=torch.bfloat16,
    )

    model = LlamaForCausalLM.from_pretrained(
        model_string,
        torch_dtype=torch.bfloat16,
        device_map={"": rank},
        quantization_config=bnb_config,
    )

    llama_tokenizer = LlamaTokenizer.from_pretrained(
        model_string,
        model_max_length=MAX_LENGTH,
        padding_side="right",
        use_fast=False,
    )

    # LoRA on the attention query/value projections only.
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=['q_proj', 'v_proj'],
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # NOTE(review): DEFAULT_PAD_TOKEN / DEFAULT_EOS_TOKEN / DEFAULT_BOS_TOKEN /
    # DEFAULT_UNK_TOKEN are not defined in the visible part of this script;
    # each is only evaluated when the tokenizer lacks that token.
    special_tokens_dict = dict()
    if llama_tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if llama_tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if llama_tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
    if llama_tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    # NOTE(review): embeddings are resized after LoRA wrapping; if a token is
    # added, check that the resized embedding layers are saved alongside the
    # adapter (PEFT ``modules_to_save`` / ``save_embedding_layers``) — confirm.
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        llama_tokenizer=llama_tokenizer,
        model=model,
    )

    return model, llama_tokenizer

def setup_trainer(args):
    """Wire training arguments, the LoRA model, datasets and collator into a ``Trainer``.

    The model is placed on ``TrainingArguments.local_rank``; the train and
    eval datasets come from ``setup_datasets``.
    """
    targs = setup_training_args(args)
    model, tokenizer = setup_model(args, targs.local_rank)
    splits = setup_datasets(args, tokenizer)
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    return Trainer(
        model=model,
        args=targs,
        train_dataset=splits["train"],
        eval_dataset=splits["val"],
        data_collator=collator,
    )

def main(args):
    """Train (optionally resuming from ``args.resume_dir``), then save state and model.

    The final model is written to ``args.expdir / args.run_name``.
    """
    trainer = setup_trainer(args)

    # Trainer.train documents resume_from_checkpoint as str/bool, so pass the
    # Path as a string; None starts a fresh run (the Trainer default).
    resume = None if args.resume_dir is None else str(args.resume_dir)
    train_result = trainer.train(resume_from_checkpoint=resume)

    print(train_result)
    trainer.save_state()
    if trainer.is_fsdp_enabled:
        # Switch the FSDP state-dict type to FULL_STATE_DICT before saving.
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model(args.expdir / args.run_name)

def build_arg_parser():
    """Return the command-line parser for this fine-tuning script.

    ``--fp8`` defaults to True and can be disabled with ``--no-fp8`` (a
    ``store_true`` flag with ``default=True`` could never be False).
    ``--model-name`` is accepted as an alias of ``--model_name``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-name", type=str, required=True)
    parser.add_argument("--expdir", type=Path, default="exp")
    parser.add_argument("--model_name", "--model-name", default="7b")
    parser.add_argument("--fp8", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--lora-rank", type=int, default=8)
    parser.add_argument("--lora-alpha", type=int, default=32)
    parser.add_argument("--lora-dropout", type=float, default=0.05)
    parser.add_argument("--data-path", type=Path, default="data/basic")
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--grad-accum", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lr-scheduler", type=str, default="cosine")
    parser.add_argument("--num-warmup-steps", type=int, default=100)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--eval-freq", default=1000, type=int)
    parser.add_argument("--save-freq", default=500, type=int)
    parser.add_argument("--format-permute-composition", action="store_true", default=False)
    parser.add_argument("--format-permute-structure", action="store_true", default=False)
    parser.add_argument("--w-attributes", type=int, default=1)
    parser.add_argument("--resume-dir", type=Path, default=None)
    parser.add_argument("--finetune-dir", type=Path, default=None)
    parser.add_argument("--debug", action="store_true", default=False)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    print(args.batch_size, args.w_attributes)
    print(args.expdir)
    main(args)

Step 2: rerun the script with the resume checkpoint path (`--resume-dir`) set to a checkpoint directory saved by the Trainer in step 1.

error:
  File "/home/wuzh/zd/GIT-Mol/crystal-text-llm-main/llama_finetune.py", line 522, in <module>
    main(args)
  File "/home/wuzh/zd/GIT-Mol/crystal-text-llm-main/llama_finetune.py", line 481, in main
    train_result = trainer.train(resume_from_checkpoint=args.resume_dir)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/transformers/trainer.py", line 1932, in train
    return inner_training_loop(
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/transformers/trainer.py", line 3307, in training_step
    loss = self.compute_loss(model, inputs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/transformers/trainer.py", line 3338, in compute_loss
    outputs = model(**inputs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 823, in forward
    args, kwargs = _root_pre_forward(self, self, args, kwargs)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 558, in _root_pre_forward
    _lazy_init(state, module)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 173, in _lazy_init
    _share_state_and_init_handle_attrs(state, root_module)
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 261, in _share_state_and_init_handle_attrs
    _p_assert(
  File "/work/zd/anaconda/fsdp/lib/python3.9/site-packages/torch/distributed/utils.py", line 145, in _p_assert
    traceback.print_stack()
Non-root FSDP instance's `_is_root` should not have been set yet or should have been set to `False`

Expected behavior

The checkpoint should load without error.

A-zhudong avatar Jul 10 '24 13:07 A-zhudong