torch, triton版本确认及显存占用分析
requirements.txt中是torch 2.0.0;安装的时候和triton 2.1.0 不兼容;
安装时triton改为2.0.0安装;
安装后单独更新安装triton至2.1.0版本;
server可以正常运行,请求时发生错误:
/root/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-centos-7-release/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::triton::gpu::BLockedEncodingAttr; From = mlir::Attribute]: Asserttion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' faliled
基础环境:redhat 7, cuda11.8, 4卡v100,python3.10
v100 请到 github 上triton的官网,找到安装 nightly 版本的命令,2.0.0 和 2.1.0 都不能很好的支持V100显卡。
@LittleYouEr
多谢回复;目前在A800上启动的镜像可以正常work了,但是在测试过程中还有问题没追溯清楚;
我们启动命令如下:
python3 -m lightllm.server.api_server --model_dir /data/models/DeepSeek-Coder-V2-Instruct
-- host 0.0.0.0
-- port 8080
-- tp 8
--max_total_token_num 96 \ #该参数为了测试裸模型下的显存占用 --max_req_input_len 64
--max_req_total_len 32
--data_type bfloat16
--trust_remote_code
问题:裸模型情况下的推理显存分布;
当前测试的deepseek v2模型,原模型有236b,裸模型参数bfloat16分布在8卡上共计444gb,单卡大致需要57gb; 在极限参数情况下,比如说:
max_total_token_num: 96 max_req_input_len 32 max_req_total_len 64
单卡实际显存显示:66.5gb,请问这将近10gb的显存主要是哪些地方在占用,因为我们也没有指定额外的参数。
第二类问题,kv cache大小及max_total_token_num的设定
deepseek v2的kv cache管理也追溯了相关代码,应该实现了mla的做法;但对于这部分我们也有两个问题...
- max_total_token_num的计算,该link给出了相关说明,我们认为可能存在一些问题:kv cache的大小,每张卡的大小实际和总kv_cache的大小一致,每张卡都在开辟一个独立的kv_cache,也就是说,例子中:
gpu: use 2 A100 80G, (--tp 2) model: llama-7b, dtype: fp16, llama-7b hidden_size is 4096, layers num is 32, the gpu mem left after gpu load all weights,
80 * 2 - 7 * 2 = 146G
gpu mem for one Token kv cache:
4096 * 2 * 2 * 32 / 1024 / 1024 / 1024 = 0.000488281G
the max token num:
146 / 0.000488281 ≈ 299008
公式中146需要除以/2,也就是146/2/0.000488281,这个还请确认下。
以上,是当前测试遇到的问题;
先感谢大佬们的工作,可以开源mla的实现,目前应该是首家...
因为我们项目时间也比较紧,所以就直接提问了,还望不吝赐教,谢谢!
I tried to use pdb.set_trace() (following the tutorial here)to debug the triton kernel in lightllm, but got the following error:
AssertionError: Function "set_trace" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this
my kernel is already with @triton.jit env:
pytorch-version: 2.1.0
triton-version: 2.1.0
Any advice on this?
Thanks in advance!
@LittleYouEr 第一个问题,因为有些中间处理过程会消耗显存,比如加载的时候有cuda操作等,torch一般申请了你不主动释放就不会释放回去,有时候就偏大,具体可能和代码细节有关系,但是实际使用的显存就是你计算占用的部分,有部分被torch给缓存了,但是没使用。
第二个问题, mla 的max total token num 参数的计算和估计,和文档中描述的可能不是那么适用,因为那个主要是老的MHA架构的计算,你的思考是正确的,如果tp了确实需要多有一份kv,为了速度,没办法。
还有一个问题就是目前的 mla 算子没有详细的调优过,可能有不少提升空间。 moe 矩阵乘法部分实现也很粗糙,有很大的提升空间,后续还会逐步改进。
I tried to use
pdb.set_trace()(following the tutorial here)to debug the triton kernel in lightllm, but got the following error:
AssertionError: Function "set_trace" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix thismy kernel is already with @triton.jit env:
pytorch-version: 2.1.0 triton-version: 2.1.0Any advice on this?
Thanks in advance!
@brisker I did not use this feature before. Maybe you can try to update your triton version to nightly version. I am not sure that triton == 2.1.0 can support pdb.set_trace().
@LittleYouEr 第一个问题,因为有些中间处理过程会消耗显存,比如加载的时候有cuda操作等,torch一般申请了你不主动释放就不会释放回去,有时候就偏大,具体可能和代码细节有关系,但是实际使用的显存就是你计算占用的部分,有部分被torch给缓存了,但是没使用。
第二个问题, mla 的max total token num 参数的计算和估计,和文档中描述的可能不是那么适用,因为那个主要是老的MHA架构的计算,你的思考是正确的,如果tp了确实需要多有一份kv,为了速度,没办法。
还有一个问题就是目前的 mla 算子没有详细的调优过,可能有不少提升空间。 moe 矩阵乘法部分实现也很粗糙,有很大的提升空间,后续还会逐步改进。
显存消耗那块看了源码,也做了一些调试输出,比较清楚了;
后面mla的优化,感觉还得靠大佬带飞啊;
感谢回复!
@hiworldwzj 请问lightllm的w8a8的Triton kernel,是否在llama上测试过相比于fp16的加速效果的benchMark?