Dinghao Zhou
Dinghao Zhou
sft: **2b gemma fsdp zero3**
generate in batch way: gemma: llama:
解释下这里为什么要把shape 变成[bs, seq_len,head, head_dim] https://github.com/wenet-e2e/wenet/blob/9805ed68638f711b6fda17627efb7aa918ce6870/wenet/transformer/attention.py#L637-#L651 来自gpt4的解释:  实测[bs, seq_len,head, head_dim], 对head_dim 上apply pos等操作要慢于[bs,head,seq_len, head_dim] ```bash 6s vs 2s (长度为300) ``` ref: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L256 **所以其他xxx attention 是否也需要有对应修改?**
> 周神,torch官方也有个llama微调的代码:https://github.com/pytorch/torchtune > > 嗯 这个有看过。 不过我们最终目的不是llm 而是为了语音理解大模型和语音合成 而且大模型训练 有自己的设计原则和技巧 我们需要把优秀的组件 继承过来
该pr会拆分成以下加个pr - [x] decoderonly https://github.com/wenet-e2e/wenet/pull/2547 - [ ] llm dataset - [ ] convert script
需要自己改下,初始化model为k2 并且 model.decode 直接调用 model.hlgxxx
plz make a pr to enhance ?
@pengzhendong plz make some advice
可以自己改一下 decode接口本身是支持batch的
不支持这么长的 改位置编码的max len可以 但是效果会变差