brevitas
Fixed wrong runtime shape inference for `BatchNorm1dToQuantScaleBias`
Fixed an issue where `BatchNorm1dToQuantScaleBias` produced an output of unexpected shape `[1, input_dim, batch_dim, input_dim]` instead of `[batch_dim, input_dim]`.
The cause was that the class ended up with a default of `runtime_shape = (1, -1, 1, 1)` when it should have been `runtime_shape = (1, -1)`.
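
To see why that default yields exactly `[1, input_dim, batch_dim, input_dim]`, here is a minimal pure-Python sketch of the broadcasting involved. It assumes the layer applies its per-channel weight roughly as `x * weight.view(runtime_shape)`; that is an assumption about the internals, not a quote of Brevitas code. `resolve_view_shape` and `broadcast_shape` are hypothetical helpers that mimic `Tensor.view` resolving a `-1` and the standard NumPy/PyTorch broadcasting rules.

```python
def resolve_view_shape(runtime_shape, numel):
    """Replace a single -1 in runtime_shape, as tensor.view() would."""
    known = 1
    for d in runtime_shape:
        if d != -1:
            known *= d
    return tuple(numel // known if d == -1 else d for d in runtime_shape)


def broadcast_shape(a, b):
    """Broadcast two shapes with NumPy/PyTorch rules (align from the right)."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)  # left-pad shorter shape with 1s
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        out.append(y if x == 1 else x)  # a size-1 dim stretches to the other
    return tuple(out)


batch_dim, input_dim = 8, 16
x_shape = (batch_dim, input_dim)  # 2D input to a BatchNorm1d-style layer

for runtime_shape in [(1, -1, 1, 1), (1, -1)]:
    # Per-channel weight has input_dim elements, reshaped by runtime_shape.
    weight_shape = resolve_view_shape(runtime_shape, input_dim)
    print(runtime_shape, "->", broadcast_shape(x_shape, weight_shape))
# (1, -1, 1, 1) -> (1, 16, 8, 16)
# (1, -1) -> (8, 16)
```

With the 4D default, the 2D input is implicitly left-padded to `(1, 1, batch_dim, input_dim)` and the weight `(1, input_dim, 1, 1)` stretches across it, producing the spurious 4D result; with `(1, -1)` the weight lines up with the channel dimension and the output keeps the input's shape.
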
This happened because `BatchNorm1dToQuantScaleBias` was not passing its `runtime_shape` parameter down to `ScaleBias`; that has also been fixed.
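
The underlying bug follows a common pattern: a subclass accepts a keyword argument but never forwards it to its parent's constructor, so the parent silently uses its own default. The sketch below illustrates that pattern with hypothetical classes (`ScaleBiasSketch`, `BuggyBN1dSketch`, `FixedBN1dSketch`); they are not the Brevitas classes and only model the `runtime_shape` handling.

```python
class ScaleBiasSketch:
    """Hypothetical stand-in for a scale/bias layer with a 4D default shape."""

    def __init__(self, num_features, runtime_shape=(1, -1, 1, 1)):
        self.num_features = num_features
        self.runtime_shape = runtime_shape


class BuggyBN1dSketch(ScaleBiasSketch):
    def __init__(self, num_features, runtime_shape=(1, -1)):
        # Bug: runtime_shape is accepted but not forwarded,
        # so the parent falls back to its (1, -1, 1, 1) default.
        super().__init__(num_features)


class FixedBN1dSketch(ScaleBiasSketch):
    def __init__(self, num_features, runtime_shape=(1, -1)):
        # Fix: forward the parameter so the 2D shape actually takes effect.
        super().__init__(num_features, runtime_shape=runtime_shape)


print(BuggyBN1dSketch(16).runtime_shape)  # (1, -1, 1, 1)
print(FixedBN1dSketch(16).runtime_shape)  # (1, -1)
```
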
Original issue: https://github.com/Xilinx/brevitas/issues/450