
LLaMA 13B and 70B fail on CPU with BF16

Open · Yufeng98 opened this issue on Sep 10, 2023 · 3 comments

LLaMA 7B runs well on CPU with both BF16 and FP32, but LLaMA 13B and 70B only work on CPU with FP32.

With BF16, LLaMA 13B and 70B fail in the embedding layer with RuntimeError: Invalid scalar type. This is presumably because the 13B and 70B checkpoints use model-parallel sizes of 2 and 8, so the embedding output goes through torch.distributed.all_gather, while the 7B checkpoint's single model-parallel rank makes that gather a no-op.

Traceback (most recent call last):
  File "/data1/llama-cpu/example_text_completion.py", line 70, in <module>
    fire.Fire(main)
  File "/home/anaconda3/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/anaconda3/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/anaconda3/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/data1/llama-cpu/example_text_completion.py", line 57, in main
    results = generator.text_completion(
  File "/data1/llama-cpu/llama/generation.py", line 264, in text_completion
    generation_tokens, generation_logprobs = self.generate(
  File "/data1/llama-cpu/llama/generation.py", line 181, in generate
    logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
  File "/data1/llama-cpu/llama/model.py", line 471, in forward
    h = self.tok_embeddings(tokens)
  File "/home/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/anaconda3/lib/python3.9/site-packages/fairscale/nn/model_parallel/layers.py", line 214, in forward
    output = gather_from_model_parallel_region(output_parallel)
  File "/home/anaconda3/lib/python3.9/site-packages/fairscale/nn/model_parallel/mappings.py", line 156, in gather_from_model_parallel_region
    return _GatherFromModelParallelRegion.apply(input_)
  File "/home/anaconda3/lib/python3.9/site-packages/fairscale/nn/model_parallel/mappings.py", line 131, in forward
    return _gather(input_)
  File "/home/anaconda3/lib/python3.9/site-packages/fairscale/nn/model_parallel/mappings.py", line 82, in _gather
    torch.distributed.all_gather(tensor_list, input_, group=group)
  File "/home/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 2075, in all_gather
    work.wait()
RuntimeError: Invalid scalar type
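
For reference, here is a minimal sketch that seems to reproduce the failure outside of LLaMA. It assumes the CPU run uses torch.distributed's default gloo backend, which does not appear to support bfloat16 in all_gather; the addresses, shapes, and workaround below are illustrative and unverified, not from the original repro.

import os
import torch
import torch.distributed as dist

# Single-process group, just to exercise the gloo collective path.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.randn(4, 8, dtype=torch.bfloat16)
out = [torch.empty_like(x) for _ in range(dist.get_world_size())]
try:
    dist.all_gather(out, x)  # gloo rejects bfloat16: "Invalid scalar type"
except RuntimeError as e:
    print(f"all_gather failed: {e}")

# Possible workaround sketch (unverified): round-trip through float32
# around the collective, then cast back to bfloat16. fairscale's _gather
# concatenates along the last dimension, mimicked here with torch.cat.
x32 = x.float()
out32 = [torch.empty_like(x32) for _ in range(dist.get_world_size())]
dist.all_gather(out32, x32)
gathered = torch.cat(out32, dim=-1).to(torch.bfloat16)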

Yufeng98 · Sep 10, 2023

I got the same issue. Is there any update or workaround yet?

jwyang-google · Sep 23, 2023

Any thoughts on this, @osalpekar?

jspisak · Oct 11, 2023

Just FYI: I was able to run both 13B and 70B with bfloat16 on CPU using the Llama 2 Chat models from Hugging Face; a sketch of that setup is below.
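
A sketch of that setup (the model id, prompt, and generation settings here are illustrative, not the exact command I ran). Transformers runs the model in-process, so no torch.distributed all_gather is involved:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-13b-chat-hf"  # assumed id; 70B works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load weights directly in bfloat16 on CPU.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))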

dineshchitlangia · Nov 14, 2023