wenet
torch batch forward
Describe the bug
For non-streaming inference, the exported onnx model can support batch forward, so I tried to modify asr_model.py to export a jit script model that also supports batch forward (not the chunk forward), just like the onnx one. The encoder runs as expected, but the decoder fails in the embedding layer. From tracing it, I suspect the ignore_id padding value of -1 causes the bug. I have compared the code and found no differences, which confuses me: why can onnx export a batch-forward model while jit script can't? Since you don't export these interfaces, maybe you know more, so could you help me out? Thank you.
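As a side note, here is a tiny standalone check of my suspicion (my own toy snippet, not wenet code), showing that a raw -1 index breaks torch.nn.Embedding in eager execution:

```python
import torch

# toy check of the suspicion, not wenet code: a raw -1 index breaks nn.Embedding
emb = torch.nn.Embedding(10, 4)
print(emb(torch.tensor([0, 9])).shape)  # works: torch.Size([2, 4])
emb(torch.tensor([0, -1]))              # raises IndexError: index out of range in self
```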
To Reproduce
Steps to reproduce the behavior:
- Add the jit export functions below in asr_model.py
- Run batch input data in the x86 runtime
@torch.jit.export
def batch_forward_encoder(
    self,
    xs: torch.Tensor,
    xs_lens: torch.Tensor,
    beam_size: int = 10
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           torch.Tensor]:
    encoder_out, encoder_mask = self.encoder.forward(xs, xs_lens, -1, -1)
    encoder_out_lens = encoder_mask.squeeze(1).sum(1)
    ctc_log_probs = self.ctc.log_softmax(encoder_out)
    encoder_out_lens = encoder_out_lens.int()
    beam_log_probs, beam_log_probs_idx = torch.topk(
        ctc_log_probs, beam_size, dim=2)
    return (encoder_out, encoder_out_lens, ctc_log_probs,
            beam_log_probs, beam_log_probs_idx)
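For reference, this is roughly how I script the model and call the exported encoder entry (just a sketch; model, feats, feats_lens and the file name are placeholders, not the actual runtime code):

```python
# rough usage sketch; `model` is the wenet ASRModel, feats / feats_lens are placeholder tensors
script_model = torch.jit.script(model)
script_model.save("final_batch.zip")

loaded = torch.jit.load("final_batch.zip")
(encoder_out, encoder_out_lens, ctc_log_probs,
 beam_log_probs, beam_log_probs_idx) = loaded.batch_forward_encoder(feats, feats_lens, 10)
```

And the decoder entry: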
@torch.jit.export
def batch_forward_attention_decoder(
    self,
    encoder_out: torch.Tensor,
    encoder_lens: torch.Tensor,
    hyps_pad_sos_eos: torch.Tensor,
    hyps_lens_sos: torch.Tensor,
    r_hyps_pad_sos_eos: torch.Tensor,
    ctc_score: torch.Tensor,
    ctc_weight: float = 0.5,
    reverse_weight: float = 0.0,
    beam_size: int = 10
) -> torch.Tensor:
    B, T, F = encoder_out.shape
    bz = beam_size
    B2 = B * bz
    encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
    encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
    encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
    T2 = hyps_pad_sos_eos.shape[2] - 1
    hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
    hyps_lens = hyps_lens_sos.view(B2,)
    hyps_pad_sos = hyps_pad[:, :-1].contiguous()
    hyps_pad_eos = hyps_pad[:, 1:].contiguous()
    r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
    r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
    r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()
    decoder_out, r_decoder_out, _ = self.decoder(
        encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos,
        reverse_weight)
    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
    V = decoder_out.shape[-1]
    decoder_out = decoder_out.view(B2, T2, V)
    mask = ~make_pad_mask(hyps_lens, T2)  # B2 x T2
    # mask index, remove ignore id
    index = torch.unsqueeze(hyps_pad_eos * mask, 2)
    score = decoder_out.gather(2, index).squeeze(2)  # B2 x T2
    # mask padded part
    score = score * mask
    decoder_out = decoder_out.view(B, bz, T2, V)
    if reverse_weight > 0:
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out.view(B2, T2, V)
        index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
        r_score = r_decoder_out.gather(2, index).squeeze(2)
        r_score = r_score * mask
        score = score * (1 - reverse_weight) + reverse_weight * r_score
        r_decoder_out = r_decoder_out.view(B, bz, T2, V)
    score = torch.sum(score, dim=1)  # B2
    score = torch.reshape(score, (B, bz)) + ctc_weight * ctc_score
    best_index = torch.argmax(score, dim=1)
    return best_index
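As a sanity check of the beam expansion used above (my own toy snippet, not wenet code), repeat(1, bz, 1).view(B2, T, F) places the beam_size copies of utterance b at rows b*bz .. b*bz+bz-1:

```python
import torch

# toy verification of encoder_out.repeat(1, bz, 1).view(B2, T, F)
B, T, F, bz = 2, 3, 4, 2
enc = torch.arange(B * T * F, dtype=torch.float32).view(B, T, F)
enc2 = enc.repeat(1, bz, 1).view(B * bz, T, F)
assert torch.equal(enc2[0], enc[0]) and torch.equal(enc2[1], enc[0])
assert torch.equal(enc2[2], enc[1]) and torch.equal(enc2[3], enc[1])
```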
Expected behavior
Run successfully.
please paste the error message
Sorry, my fault.
I saved the input data of the onnx decoder (with the ignore_id padding) and replaced the random inputs with these data to run export_onnx.py for the decoder.
The config (reverse_weight=0.0):
{'accum_grad': 1, 'cmvn_file': 'global_cmvn', 'collate_conf': {'feature_dither': 0.0, 'spec_aug': True, 'spec_aug_conf': {'max_f': 10, 'max_t': 50, 'max_w': 80, 'num_f_mask': 2, 'num_t_mask': 2, 'warp_for_time': False}}, 'dataset_conf': {'batch_size': 10, 'batch_type': 'dynamic', 'max_frames_in_batch': 12000, 'max_length': 10240, 'min_length': 0, 'sort': True, 'filter_conf': {'max_length': 2000, 'min_length': 50, 'token_max_length': 400, 'token_min_length': 1, 'min_output_input_ratio': 0.0005, 'max_output_input_ratio': 0.1}, 'fbank_conf': {'num_mel_bins': 80, 'frame_shift': 10, 'frame_length': 25, 'dither': 0.0}, 'resample_conf': {'resample_rate': 16000}, 'batch_conf': {'batch_type': 'static', 'batch_size': 8}}, 'decoder': 'transformer', 'decoder_conf': {'attention_heads': 8, 'dropout_rate': 0.1, 'linear_units': 2048, 'num_blocks': 6, 'positional_dropout_rate': 0.1, 'self_attention_dropout_rate': 0.0, 'src_attention_dropout_rate': 0.0}, 'encoder': 'conformer', 'encoder_conf': {'activation_type': 'swish', 'attention_dropout_rate': 0.0, 'attention_heads': 4, 'causal': True, 'cnn_module_kernel': 15, 'cnn_module_norm': 'layer_norm', 'dropout_rate': 0.1, 'input_layer': 'conv2d', 'linear_units': 2048, 'normalize_before': True, 'num_blocks': 12, 'output_size': 256, 'pos_enc_layer_type': 'rel_pos', 'positional_dropout_rate': 0.1, 'selfattention_layer_type': 'rel_selfattn', 'use_cnn_module': True, 'use_dynamic_chunk': False}, 'grad_clip': 5, 'input_dim': 80, 'is_json_cmvn': False, 'log_interval': 100, 'max_epoch': 160, 'model_conf': {'ctc_weight': 0.5, 'length_normalized_loss': False, 'lsm_weight': 0.1}, 'optim': 'adam', 'optim_conf': {'lr': 0.002}, 'output_dim': 5070, 'raw_wav': False, 'scheduler': 'warmuplr', 'scheduler_conf': {'warmup_steps': 25000}}
the input data:
=> hyps_pad_sos_eos: torch.Size([2, 10, 36])
tensor([[[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 202, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 3135, 669, 5042, 4648,
2451, 2135, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811, 581,
3246, 4647, 5069],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 202, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 5042, 4648, 2451, 2135,
4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811, 581, 3246, 4647,
5069, -1, -1],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 202, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 3135, 669, 5042, 4648,
2451, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811, 581, 3246,
4647, 5069, -1],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 202, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 5042, 4648, 2451, 4647,
4923, 3540, 4708, 1601, 2989, 4409, 1811, 581, 3246, 4647, 5069,
-1, -1, -1],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 202, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 3135, 669, 5042, 4648,
2451, 2117, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811, 581,
3246, 4647, 5069],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 25, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 3135, 669, 5042, 4648,
2451, 2135, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811, 581,
3246, 4647, 5069],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 202, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 5042, 4648, 2451, 2117,
4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811, 581, 3246, 4647,
5069, -1, -1],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 25, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 5042, 4648, 2451, 2135,
4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811, 581, 3246, 4647,
5069, -1, -1],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 25, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 3135, 669, 5042, 4648,
2451, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811, 581, 3246,
4647, 5069, -1],
[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262, 25, 4954,
2, 1054, 4648, 2451, 2117, 169, 5042, 5042, 4648, 2451, 4647,
4923, 3540, 4708, 1601, 2989, 4409, 1811, 581, 3246, 4647, 5069,
-1, -1, -1]],
[[5069, 3165, 4954, 2962, 5025, 4182, 2962, 5025, 4182, 4532, 3511,
2215, 2962, 2054, 2263, 4647, 5069, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 5053, 3165, 4954, 2962, 5025, 4182, 2962, 5025, 4182, 4532,
3511, 2215, 2962, 2054, 2263, 4647, 5069, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 3165, 4954, 2962, 5025, 4182, 2962, 5025, 4182, 4532, 3511,
2215, 4921, 2962, 2054, 2263, 4647, 5069, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 3165, 4954, 2962, 5025, 4182, 2962, 4979, 4182, 4532, 3511,
2215, 2962, 2054, 2263, 4647, 5069, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 5053, 3165, 4954, 2962, 5025, 4182, 2962, 5025, 4182, 4532,
3511, 2215, 4921, 2962, 2054, 2263, 4647, 5069, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 5053, 3165, 4954, 2962, 5025, 4182, 2962, 4979, 4182, 4532,
3511, 2215, 2962, 2054, 2263, 4647, 5069, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 3165, 4954, 2962, 5025, 4182, 2962, 4979, 4182, 4532, 3511,
2215, 4921, 2962, 2054, 2263, 4647, 5069, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 5053, 3165, 4954, 2962, 5025, 4182, 2962, 4979, 4182, 4532,
3511, 2215, 4921, 2962, 2054, 2263, 4647, 5069, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 0, 5069, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1],
[5069, 0, 5069, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1]]], dtype=torch.int32)
=> hyps_lens_sos: torch.Size([2, 10])
tensor([[35, 33, 34, 32, 35, 35, 33, 33, 34, 32],
[16, 17, 17, 16, 18, 17, 17, 18, 2, 2]], dtype=torch.int32)
=> r_hyps_pad_sos_eos: torch.Size([2, 10, 36])
tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]], dtype=torch.int32)
=> encoder_out: torch.Size([2, 538, 256])
tensor([[[-0.0112931002, -0.0037228700, 0.0140792998, ...,
-0.0213916991, -0.0696593001, 0.0179855991],
[ 0.0236336999, -0.0084405299, 0.0264567994, ...,
-0.0065244199, -0.1207040027, -0.0038900599],
[ 0.0716644004, -0.0227231998, 0.0324570984, ...,
0.0369418003, -0.2284329981, -0.0327154994],
...,
[ 0.0007195030, 0.1284279972, 0.0712073967, ...,
-0.0373377986, 0.0767147988, 0.2463870049],
[-0.1903409958, 0.2421520054, 0.0602861010, ...,
-0.1040119976, 0.1219519973, -0.0208250992],
[-0.2125509977, 0.2335509956, 0.1081859991, ...,
-0.0806310028, 0.0531232990, -0.1088310033]],
[[-0.0137972003, -0.0014524700, 0.0161499996, ...,
-0.0094943298, -0.0790612996, 0.0239565000],
[ 0.0199223999, 0.0055981702, 0.0381584018, ...,
0.0017427501, -0.1110450029, 0.0021909799],
[ 0.0662441999, -0.0008342030, 0.0566346012, ...,
0.0435902998, -0.1948059946, -0.0183284003],
...,
[ 0.0337589011, -0.2291299999, 0.0693117008, ...,
0.0938744023, -0.3226909935, 0.1645780057],
[ 0.0337589011, -0.2291299999, 0.0693117008, ...,
0.0938744023, -0.3226909935, 0.1645780057],
[ 0.0337589011, -0.2291299999, 0.0693117008, ...,
0.0938744023, -0.3226909935, 0.1645780057]]])
=> encoder_out_lens: torch.Size([2])
tensor([538, 299], dtype=torch.int32)
=> ctc_score: torch.Size([2, 10])
tensor([[-2.4608499527e+01, -4.2024700165e+01, -2.9768400192e+01,
-4.7184600830e+01, -2.1780700684e+01, -2.8415100098e+01,
-3.9196899414e+01, -4.5831298828e+01, -3.3575000763e+01,
-5.0991199493e+01],
[-2.4205099106e+01, -2.6406499863e+01, -1.8420799255e+01,
-2.6921899796e+01, -2.0622200012e+01, -2.9123300552e+01,
-2.1137599945e+01, -2.3339000702e+01, 1.1754900068e-38,
1.1754900068e-38]])
The error message:
Traceback (most recent call last):
File "wenet/bin/export_onnx_gpu.py", line 415, in <module>
torch.onnx.export(decoder,
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 305, in export
return utils.export(model, args, f, export_params, verbose, training,
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 118, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 719, in _export
_model_to_graph(model, args, verbose, input_names,
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 499, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 440, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 391, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
result = self.forward(*input, **kwargs)
File "wenet/bin/export_onnx_gpu.py", line 127, in forward
decoder_out, r_decoder_out, _ = self.decoder(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
result = self.forward(*input, **kwargs)
File "/ws/code/wenet2/wenet/transformer/decoder.py", line 122, in forward
x, _ = self.embed(tgt)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 141, in forward
input = module(input)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/sparse.py", line 158, in forward
return F.embedding(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 2183, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self
I have no idea why this error happened. Just wondering why you prefer to use libtorch rather than onnx. If you would like to use triton, I suggest using onnx, which is also friendly to further optimization techniques like tensorrt.
There are actually many limitations to using triton: datacenter GPUs, the CUDA version, the triton version required to support stateful models, etc. In our practice, triton and onnxruntime also have their own problems, some of which are even fatal, like memory leaks (https://github.com/microsoft/onnxruntime/issues/8147).
Anyway, thank you for your reply.
Is this error caused by -1 in hyps_pad_sos_eos? I have the same problem; in my case, hyps_pad_sos_eos has values greater than output_dim.
ONNX can support negative indices. JIT maybe not? You may use another value instead of -1 in this case.
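For example, something along these lines (an untested sketch, assuming eos is the model's eos id, which looks like 5069 in your dump) before the hypotheses reach the decoder embedding:

```python
# untested sketch: replace the ignore_id (-1) padding with a valid token id (eos)
# before it reaches the embedding; padded positions are masked out of the score later anyway
hyps_pad_sos = hyps_pad_sos.masked_fill(hyps_pad_sos == -1, eos)
r_hyps_pad_sos = r_hyps_pad_sos.masked_fill(r_hyps_pad_sos == -1, eos)
```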
thanks, I'll try.
This issue has been automatically closed due to inactivity.