
Update Mixtral-8x7B Optimization

jychen21 opened this issue 10 months ago • 12 comments

What does this PR do?

  • Update Mixtral-8x7B Optimization: reuse_cache / enable FP8 KV Cache / FP8 Attn / bucket_internal ...

  • Support long-sequence prompts

QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py \
--model_name_or_path mistralai/Mixtral-8x7B-v0.1  \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--reuse_cache \
--bucket_size 128 \
--bucket_internal \
--max_new_tokens 100 \
--bf16 \
--batch_size 1

QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generation.py \
--model_name_or_path mistralai/Mixtral-8x7B-v0.1  \
--use_hpu_graphs \
--use_kv_cache \
--limit_hpu_graphs \
--reuse_cache \
--bucket_internal \
--bucket_size 128 \
--max_new_tokens 100 \
--bf16 \
--fp8 \
--batch_size 2 \
--warmup 1 \
--n_iterations 1 \
--max_input_tokens 32000
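
For reference, the two QUANT_CONFIG files drive the HQT (habana_quantization_toolkit) flow: the first command is a BF16 calibration pass that records maxabs statistics, and the second replays those statistics to run attention and the KV cache in FP8. A rough sketch of what such configs usually contain follows; the exact keys and the dump_stats_path value depend on the toolkit version, so treat this as an illustration rather than the shipped files.

maxabs_measure.json (measurement pass, hypothetical contents):

{
    "method": "HOOKS",
    "mode": "MEASURE",
    "observer": "maxabs",
    "allowlist": {"types": [], "names": []},
    "blocklist": {"types": [], "names": []},
    "dump_stats_path": "./hqt_output/measure"
}

maxabs_quant_mixtral.json (quantization pass, hypothetical contents):

{
    "method": "HOOKS",
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "allowlist": {"types": [], "names": []},
    "blocklist": {"types": [], "names": []},
    "dump_stats_path": "./hqt_output/measure"
}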


jychen21 · Mar 26 '24 07:03

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@schoi-habana , please provide details of how you optimized Falcon-180B FP8 so that Jinyan can follow the same approach for this model. Thanks

mandy-li · Mar 29 '24 16:03

I tested this PR with run_generation.py in 1.16.0 docker. It could fit 30k input tokens but the generated output was empty. Did you check the output?

input 1: ('DeepSpeed is a machine learning framework',) output 1: ('DeepSpeed is a machine learning framework',)

schoi-habana · Apr 05 '24 23:04

@jychen-habana after you implement ScopedLinearAllreduce, please see if in-place addition in this PR https://github.com/HabanaAI/optimum-habana-fork/pull/65 helps this model
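
The in-place addition referred to above would look roughly like the snippet below. This is a minimal sketch, not the actual HabanaAI#65 diff; the function and tensor names are hypothetical, and it only applies to inference (in-place ops on the residual would interfere with autograd during training).

import torch

def add_residual_inplace(hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Out-of-place version allocates a fresh output tensor on every call:
    #   hidden_states = residual + hidden_states
    # The in-place version reuses the residual buffer instead, which cuts
    # one intermediate allocation per decoder layer and can help memory
    # pressure when HPU graphs keep activations alive.
    residual.add_(hidden_states)
    return residual

# Example usage inside a decoder layer (names hypothetical):
# hidden_states = add_residual_inplace(attn_output, residual)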

schoi-habana · Apr 08 '24 06:04

I tested this PR with run_generation.py in 1.16.0 docker. It could fit 30k input tokens but the generated output was empty. Did you check the output?

input 1: ('DeepSpeed is a machine learning framework',) output 1: ('DeepSpeed is a machine learning framework',)

In the 1.15 setup env, I didn't see this issue.

jychen21 · Apr 09 '24 02:04

@jychen-habana , as we synced offline:

  1. kv_cache_fp8 is the previous way to support FP8 inference and will be removed soon; all models' FP8 inference should go through HQT.
  2. Your current code in this PR causes a regression for HQT measurement.

fixed.

jychen21 · Apr 09 '24 02:04

@jychen-habana after you implement ScopedLinearAllreduce, please see if in-place addition in this PR HabanaAI#65 helps this model

Sure.

jychen21 · Apr 09 '24 02:04

@jychen-habana , please post the performance measurements with/without this PR here.

mandy-li · Apr 16 '24 15:04

@jychen-habana , please rebase onto the latest code in the OH main branch

mandy-li · Apr 16 '24 17:04

@jychen-habana , this PR doesn't work with the Synapse 1.15 release docker during measurement.

QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --model_name_or_path /mnt/weka/data/mixtral/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/1e637f2d7cb0a9d6fb1922f305cb784995190a83/ --use_hpu_graphs --use_kv_cache --limit_hpu_graphs --bucket_size 128 --max_new_tokens 128 --batch_size 1 --bf16

Error:

File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 787, in forward outputs = self.model( File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl result = forward_call(*args, **kwargs) File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 692, in forward layer_outputs = decoder_layer( File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl result = forward_call(*args, **kwargs) File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 518, in forward hidden_states, self_attn_weights, present_key_value = self.self_attn( File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1564, in _call_impl result = forward_call(*args, **kwargs) File "/home/jwang/test/optimum-habana-jychen/optimum/habana/transformers/models/mixtral/modeling_mixtral.py", line 356, in forward key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) File "/usr/local/lib/python3.10/dist-packages/habana_quantization_toolkit/_quant_common/helper_modules.py", line 264, in update qinput = self.quant_input_0(cur) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1691, in getattr raise AttributeError(f"'{type(self).name}' object has no attribute '{name}'") AttributeError: 'PatchedKVCache' object has no attribute 'quant_input_0'

mandy-li · Apr 16 '24 20:04

Do not merge! Will break this PR into small pieces: https://github.com/huggingface/optimum-habana/pull/898 https://github.com/huggingface/optimum-habana/pull/901 https://github.com/huggingface/optimum-habana/pull/903

jychen21 · Apr 18 '24 05:04

@jychen-habana , this PR doesn't work with the Synapse 1.15 release docker during measurement.

QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --model_name_or_path /mnt/weka/data/mixtral/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/1e637f2d7cb0a9d6fb1922f305cb784995190a83/ --use_hpu_graphs --use_kv_cache --limit_hpu_graphs --bucket_size 128 --max_new_tokens 128 --batch_size 1 --bf16

Please add --reuse_cache when measuring with bf16; from my understanding, the KV cache needs to be an nn.Module so that it can be measured.

For quantization mode, it's fine to just remove --reuse_cache.

Or if there is another solution, please let me know.
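
To illustrate the nn.Module point above: a KV cache that is itself an nn.Module gives the quantization toolkit a module boundary it can patch and attach measurement observers to. The sketch below is an assumption-level illustration, not the actual optimum-habana KVCache or HQT PatchedKVCache implementation.

import torch

class SketchKVCache(torch.nn.Module):
    # Because the cache is an nn.Module, a quantization toolkit can wrap it
    # (e.g. as a patched module) and observe the tensors flowing through
    # update() during the measurement pass.
    def __init__(self):
        super().__init__()
        self.cache = None

    def allocate(self, shape, dtype, device):
        # Preallocate the full-sequence cache once so decode steps can
        # write into it in place.
        if self.cache is None or self.cache.shape != torch.Size(shape):
            self.cache = torch.zeros(shape, dtype=dtype, device=device)
        return self.cache

    @staticmethod
    def update(prev, cur, dim, idx):
        # Write the current step's keys/values at position idx along dim,
        # or fall back to concatenation when no index is given.
        if idx is not None:
            prev.index_copy_(dim, idx, cur)
            return prev
        return torch.cat((prev, cur), dim=dim)

    def forward(self, cur, dim, idx):
        return self.update(self.cache, cur, dim, idx)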

jychen21 · Apr 18 '24 05:04