pfeatherstone
How do I "ignore the masked values"? The output of the loss function is a batch of numbers, i.e. if `ref` has shape [B,T] and `deg` has shape [B,T], then...
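One common way to do this (a sketch of the general pattern, not necessarily what this repo does; `masked_loss` is a hypothetical helper): compute the loss with `reduction='none'`, zero out the masked positions, and normalize by the number of valid elements.

```python
import torch
import torch.nn.functional as F

def masked_loss(logits, targets, mask):
    # logits: [B, T, C], targets: [B, T], mask: [B, T] with 1 = valid, 0 = padded
    per_token = F.cross_entropy(
        logits.transpose(1, 2), targets, reduction='none')  # [B, T]
    per_token = per_token * mask                 # masked positions contribute 0
    return per_token.sum() / mask.sum().clamp(min=1)

B, T, C = 2, 5, 10
logits = torch.randn(B, T, C)
targets = torch.randint(0, C, (B, T))
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.float)
loss = masked_loss(logits, targets, mask)
```

Since masked positions are zeroed before the reduction, changing the labels under the mask leaves the loss unchanged.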
Pretty much. Basically nothing can be optional with tracing and ONNX, so both mems and mask must be specified.
I haven't been able to train my models yet, even with normal transformers, using larger context lengths (my weird TTS + STT system). CTC loss isn't converging at all. So...
You still need to pass mems in the forward pass when tracing. So I think the fix would be to pass mems full of zeros and make the code handle...
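A minimal sketch of that workaround (the `Block` module here is a toy stand-in for a transformer layer that takes XL-style memories, not the repo's actual class): trace with a zero-filled mems tensor so the traced graph always has a concrete input.

```python
import torch

class Block(torch.nn.Module):
    # toy stand-in: concatenates memories along the time axis,
    # projects, and returns only the positions belonging to x
    def __init__(self, dim=8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x, mems):
        ctx = torch.cat([mems, x], dim=1)        # [B, M + T, D]
        return self.proj(ctx)[:, mems.shape[1]:]  # [B, T, D]

model = Block()
x = torch.randn(1, 4, 8)
zero_mems = torch.zeros(1, 2, 8)   # "mems full of zeros" instead of None
traced = torch.jit.trace(model, (x, zero_mems))
```

Whether attending to those zero positions is harmless is exactly the open question in this thread; masking them out would be the safer option.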
Then probably need some control flow in a torch.jit.script function which handles the different cases
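That control flow could look something like this (a hypothetical `with_optional_mems` function, just to illustrate the pattern): `torch.jit.script` supports `Optional[Tensor]` arguments and real branching, unlike tracing.

```python
import torch
from typing import Optional

@torch.jit.script
def with_optional_mems(x: torch.Tensor,
                       mems: Optional[torch.Tensor]) -> torch.Tensor:
    # scripted control flow: handle the mems-is-None case explicitly
    if mems is None:
        mems = torch.zeros(x.shape[0], 0, x.shape[2], dtype=x.dtype)
    return torch.cat([mems, x], dim=1)

x = torch.randn(2, 3, 4)
a = with_optional_mems(x, None)                  # no memories
b = with_optional_mems(x, torch.randn(2, 2, 4))  # two memory slots
```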
Gave this a go; it turns out that `torch.jit.trace()` doesn't accept `None` in `example_inputs`. So we cannot trace with `mems` not None and expect it to work when None, or vice...
Yeah I have. I could submit a PR, but I didn't know if passing zeros and allowing the model to attend to them was OK.
Ah yes I need to try. Can it wait till Tuesday?
In the memory replay backpropagation algorithm, the labels are partitioned in the same way as the logits. The loss is evaluated per block. For CTC that doesn't make sense since...
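For reference, this is why per-block evaluation doesn't fit: CTC marginalizes over all alignments between the *whole* logit sequence and an unaligned target sequence, so the targets can't be partitioned along with the logits. A standard full-sequence usage (toy shapes, assuming class 0 is the blank):

```python
import torch

B, T, C = 2, 50, 20   # batch, time steps, classes (0 = blank)
# nn.CTCLoss expects log-probs shaped [T, B, C]
log_probs = torch.randn(T, B, C).log_softmax(dim=-1)
targets = torch.randint(1, C, (B, 12))            # unaligned label sequences
input_lengths = torch.full((B,), T, dtype=torch.long)
target_lengths = torch.full((B,), 12, dtype=torch.long)

ctc = torch.nn.CTCLoss(blank=0)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
```

Splitting `log_probs` into blocks would require knowing which target labels fall in each block, which is precisely the alignment CTC is supposed to discover.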
@lucidrains Or if we forget CTC, can you think of a way to make this work with unaligned targets?