
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'

Open yuanlaishihaoge opened this issue 2 months ago • 8 comments

(algo_python38) root@4347dc632bb3:/data/data/llama3-main# torchrun --nproc_per_node 1 example_chat_completion.py --ckpt_dir Meta-Llama-3-8B/ --tokenizer_path Meta-Llama-3-8B/tokenizer.model --max_seq_len 512 --max_batch_size 6

initializing model parallel with size 1
initializing ddp with size 1
initializing pipeline with size 1
Loaded in 14.34 seconds
Traceback (most recent call last):
  File "example_chat_completion.py", line 58, in <module>
    fire.Fire(main)
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "example_chat_completion.py", line 41, in main
    results = generator.chat_completion(
  File "/data/data/llama3-main/llama/generation.py", line 309, in chat_completion
    generation_tokens, generation_logprobs = self.generate(
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/data/data/llama3-main/llama/generation.py", line 176, in generate
    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/data/data/llama3-main/llama/model.py", line 290, in forward
    mask = torch.triu(mask, diagonal=1)
RuntimeError: "triu_tril_cuda_template" not implemented for 'BFloat16'
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 201) of binary: /opt/conda/envs/algo_python38/bin/python
Traceback (most recent call last):
  File "/opt/conda/envs/algo_python38/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==1.13.1', 'console_scripts', 'torchrun')())
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/torch/distributed/run.py", line 762, in main
    run(args)
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/envs/algo_python38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
example_chat_completion.py FAILED


Failures:
  <NO_OTHER_FAILURES>

Root Cause (first observed failure):
[0]:
  time      : 2024-04-22_13:29:53
  host      : 4347dc632bb3
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 201)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

yuanlaishihaoge avatar Apr 22 '24 05:04 yuanlaishihaoge

Did you modify the code to load the model in half precision?

jidandan666 avatar Apr 22 '24 09:04 jidandan666

Same issue. Anybody knows how to solve it?

SorasakiHiina avatar Apr 22 '24 09:04 SorasakiHiina

I got the same issue

huhuhu5798 avatar Apr 22 '24 10:04 huhuhu5798

I got the same issue

ghLcd9dG avatar Apr 23 '24 03:04 ghLcd9dG

Update torch, or edit llama/generation.py:

class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
        seed: int = 1,
    ) -> "Llama":
        ...
        assert model_args.vocab_size == tokenizer.n_words
        # "triu_tril_cuda_template" has no BFloat16 CUDA kernel before
        # torch 2.1.0, so force float16 even when bf16 is supported:
        if torch.cuda.is_bf16_supported():
            # torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)

        ...

lifetruth-liu avatar Apr 23 '24 03:04 lifetruth-liu

it works! thanks

ghLcd9dG avatar Apr 23 '24 03:04 ghLcd9dG

Let me summarize it. The cause is that `triu_tril_cuda_template` was only implemented for BFloat16 in torch 2.1.0 and later. Reference: https://github.com/huggingface/diffusers/issues/3453 So basically you have two methods to solve it.

  1. update your torch to version 2.1.0 or newer
  2. in generation.py, set the default tensor type to half precision: torch.set_default_tensor_type(torch.cuda.HalfTensor)
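If it helps, the version gate behind option 1 can be made explicit. A minimal sketch (the `needs_half_fallback` helper is hypothetical, not part of the repo), deciding from the torch version string whether the float16 fallback is still needed:

```python
import re

# Hypothetical helper: triu/tril gained BFloat16 CUDA kernels in torch 2.1.0,
# so any older version needs the float16 fallback.
def needs_half_fallback(torch_version: str) -> bool:
    # Ignore suffixes such as "+cu118" or ".dev0" when comparing.
    m = re.match(r"(\d+)\.(\d+)", torch_version)
    if m is None:
        return True  # unparseable version string: play it safe
    return (int(m.group(1)), int(m.group(2))) < (2, 1)

print(needs_half_fallback("1.13.1"))      # True  (the torch in the traceback above)
print(needs_half_fallback("2.1.0+cu118")) # False
```

In `Llama.build` one could then keep `torch.cuda.BFloat16Tensor` only when bf16 is supported and `needs_half_fallback(torch.__version__)` is False.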

ghLcd9dG avatar Apr 23 '24 03:04 ghLcd9dG

torch.set_default_tensor_type(torch.cuda.HalfTensor)

I have the same problem when I train llama3; in modeling_llama.py, line 1095:

            causal_mask = torch.triu(causal_mask, diagonal=1)

I fixed it with:

            causal_mask = causal_mask.to(torch.float32)  # changed: cast to fp32 for triu
            causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask = causal_mask.to('cuda', dtype=torch.bfloat16)  # changed: cast back
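For reference, `torch.triu(mask, diagonal=1)` just zeroes everything on or below the main diagonal; a pure-Python sketch of the same operation (illustrative only, not from the repo):

```python
# Keep only the strict upper triangle (j - i >= diagonal); zero the rest.
# In a causal mask this upper triangle holds the "future" positions that
# get set to -inf before the softmax.
def triu(matrix, diagonal=1):
    return [
        [val if j - i >= diagonal else 0 for j, val in enumerate(row)]
        for i, row in enumerate(matrix)
    ]

print(triu([[1, 1, 1],
            [1, 1, 1],
            [1, 1, 1]]))  # [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
```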

I pretrained the base model on Chinese data, but the results are very bad; I don't know whether this change hurts precision. Can anyone help me?

cooper12121 avatar Apr 26 '24 07:04 cooper12121