tensor_parallel
Max recursion error when using with LoRA
I get the following error when attempting to use LoRA with Llama 2:
File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
return getattr(self.tp_wrapped_module, attr)
[Previous line repeated 2979 more times]
RecursionError: maximum recursion depth exceeded
The recursion is triggered when the peft module executes: if getattr(model, "is_gradient_checkpointing", True):
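For context, the failure mode looks like a forwarding __getattr__ re-entering itself. Below is a minimal, self-contained toy sketch of that pattern (the class name and details are mine, not tensor_parallel's actual wrapper code; in the real library forwarding normally works and the recursion only shows up in this peft interaction):

import torch
from torch import nn


class ForwardingWrapper(nn.Module):
    """Toy stand-in for a wrapper that forwards unknown attributes to a wrapped module."""

    def __init__(self, module: nn.Module):
        super().__init__()
        # nn.Module.__setattr__ stores submodules in self._modules,
        # not in the instance __dict__, so plain lookup of the name misses it.
        self.tp_wrapped_module = module

    def __getattr__(self, attr):
        # Called only when normal lookup fails. Reading self.tp_wrapped_module
        # here also fails normal lookup (it lives in _modules), so this line
        # calls __getattr__ again and again until the recursion limit is hit.
        return getattr(self.tp_wrapped_module, attr)


wrapper = ForwardingWrapper(nn.Linear(4, 4))
try:
    # Mirrors peft's check: getattr(model, "is_gradient_checkpointing", True)
    getattr(wrapper, "is_gradient_checkpointing", True)
except RecursionError as err:
    print("RecursionError:", err)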
Below is a minimal reproducible example that breaks when tensor parallel is enabled and works when it is disabled:
import os
import functools

import torch
from transformers import LlamaTokenizer, LlamaConfig, LlamaForCausalLM
import tensor_parallel as tp
from peft import get_peft_model, LoraConfig

USE_TENSOR_PARALLEL = True
LLAMA_HF_PATH = "./models/llama2/llama_hf_converted/7b"


def spawn(main_fn, world_size):
    wrapped = functools.partial(_wrap_main_fn, main_fn=main_fn)
    torch.multiprocessing.spawn(wrapped, args=(world_size,), nprocs=world_size, daemon=True)


def _wrap_main_fn(rank, world_size, main_fn):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12357'
    torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)
    main_fn(rank)
    torch.distributed.destroy_process_group()


def get_model(device):
    torch.set_default_dtype(torch.bfloat16)
    config = LlamaConfig.from_pretrained(LLAMA_HF_PATH)
    model = LlamaForCausalLM.from_pretrained(LLAMA_HF_PATH, config=config)
    model.half()
    torch.cuda.set_device(device)

    if USE_TENSOR_PARALLEL:
        tpmodel = tp.tensor_parallel(model, [device], distributed=True)
        model = tpmodel[0]
        # target modules when the model is wrapped by tensor_parallel
        lora_target_modules = ['q_proj.tp_wrapped_module', 'v_proj.tp_wrapped_module']
    else:
        lora_target_modules = ['q_proj', 'v_proj']

    print('Before lora, is_gradient_checkpointing=', getattr(model, "is_gradient_checkpointing", None))
    peft_config = LoraConfig(
        inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.01,
        target_modules=lora_target_modules,
    )
    model_peft = get_peft_model(model, peft_config)
    print('After lora, is_gradient_checkpointing=', getattr(model, "is_gradient_checkpointing", None))
    return model_peft


def main_fn(rank):
    devices = ['cuda:0', 'cuda:1']
    get_model(devices[rank])


if __name__ == '__main__':
    spawn(main_fn, world_size=2)
When setting USE_TENSOR_PARALLEL = False the code works, but when setting USE_TENSOR_PARALLEL = True I get the following error:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/path/libraries/conda/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/path/projects/silos/TEST_tensor_parallel.py", line 24, in _wrap_main_fn
main_fn(rank)
File "/path/projects/silos/TEST_tensor_parallel.py", line 56, in main_fn
get_model(devices[rank])
File "/path/projects/silos/TEST_tensor_parallel.py", line 49, in get_model
model_peft = get_peft_model(model, peft_config)
File "/path/libraries/conda/lib/python3.9/site-packages/peft/mapping.py", line 105, in get_peft_model
return PeftModel(model, peft_config, adapter_name=adapter_name)
File "/path/libraries/conda/lib/python3.9/site-packages/peft/peft_model.py", line 120, in __init__
if getattr(model, "is_gradient_checkpointing", True):
File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in is_gradient_checkpointing
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in <genexpr>
File "/path/libraries/conda/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1745, in <genexpr>
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
return getattr(self.tp_wrapped_module, attr)
File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
return getattr(self.tp_wrapped_module, attr)
File "/path/libraries/conda/lib/python3.9/site-packages/tensor_parallel/wrapper.py", line 75, in __getattr__
return getattr(self.tp_wrapped_module, attr)
[Previous line repeated 2979 more times]
RecursionError: maximum recursion depth exceeded
I've identified that the error happens exactly at this setattr line in the peft LoRA code:
https://github.com/huggingface/peft/blob/52ff0cde9f2cc64059e171c2cfd94512914c85df/src/peft/tuners/lora/model.py#L225
When setattr(parent, child_name, new_module) is executed, parent is a tensor-parallel wrapper, child_name is the string "tp_wrapped_module", and new_module is a LoRA linear layer.
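To make that concrete, here is roughly how those arguments come about for a matched target key ending in ".tp_wrapped_module" (a paraphrased sketch of the parent/child split peft performs, not a verbatim copy of its helper):

def get_parent_and_child(model, key):
    # Split a dotted module path into (parent module, child attribute name).
    # For key = "...self_attn.q_proj.tp_wrapped_module" the parent is the
    # tensor_parallel wrapper around q_proj and the child name is
    # "tp_wrapped_module".
    parent_key, _, child_name = key.rpartition(".")
    parent = model.get_submodule(parent_key)
    return parent, child_name

# peft then swaps the child for a LoRA layer, roughly:
#   parent, child_name = get_parent_and_child(model, matched_key)
#   setattr(parent, child_name, lora_linear)  # parent is the TP wrapper here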
I think I fixed it.
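The fix itself is not shown above. For reference, one common way to make this kind of forwarding __getattr__ robust is to resolve the wrapped module through nn.Module's own lookup instead of plain attribute access, so a miss cannot re-enter the forwarder. This is a generic sketch, not necessarily the change that was actually made:

from torch import nn


class SaferForwardingWrapper(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.tp_wrapped_module = module

    def __getattr__(self, attr):
        if attr == "tp_wrapped_module":
            # nn.Module.__getattr__ checks _parameters/_buffers/_modules,
            # so this finds the wrapped module without re-entering this method.
            return super().__getattr__(attr)
        # Forward everything else to the wrapped module; a genuine miss now
        # raises AttributeError instead of recursing.
        return getattr(super().__getattr__("tp_wrapped_module"), attr)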