An error occurs when the input has more than 2 dimensions

Open Nipi64310 opened this issue 2 years ago • 1 comments

Hi @nebuly-ai

Thanks for sharing this! When I test @accelerate_model() on BERT, the input has shape (batch_size, text_length, hidden_size). This raises RuntimeError: self must be a matrix, because output = input.mm(weight.t()) does not support matrix multiplication on inputs with more than 2 dimensions.

Nipi64310 avatar Jun 16 '22 10:06 Nipi64310

Hello @Nipi64310! That's right, torch.mm doesn't work with tensors having more than 2 dimensions! From the next release, I'll replace it with torch.matmul instead and that should solve the problem. Thank you for letting us know about the issue!

diegofiori avatar Jun 17 '22 22:06 diegofiori
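
For readers who want to reproduce the behaviour discussed above, here is a minimal sketch in plain PyTorch. It is not the library's actual code: the shapes and the names x, weight, and out_features are illustrative assumptions, chosen only to mimic a BERT-style 3-D input passed through a linear-layer weight.

```python
import torch

# Illustrative sizes (assumptions, not taken from the issue).
batch_size, text_length, hidden_size, out_features = 2, 5, 8, 4

# A BERT-style 3-D activation and a 2-D linear weight.
x = torch.randn(batch_size, text_length, hidden_size)
weight = torch.randn(out_features, hidden_size)

# torch.mm only accepts 2-D tensors, so a 3-D input raises RuntimeError.
try:
    torch.mm(x, weight.t())
    mm_failed = False
except RuntimeError:
    mm_failed = True

# torch.matmul broadcasts over the leading (batch) dimensions.
out = torch.matmul(x, weight.t())
out_shape = tuple(out.shape)

# Equivalent workaround that keeps mm: flatten to 2-D, multiply, restore shape.
flat = (
    x.reshape(-1, hidden_size)
    .mm(weight.t())
    .reshape(batch_size, text_length, out_features)
)
```

With these shapes, the matmul result has shape (batch_size, text_length, out_features), and the flatten-then-mm workaround gives the same values up to floating-point tolerance, which may be useful until a release containing the fix is available.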