optimum-habana
Update Mixtral-8x7B Optimization
What does this PR do?
- Update Mixtral-8x7B optimization: reuse_cache, FP8 KV cache, FP8 attention, bucket_internal, ...
- Support long-sequence prompts
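FP8 measurement step (BF16 run that collects the maxabs calibration statistics):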
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py \
--model_name_or_path mistralai/Mixtral-8x7B-v0.1 \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--reuse_cache \
--bucket_size 128 \
--bucket_internal \
--max_new_tokens 100 \
--bf16 \
--batch_size 1
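FP8 quantized inference; --max_input_tokens 32000 exercises the long-sequence prompt support: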
QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generation.py \
--model_name_or_path mistralai/Mixtral-8x7B-v0.1 \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--reuse_cache \
--bucket_internal \
--bucket_size 128 \
--max_new_tokens 100 \
--bf16 \
--fp8 \
--batch_size 2 \
--warmup 1 \
--n_iterations 1 \
--max_input_tokens 32000
@schoi-habana, please provide details of how you optimized Falcon-180B FP8 so Jinyan can follow the same approach for this model. Thanks!
I tested this PR with run_generation.py in the 1.16.0 docker. It could fit a 30k-token input, but the generated output was empty. Did you check the output?
input 1: ('DeepSpeed is a machine learning framework',) output 1: ('DeepSpeed is a machine learning framework',)
@jychen-habana, after you implement ScopedLinearAllreduce, please check whether the in-place addition in this PR helps this model: https://github.com/HabanaAI/optimum-habana-fork/pull/65
Following up on the empty-output report above (30k-token input in the 1.16.0 docker): in a 1.15 setup environment, I didn't hit this issue.
@jychen-habana, as we synced offline:
- kv_cache_fp8 is the previous way to support FP8 inference and will be removed soon; all models' FP8 inference should use HQT instead (rough flow sketched below).
- Your current code in this PR causes a regression in HQT measurement.
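For reference, a minimal sketch of the HQT flow, assuming the habana_quantization_toolkit entry points prep_model and finish_measurements (check the installed toolkit version); load_mixtral_model and run_prompts are hypothetical helpers:

import os
import habana_quantization_toolkit as hqt

# QUANT_CONFIG selects measurement (maxabs_measure.json) or quantized
# inference (maxabs_quant_mixtral.json); the toolkit reads it from the env.
os.environ.setdefault("QUANT_CONFIG", "./quantization_config/maxabs_measure.json")

model = load_mixtral_model()  # hypothetical helper returning an HPU-ready model

# Patch supported nn.Module submodules (Linear, KVCache, ...) so their
# inputs/outputs are measured or quantized according to QUANT_CONFIG.
hqt.prep_model(model)

run_prompts(model)  # hypothetical helper: run generation as usual

# In measurement mode, dump the collected maxabs statistics to disk for the
# later quantization run.
hqt.finish_measurements(model)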
Fixed.
Sure, I'll check whether the in-place addition in HabanaAI#65 helps once ScopedLinearAllreduce is implemented.
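A minimal sketch of the in-place residual addition, my reading of HabanaAI#65 rather than its exact diff:

import torch

def apply_residual(residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    # Out-of-place version allocates a fresh tensor on every call:
    #   return residual + hidden_states
    # In-place version accumulates into the existing buffer instead, which can
    # reduce memory pressure and intermediate allocations for a model as large
    # as Mixtral-8x7B:
    residual.add_(hidden_states)
    return residual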
@jychen-habana, please post the performance measurements with and without this PR here.
@jychen-habana, please rebase onto the latest code in the OH main branch.
@jychen-habana, this PR doesn't work with the Synapse 1.15 release docker during measurement:
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py \
--model_name_or_path /mnt/weka/data/mixtral/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/1e637f2d7cb0a9d6fb1922f305cb784995190a83/ \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--bucket_size 128 \
--max_new_tokens 128 \
--batch_size 1 \
--bf16
Error:
File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 787, in forward outputs = self.model( File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl result = forward_call(*args, **kwargs) File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 692, in forward layer_outputs = decoder_layer( File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl result = forward_call(*args, **kwargs) File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 518, in forward hidden_states, self_attn_weights, present_key_value = self.self_attn( File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl result = forward_call(*args, **kwargs) File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 356, in forward key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) File "/usr/local/lib/python3.10/dist-packages/habana_quantization_toolkit/_quant_common/helper_modules.py", line 264, in update qinput = self.quant_input_0(cur) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1691, in getattr raise AttributeError(f"'{type(self).name}' object has no attribute '{name}'") AttributeError: 'PatchedKVCache' object has no attribute 'quant_input_0'
Do not merge! This PR will be split into smaller pieces: https://github.com/huggingface/optimum-habana/pull/898 https://github.com/huggingface/optimum-habana/pull/901 https://github.com/huggingface/optimum-habana/pull/903
Regarding the measurement failure with the Synapse 1.15 docker above: please add --reuse_cache when measuring with bf16. From my understanding, the KV cache needs to be an nn.Module so that it can be measured. For quantization mode, it's fine to just remove --reuse_cache. Or, if there is another solution, please let me know.
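To illustrate my understanding, a simplified sketch modeled on the k_cache.update call in the traceback, not the exact optimum-habana implementation: because the cache is an nn.Module, HQT can swap it for a patched wrapper (PatchedKVCache) that observes the tensors passed to update(); plain tensors carried in past_key_values never get wrapped, so nothing is measured.

import torch

# Simplified KV cache of the kind --reuse_cache enables. As an nn.Module it
# is visible to habana_quantization_toolkit, which can replace it with
# PatchedKVCache to measure/quantize the inputs of update().
class KVCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cache = None

    def allocate(self, shape, dtype, device):
        # Preallocate the buffer once so HPU graphs can be reused across steps.
        if self.cache is None or self.cache.shape != shape:
            self.cache = torch.zeros(shape, dtype=dtype, device=device)
        return self.cache

    @staticmethod
    def update(prev, cur, dim, idx, inp_seq_len):
        # Write the current key/value states into the preallocated buffer at
        # token position idx (assumed 1-based here), or append when no index
        # is given.
        if idx is not None:
            return prev.index_copy_(dim, idx - 1, cur)
        return torch.cat((prev, cur), dim=dim)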