
The LLaMA implementation in Keras Hub exhibits significant accuracy deviations compared to the reference implementation (Hugging Face).

Open · pass-lin opened this issue on Nov 27, 2024 · 5 comments

import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # Chinese Hugging Face mirror
os.environ["KERAS_BACKEND"] = "torch"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
model_name = 'NousResearch/Meta-Llama-3.1-8B'

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import keras

hf_model = AutoModelForCausalLM.from_pretrained(model_name,
                                                device_map="cuda:0",
                                                torch_dtype=torch.bfloat16,
                                                _attn_implementation='eager',
                                                trust_remote_code=False).eval()
import keras_hub

keras.config.set_dtype_policy('bfloat16')
keras_model = keras_hub.models.Llama3CausalLM.from_preset('hf://' + model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Test sentence (Chinese): "FLOPs determine how long a network takes to run;
# the parameter count determines how much GPU memory it occupies."
input_ids, mask = tokenizer('计算量决定了网络执行时间的长短,参数量决定了占用显存的量').values()
input_ids = keras.ops.expand_dims(input_ids, 0)
mask = keras.ops.expand_dims(mask, 0)

x1 = hf_model.forward(input_ids, attention_mask=mask)
x2 = keras_model([mask, input_ids])

error = keras.ops.abs(x1.logits - x2)

# Global error statistics over all positions and vocabulary entries.
print(keras.ops.max(error))
print(keras.ops.min(error))
print(keras.ops.mean(error))
print(keras.ops.std(error))

# Per-position statistics (reduced over the vocabulary axis).
print(keras.ops.max(error, -1))
print(keras.ops.min(error, -1))
print(keras.ops.mean(error, -1))
print(keras.ops.std(error, -1))

The output is

tensor(3.2188, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MaxBackward1>)
tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MinBackward1>)
tensor(0.2441, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<ToCopyBackward0>)
tensor(0.2129, device='cuda:0', dtype=torch.bfloat16, grad_fn=<StdBackward0>)
tensor([[0.5938, 0.4062, 2.0938, 1.0781, 1.2188, 2.4062, 2.2812, 1.5625, 1.5234,
         1.3750, 1.4844, 2.9531, 2.3281, 1.7344, 2.4062, 2.1875, 2.4062, 3.2188,
         1.6953, 1.6250, 1.7969, 1.5078]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<AmaxBackward0>)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<AminBackward0>)
tensor([[0.1060, 0.0669, 0.3340, 0.1436, 0.1621, 0.3320, 0.2617, 0.2236, 0.2246,
         0.2090, 0.2422, 0.2490, 0.2930, 0.2637, 0.2500, 0.3066, 0.3574, 0.3340,
         0.2676, 0.2598, 0.2344, 0.2461]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)
tensor([[0.0679, 0.0510, 0.2578, 0.1138, 0.1270, 0.2617, 0.2080, 0.1719, 0.1768,
         0.1631, 0.1855, 0.2090, 0.2285, 0.2061, 0.2002, 0.2432, 0.2812, 0.2715,
         0.2090, 0.2031, 0.1816, 0.1904]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<StdBackward0>)

The Hugging Face reference implementation and the Keras Hub implementation produce significantly different logits, even though both runs use the PyTorch backend, which should largely rule out framework-specific numerical differences. In practical use, I have also observed that the Keras Hub LLaMA implementation falls into repetitive decoding more easily, while the HF and vLLM implementations rarely do. Should this precision difference be fixed?
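
As a side note, absolute logit differences can overstate the practical impact; what matters for decoding is whether the token ranking changes. A minimal follow-up check (a sketch, assuming x1 and x2 from the script above) compares the argmax token at every position:

import torch

hf_top = torch.argmax(x1.logits, dim=-1).cpu()              # (1, seq_len)
keras_top = torch.argmax(torch.as_tensor(x2), dim=-1).cpu()

# Fraction of positions where both implementations pick the same top token.
agreement = (hf_top == keras_top).float().mean()
print(f"top-1 token agreement: {agreement.item():.2%}")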

pass-lin · Nov 27, 2024

Upon further experimentation, I found that the issue is not limited to bf16: errors of a similar magnitude occur under fp32 as well. Under fp32, an error of 1e-5 or below is typically considered acceptable, but the error here is orders of magnitude larger. The LLaMA implementation therefore appears to carry a considerable intrinsic error.
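
To make the tolerance concrete, this is the kind of check meant here (a minimal sketch with hypothetical tensors a and b, not code from the report):

import torch

def logits_match(a: torch.Tensor, b: torch.Tensor) -> bool:
    # rtol/atol of 1e-5 is the usual fp32 ballpark; bf16 would need a far
    # looser tolerance (on the order of 1e-2) given its ~3 decimal digits
    # of precision.
    return torch.allclose(a.float(), b.float(), rtol=1e-5, atol=1e-5)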

import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # Chinese Hugging Face mirror
os.environ["KERAS_BACKEND"] = "torch"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
model_name = 'NousResearch/Meta-Llama-3.1-8B'

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import keras

hf_model = AutoModelForCausalLM.from_pretrained(model_name,
                                                device_map="cuda:1",
                                                torch_dtype=torch.float32,
                                                _attn_implementation='eager',
                                                trust_remote_code=False).eval()
import keras_hub

# keras.config.set_dtype_policy('bfloat16')  # full fp32 this time
keras_model = keras_hub.models.Llama3CausalLM.from_preset('hf://' + model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Same Chinese test sentence as above.
input_ids, mask = tokenizer('计算量决定了网络执行时间的长短,参数量决定了占用显存的量').values()
input_ids = keras.ops.expand_dims(input_ids, 0)
mask = keras.ops.expand_dims(mask, 0)

# The HF model lives on cuda:1; move its inputs there, then compare on CPU.
x1 = hf_model.forward(input_ids.cuda("cuda:1"), attention_mask=mask.cuda("cuda:1"))
x2 = keras_model([mask, input_ids])

error = keras.ops.abs(x1.logits.cpu() - x2.cpu())

# Global error statistics over all positions and vocabulary entries.
print(keras.ops.max(error))
print(keras.ops.min(error))
print(keras.ops.mean(error))
print(keras.ops.std(error))

# Per-position statistics (reduced over the vocabulary axis).
print(keras.ops.max(error, -1))
print(keras.ops.min(error, -1))
print(keras.ops.mean(error, -1))
print(keras.ops.std(error, -1))

The output is

tensor(3.3085, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(0., device='cuda:0', grad_fn=<MinBackward1>)
tensor(0.2417, device='cuda:0', grad_fn=<MeanBackward1>)
tensor(0.2120, device='cuda:0', grad_fn=<StdBackward0>)
tensor([[0.4981, 0.5633, 1.9278, 1.0281, 0.9935, 2.5044, 2.2573, 1.5885, 1.5354,
         1.3483, 1.4797, 2.9066, 2.3571, 1.6378, 2.4488, 2.2407, 2.5110, 3.3085,
         1.7227, 1.6624, 1.7762, 1.5082]], device='cuda:0',
       grad_fn=<AmaxBackward0>)
tensor([[1.9670e-06, 4.7684e-07, 4.7684e-07, 0.0000e+00, 4.7684e-07, 1.4305e-06,
         1.5497e-06, 0.0000e+00, 4.7684e-07, 1.6689e-06, 2.3842e-06, 2.0266e-06,
         1.2398e-05, 1.1921e-07, 3.8147e-06, 9.0599e-06, 5.0068e-06, 4.5300e-06,
         2.3842e-07, 1.1921e-06, 2.6226e-06, 7.1526e-06]], device='cuda:0',
       grad_fn=<AminBackward0>)
tensor([[0.0929, 0.0694, 0.3153, 0.1340, 0.1513, 0.3295, 0.2608, 0.2181, 0.2219,
         0.2122, 0.2425, 0.2484, 0.2934, 0.2523, 0.2490, 0.3077, 0.3529, 0.3431,
         0.2756, 0.2599, 0.2309, 0.2557]], device='cuda:0',
       grad_fn=<MeanBackward1>)
tensor([[0.0589, 0.0522, 0.2429, 0.1061, 0.1178, 0.2618, 0.2068, 0.1680, 0.1737,
         0.1650, 0.1862, 0.2082, 0.2285, 0.1969, 0.1989, 0.2443, 0.2786, 0.2801,
         0.2143, 0.2034, 0.1799, 0.1978]], device='cuda:0',
       grad_fn=<StdBackward0>)

pass-lin · Dec 03, 2024

Thanks! Will take a look.

mattdangerw · Dec 03, 2024

> Thanks! Will take a look.

Hello, have you found out what might have caused it?

pass-lin · Feb 07, 2025

Hi @pass-lin -

We have debugged the issue. The Llama3CausalLM task model in keras_hub is built from Llama3Backbone and Llama3CausalLMPreprocessor. Llama3Backbone currently has no implementation of its own: it simply subclasses LlamaBackbone, so the model runs with LlamaBackbone's behavior, which produces results with accuracy deviations. We'll debug further and work on a proper Llama3Backbone implementation soon. Thanks!
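
If it helps the debugging, one low-effort way to narrow down where the two implementations diverge is to dump both configurations and compare the hyperparameters the checkpoint converter actually carried over (a diagnostic sketch, assuming hf_model and keras_model loaded as in the scripts above; key names differ between the two libraries, so compare by hand):

from pprint import pprint

# One plausible suspect for a Meta-Llama-3.1 checkpoint is the
# rope_scaling section of the HF config, which a plain LLaMA backbone
# would not apply; this is a guess, not a confirmed root cause.
pprint(hf_model.config.to_dict())
pprint(keras_model.backbone.get_config())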

mehtamansi29 · Feb 13, 2025

@mehtamansi29 Hello, has this bug been fixed in the newly released 0.19?

pass-lin · Feb 27, 2025