
Modify model to allow JIT tracing

Open xDuck opened this issue 4 years ago • 7 comments

Hi, thanks for the repo! I am wondering if you have plans to convert the model to be JIT-traceable for exporting to C++? I tried to JIT trace and it generated some critical warnings:

FastSpeech2/env/lib/python3.7/site-packages/torch/tensor.py:593: RuntimeWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  'incorrect results).', category=RuntimeWarning)
FastSpeech2/utils/tools.py:97: TracerWarning: Converting a tensor to a NumPy array might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_len = max_len.detach().cpu().numpy()[0]
FastSpeech2/transformer/Models.py:82: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not self.training and src_seq.shape[1] > self.max_seq_len:
FastSpeech2/transformer/Models.py:90: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  :, :max_len, :
FastSpeech2/model/modules.py:186: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  expand_size = predicted[i].item()
FastSpeech2/model/modules.py:180: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return output, torch.LongTensor(mel_len).to(device)
FastSpeech2/utils/tools.py:94: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_len = torch.max(lengths).item()
FastSpeech2/transformer/Models.py:145: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not self.training and enc_seq.shape[1] > self.max_seq_len:
FastSpeech2/transformer/Models.py:154: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  max_len = min(max_len, self.max_seq_len)
FastSpeech2/transformer/Models.py:158: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  dec_output = enc_seq[:, :max_len, :] + self.position_enc[
FastSpeech2/transformer/Models.py:159: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  :, :max_len, :
FastSpeech2/transformer/Models.py:161: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  mask = mask[:, :max_len]
FastSpeech2/transformer/Models.py:162: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  slf_attn_mask = slf_attn_mask[:, :, :max_len]

I made the following changes:

tools.py:91

def get_mask_from_lengths(lengths, max_len=None):
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()
    else:
        print(max_len)
        max_len = max_len.detach().cpu().numpy()[0]
        print(max_len)
    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask

and

synthesize:87

def synthesize(model, step, configs, vocoder, batchs, control_values):
    preprocess_config, model_config, train_config = configs
    pitch_control, energy_control, duration_control = control_values

    for batch in batchs:
        batch = to_device(batch, device)
        with torch.no_grad():
            traced_script_module = torch.jit.trace(
                model, (batch[2], batch[3], batch[4], torch.tensor([batch[5]]))
            )
            traced_script_module.save("traced_fastspeech_model.pt")
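
One way to sanity-check an export like this is to reload the saved file and compare its output with the eager model. A minimal sketch, assuming a stand-in `TinyModel` (hypothetical, not part of FastSpeech2) and a temporary file path:

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real model; tensor ops only, so it traces cleanly."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.proj(x))


model = TinyModel().eval()
example = torch.randn(1, 5, 4)

with torch.no_grad():
    traced = torch.jit.trace(model, example)

# Round-trip through disk; this is the kind of file C++ loads with torch::jit::load.
path = os.path.join(tempfile.mkdtemp(), "traced_tiny_model.pt")
traced.save(path)
reloaded = torch.jit.load(path)

with torch.no_grad():
    roundtrip_matches = torch.allclose(model(example), reloaded(example))
    # A different sequence length than the example input.
    longer_out = reloaded(torch.randn(1, 9, 4))
```

Running the reloaded module on an input of a different length than the example is a cheap way to spot shapes that were baked in during tracing.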

It seems like most of the issues come from max_len being used in conditionals and array slices. I will look into this more, but I wanted to see if you had tried this before.
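
The effect behind those warnings can be reproduced in isolation. A minimal sketch, assuming a simplified helper `mask_via_item` (hypothetical, loosely modelled on `get_mask_from_lengths`) that pulls `max_len` out of a tensor with `.item()`:

```python
import warnings

import torch


def mask_via_item(lengths):
    # .item() returns a Python int, so the tracer records it as a constant.
    max_len = lengths.max().item()
    ids = torch.arange(max_len)
    return ids.unsqueeze(0) >= lengths.unsqueeze(1)


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for this demo
    traced = torch.jit.trace(mask_via_item, torch.tensor([2, 4]))

new_lengths = torch.tensor([3, 6])
eager_shape = tuple(mask_via_item(new_lengths).shape)  # follows the new input
traced_shape = tuple(traced(new_lengths).shape)  # keeps the width seen while tracing
```

The traced function keeps the width of 4 from the example input, while the eager call grows to 6. This is the "might not generalize to other inputs" case the warnings describe.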

xDuck avatar Mar 17 '21 20:03 xDuck

@xDuck As far as I know, jit.trace only works for models with fixed-shape inputs, and this model takes inputs of variable size. Does it make sense to use jit.trace here? (I also tried jit.script(), but it produces even more errors...)
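
For comparison, scripting compiles the Python control flow instead of recording one run, so data-dependent values can stay dynamic. A minimal sketch of a script-friendly mask helper; the name `get_mask_from_lengths_scripted` and the `Optional[int]` signature are assumptions for illustration, not the repo's code:

```python
from typing import Optional

import torch


@torch.jit.script
def get_mask_from_lengths_scripted(
    lengths: torch.Tensor, max_len: Optional[int] = None
) -> torch.Tensor:
    # Resolve the width at run time; int(tensor) stays dynamic under scripting.
    if max_len is None:
        width = int(lengths.max())
    else:
        width = max_len
    ids = torch.arange(width, device=lengths.device)
    # True marks padded positions, matching the original helper's convention.
    return ids.unsqueeze(0) >= lengths.unsqueeze(1)


short_mask = get_mask_from_lengths_scripted(torch.tensor([2, 4]))
long_mask = get_mask_from_lengths_scripted(torch.tensor([3, 6]))
```

Passing `max_len` as a plain `int` rather than a tensor avoids the NumPy conversion that triggered one of the warnings above.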

KinamSalad avatar Mar 23 '21 06:03 KinamSalad

@xDuck @KinamSalad I think the code should be modified to enable the use of torch.jit. It's in my plans for the next major update.

ming024 avatar Mar 24 '21 01:03 ming024

Thank you! I have done some work on it and got it almost complete. I ended up removing the ability to do batch runs (batch size is now always 1) because I didn’t need them. I planned to add batching back when I had time, but got busy.

The one thing I didn’t finish was the Length Regulator module.

xDuck avatar Mar 24 '21 01:03 xDuck

@xDuck Yeah, I think the length regulator may be a major obstacle to scripting the whole model. Looking forward to your result!

ming024 avatar Mar 24 '21 02:03 ming024

I was able to find a couple of hours to work on this again. Here is the updated length regulator that compiles with JIT. I now have the whole model running through JIT, but I cheated by removing all of the batch handling and only supporting a batch size of 1 because I'm lazy, so I won't make a PR on this repo. The rest of the model is pretty straightforward to convert to JIT.

The other catch here is the model no longer returns mel_len but that can be derived from the outputs that already exist.
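
A minimal sketch of that derivation, assuming the model still exposes the rounded per-phoneme durations as an integer tensor zero-padded per utterance (the `durations` values here are made up for illustration):

```python
import torch

# Hypothetical rounded durations for a batch of 2 utterances, zero-padded.
durations = torch.tensor([[2, 0, 3], [1, 1, 0]])

# Each phoneme contributes `duration` frames, so the row sum is the mel length.
mel_len = durations.sum(dim=1)
```

One edge case to keep in mind: the length regulator below replaces an all-zero duration row with ones, so for such a row the length would be the number of phonemes rather than 0.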

Credit to https://github.com/rishikksh20/FastSpeech2 - I referenced their code pretty heavily in this.

from typing import List

import torch
import torch.nn.functional as F


@torch.jit.script
def pad_2d_tensor(xs: List[torch.Tensor], pad_value: float = 0.0):
    max_len = max([xs[i].size(0) for i in range(len(xs))])

    out_list = []

    for i, batch in enumerate(xs):
        one_batch_padded = F.pad(
            batch, (0, 0, 0, max_len - batch.size(0)), "constant", pad_value
        )
        out_list.append(one_batch_padded)

    out_padded = torch.stack(out_list)
    return out_padded


@torch.jit.script
def expand(x: torch.Tensor, d: torch.Tensor):
    if d.sum() == 0:
        d = d.fill_(1)
    out = []
    for x_, d_ in zip(x, d):
        if d_ != 0:
            out.append(x_.repeat(int(d_), 1))
    return out

@torch.jit.script
def repeat_one_sequence(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    if d.sum() == 0:
        d = d.fill_(1)
    out = []
    for x_, d_ in zip(x, d):
        if d_ != 0:
            out.append(x_.repeat(int(d_), 1))

    return torch.cat(out, dim=0)

@torch.jit.script
def LR(x: torch.Tensor, duration: torch.Tensor):
    output = [repeat_one_sequence(x, d) for x, d in zip(x, duration)]
    output = pad_2d_tensor(output, 0.0)
    return output

@ming024

(Running my model in C++ I am running at about 15x realtime for the FastSpeech2 portion)
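
Since batching was dropped, the per-sequence loop could possibly be replaced by a single tensor op. A sketch under that assumption, using `torch.repeat_interleave`; the helper name `length_regulate_single` is hypothetical, it does not reproduce the all-zero fallback of `repeat_one_sequence`, and whether it scripts cleanly on a given PyTorch version would need checking:

```python
import torch


def length_regulate_single(x: torch.Tensor, duration: torch.Tensor) -> torch.Tensor:
    """Repeat each row of x (phonemes, hidden) by its integer duration."""
    # Rows with duration 0 are dropped, like the `if d_ != 0` check in the loop version.
    return torch.repeat_interleave(x, duration, dim=0)


x = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
duration = torch.tensor([2, 0, 1])

frames = length_regulate_single(x, duration)
mel_len = frames.size(0)  # number of output frames for this single utterance
```

With a batch size of 1 there is nothing to pad, so `mel_len` falls out of the output shape directly.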

xDuck avatar Mar 31 '21 17:03 xDuck

@xDuck Great job!!!! Thanks for your work! I will try it in a few days!

ming024 avatar Apr 01 '21 08:04 ming024

Thanks a lot for your repo. I am wondering if you will be providing the update?

YoLi-sw avatar Oct 19 '22 09:10 YoLi-sw