keras
[Contributors Wanted] Implement `compute_output_shape()` method for `MultiHeadAttention`
Nature of the task
Implement the `compute_output_shape()` method on the `MultiHeadAttention` layer. It should have the following signature:
def compute_output_shape(self, query_shape, value_shape, key_shape=None):
It should accept either TensorShape instances or tuples as shapes, and it should return a TensorShape instance.
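As a rough starting point for contributors, here is a minimal sketch of the shape logic, assuming the output keeps every query dimension except possibly the last, which is replaced when the layer was built with an `output_shape` argument. The function name `sketch_mha_output_shape` and its `output_shape` parameter are hypothetical and used only for illustration. It works on plain tuples and returns a tuple; the real method would live on the layer, read the configured output shape from the layer itself, and wrap the result in a `TensorShape`. It is not the merged implementation.

```python
from typing import Optional, Sequence, Tuple, Union

Dim = Optional[int]
Shape = Tuple[Dim, ...]


def sketch_mha_output_shape(
    query_shape: Sequence[Dim],
    value_shape: Sequence[Dim],
    key_shape: Optional[Sequence[Dim]] = None,
    output_shape: Union[None, int, Sequence[int]] = None,  # hypothetical stand-in for the layer's setting
) -> Shape:
    """Illustrative only: compute an attention output shape from input shapes."""
    query_shape = tuple(query_shape)
    value_shape = tuple(value_shape)
    # When no key is given, the value is used as the key.
    key_shape = value_shape if key_shape is None else tuple(key_shape)

    # Assumption: key and value must agree on every axis between batch and features.
    if value_shape[1:-1] != key_shape[1:-1]:
        raise ValueError(
            "All dimensions of value and key, except the first and last, "
            f"must match. Got value_shape={value_shape}, key_shape={key_shape}"
        )

    # Without a configured output shape, the output mirrors the query.
    if output_shape is None:
        return query_shape

    # Otherwise the last query dimension is replaced by the configured shape.
    if isinstance(output_shape, int):
        output_shape = (output_shape,)
    return query_shape[:-1] + tuple(output_shape)
```

For example, `sketch_mha_output_shape((None, 8, 16), (None, 4, 16))` gives back the query shape unchanged, and passing `output_shape=32` swaps only the last dimension.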
Why?
The method is not required in order to use the layer, but it is useful for computing the output shape of a layer without having to actually call it.
I would like to take it if it is still available.
@Pouyanpi thanks! Feel free to open a PR.