
GEMM operator calculates the `c` output shape incorrectly when input `a` is transposed?

Open kandrio opened this issue 6 months ago • 0 comments

Hey everyone!

I believe I have found a bug in the GEMM operator.

To the best of my knowledge, the output shape of the `c` StorageView in the GEMM operator should always be `{m, n}`.

However, `Gemm::compute()` calculates the shape of `c` incorrectly if `a` is transposed (that is, `a` has shape `{k, m}` instead of `{m, k}`).

Take a look at this piece of code: https://github.com/OpenNMT/CTranslate2/blob/4f8a4f334c59588223b6f1f24b707d7e8d5fe08c/src/ops/gemm.cc#L84-L88

  • `Shape output_shape(a.shape());` sets `output_shape` to `{k, m}`.
  • `output_shape[output_shape.size() - 1] = n;` sets `output_shape` to `{k, n}`.
  • `c.resize(std::move(output_shape));` resizes `c` to `{k, n}`.
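To make the three steps above concrete, here is a minimal standalone sketch. It does not use the CTranslate2 API: `Shape` is just a `std::vector<long>` stand-in, and both `shape_as_described` and `expected_shape` are hypothetical helper names introduced for illustration. It also assumes a 2-D `a`, so it may not cover batched inputs.

```cpp
#include <cassert>
#include <vector>

// Stand-in for the library's shape type (an assumption for this sketch).
using Shape = std::vector<long>;

// Mirrors the steps listed above: copy a's shape, then overwrite
// only the last dimension with n.
Shape shape_as_described(const Shape& a_shape, long n) {
  assert(!a_shape.empty());
  Shape output_shape(a_shape);
  output_shape[output_shape.size() - 1] = n;
  return output_shape;
}

// Hypothetical alternative for a 2-D a: take m from whichever axis
// is not contracted, so the result is {m, n} in both layouts.
Shape expected_shape(const Shape& a_shape, bool trans_a, long n) {
  assert(a_shape.size() == 2);
  const long m = trans_a ? a_shape[1] : a_shape[0];
  return Shape{m, n};
}
```

With `k = 3`, `m = 5`, `n = 7` and a transposed `a` of shape `{3, 5}`, `shape_as_described` yields `{3, 7}`, while `expected_shape` yields `{5, 7}`. For a non-transposed `a` of shape `{5, 3}`, both yield `{5, 7}`.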

Am I missing something here?

Bear in mind that there are currently no unit tests that would catch this.

kandrio, Dec 09 '23 15:12