BERT-pytorch
ONNX conversion: TransformerBlock problem
Is there a workaround for an error that occurs when converting BERT-pytorch to ONNX? The export fails while tracing through TransformerBlock. Is there a way to rewrite forward() to get around this error?
def forward(self, x, mask):
    x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))  # <--- Error happens here!
    x = self.output_sublayer(x, self.feed_forward)
    return self.dropout(x)
The error is:
builtins.ValueError: Auto nesting doesn't know how to process an input object of type bert_pytorch.model.transformer.TransformerBlock.forward.<locals>.<lambda>. Accepted types: Tensors, or lists/tuples of them
The tracer can't handle the lambda, so I'm wondering whether rewriting forward() without it would fix the export. (Sorry, I'm not yet comfortable enough with Python to do that without breaking anything.)