lightseq
Running ls_bert from the Example directory fails with TypeError: infer(): incompatible function arguments
Error when running ls_bert.py. Environment: LightSeq 2.2.0, Python 3.7, CUDA 11.0, tensorflow-gpu 2.4.
Bert buf_bytesize: 1409286144
====================START warmup====================
=========lightseq=========
lightseq generating...
Traceback (most recent call last):
File "test/ls_bert.py", line 111, in
Invoked with: <lightseq.inference.Bert object at 0x7fad1706e5f0>, tensor([[ 101, 7592, 1010, 2026, 3899, 2003, 10140, 102], [ 101, 4931, 1010, 2129, 2024, 2017, 102, 0], [ 101, 2023, 2003, 1037, 3231, 102, 0, 0], [ 101, 5604, 1996, 2944, 2153, 102, 0, 0]])
Changing it to be consistent with version 2.1.3 fixes it; the call was missing the attn_mask argument:

def infer(self, inputs, attn_mask):
    last_hidden_states = self.ls_bert.infer(inputs, attn_mask)
    last_hidden_states = torch.Tensor(last_hidden_states).float()
    pooled_output = self.pooler(last_hidden_states.to("cuda:0"))
    logits = self.classifier(pooled_output)
    return logits
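For context, a minimal sketch of how the attn_mask that infer() now expects could be derived from the padded input ids shown in the traceback (assuming the HuggingFace convention of 1 for real tokens and 0 for padding; build_attn_mask is a hypothetical helper, not part of LightSeq):

```python
def build_attn_mask(input_ids, pad_id=0):
    """Mark real tokens with 1 and padding positions with 0,
    matching the attention_mask a BERT tokenizer would produce."""
    return [[1 if tok != pad_id else 0 for tok in seq] for seq in input_ids]

# Input ids from the traceback above (batch of padded sequences).
batch = [
    [101, 7592, 1010, 2026, 3899, 2003, 10140, 102],
    [101, 4931, 1010, 2129, 2024, 2017, 102, 0],
    [101, 2023, 2003, 1037, 3231, 102, 0, 0],
    [101, 5604, 1996, 2944, 2153, 102, 0, 0],
]
mask = build_attn_mask(batch)
```

In practice the tokenizer already returns this mask (e.g. encoding["attention_mask"]), so it only needs to be forwarded to infer().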
But lightseq is still slower than the unaccelerated huggingface version; the GPU is a 1080Ti.
====================END warmup====================
tokenizing the sentences...
=========lightseq=========
lightseq generating...
lightseq time: 0.12853449676185846s
lightseq results (class predictions): [1 1 1 1]
=========huggingface=========
huggingface generating...
huggingface time: 0.017190586775541306s
huggingface results (class predictions): [1 1 1 1]
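When comparing numbers like the ones above, a simple warmup-then-average harness helps rule out one-off measurement noise. A minimal sketch (stdlib only; for CUDA workloads you would additionally need to synchronize the device, e.g. torch.cuda.synchronize(), before reading the clock):

```python
import time

def benchmark(fn, *args, warmup=1, repeat=5):
    """Run fn a few times untimed to warm caches/JIT, then
    return the average wall-clock time over `repeat` timed runs."""
    for _ in range(warmup):
        fn(*args)
    start = time.perf_counter()
    for _ in range(repeat):
        fn(*args)
    return (time.perf_counter() - start) / repeat
```

Even with averaging, the gap reported here is large enough that it points at the kernels themselves rather than timing artifacts.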
It seems the prebuilt wheel is compiled for FP16, and the 1080Ti's half-precision units are weak, so FP32 apparently requires building from source.