lightllm icon indicating copy to clipboard operation
lightllm copied to clipboard

torch, triton版本确认及显存占用分析

Open LittleYouEr opened this issue 1 year ago • 8 comments

requirements.txt中是torch 2.0.0;安装的时候和triton 2.1.0 不兼容;

安装时triton改为2.0.0安装;

安装后单独更新安装triton至2.1.0版本;

server可以正常运行,请求时发生错误:

/root/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-centos-7-release/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::triton::gpu::BLockedEncodingAttr; From = mlir::Attribute]: Asserttion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' faliled

基础环境:redhat 7, cuda11.8, 4卡v100,python3.10

LittleYouEr avatar Aug 06 '24 07:08 LittleYouEr

v100 请到 github 上triton的官网,找到安装 nightly 版本的命令,2.0.0 和 2.1.0 都不能很好的支持V100显卡。

hiworldwzj avatar Aug 06 '24 09:08 hiworldwzj

@LittleYouEr

hiworldwzj avatar Aug 06 '24 09:08 hiworldwzj

多谢回复;目前在A800上启动的镜像可以正常work了,但是在测试过程中还有问题没追溯清楚;

我们启动命令如下:

python3 -m lightllm.server.api_server --model_dir /data/models/DeepSeek-Coder-V2-Instruct
-- host 0.0.0.0
-- port 8080
-- tp 8
--max_total_token_num 96 \ #该参数为了测试裸模型下的显存占用 --max_req_input_len 64
--max_req_total_len 32
--data_type bfloat16
--trust_remote_code

问题:裸模型情况下的推理显存分布;

当前测试的deepseek v2模型,原模型有236b,裸模型参数bfloat16分布在8卡上共计444gb,单卡大致需要57gb; 在极限参数情况下,比如说:

max_total_token_num: 96 max_req_input_len 32 max_req_total_len 64

单卡实际显存显示:66.5gb,请问这将近10gb的显存主要是哪些地方在占用,因为我们也没有指定额外的参数。

第二类问题,kv cache大小及max_total_token_num的设定

deepseek v2的kv cache管理也追溯了相关代码,应该实现了mla的做法;但对于这部分我们也有两个问题...

  1. max_total_token_num的计算,该link给出了相关说明,我们认为可能存在一些问题:kv cache的大小,每张卡的大小实际和总kv_cache的大小一致,每张卡都在开辟一个独立的kv_cache,也就是说,例子中:

gpu: use 2 A100 80G, (--tp 2) model: llama-7b, dtype: fp16, llama-7b hidden_size is 4096, layers num is 32, the gpu mem left after gpu load all weights,

80 * 2 - 7 * 2 = 146G

gpu mem for one Token kv cache:

4096 * 2 * 2 * 32 / 1024 / 1024 / 1024 = 0.000488281G

the max token num:

146 / 0.000488281 ≈ 299008

公式中146需要除以/2,也就是146/2/0.000488281,这个还请确认下。

以上,是当前测试遇到的问题;

先感谢大佬们的工作,可以开源mla的实现,目前应该是首家...

因为我们项目时间也比较紧,所以就直接提问了,还望不吝赐教,谢谢!

LittleYouEr avatar Aug 12 '24 02:08 LittleYouEr

I tried to use pdb.set_trace() (following the tutorial here)to debug the triton kernel in lightllm, but got the following error:

AssertionError: Function "set_trace" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this

my kernel is already with @triton.jit env:

pytorch-version: 2.1.0
triton-version: 2.1.0

Any advice on this?

Thanks in advance!

brisker avatar Aug 12 '24 13:08 brisker

@LittleYouEr 第一个问题,因为有些中间处理过程会消耗显存,比如加载的时候有cuda操作等,torch一般申请了你不主动释放就不会释放回去,有时候就偏大,具体可能和代码细节有关系,但是实际使用的显存就是你计算占用的部分,有部分被torch给缓存了,但是没使用。

第二个问题, mla 的max total token num 参数的计算和估计,和文档中描述的可能不是那么适用,因为那个主要是老的MHA架构的计算,你的思考是正确的,如果tp了确实需要多有一份kv,为了速度,没办法。

还有一个问题就是目前的 mla 算子没有详细的调优过,可能有不少提升空间。 moe 矩阵乘法部分实现也很粗糙,有很大的提升空间,后续还会逐步改进。

hiworldwzj avatar Aug 12 '24 15:08 hiworldwzj

I tried to use pdb.set_trace() (following the tutorial here)to debug the triton kernel in lightllm, but got the following error:

AssertionError: Function "set_trace" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this

my kernel is already with @triton.jit env:

pytorch-version: 2.1.0
triton-version: 2.1.0

Any advice on this?

Thanks in advance!

@brisker I did not use this feature before. Maybe you can try to update your triton version to nightly version. I am not sure that triton == 2.1.0 can support pdb.set_trace().

hiworldwzj avatar Aug 12 '24 15:08 hiworldwzj

@LittleYouEr 第一个问题,因为有些中间处理过程会消耗显存,比如加载的时候有cuda操作等,torch一般申请了你不主动释放就不会释放回去,有时候就偏大,具体可能和代码细节有关系,但是实际使用的显存就是你计算占用的部分,有部分被torch给缓存了,但是没使用。

第二个问题, mla 的max total token num 参数的计算和估计,和文档中描述的可能不是那么适用,因为那个主要是老的MHA架构的计算,你的思考是正确的,如果tp了确实需要多有一份kv,为了速度,没办法。

还有一个问题就是目前的 mla 算子没有详细的调优过,可能有不少提升空间。 moe 矩阵乘法部分实现也很粗糙,有很大的提升空间,后续还会逐步改进。

显存消耗那块看了源码,也做了一些调试输出,比较清楚了;

后面mla的优化,感觉还得靠大佬带飞啊;

感谢回复!

LittleYouEr avatar Aug 16 '24 09:08 LittleYouEr

@hiworldwzj 请问lightllm的w8a8的Triton kernel,是否在llama上测试过相比于fp16的加速效果的benchMark?

brisker avatar Oct 09 '24 11:10 brisker