An error occurs when the input tensor has more than 2 dimensions
Hi @nebuly-ai
Thanks for sharing this!
When I test `@accelerate_model()` on BERT, the input shape is `(batch_size, text_length, hidden_size)`, and I get the error `RuntimeError: self must be a matrix`. This happens because `output = input.mm(weight.t())` does not support matrix multiplication with more than 2 dimensions.
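A minimal sketch that reproduces the failure outside the library, assuming arbitrary illustrative sizes (`batch_size=2`, `text_length=4`, `hidden_size=8`, `out_features=16`) and a weight in the `(out_features, hidden_size)` layout that `nn.Linear` uses. It also shows one possible workaround (flattening the leading dimensions before `mm`); this is not the project's code, just an illustration of the shape problem.

```python
import torch

# Illustrative sizes only; not taken from the issue.
batch_size, text_length, hidden_size, out_features = 2, 4, 8, 16
x = torch.randn(batch_size, text_length, hidden_size)  # BERT-style 3-D activations
weight = torch.randn(out_features, hidden_size)        # nn.Linear layout: (out, in)

# torch.mm only accepts 2-D tensors, so a 3-D input raises a RuntimeError.
try:
    x.mm(weight.t())
    mm_failed = False
except RuntimeError as err:
    mm_failed = True
    print(f"torch.mm failed: {err}")

# Workaround: collapse (batch_size, text_length) into one dimension,
# multiply as a 2-D matrix, then restore the original leading shape.
out_2d = x.reshape(-1, hidden_size).mm(weight.t())
out = out_2d.reshape(batch_size, text_length, out_features)
print(out.shape)  # torch.Size([2, 4, 16])
```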
Hello @Nipi64310! That's right, `torch.mm` doesn't work with tensors that have more than 2 dimensions! In the next release, I'll replace it with `torch.matmul`, which should solve the problem. Thank you for letting us know about the issue!
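A short sketch of why `torch.matmul` avoids the error: it broadcasts over leading dimensions, so a `(..., n, k)` tensor times a `(k, m)` matrix yields `(..., n, m)`. The sizes below are arbitrary assumptions for illustration, and this is not the actual patch shipped in the release.

```python
import torch

# Arbitrary illustrative sizes.
x = torch.randn(2, 4, 8)     # (batch_size, text_length, hidden_size)
weight = torch.randn(16, 8)  # (out_features, hidden_size), nn.Linear layout

# torch.matmul treats the leading dimensions as batch dimensions,
# so the same call works for both 3-D and 2-D inputs.
out_3d = torch.matmul(x, weight.t())
out_2d = torch.matmul(x[0], weight.t())

print(out_3d.shape)  # torch.Size([2, 4, 16])
print(out_2d.shape)  # torch.Size([4, 16])
```

For a linear layer specifically, `torch.nn.functional.linear(x, weight)` computes the same `x @ weight.T` product and also accepts inputs with any number of leading dimensions.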