llm-foundry
How to install MPT-7B?
Hey, I want to use the HF transformers library for the fastest inference currently possible (triton). I am trying to use https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py:
python test.py \
--name_or_path "../../../text/models/mpt-7b" \
--temperature 1.0 \
--top_p 0.95 \
--top_k 50 \
--seed 1 \
--max_new_tokens 256 \
--attn_impl triton \
--prompts \
"The answer to life, the universe, and happiness is" \
"MosaicML is an ML training efficiency startup that is known for" \
"Here's a quick recipe for baking chocolate chip cookies: Start by" \
"The best 5 cities to visit in Europe are" \
--device cuda \
--trust_remote_code
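For context, my understanding is that the script boils down to roughly the following (a minimal sketch based on the AutoConfig recipe from the MPT-7B model card; the local path mirrors my command above, and the actual test.py code differs):

# Minimal loading sketch (assumption: follows the MPT-7B model-card recipe;
# "../../../text/models/mpt-7b" is my local copy of the weights).
import torch
import transformers

name = "../../../text/models/mpt-7b"
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'  # the implementation I am trying to use

model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to('cuda')
tokenizer = transformers.AutoTokenizer.from_pretrained(name, trust_remote_code=True)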
However, I only receive the error below (I installed the latest flash_attn because I could not install the exact version specified), and it takes ages before the "Loading checkpoint shards" progress bar appears:
Loading HF Config...
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Loading HF model to device=cuda and dtype=torch.bfloat16...
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/attention.py:144: UserWarning: While `attn_impl: triton` can be faster than `attn_impl: flash` it uses more memory. When training larger models this can trigger alloc retries which hurts performance. If encountered, we recommend using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.
warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00, 2.07s/it]
n_params=6649286656
Loading HF tokenizer...
/home/ubuntu/ml/llm/inference/hf/transformers-direct/test.py:167: UserWarning: pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.
warnings.warn(
Generate kwargs:
{'max_new_tokens': 256, 'temperature': 1.0, 'top_p': 0.95, 'top_k': 50, 'use_cache': True, 'do_sample': True, 'eos_token_id': 0, 'pad_token_id': 0}
Tokenizing prompts...
NOT using autocast...
Warming up...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Traceback (most recent call last):
File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-3d28bf8dd4111863b189d74cf84b730b-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-2d732a2488b7ed996facc3e641ee56bf-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (True, False), (True, False), (True, False), (True, False)))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 937, in build_triton_ir
generator.visit(fn.parse())
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
return visitor(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 183, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 415, in generic_visit
self.visit(item)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
return visitor(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
has_ret = self.visit_compound_statement(node.body)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
return visitor(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 678, in visit_For
self.visit_compound_statement(node.body)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
return visitor(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 319, in visit_AugAssign
self.visit(assign)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
return visitor(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 301, in visit_Assign
values = self.visit(node.value)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
return visitor(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 339, in visit_BinOp
rhs = self.visit(node.right)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
return visitor(node)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 797, in visit_Call
return fn(*args, _builder=self.builder, **kws)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/impl/base.py", line 22, in wrapper
return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ubuntu/ml/llm/inference/hf/transformers-direct/test.py", line 269, in <module>
main(parse_args())
File "/home/ubuntu/ml/llm/inference/hf/transformers-direct/test.py", line 218, in main
_ = _generate(encoded_inp)
File "/home/ubuntu/ml/llm/inference/hf/transformers-direct/test.py", line 209, in _generate
return model.generate(
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/transformers/generation/utils.py", line 1485, in generate
return self.sample(
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/transformers/generation/utils.py", line 2524, in sample
outputs = self(
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/modeling_mpt.py", line 237, in forward
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/modeling_mpt.py", line 183, in forward
(x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/blocks.py", line 36, in forward
(b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/attention.py", line 171, in forward
(context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/attention.py", line 111, in triton_flash_attn_fn
attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/flash_attn/flash_attn_triton.py", line 810, in forward
o, lse, ctx.softmax_scale = _flash_attn_forward(
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/flash_attn/flash_attn_triton.py", line 623, in _flash_attn_forward
_fwd_kernel[grid](
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 199, in run
return self.fn.run(*args, **kwargs)
File "<string>", line 41, in _fwd_kernel
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 1621, in compile
next_module = compile(module)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 1550, in <lambda>
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 962, in ast_to_ttir
mod, _ = build_triton_ir(fn, signature, specialization, constants)
File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 942, in build_triton_ir
raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 78:24:
def _fwd_kernel(
Q, K, V, Bias, Out,
Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_bb, stride_bh, stride_bm,
stride_ob, stride_oh, stride_om,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
# off_b = tl.program_id(1)
# off_h = tl.program_id(2)
# off_hb = off_b * nheads + off_h
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Initialize pointers to Q, K, V
# Adding parenthesis around indexing might use int32 math instead of int64 math?
# https://github.com/openai/triton/issues/741
# I'm seeing a tiny bit of difference (5-7us)
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0)
# loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
^
I can run the model with torch attention, but I only get around 10 to 20 tokens/second on a 3090. As I asked in another issue, FasterTransformer inference support will probably take months to land.
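(The 10 to 20 tokens/second figure comes from a rough measurement along these lines; model and tokenizer are the already-loaded objects, and this is a sketch, not the exact code in test.py:)

# Rough throughput measurement (a sketch, not the exact code in test.py).
import time
import torch

inputs = tokenizer("The answer to life, the universe, and happiness is",
                   return_tensors='pt').to('cuda')
torch.cuda.synchronize()
start = time.time()
out = model.generate(**inputs, max_new_tokens=256, do_sample=True,
                     temperature=1.0, top_p=0.95, top_k=50)
torch.cuda.synchronize()
elapsed = time.time() - start
new_tokens = out.shape[1] - inputs['input_ids'].shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")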
On the other hand, what can I do to make triton (the advertised flash attention support) work? Is it even possible to run the model in triton mode?
I really need the high performance and I am super grateful for any help! :)
Same issue here: triton.compiler.CompilationError: at 78:24: during training.
Changing triton_flash_attention to flash_attention in the yaml is a workaround for this problem~
Flash attn does not support alibi, so you're technically not running the correct model...
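As a rough illustration (key names assume the attn_config dict in MPT's custom config, not the exact training yaml), you can inspect which implementation and positional scheme a config would use:

# Illustration only: inspect the attention settings of the MPT config
# (key names assume the attn_config dict exposed via trust_remote_code).
from transformers import AutoConfig

config = AutoConfig.from_pretrained('mosaicml/mpt-7b', trust_remote_code=True)
print(config.attn_config.get('attn_impl'))  # e.g. 'torch', 'triton', or 'flash'
print(config.attn_config.get('alibi'))      # True for MPT-7B; switching to 'flash' drops alibi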
What version of torch are you using? Are you following the requirements and installation instructions (which say to use torch 1.13)?
(This is an issue you would hit if you are using torch 2.)
torch 1.13, right? Not 2.0?
Does it work only with the triton dev version? Does the official triton 2.0 not work?
The triton that ships with torch 2 contains breaking changes relative to the triton version we need for our training setup. Please use torch 1.13 as instructed in the requirements. torch 2 is NOT supported yet.
(We're working to enable torch 2 here.)
Note: our setup does install 2 versions of triton; please follow the install instructions.
Closing this issue; if you still have problems, feel free to re-open it.
Great stuff, I just want to say thanks. The triton flash attention works with:
torch==1.13.1+cu117
triton-pre-mlir @ git+https://github.com/vchiley/triton.git@2dd3b957698a39bbca615c02a447a98482c144a3#subdirectory=python
and flash-attn==v1.0.3.post0
Everything was installed from setup.py (the gpu setup) using the mosaicml/pytorch:latest Docker image.
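To double-check that an environment matches that combination, something like this prints the installed versions (just a sketch; the distribution names are the ones listed above):

# Sketch: print installed versions to compare against the working combination
# above (torch 1.13.1+cu117, flash-attn 1.0.3.post0, triton-pre-mlir fork).
from importlib.metadata import version, PackageNotFoundError

for pkg in ("torch", "flash-attn", "triton-pre-mlir"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")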