tensor_parallel

Maximum recursion depth exceeded when using tensor_parallel with LoRA

Open Ar-Kareem opened this issue 10 months ago • 2 comments

I get the following error when attempting to use LoRA with Llama 2:


  File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
    return getattr(self.tp_wrapped_module, attr)
  [Previous line repeated 2979 more times]
RecursionError: maximum recursion depth exceeded

The error is triggered when the PEFT module executes: `if getattr(model, "is_gradient_checkpointing", True):`

Below is a minimal reproducible example. It fails when tensor parallelism is enabled and works when it is disabled.

import os
import functools
import torch
from transformers import LlamaTokenizer, LlamaConfig, LlamaForCausalLM
import tensor_parallel as tp
from peft import get_peft_model, LoraConfig

# Toggle for the repro: True shards the model with tensor_parallel (fails with
# RecursionError inside get_peft_model); False skips sharding (runs cleanly).
USE_TENSOR_PARALLEL: bool = True
# Local checkpoint directory passed to LlamaConfig/LlamaForCausalLM.from_pretrained.
LLAMA_HF_PATH: str = "./models/llama2/llama_hf_converted/7b"


def spawn(main_fn, world_size):
    """Start ``world_size`` worker processes that each run ``main_fn`` under a process group.

    Every worker goes through ``_wrap_main_fn``, which receives its rank from
    ``torch.multiprocessing.spawn`` plus ``world_size`` and the bound ``main_fn``.
    """
    worker = functools.partial(_wrap_main_fn, main_fn=main_fn)
    spawn_args = (world_size,)
    torch.multiprocessing.spawn(
        worker,
        args=spawn_args,
        nprocs=world_size,
        daemon=True,
    )

def _wrap_main_fn(rank, world_size, main_fn):
    """Initialize the NCCL process group for one spawned worker, run ``main_fn``, then tear down.

    Args:
        rank: Index of this worker, supplied by ``torch.multiprocessing.spawn``.
        world_size: Total number of workers in the process group.
        main_fn: Callable invoked as ``main_fn(rank)`` once the group is up.
    """
    # Single-node rendezvous: every rank connects to the same local address/port.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12357'
    torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)
    try:
        main_fn(rank)
    finally:
        # Destroy the group even when main_fn raises (as it does in this repro);
        # previously an exception skipped this call entirely.
        torch.distributed.destroy_process_group()

def get_model(device):
    """Load Llama 2, optionally shard it with tensor_parallel, and attach LoRA adapters.

    Args:
        device: CUDA device string for this rank (e.g. ``'cuda:0'``). It is passed to
            ``torch.cuda.set_device`` and used as the only tensor_parallel device.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.

    With ``USE_TENSOR_PARALLEL = True``, ``get_peft_model`` is where the reported
    RecursionError surfaces. PEFT's ``getattr(model, "is_gradient_checkpointing", True)``
    check walks the submodules and recurses without bound inside tensor_parallel's
    wrapper ``__getattr__``.
    """
    # NOTE(review): this changes the process-wide default dtype to bfloat16, while
    # model.half() below casts the base weights to float16. Tensors created later
    # (presumably including LoRA adapter weights) would then default to bfloat16.
    # TODO confirm this dtype mix is intended.
    torch.set_default_dtype(torch.bfloat16)
    config = LlamaConfig.from_pretrained(LLAMA_HF_PATH)
    model = LlamaForCausalLM.from_pretrained(LLAMA_HF_PATH, config=config)
    model.half()

    torch.cuda.set_device(device)
    if USE_TENSOR_PARALLEL:
        # Each rank shards onto its own device only. distributed=True presumably makes
        # tensor_parallel use the process group set up by the caller. Verify.
        tpmodel = tp.tensor_parallel(model, [device], distributed=True)
        model = tpmodel[0]
        # Once wrapped by tensor_parallel, the original projection modules sit one
        # level deeper, under the ``tp_wrapped_module`` attribute.
        lora_target_modules = ['q_proj.tp_wrapped_module', 'v_proj.tp_wrapped_module']
    else:
        lora_target_modules = ['q_proj', 'v_proj']

    print('Before lora, is_gradient_checkpointing=', getattr(model, "is_gradient_checkpointing", None))
    peft_config = LoraConfig(
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.01,
        target_modules=lora_target_modules,
    )
    model_peft = get_peft_model(model, peft_config)
    # Probes the base ``model`` (not ``model_peft``) again after get_peft_model has run.
    print('After lora, is_gradient_checkpointing=', getattr(model, "is_gradient_checkpointing", None))
    return model_peft


def main_fn(rank):
    """Per-worker entry point: build the model on the GPU whose index equals ``rank``.

    The device is derived from the rank instead of being indexed from a fixed
    two-element list. That list raised IndexError for any world_size > 2. Ranks
    0 and 1 still map to 'cuda:0' and 'cuda:1' exactly as before.
    """
    get_model(f'cuda:{rank}')

if __name__ == '__main__':
    # Launch two worker processes; each runs main_fn with its own rank.
    spawn(main_fn=main_fn, world_size=2)

With USE_TENSOR_PARALLEL = False the code works, but with USE_TENSOR_PARALLEL = True I get the following error:

-- Process 0 terminated with the following error:                                                                                                                                   
Traceback (most recent call last):                                                                                                                                                  
  File "/path/libraries/conda/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap                                                            
    fn(i, *args)                                                                                                                                                                    
  File "/path/projects/silos/TEST_tensor_parallel.py", line 24, in _wrap_main_fn                                                                                        
    main_fn(rank)                                                                                                                                                                   
  File "/path/projects/silos/TEST_tensor_parallel.py", line 56, in main_fn                                                                                              
    get_model(devices[rank])                                                                                                                                                        
  File "/path/projects/silos/TEST_tensor_parallel.py", line 49, in get_model
    model_peft = get_peft_model(model, peft_config)
  File "/path/libraries/conda/lib/python3.9/site-packages/peft/mapping.py", line 105, in get_peft_model
    return PeftModel(model, peft_config, adapter_name=adapter_name)
  File "/path/libraries/conda/lib/python3.9/site-packages/peft/peft_model.py", line 120, in __init__
    if getattr(model, "is_gradient_checkpointing", True):
  File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in is_gradient_checkpointing
    return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
  File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in <genexpr>
  File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in <genexpr>
    return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
  File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
    return getattr(self.tp_wrapped_module, attr)
  File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
    return getattr(self.tp_wrapped_module, attr)
  File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
    return getattr(self.tp_wrapped_module, attr)
  [Previous line repeated 2979 more times]
RecursionError: maximum recursion depth exceeded

Ar-Kareem avatar Oct 02 '23 05:10 Ar-Kareem