InfLLM
IndexError when attempting to run the Triton FlashAttention kernel
Hi, I am attempting to run mistral-inf-llm-fattn on a single V100, but I am getting an IndexError. Do you have any idea what the problem might be? Below is the full output:
Exception has occurred: IndexError
map::at
File "/home/aoomerjee/EM-LLM/inf_llm/attention/dot_production_attention/triton_impl.py", line 430, in _forward
_attn_fwd[grid](
File "/home/aoomerjee/EM-LLM/inf_llm/attention/dot_production_attention/triton_impl.py", line 534, in append
o, m, l = _forward(
^^^^^^^^^
File "/home/aoomerjee/EM-LLM/inf_llm/attention/inf_llm_context_manager.py", line 555, in _retrieve_and_attend
attn.append(
File "/home/aoomerjee/EM-LLM/inf_llm/attention/inf_llm_context_manager.py", line 779, in append
exc_block_attn_output, exc_repr_score = self._retrieve_and_attend(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/aoomerjee/EM-LLM/inf_llm/attention/inf_llm.py", line 64, in forward
o = past_key_value.append(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/aoomerjee/EM-LLM/inf_llm/utils/patch_hf.py", line 19, in hf_forward
ret = forward(
^^^^^^^^
File "/home/aoomerjee/EM-LLM/inf_llm/utils/patch_hf.py", line 87, in model_forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/home/aoomerjee/EM-LLM/inf_llm/utils/greedy_search.py", line 46, in _model_pass
out = self.model(
^^^^^^^^^^^
File "/home/aoomerjee/EM-LLM/inf_llm/utils/greedy_search.py", line 74, in _decode
out = self._model_pass(
^^^^^^^^^^^^^^^^^
File "/home/aoomerjee/EM-LLM/inf_llm/utils/greedy_search.py", line 32, in generate
result = self._decode(input_ids, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/aoomerjee/EM-LLM/benchmark/pred.py", line 259, in get_pred
output = searcher.generate(
^^^^^^^^^^^^^^^^^^
File "/home/aoomerjee/EM-LLM/benchmark/pred.py", line 324, in <module>
preds = get_pred(
^^^^^^^^^
IndexError: map::at
It seems that Triton 2.2.0 does not support the V100: the `IndexError: map::at` comes from Triton's compiler failing a lookup for the GPU architecture (the V100 is compute capability 7.0 / sm_70), not from the InfLLM code itself.
Try pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly==2.1.0.dev20231014192330.
And change the torch_dtype to torch.half, since the V100 has no hardware bfloat16 support.
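As a sketch of the dtype part of the fix: the helper below (`pick_dtype` is a name I made up for illustration, not part of InfLLM) encodes the rule that bfloat16 requires Ampere (sm_80) or newer, so pre-Ampere cards like the V100 must fall back to float16 (`torch.half`).

```python
def pick_dtype(device_capability: tuple[int, int]) -> str:
    """Pick a torch_dtype name from the CUDA compute capability.

    bfloat16 needs Ampere (compute capability 8.0) or newer;
    the V100 is 7.0, so it must fall back to float16 (torch.half).
    """
    return "bfloat16" if device_capability >= (8, 0) else "float16"

# A V100 reports compute capability (7, 0), so it gets float16:
print(pick_dtype((7, 0)))  # -> float16
print(pick_dtype((8, 0)))  # -> bfloat16 (A100 and newer)
```

At runtime you would pass in `torch.cuda.get_device_capability()` and use the result when loading the model, e.g. `AutoModelForCausalLM.from_pretrained(..., torch_dtype=torch.half)` on a V100.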