
Support for mosaicml/mpt-30b-instruct model

Open · maziyarpanahi opened this issue 2 years ago • 17 comments

Feature request

I was wondering if there will be support for the newly released mpt-30b-instruct.

Motivation

It's not possible to use the mosaicml/mpt-30b-instruct model:

ValueError: sharded is not supported for AutoModel

Your contribution

I am not sure how to add support for new LLM models. (If there is a step-by-step guide on where to start, that would be great, and I could contribute.)

maziyarpanahi avatar Jun 23 '23 16:06 maziyarpanahi

Did you try --trust-remote-code while running the Docker container?

mantrakp04 avatar Jun 23 '23 21:06 mantrakp04

Did you try --trust-remote-code while running the Docker container?

It's very slow. This model is not supported for sharding at the moment in text-generation-inference.

tim-a-davis avatar Jun 23 '23 22:06 tim-a-davis

Then try implementing a rudimentary version of it yourself: you can use Rust or JS for the router and Python for inference. Copy the custom kernels from the repo and modify them as needed. MPT already has a flash attention implementation in its "remote code" files, so use that, together with batch_encode_plus for tokenizing and batch_decode for decoding. Implement batching on the router server and voilà, you have your own server ready for inference.
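
Something like this rough sketch on the Python side (untested; the router glue, streaming, and the actual batching window are up to you, and the helper here is just illustrative):

import torch
import transformers

name = 'mosaicml/mpt-30b-instruct'

tokenizer = transformers.AutoTokenizer.from_pretrained(name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token   # the MPT tokenizer has no pad token by default
tokenizer.padding_side = 'left'             # pad on the left for decoder-only generation

model = transformers.AutoModelForCausalLM.from_pretrained(
  name,
  torch_dtype=torch.bfloat16,
  trust_remote_code=True,
  device_map='auto',  # needs accelerate; spreads the 30B weights over the available GPUs
)

def generate_batch(prompts, max_new_tokens=128):
  # batch_encode_plus pads the prompts so the whole batch runs in one forward pass
  enc = tokenizer.batch_encode_plus(prompts, return_tensors='pt', padding=True).to(model.device)
  with torch.no_grad():
    out = model.generate(**enc, max_new_tokens=max_new_tokens)
  # batch_decode turns the generated token ids back into strings
  return tokenizer.batch_decode(out, skip_special_tokens=True)

# the Rust/JS router would collect incoming requests into `prompts`
print(generate_batch(['What is MPT-30B?', 'Write a haiku about GPUs.']))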

mantrakp04 avatar Jun 23 '23 22:06 mantrakp04

Then try implementing a rudimentary version of it yourself: you can use Rust or JS for the router and Python for inference. Copy the custom kernels from the repo and modify them as needed. MPT already has a flash attention implementation in its "remote code" files, so use that, together with batch_encode_plus for tokenizing and batch_decode for decoding. Implement batching on the router server and voilà, you have your own server ready for inference.

Maybe you could write me one as an example?

tim-a-davis avatar Jun 23 '23 22:06 tim-a-davis

I'm working on one right now, if you would like to help out (Discord: mantrakp).

mantrakp04 avatar Jun 23 '23 22:06 mantrakp04

I am also very interested in this. I know the router side, but how do you actually batch-compute multiple requests at once "on the fly" with transformers?
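
For reference, the naive version I can think of is to collect requests for a few milliseconds and then run them together as one padded batch, roughly like this sketch (the compute function is just a stand-in, and this is not TGI's actual continuous batching):

import asyncio

MAX_BATCH = 8
MAX_WAIT_S = 0.01  # how long to wait for more requests before running the batch

def run_batch(prompts):
  # stand-in for tokenize -> model.generate -> decode on the whole batch
  return [f'generated({p})' for p in prompts]

async def batcher(queue):
  while True:
    batch = [await queue.get()]              # block until the first request arrives
    loop = asyncio.get_running_loop()
    deadline = loop.time() + MAX_WAIT_S
    while len(batch) < MAX_BATCH:            # gather more until full or the window closes
      timeout = deadline - loop.time()
      if timeout <= 0:
        break
      try:
        batch.append(await asyncio.wait_for(queue.get(), timeout))
      except asyncio.TimeoutError:
        break
    results = run_batch([prompt for prompt, _ in batch])
    for (_, future), result in zip(batch, results):
      future.set_result(result)              # unblock each waiting request

async def handle_request(queue, prompt):
  future = asyncio.get_running_loop().create_future()
  await queue.put((prompt, future))
  return await future

async def main():
  queue = asyncio.Queue()
  asyncio.create_task(batcher(queue))
  print(await asyncio.gather(*(handle_request(queue, f'prompt {i}') for i in range(5))))

asyncio.run(main())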

SinanAkkoyun avatar Jun 23 '23 23:06 SinanAkkoyun

(And can we expect an optimized TGI implementation soon?)

SinanAkkoyun avatar Jun 23 '23 23:06 SinanAkkoyun

Take example from the other models we have done in server/text_generation_server/models/custom_modeling/*.py, maybe?

There are also some files in server/text_generation_server/models/*.py. Those declare the model as being flash-enabled (the batching happens differently when a model supports flash).

If you succeed, PRs are welcome!

Narsil avatar Jun 26 '23 09:06 Narsil

Is MPT even supported? https://github.com/huggingface/text-generation-inference/issues/290

louis030195 avatar Jun 27 '23 21:06 louis030195

It's supported on a "best effort" basis.

I started some work to actually support it, but it means rewriting flash attention (the CUDA version) with added bias, which may take some time.

Narsil avatar Jun 28 '23 10:06 Narsil

Sad news: I didn't succeed. The MPT model is a bit different; I tried loading it, but it didn't work as expected and kept mixing up tokens. I am looking forward to your implementation, Narsil. Sorry for the wait.

Thanks in advance to Narsil :D

mantrakp04 avatar Jun 28 '23 23:06 mantrakp04

It's supported on a "best effort" basis.

I started some work to actually support it, but it means rewriting flash attention (the CUDA version) with added bias, which may take some time.

Can you give some guidance on how you started writing the flash attention part, and what are your thoughts on implementing dynamic batching for this, since it only supports one concurrent request on AutoModel for now? A little guidance would be really great; maybe we can collaborate and try this out.

ankit201 avatar Jun 29 '23 06:06 ankit201

on implementing dynamic batching for this, since it only supports one concurrent request on AutoModel for now

This won't require work once we have flash attention.

Narsil avatar Jun 29 '23 11:06 Narsil

on implementing dynamic batching for this, since it only supports one concurrent request on AutoModel for now

This won't require work once we have flash attention.

Please correct me if I'm wrong, but do we need to implement this, given that the MPT-30B models already have FlashAttention usage built into their config? From mpt-30b-chat:

import torch
import transformers

name = 'mosaicml/mpt-30b-chat'

config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'triton'  # change this to use triton-based FlashAttention
config.init_device = 'cuda:0' # For fast initialization directly on GPU!

model = transformers.AutoModelForCausalLM.from_pretrained(
  name,
  config=config,
  torch_dtype=torch.bfloat16, # Load model weights in bfloat16
  trust_remote_code=True
)
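
Continuing from that snippet, generation with it would look roughly like this (assuming the model repo ships its tokenizer files; the prompt is just an example):

tokenizer = transformers.AutoTokenizer.from_pretrained(name)

inputs = tokenizer('What is ALiBi?', return_tensors='pt').to('cuda:0')
with torch.no_grad():
  outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])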

ankit201 avatar Jun 29 '23 13:06 ankit201

Because it doesn't implement the flash attention we want.

This is Triton's flash attention, which doesn't support "unpadded" batching, which is the one necessary to work nicely in TGI (removing padding removes a LOT of issues and unnecessary memory, and speeds up inference much more than flash by itself).

Flash attention itself actually doesn't play that big a role in speeding things up at inference, since most of the time is spent in decode, where it doesn't really help. But the no-padding part is extremely important.
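
To illustrate what "unpadded" means (plain PyTorch, just showing the data layout rather than any kernel interface): instead of padding every sequence to the longest one in the batch, the sequences are concatenated into one flat tensor and a vector of cumulative lengths records where each one starts and ends.

import torch

# three requests of different lengths (token ids are arbitrary)
seqs = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([6, 7]), torch.tensor([8, 9, 10])]

# padded layout: shape (batch, max_len), a third of it is padding that must be masked out
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  0,  0,  0],
#         [ 8,  9, 10,  0,  0]])

# unpadded layout: one flat tensor plus cumulative sequence lengths
input_ids = torch.cat(seqs)                                  # tensor([1, ..., 10]), no padding
cu_seqlens = torch.tensor([0, 5, 7, 10], dtype=torch.int32)  # sequence i is input_ids[cu_seqlens[i]:cu_seqlens[i+1]]
# no attention mask is needed; the kernel works directly on the flat tensor
# given cu_seqlens and the maximum sequence length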

Narsil avatar Jun 30 '23 07:06 Narsil

Here is the non-flash version (as a temporary measure, since modifying the kernel is taking more time than I anticipated): https://github.com/huggingface/text-generation-inference/pull/514

This should enable sharding at least.

Narsil avatar Jul 01 '23 10:07 Narsil

Here is the non-flash version (as a temporary measure, since modifying the kernel is taking more time than I anticipated): #514

This should enable sharding at least.

Many thanks for this. Looking forward to the flash class too. Cheers!

ankit201 avatar Jul 01 '23 12:07 ankit201

Because it doesn't implement the flash attention we want.

This is Triton's flash attention, which doesn't support "unpadded" batching, which is the one necessary to work nicely in TGI (removing padding removes a LOT of issues and unnecessary memory, and speeds up inference much more than flash by itself).

Flash attention itself actually doesn't play that big a role in speeding things up at inference, since most of the time is spent in decode, where it doesn't really help. But the no-padding part is extremely important.

Triton is the only flash attention implementation that supports ALiBi, if I understand this correctly.

So for TGI, if we want to use MPT with ALiBi, does that leave us with just the native PyTorch implementation?
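
For context, ALiBi is just a per-head linear bias added to the attention scores before softmax (which is why the CUDA flash kernel has to be modified to accept an extra bias). A rough sketch of that bias in plain PyTorch, with an arbitrary head count and sequence length:

import torch

def alibi_slopes(n_heads):
  # geometric sequence 2^(-8/n), 2^(-16/n), ... per the ALiBi paper (n_heads a power of two)
  start = 2 ** (-8 / n_heads)
  return torch.tensor([start ** (i + 1) for i in range(n_heads)])

def alibi_bias(n_heads, seq_len):
  slopes = alibi_slopes(n_heads)                                               # (heads,)
  distance = torch.arange(seq_len)[None, :] - torch.arange(seq_len)[:, None]   # key index minus query index
  # (heads, seq, seq): zero on the diagonal, increasingly negative for keys further in the past
  return slopes[:, None, None] * distance[None, :, :]

bias = alibi_bias(n_heads=8, seq_len=4)
# scores = q @ k.transpose(-1, -2) / sqrt(head_dim) + bias, then causal mask and
# softmax as usual; no positional embeddings are added to the token embeddings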

ConProgramming avatar Jul 03 '23 17:07 ConProgramming

We will fork the flash attention CUDA kernels and add it ourselves.

OlivierDehaene avatar Jul 04 '23 07:07 OlivierDehaene