
How to install MPT-7B?

Open SinanAkkoyun opened this issue 2 years ago • 2 comments

Hey, I want to use the HF transformers library with the fastest inference currently possible (triton). I am trying to use https://github.com/mosaicml/llm-foundry/blob/main/scripts/inference/hf_generate.py

python test.py --temperature 1.0 \
--name_or_path "../../../text/models/mpt-7b" \
--top_p 0.95 \
--top_k 50 \
--seed 1 \
--max_new_tokens 256 \
--attn_impl triton \
--prompts \
"The answer to life, the universe, and happiness is" \
"MosaicML is an ML training efficiency startup that is known for" \
"Here's a quick recipe for baking chocolate chip cookies: Start by" \
"The best 5 cities to visit in Europe are" \
--device cuda \
--trust_remote_code

However, I only get the error below. I installed the latest flash_attn because I cannot install the exact version specified, and loading takes ages before "Loading checkpoint shards" appears:

Loading HF Config...
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Loading HF model to device=cuda and dtype=torch.bfloat16...
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/attention.py:144: UserWarning: While `attn_impl: triton` can be faster than `attn_impl: flash` it uses more memory. When training larger models this can trigger alloc retries which hurts performance. If encountered, we recommend using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.
  warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.07s/it]
n_params=6649286656

Loading HF tokenizer...
/home/ubuntu/ml/llm/inference/hf/transformers-direct/test.py:167: UserWarning: pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.
  warnings.warn(

Generate kwargs:
{'max_new_tokens': 256, 'temperature': 1.0, 'top_p': 0.95, 'top_k': 50, 'use_cache': True, 'do_sample': True, 'eos_token_id': 0, 'pad_token_id': 0}

Tokenizing prompts...
NOT using autocast...
Warming up...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0-3d28bf8dd4111863b189d74cf84b730b-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-2d732a2488b7ed996facc3e641ee56bf-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (True, False), (True, False), (True, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 415, in generic_visit
    self.visit(item)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 678, in visit_For
    self.visit_compound_statement(node.body)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 319, in visit_AugAssign
    self.visit(assign)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 301, in visit_Assign
    values = self.visit(node.value)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/ast.py", line 407, in visit
    return visitor(node)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 797, in visit_Call
    return fn(*args, _builder=self.builder, **kws)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/impl/base.py", line 22, in wrapper
    return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ubuntu/ml/llm/inference/hf/transformers-direct/test.py", line 269, in <module>
    main(parse_args())
  File "/home/ubuntu/ml/llm/inference/hf/transformers-direct/test.py", line 218, in main
    _ = _generate(encoded_inp)
  File "/home/ubuntu/ml/llm/inference/hf/transformers-direct/test.py", line 209, in _generate
    return model.generate(
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/transformers/generation/utils.py", line 1485, in generate
    return self.sample(
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/transformers/generation/utils.py", line 2524, in sample
    outputs = self(
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/modeling_mpt.py", line 237, in forward
    outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/modeling_mpt.py", line 183, in forward
    (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/blocks.py", line 36, in forward
    (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/attention.py", line 171, in forward
    (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
  File "/home/ubuntu/.cache/huggingface/modules/transformers_modules/mpt-7b/attention.py", line 111, in triton_flash_attn_fn
    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/flash_attn/flash_attn_triton.py", line 810, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/flash_attn/flash_attn_triton.py", line 623, in _flash_attn_forward
    _fwd_kernel[grid](
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 199, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 41, in _fwd_kernel
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 1621, in compile
    next_module = compile(module)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 1550, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 962, in ast_to_ttir
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
  File "/home/ubuntu/.anaconda3/envs/transformers-direct/lib/python3.9/site-packages/triton/compiler.py", line 942, in build_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 78:24:
def _fwd_kernel(
    Q, K, V, Bias, Out,
    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^

SinanAkkoyun avatar May 14 '23 00:05 SinanAkkoyun

I can run the model with torch attention, but I only get around 10 to 20 tokens/second with a 3090. As I asked in another issue, FasterTransformer inference support will probably take months to finish.

On the other hand, what can I do to make triton (the advertised flash attention support) work? Is it even possible to run it in triton mode?

I really need the high performance and I am super grateful for any help! :)

SinanAkkoyun avatar May 14 '23 01:05 SinanAkkoyun

same issue here, triton.compiler.CompilationError: at 78:24: during training

germanjke avatar May 15 '23 10:05 germanjke

Changing triton_flash_attention to flash_attention in the yaml solves this problem~

germanjke avatar May 18 '23 10:05 germanjke

"Changing triton_flash_attention to flash_attention in the yaml solves this problem~"

flash attn does not support alibi. you're technically not running the correct model...

vchiley avatar May 18 '23 21:05 vchiley

What version of torch are you using? Are you following the requirements and installation instructions (which say to use torch1.13)?

(this is an issue you would have if you are using torch2)
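For anyone answering the version question, a small script along these lines can print what is actually installed in the active environment. This is only a minimal sketch using the standard library: the distribution names (`torch`, `triton`, `triton-pre-mlir`, `flash-attn`) are the ones mentioned in this thread, and the helper name `report_versions` is made up for illustration.

```python
from importlib import metadata

# Distribution names as they appear in this thread (assumption: your
# environment uses the same names; adjust the tuple if it does not).
PACKAGES = ("torch", "triton", "triton-pre-mlir", "flash-attn")


def report_versions(packages=PACKAGES):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting the output into a report like this one makes it easier to tell whether the environment matches the pinned requirements.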

vchiley avatar May 18 '23 21:05 vchiley

torch 1.13 right, not 2.0

germanjke avatar May 18 '23 21:05 germanjke

Does it work only with the triton dev version? Does the official triton 2.0 not work?

germanjke avatar May 18 '23 21:05 germanjke

The triton version installed with torch2 introduced breaking changes relative to the triton version we need for our training setup. Please use torch1.13 as instructed in the requirements. torch2 is NOT supported yet.

(we're working to enable torch2 here)
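The `TypeError: dot() got an unexpected keyword argument 'trans_b'` in the traceback is the visible symptom of that mismatch: the flash-attn Triton kernel calls `tl.dot(q, k, trans_b=True)`, and the installed triton no longer accepts that keyword. A rough way to check an environment before running the model is sketched below. The helper `accepts_keyword` is hypothetical, and the check assumes `inspect.signature` can see through Triton's wrapper around `tl.dot`, which has not been verified for every triton release.

```python
import inspect


def accepts_keyword(fn, name):
    """Return True if `fn` can be called with keyword argument `name`."""
    params = inspect.signature(fn).parameters
    if name in params and params[name].kind is not inspect.Parameter.POSITIONAL_ONLY:
        return True
    # A **kwargs parameter accepts any keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


if __name__ == "__main__":
    try:
        import triton.language as tl
    except ImportError:
        print("triton is not installed")
    else:
        # False here suggests the kernel in the traceback will fail to compile.
        print("tl.dot accepts trans_b:", accepts_keyword(tl.dot, "trans_b"))
```

If the check prints `False`, the installed triton is likely the one that breaks the kernel, and the pinned install instructions are the safer route.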

vchiley avatar May 19 '23 15:05 vchiley

torch2 now works 🥳

Note: our setup does install 2 versions of triton, please follow the install instructions

Closing issue, if you still have issues, feel free to re-open the issue.

vchiley avatar May 19 '23 22:05 vchiley

great stuff, I just want to say thanks. The triton flash attention works with:
torch==1.13.1+cu117
triton-pre-mlir @ git+https://github.com/vchiley/triton.git@2dd3b957698a39bbca615c02a447a98482c144a3#subdirectory=python
flash-attn==v1.0.3.post0

I installed everything from setup.py (gpu setup) in this docker image: mosaicml/pytorch:latest

germanjke avatar May 21 '23 22:05 germanjke