torch batch forward

Open · murphypei opened this issue 2 years ago · 5 comments

Describe the bug
For non-streaming inference, the exported ONNX model appears to support batch forward. I tried to modify asr_model.py to export a TorchScript (JIT) model that also supports batch forward (not chunk forward), just like the ONNX model. The encoder runs as expected, but the decoder fails in its embedding layer. From tracing, I suspect the padding value ignore_id = -1 causes the bug. I have checked the code and see no differences, which confuses me: why can the ONNX export handle batch forward while the TorchScript model cannot? Since the official scripts don't export these interfaces, maybe you know more about this. Could you help me out? Thank you.
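
For reference, the mechanism this report relies on is `@torch.jit.export`: methods marked with it are compiled by `torch.jit.script` alongside `forward` and can be called by name on the scripted module. The sketch below uses a hypothetical `ToyModel` (not wenet's `ASRModel`) with a linear "encoder" and a CTC-style head, only to show a batch entry point being scripted and called.

```python
from typing import Tuple

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for ASRModel: a tiny encoder plus a CTC-style head."""

    def __init__(self, feat_dim: int = 8, vocab_size: int = 6):
        super().__init__()
        self.encoder = nn.Linear(feat_dim, 16)
        self.ctc_head = nn.Linear(16, vocab_size)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        return self.encoder(xs)

    @torch.jit.export
    def batch_forward_encoder(
        self, xs: torch.Tensor, beam_size: int = 3
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # xs: (B, T, feat_dim); the whole batch is processed in one call, no chunking
        encoder_out = self.encoder(xs)
        log_probs = torch.log_softmax(self.ctc_head(encoder_out), dim=-1)
        topk_probs, topk_idx = torch.topk(log_probs, beam_size, dim=2)
        return encoder_out, topk_probs, topk_idx


scripted = torch.jit.script(ToyModel())
# Exported methods are compiled into the ScriptModule and callable by name
enc, probs, idx = scripted.batch_forward_encoder(torch.randn(2, 5, 8))
print(enc.shape, probs.shape, idx.shape)
```

Exported methods are also serialized by `torch.jit.save`, so a runtime that loads the file can invoke them by name.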

To Reproduce
Steps to reproduce the behavior:

  1. Add the JIT export functions below to asr_model.py.
  2. Run batched input data through the x86 runtime.
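# Methods added to class ASRModel (asr_model.py).
# Module-level imports this code needs:
from typing import Tuple

import torch

from wenet.utils.mask import make_pad_mask

@torch.jit.export
def batch_forward_encoder(
    self,
    xs: torch.Tensor,
    xs_lens: torch.Tensor,
    beam_size: int = 10
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Full-context (non-streaming) encoding:
    # decoding_chunk_size=-1, num_decoding_left_chunks=-1
    encoder_out, encoder_mask = self.encoder.forward(xs, xs_lens, -1, -1)
    encoder_out_lens = encoder_mask.squeeze(1).sum(1)  # (B,) valid frames per utterance
    ctc_log_probs = self.ctc.log_softmax(encoder_out)  # (B, T, V)
    encoder_out_lens = encoder_out_lens.int()
    # Top-k CTC log-probs per frame for the downstream CTC beam search
    beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs, beam_size, dim=2)
    return encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs, beam_log_probs_idx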

@torch.jit.export
def batch_forward_attention_decoder(
    self,
    encoder_out: torch.Tensor,
    encoder_lens: torch.Tensor,
    hyps_pad_sos_eos: torch.Tensor,
    hyps_lens_sos: torch.Tensor,
    r_hyps_pad_sos_eos: torch.Tensor,
    ctc_score: torch.Tensor,
    ctc_weight: float = 0.5,
    reverse_weight: float = 0.0,
    beam_size: int = 10
) -> torch.Tensor:
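    B, T, F = encoder_out.shape
    bz = beam_size
    B2 = B * bz
    # Tile encoder output and mask once per hypothesis: (B, T, F) -> (B*bz, T, F)
    encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
    encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
    encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
    # hyps_pad_sos_eos: (B, bz, T2 + 1) rows of <sos> + tokens + <eos> + padding
    T2 = hyps_pad_sos_eos.shape[2] - 1
    hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
    hyps_lens = hyps_lens_sos.view(B2,)
    # Decoder input: it goes through the embedding, so every id must be in [0, vocab_size)
    hyps_pad_sos = hyps_pad[:, :-1].contiguous()
    # Targets whose log-probs are gathered below
    hyps_pad_eos = hyps_pad[:, 1:].contiguous()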

    r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
    r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
    r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()

    decoder_out, r_decoder_out, _ = self.decoder(
        encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos,
        reverse_weight)
    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
    V = decoder_out.shape[-1]
    decoder_out = decoder_out.view(B2, T2, V)
    mask = ~make_pad_mask(hyps_lens, T2)  # B2 x T2
    # mask index, remove ignore id; gather() needs an int64 index
    index = torch.unsqueeze(hyps_pad_eos * mask, 2).long()
    score = decoder_out.gather(2, index).squeeze(2)  # B2 x T2
    # mask padded part
    score = score * mask
    decoder_out = decoder_out.view(B, bz, T2, V)
    if reverse_weight > 0:
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out.view(B2, T2, V)
        index = torch.unsqueeze(r_hyps_pad_eos * mask, 2).long()
        r_score = r_decoder_out.gather(2, index).squeeze(2)
        r_score = r_score * mask
        score = score * (1 - reverse_weight) + reverse_weight * r_score
        r_decoder_out = r_decoder_out.view(B, bz, T2, V)
    score = torch.sum(score, dim=1)  # B2
    score = torch.reshape(score, (B, bz)) + ctc_weight * ctc_score
    best_index = torch.argmax(score, dim=1)
    return best_index
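
The scoring tail of `batch_forward_attention_decoder` (tiling, masking, gather, CTC fusion) can be checked without a trained model. This is a minimal sketch on hand-made tensors; `make_pad_mask` here is a local re-implementation assumed to match wenet's semantics (True at padded positions), and `rescore` is a hypothetical helper name.

```python
import math

import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Local stand-in: True where position >= length (padded), shape (N, max_len)
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)


def rescore(decoder_log_probs: torch.Tensor, hyps_pad_eos: torch.Tensor,
            hyps_lens: torch.Tensor, ctc_score: torch.Tensor,
            ctc_weight: float = 0.5) -> torch.Tensor:
    """decoder_log_probs: (B*bz, T2, V), hyps_pad_eos: (B*bz, T2),
    hyps_lens: (B*bz,), ctc_score: (B, bz) -> best beam index per utterance (B,)."""
    B, bz = ctc_score.shape
    T2 = hyps_pad_eos.shape[1]
    mask = ~make_pad_mask(hyps_lens, T2)                   # True at real targets
    index = (hyps_pad_eos * mask).long().unsqueeze(2)      # padding -> id 0, safe for gather
    score = decoder_log_probs.gather(2, index).squeeze(2)  # (B*bz, T2)
    score = (score * mask).sum(dim=1)                      # drop padded positions
    score = score.view(B, bz) + ctc_weight * ctc_score
    return score.argmax(dim=1)


# Tiling: repeat(1, bz, 1).view(...) places the bz copies of utterance b side by side
enc = torch.randn(2, 4, 3)  # (B, T, F)
bz = 3
tiled = enc.repeat(1, bz, 1).view(2 * bz, 4, 3)
assert torch.equal(tiled, enc.repeat_interleave(bz, dim=0))

# One utterance, two hypotheses, uniform decoder over V = 3 tokens
log_probs = torch.full((2, 2, 3), math.log(1.0 / 3))
hyps_pad_eos = torch.tensor([[1, 2], [2, -1]])  # second hypothesis padded with -1
hyps_lens = torch.tensor([2, 1])
ctc_score = torch.tensor([[0.0, -5.0]])
best_with_ctc = rescore(log_probs, hyps_pad_eos, hyps_lens, ctc_score, 0.5)
best_without_ctc = rescore(log_probs, hyps_pad_eos, hyps_lens, ctc_score, 0.0)
print(best_with_ctc.tolist(), best_without_ctc.tolist())
```

Because the targets are multiplied by the mask before `gather`, a -1 pad never reaches `gather` in this part; the decoder's embedding, however, receives `hyps_pad_sos` without any masking.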

Expected behavior
The batched decoder runs successfully.

murphypei, Jul 27 '22

Please paste the error message.

robin1001, Jul 27 '22

> please paste the error message

Sorry, my fault.

I saved the input data of the ONNX decoder (padded with ignore_id) and replaced the random dummy inputs with this data when running export_onnx_gpu.py for the decoder.
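
To make the padding convention concrete: in the dumped tensors below, each row is `<sos>` + tokens + `<eos>` padded with ignore_id = -1, `<sos>` = `<eos>` = 5069 (output_dim - 1), and `hyps_lens_sos` counts the tokens plus `<sos>`. A hypothetical helper `pad_hyps` (not a wenet function) that builds such rows with a configurable pad value might look like this; toy ids stand in for real ones. Note that `hyps_pad_sos = hyps_pad[:, :-1]` keeps the -1 entries, and that tensor is what reaches the decoder embedding.

```python
from typing import List, Tuple

import torch


def pad_hyps(hyps: List[List[int]], sos: int, eos: int,
             pad_value: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Builds <sos> + tokens + <eos> rows padded to a common length.
    Returns (hyps_pad_sos_eos: (N, max_len + 2), hyps_lens_sos: (N,))."""
    max_len = max(len(h) for h in hyps)
    out = torch.full((len(hyps), max_len + 2), pad_value, dtype=torch.long)
    lens = torch.zeros(len(hyps), dtype=torch.long)
    for i, h in enumerate(hyps):
        row = [sos] + h + [eos]
        out[i, : len(row)] = torch.tensor(row, dtype=torch.long)
        lens[i] = len(h) + 1  # tokens plus <sos>
    return out, lens


# Toy vocabulary of 6 ids with <sos> = <eos> = 5, mirroring the 5069 convention
hyps = [[1, 2, 3], [4]]
pad_minus1, lens = pad_hyps(hyps, sos=5, eos=5, pad_value=-1)  # ignore_id padding
pad_eos, _ = pad_hyps(hyps, sos=5, eos=5, pad_value=5)         # valid-id padding
print(pad_minus1.tolist(), lens.tolist(), pad_eos.tolist())
```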

The config (reverse_weight=0.0):

{'accum_grad': 1, 'cmvn_file': 'global_cmvn', 'collate_conf': {'feature_dither': 0.0, 'spec_aug': True, 'spec_aug_conf': {'max_f': 10, 'max_t': 50, 'max_w': 80, 'num_f_mask': 2, 'num_t_mask': 2, 'warp_for_time': False}}, 'dataset_conf': {'batch_size': 10, 'batch_type': 'dynamic', 'max_frames_in_batch': 12000, 'max_length': 10240, 'min_length': 0, 'sort': True, 'filter_conf': {'max_length': 2000, 'min_length': 50, 'token_max_length': 400, 'token_min_length': 1, 'min_output_input_ratio': 0.0005, 'max_output_input_ratio': 0.1}, 'fbank_conf': {'num_mel_bins': 80, 'frame_shift': 10, 'frame_length': 25, 'dither': 0.0}, 'resample_conf': {'resample_rate': 16000}, 'batch_conf': {'batch_type': 'static', 'batch_size': 8}}, 'decoder': 'transformer', 'decoder_conf': {'attention_heads': 8, 'dropout_rate': 0.1, 'linear_units': 2048, 'num_blocks': 6, 'positional_dropout_rate': 0.1, 'self_attention_dropout_rate': 0.0, 'src_attention_dropout_rate': 0.0}, 'encoder': 'conformer', 'encoder_conf': {'activation_type': 'swish', 'attention_dropout_rate': 0.0, 'attention_heads': 4, 'causal': True, 'cnn_module_kernel': 15, 'cnn_module_norm': 'layer_norm', 'dropout_rate': 0.1, 'input_layer': 'conv2d', 'linear_units': 2048, 'normalize_before': True, 'num_blocks': 12, 'output_size': 256, 'pos_enc_layer_type': 'rel_pos', 'positional_dropout_rate': 0.1, 'selfattention_layer_type': 'rel_selfattn', 'use_cnn_module': True, 'use_dynamic_chunk': False}, 'grad_clip': 5, 'input_dim': 80, 'is_json_cmvn': False, 'log_interval': 100, 'max_epoch': 160, 'model_conf': {'ctc_weight': 0.5, 'length_normalized_loss': False, 'lsm_weight': 0.1}, 'optim': 'adam', 'optim_conf': {'lr': 0.002}, 'output_dim': 5070, 'raw_wav': False, 'scheduler': 'warmuplr', 'scheduler_conf': {'warmup_steps': 25000}}

The input data:

=> hyps_pad_sos_eos: torch.Size([2, 10, 36])
 tensor([[[5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,  202, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 3135,  669, 5042, 4648,
          2451, 2135, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811,  581,
          3246, 4647, 5069],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,  202, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 5042, 4648, 2451, 2135,
          4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811,  581, 3246, 4647,
          5069,   -1,   -1],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,  202, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 3135,  669, 5042, 4648,
          2451, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811,  581, 3246,
          4647, 5069,   -1],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,  202, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 5042, 4648, 2451, 4647,
          4923, 3540, 4708, 1601, 2989, 4409, 1811,  581, 3246, 4647, 5069,
            -1,   -1,   -1],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,  202, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 3135,  669, 5042, 4648,
          2451, 2117, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811,  581,
          3246, 4647, 5069],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,   25, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 3135,  669, 5042, 4648,
          2451, 2135, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811,  581,
          3246, 4647, 5069],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,  202, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 5042, 4648, 2451, 2117,
          4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811,  581, 3246, 4647,
          5069,   -1,   -1],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,   25, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 5042, 4648, 2451, 2135,
          4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811,  581, 3246, 4647,
          5069,   -1,   -1],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,   25, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 3135,  669, 5042, 4648,
          2451, 4647, 4923, 3540, 4708, 1601, 2989, 4409, 1811,  581, 3246,
          4647, 5069,   -1],
         [5069, 4115, 4115, 4115, 4115, 1414, 4647, 5042, 4262,   25, 4954,
             2, 1054, 4648, 2451, 2117,  169, 5042, 5042, 4648, 2451, 4647,
          4923, 3540, 4708, 1601, 2989, 4409, 1811,  581, 3246, 4647, 5069,
            -1,   -1,   -1]],

        [[5069, 3165, 4954, 2962, 5025, 4182, 2962, 5025, 4182, 4532, 3511,
          2215, 2962, 2054, 2263, 4647, 5069,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069, 5053, 3165, 4954, 2962, 5025, 4182, 2962, 5025, 4182, 4532,
          3511, 2215, 2962, 2054, 2263, 4647, 5069,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069, 3165, 4954, 2962, 5025, 4182, 2962, 5025, 4182, 4532, 3511,
          2215, 4921, 2962, 2054, 2263, 4647, 5069,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069, 3165, 4954, 2962, 5025, 4182, 2962, 4979, 4182, 4532, 3511,
          2215, 2962, 2054, 2263, 4647, 5069,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069, 5053, 3165, 4954, 2962, 5025, 4182, 2962, 5025, 4182, 4532,
          3511, 2215, 4921, 2962, 2054, 2263, 4647, 5069,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069, 5053, 3165, 4954, 2962, 5025, 4182, 2962, 4979, 4182, 4532,
          3511, 2215, 2962, 2054, 2263, 4647, 5069,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069, 3165, 4954, 2962, 5025, 4182, 2962, 4979, 4182, 4532, 3511,
          2215, 4921, 2962, 2054, 2263, 4647, 5069,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069, 5053, 3165, 4954, 2962, 5025, 4182, 2962, 4979, 4182, 4532,
          3511, 2215, 4921, 2962, 2054, 2263, 4647, 5069,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069,    0, 5069,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1],
         [5069,    0, 5069,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
            -1,   -1,   -1]]], dtype=torch.int32)
=> hyps_lens_sos: torch.Size([2, 10])
 tensor([[35, 33, 34, 32, 35, 35, 33, 33, 34, 32],
        [16, 17, 17, 16, 18, 17, 17, 18,  2,  2]], dtype=torch.int32)
=> r_hyps_pad_sos_eos: torch.Size([2, 10, 36])
 tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]], dtype=torch.int32)
=> encoder_out: torch.Size([2, 538, 256])
 tensor([[[-0.0112931002, -0.0037228700,  0.0140792998,  ...,
          -0.0213916991, -0.0696593001,  0.0179855991],
         [ 0.0236336999, -0.0084405299,  0.0264567994,  ...,
          -0.0065244199, -0.1207040027, -0.0038900599],
         [ 0.0716644004, -0.0227231998,  0.0324570984,  ...,
           0.0369418003, -0.2284329981, -0.0327154994],
         ...,
         [ 0.0007195030,  0.1284279972,  0.0712073967,  ...,
          -0.0373377986,  0.0767147988,  0.2463870049],
         [-0.1903409958,  0.2421520054,  0.0602861010,  ...,
          -0.1040119976,  0.1219519973, -0.0208250992],
         [-0.2125509977,  0.2335509956,  0.1081859991,  ...,
          -0.0806310028,  0.0531232990, -0.1088310033]],

        [[-0.0137972003, -0.0014524700,  0.0161499996,  ...,
          -0.0094943298, -0.0790612996,  0.0239565000],
         [ 0.0199223999,  0.0055981702,  0.0381584018,  ...,
           0.0017427501, -0.1110450029,  0.0021909799],
         [ 0.0662441999, -0.0008342030,  0.0566346012,  ...,
           0.0435902998, -0.1948059946, -0.0183284003],
         ...,
         [ 0.0337589011, -0.2291299999,  0.0693117008,  ...,
           0.0938744023, -0.3226909935,  0.1645780057],
         [ 0.0337589011, -0.2291299999,  0.0693117008,  ...,
           0.0938744023, -0.3226909935,  0.1645780057],
         [ 0.0337589011, -0.2291299999,  0.0693117008,  ...,
           0.0938744023, -0.3226909935,  0.1645780057]]])
=> encoder_out_lens: torch.Size([2])
 tensor([538, 299], dtype=torch.int32)
=> ctc_score: torch.Size([2, 10])
 tensor([[-2.4608499527e+01, -4.2024700165e+01, -2.9768400192e+01,
         -4.7184600830e+01, -2.1780700684e+01, -2.8415100098e+01,
         -3.9196899414e+01, -4.5831298828e+01, -3.3575000763e+01,
         -5.0991199493e+01],
        [-2.4205099106e+01, -2.6406499863e+01, -1.8420799255e+01,
         -2.6921899796e+01, -2.0622200012e+01, -2.9123300552e+01,
         -2.1137599945e+01, -2.3339000702e+01,  1.1754900068e-38,
          1.1754900068e-38]])

The error message:

Traceback (most recent call last):
  File "wenet/bin/export_onnx_gpu.py", line 415, in <module>
    torch.onnx.export(decoder,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 305, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 118, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 719, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 499, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 440, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 391, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "wenet/bin/export_onnx_gpu.py", line 127, in forward
    decoder_out, r_decoder_out, _ = self.decoder(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/ws/code/wenet2/wenet/transformer/decoder.py", line 122, in forward
    x, _ = self.embed(tgt)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1098, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/sparse.py", line 158, in forward
    return F.embedding(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 2183, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self

murphypei, Jul 28 '22
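
The failing frame is `F.embedding`, and the failure can be reproduced in isolation. Note that `torch.onnx.export` builds its graph by tracing, which runs the model in eager PyTorch, so this error is raised by PyTorch's embedding before any ONNX graph exists. A minimal sketch with a made-up vocabulary size of 6:

```python
import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=6, embedding_dim=4)  # toy vocab, valid ids 0..5


def embeds_ok(tokens: torch.Tensor) -> bool:
    # True if the lookup succeeds, False on "index out of range in self"
    try:
        embed(tokens)
        return True
    except IndexError:
        return False


print(embeds_ok(torch.tensor([[5, 1, 2, 5, 5]])))   # padded with a valid id
print(embeds_ok(torch.tensor([[5, 1, 2, 5, -1]])))  # padded with ignore_id = -1
print(embeds_ok(torch.tensor([[5, 1, 6]])))         # id equal to the vocab size
```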

I have no idea why this error happens. Just wondering: why do you prefer libtorch over ONNX? If you would like to use Triton, I suggest ONNX, which is also friendlier to further optimization techniques such as TensorRT.

yuekaizhang, Jul 28 '22

> I have no idea why this error happened. Just wondering why you prefer use libtorch rather than onnx. If you would like use triton, I suggest use onnx which is also friendly to further optimization techniques like tensorrt.

There are actually many limitations to using Triton: datacenter GPUs, the CUDA version, the Triton version needed to support stateful models, etc. In our experience, Triton and ONNX Runtime also have their own problems, some of them even fatal, such as a memory leak (https://github.com/microsoft/onnxruntime/issues/8147).

Anyway, thank you for your reply.

murphypei, Jul 28 '22

Does this error occur when there is a -1 in hyps_pad_sos_eos? I have the same problem; in my case, hyps_pad_sos_eos contains values greater than output_dim.

ziyu123, Sep 23 '22
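
Both cases in this thread (a negative pad id, or an id at or above output_dim) fail the same way, because the valid embedding rows are 0 to output_dim - 1. A quick pre-check before calling the decoder can point at the offending positions; `find_bad_token_ids` is a hypothetical helper, not a wenet function.

```python
import torch


def find_bad_token_ids(tokens: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Returns an (N, tokens.dim()) tensor of positions whose id is outside [0, vocab_size)."""
    bad = (tokens < 0) | (tokens >= vocab_size)
    return bad.nonzero()


# output_dim = 5070 as in the config above; 5069 is <sos>/<eos>
tokens = torch.tensor([[5069, 12, -1], [5069, 5070, 5069]])
bad_positions = find_bad_token_ids(tokens, vocab_size=5070)
print(bad_positions.tolist())
```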

ONNX supports negative indices; JIT may not. You could use another value instead of -1 in this case.

Slyne, Sep 23 '22
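
Following that suggestion, here is a minimal sketch of two ways to keep -1 away from the embedding, using toy sizes. The ONNX `Gather` operator accepts negative indices and counts them from the end, so with `<sos>` = `<eos>` = output_dim - 1 (5069 in the dumped data) a -1 pad is looked up as `<eos>`, while PyTorch's embedding rejects it. Overwriting the padding with a valid id such as `<eos>` is one option; assuming the decoder masks padded positions through `hyps_lens` and the score is masked as in the method above, the choice of pad id should not change the scores of real tokens.

```python
import torch
import torch.nn as nn

V = 6        # toy vocab size
EOS = V - 1  # <sos> = <eos> = V - 1, mirroring the issue's data
embed = nn.Embedding(V, 4)
hyps_pad_sos = torch.tensor([[EOS, 1, 2, EOS, -1, -1]])

# Option 1: overwrite the padding with a valid id before the decoder.
fixed = hyps_pad_sos.masked_fill(hyps_pad_sos < 0, EOS)

# Option 2: emulate ONNX Gather's negative-index rule (index -k -> V - k).
wrapped = torch.where(hyps_pad_sos < 0, hyps_pad_sos + V, hyps_pad_sos)

out = embed(fixed)  # no IndexError now
print(fixed.tolist(), torch.equal(fixed, wrapped), tuple(out.shape))
```

Inside the exported TorchScript method, option 1 could be a single line such as `hyps_pad = hyps_pad.masked_fill(hyps_pad < 0, self.eos)`, assuming the model keeps its `<eos>` id in `self.eos`.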

> Onnx can support negative index. JIT maybe not ? You may use other value instead of -1 in this case.

Thanks, I'll try.

murphypei, Sep 26 '22

> this error if there is -1 in hyps_pad_sos_eos ? I have the same problem, in my case is hyps_pad_sos_eos have the value greater than output_dim

Thanks, I'll try.

murphypei, Sep 26 '22

This issue has been automatically closed due to inactivity.

github-actions[bot], Jan 17 '24