Poor Performance of W4A4 on MMLU task
I evaluated LLaMA 3.1-8B and LLaMA 2-7B under ABQ W4A4 quantization. The results below show a significant drop in MMLU accuracy.
| Model | Precision | WikiText-2 (ppl) | MMLU (acc %) |
|---|---|---|---|
| llama3.1-8b-instruct | FP16 | 7.2129 | 68.09 |
| llama3.1-8b-instruct | W4A4 | 18.596 | 26.31 |
| llama2-7b | FP16 | 5.472 | 41.80 |
| llama2-7b | W4A4 | 9.220 | 25.12 |
Quantization commands:

```bash
# cache scale and shift
python generate_act_scale_shift.py --model Meta-Llama-3.1-8B-Instruct

# weight-activation quantization
python main.py \
    --model Meta-Llama-3.1-8B-Instruct \
    --epochs 20 --output_dir $OUTPUT_DIR --save_dir $SAVE_DIR \
    --wbits 4 --abits 4 --lwc --let
```
Evaluation script used in `main.evaluate`:
```python
import lm_eval
from lm_eval import utils as lm_eval_utils
from lm_eval.models.huggingface import HFLM

# Wrap the quantized model for lm-eval-harness.
hflm = HFLM(pretrained=lm.model, tokenizer=lm.tokenizer, batch_size=8)

task_manager = lm_eval.tasks.TaskManager(
    include_path="/usr/local/lib/python3.10/dist-packages/lm_eval/tasks/",
    include_defaults=False,
)
task_names = lm_eval_utils.pattern_match(args.tasks, task_manager.all_tasks)

results = {}
task_names = ['mmlu']  # evaluate MMLU only
for task_name in task_names:
    logger.info(f"Evaluating {task_name}...")
    result = lm_eval.simple_evaluate(
        hflm, tasks=[task_name], batch_size=8, task_manager=task_manager
    )['results']
    result = result[task_name]
    # Prefer normalized accuracy when available, otherwise raw accuracy.
    acc = round(result.get('acc_norm,none', result['acc,none']) * 100, 2)
    results[task_name] = acc
    logger.info(f"acc: {acc}%")

metric_vals = {task: result for task, result in results.items()}
metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals.values()), 2)
logger.info(metric_vals)
```
In practice, W4A4 low-bit quantization algorithms optimized under the "smooth paradigm" often hit performance bottlenecks on complex evaluation tasks, particularly when applied to cutting-edge models such as LLaMA-3.1. These state-of-the-art models, trained on trillions of tokens, demand higher representational precision, which exacerbates the challenges of low-bit quantization. Similar approaches, including OmniQuant, I-LLM, and AffineQuant, also exhibit significant accuracy degradation under such settings. To address this, we recommend replacing the conventional per-channel (weight) or per-token (activation) quantization with a per-group scheme, which better balances precision and efficiency across layers and improves performance on complex reasoning benchmarks such as MMLU and GSM8K.
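For illustration, here is a minimal per-group weight fake-quantization sketch. This is not the ABQ-LLM implementation; the function name, the symmetric rounding scheme, and the group size of 128 are assumptions chosen to show the idea: each row of the weight matrix gets one scale per group of input channels instead of a single per-channel scale.

```python
import torch

def quantize_per_group(w: torch.Tensor, n_bits: int = 4, group_size: int = 128):
    """Fake-quantize a (out_features, in_features) weight with one scale per group.
    Illustrative sketch only, not the ABQ-LLM code."""
    qmax = 2 ** (n_bits - 1) - 1                       # 7 for signed 4-bit
    out_features, in_features = w.shape
    assert in_features % group_size == 0
    # One quantization group = `group_size` consecutive input channels of a row.
    w_g = w.reshape(out_features, in_features // group_size, group_size)
    scale = w_g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / qmax
    w_q = torch.clamp(torch.round(w_g / scale), -qmax - 1, qmax)
    # Return the dequantized ("fake-quantized") weight for simulated evaluation.
    return (w_q * scale).reshape(out_features, in_features), scale

w = torch.randn(4096, 4096)
w_dq, scale = quantize_per_group(w)
print(scale.shape)  # torch.Size([4096, 32, 1]) -> 32 scales per row vs. 1 for per-channel
```

If the repo exposes a group-size option for weight quantization, enabling it (a group size of 128 is a common choice) should recover much of the MMLU gap at a modest storage cost for the extra scales.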