Long-Context-Data-Engineering
ValueError: TensorParallelPreTrainedModel does not support Flash Attention 2.0 yet.
Hi, when I use the tensor-parallel package the way the repo suggests:
import torch
import transformers
import tensor_parallel as tp

model = transformers.LlamaForCausalLM.from_pretrained(
    model_path,
    use_flash_attention_2="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

# This is the continued-pretrained LLaMA 2 7B model with a modified RoPE.
def reset_rope(model, model_max_train_len, scaling_factor):
    for l in model.model.layers:
        l.self_attn.rotary_emb.scaling_factor = scaling_factor
        # l.self_attn.rotary_emb._set_cos_sin_cache(seq_len=model_max_train_len, device="cpu", dtype=torch.float32)
    return

scaling_factor = 10  # hardcoded here
reset_rope(model, model_max_train_len=81920, scaling_factor=scaling_factor)
model = tp.tensor_parallel(model, sharded=True)
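(Aside: `use_flash_attention_2` is deprecated, as the warning below points out. On recent transformers versions, assuming transformers >= 4.36, the equivalent call would be the following. Note this alone does not avoid the crash; the failure happens later, inside `tp.tensor_parallel`.)

model = transformers.LlamaForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)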
I hit this error:
The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████| 3/3 [00:01<00:00, 1.66it/s]
Traceback (most recent call last):
File "/vepfs/wcf/G/zecheng/modelzipper/projects/state-space-model/src/test_passkey_search.py", line 81, in <module>
fire.Fire(main)
File "/opt/conda/envs/pan/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/opt/conda/envs/pan/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/opt/conda/envs/pan/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/vepfs/wcf/G/zecheng/modelzipper/projects/state-space-model/src/test_passkey_search.py", line 39, in main
model = tp.tensor_parallel(model, sharded=True)
File "/opt/conda/envs/pan/lib/python3.10/site-packages/tensor_parallel/factory.py", line 61, in tensor_parallel
return TensorParallelPreTrainedModel(
File "/opt/conda/envs/pan/lib/python3.10/site-packages/tensor_parallel/pretrained_model.py", line 47, in __init__
super().__init__(module.config) # Temporary empty config. Gets replaced in from_pretrained
File "/opt/conda/envs/pan/lib/python3.10/site-packages/transformers/modeling_utils.py", line 1230, in __init__
config = self._autoset_attn_implementation(
File "/opt/conda/envs/pan/lib/python3.10/site-packages/transformers/modeling_utils.py", line 1377, in _autoset_attn_implementation
cls._check_and_enable_flash_attn_2(
File "/opt/conda/envs/pan/lib/python3.10/site-packages/transformers/modeling_utils.py", line 1458, in _check_and_enable_flash_attn_2
raise ValueError(
ValueError: TensorParallelPreTrainedModel does not support Flash Attention 2.0 yet. Please request to add support where the model is hosted,
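From the traceback, the failure is in `TensorParallelPreTrainedModel.__init__`: it passes the wrapped model's config to `PreTrainedModel.__init__`, which re-runs `_autoset_attn_implementation`. Since the config still requests flash_attention_2, transformers checks FA2 support against the wrapper class itself, which does not declare it, hence the ValueError. Below is a minimal, untested sketch of a possible workaround: instantiate the FA2 attention modules at `from_pretrained` time, then reset the attention-implementation flag on the config so the wrapper's init no longer tries to validate FA2. `config._attn_implementation` is an internal transformers attribute and may change between versions, so treat this as an assumption, not a supported API.

import torch
import transformers
import tensor_parallel as tp

# The FA2 attention modules are instantiated here, before wrapping.
model = transformers.LlamaForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

# Assumption: resetting the (internal) attention-implementation flag keeps
# the already-built FA2 layers but stops TensorParallelPreTrainedModel's
# __init__ from re-validating FA2 against the unsupported wrapper class.
model.config._attn_implementation = "eager"

model = tp.tensor_parallel(model, sharded=True)

Alternatively, loading without any Flash Attention 2 request also avoids the error, at the cost of attention speed and memory on long contexts.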