CTranslate2
GEMM operator calculates the `c` output shape incorrectly when input `a` is transposed?
Hey everyone!
I believe I have found a bug in the GEMM operator.
To the best of my knowledge, the output shape of the `c` StorageView in the GEMM operator should always be `{m, n}`.
However, `Gemm::compute()` calculates the shape of `c` incorrectly if `a` is transposed (`a` has shape `{k, m}` instead of `{m, k}`).
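To make the expected rule concrete, here is a minimal sketch for 2-D operands. It assumes the usual BLAS convention that a transpose flag swaps the two dimensions of an operand, so `op(a)` is `{m, k}`, `op(b)` is `{k, n}` and `c` is `{m, n}` regardless of how `a` is stored. The `Shape` alias and the `gemm_output_shape` helper are hypothetical and are not part of CTranslate2's API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for CTranslate2's Shape type (assumption).
using Shape = std::vector<int64_t>;

// Hypothetical helper: expected GEMM output shape for 2-D operands.
// If trans_a is set, a is stored as {k, m}; otherwise as {m, k}.
// If trans_b is set, b is stored as {n, k}; otherwise as {k, n}.
Shape gemm_output_shape(const Shape& a, const Shape& b, bool trans_a, bool trans_b) {
  assert(a.size() == 2 && b.size() == 2 && "sketch only handles 2-D inputs");
  const int64_t m = trans_a ? a[1] : a[0];
  const int64_t k = trans_a ? a[0] : a[1];
  const int64_t kb = trans_b ? b[1] : b[0];
  const int64_t n = trans_b ? b[0] : b[1];
  assert(k == kb && "inner dimensions must match");
  (void)k;   // silence unused-variable warnings when NDEBUG is defined
  (void)kb;
  // The output is always {m, n}, independent of the transpose flags.
  return {m, n};
}
```

For example, with `a` stored transposed as `{3, 2}` (so `k = 3`, `m = 2`) and `b` as `{3, 4}`, `gemm_output_shape({3, 2}, {3, 4}, true, false)` gives `{2, 4}`, which is the `{m, n}` shape expected above.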
Take a look at this piece of code: https://github.com/OpenNMT/CTranslate2/blob/4f8a4f334c59588223b6f1f24b707d7e8d5fe08c/src/ops/gemm.cc#L84-L88
- `Shape output_shape(a.shape());`: sets `output_shape` to `{k, m}`
- `output_shape[output_shape.size() - 1] = n;`: sets `output_shape` to `{k, n}`
- `c.resize(std::move(output_shape));`: resizes `c` to `{k, n}`.
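The three steps above can be reproduced in isolation with the sketch below. `current_output_shape` mirrors the linked lines; `fixed_output_shape` is only one possible fix (an assumption on my part, not necessarily what upstream would do) that also overwrites the second-to-last dimension with `m`, assuming a transposed `a` is stored as `{..., k, m}`. Both function names and the `Shape` alias are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for CTranslate2's Shape type (assumption).
using Shape = std::vector<int64_t>;

// Mirrors the linked code: copy a's shape, then overwrite only the last
// dimension with n. When a is stored as {k, m}, this yields {k, n}.
Shape current_output_shape(const Shape& a_shape, int64_t n) {
  Shape output_shape(a_shape);
  output_shape[output_shape.size() - 1] = n;
  return output_shape;
}

// One possible fix (hypothetical): keep any leading batch dimensions from a,
// but set the trailing two dimensions to {m, n} explicitly.
Shape fixed_output_shape(const Shape& a_shape, int64_t m, int64_t n) {
  assert(a_shape.size() >= 2 && "GEMM operands need at least 2 dimensions");
  Shape output_shape(a_shape);
  output_shape[output_shape.size() - 2] = m;
  output_shape[output_shape.size() - 1] = n;
  return output_shape;
}
```

With `a` stored as `{3, 2}` (`k = 3`, `m = 2`) and `n = 4`, `current_output_shape({3, 2}, 4)` returns `{3, 4}` while `fixed_output_shape({3, 2}, 2, 4)` returns `{2, 4}`. The two only agree when `k == m`, so a test that uses a square `a` would not expose the difference.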
Am I missing something here?
Bear in mind that there are currently no unit tests that would catch this.