Memoize or precompute subgraphs that depend only on input shapes
Many models have subgraphs that depend only on the shapes of their inputs, and thus produce the same outputs when the model is called repeatedly with inputs of the same shape. These subgraphs are usually cheap since the tensors flowing through them are small, but each operation that runs still carries some overhead. These subgraphs could be memoized to avoid re-running them unnecessarily.
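A minimal sketch of the memoization idea, in Rust since that is rten's language. `ShapeSubgraphCache` and `get_or_eval` are hypothetical names, not rten APIs, and the cached output is simplified to a `Vec<i64>`. The sketch assumes the subgraph really depends only on input shapes and not on input values, which is what makes the shapes a valid cache key.

```rust
use std::collections::HashMap;

/// Hypothetical cache for a subgraph whose outputs depend only on the
/// shapes of its inputs. The key is the list of input shapes.
struct ShapeSubgraphCache {
    cache: HashMap<Vec<Vec<usize>>, Vec<i64>>,
    /// Number of cache misses, i.e. times the subgraph was actually run.
    evaluations: usize,
}

impl ShapeSubgraphCache {
    fn new() -> Self {
        ShapeSubgraphCache {
            cache: HashMap::new(),
            evaluations: 0,
        }
    }

    /// Return the cached output for `shapes`, running `eval` only on a miss.
    /// Only valid if `eval` depends on nothing but the shapes it is given.
    fn get_or_eval<F>(&mut self, shapes: &[Vec<usize>], eval: F) -> Vec<i64>
    where
        F: FnOnce(&[Vec<usize>]) -> Vec<i64>,
    {
        if let Some(out) = self.cache.get(shapes) {
            return out.clone();
        }
        self.evaluations += 1;
        let out = eval(shapes);
        self.cache.insert(shapes.to_vec(), out.clone());
        out
    }
}

fn main() {
    let mut cache = ShapeSubgraphCache::new();
    // Stand-in for a Shape -> Slice subgraph that extracts the last dimension.
    let last_dim = |shapes: &[Vec<usize>]| vec![shapes[0][shapes[0].len() - 1] as i64];

    // Repeated calls with the same input shape only run the subgraph once.
    for _ in 0..3 {
        assert_eq!(cache.get_or_eval(&[vec![1, 16, 10, 64]], last_dim), vec![64]);
    }
    println!("subgraph evaluations: {}", cache.evaluations);
}
```

One design caveat: in a decoder whose KV-cache length grows on every step, keying on the full input shapes would miss every time. A practical cache would likely need a finer key, such as only the dimensions the subgraph actually reads.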
As a starting point, it would be useful to do some experiments to see how many operations can be saved on some popular models, especially decoder models which are run repeatedly.
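One way to run that experiment is to walk a topologically sorted graph, flag every node whose output depends only on input shapes, and count the flagged nodes. The sketch below uses a hypothetical, heavily simplified graph representation rather than rten's model types. It conservatively excludes nodes with no inputs, which sidesteps operators such as random-number generators that are not deterministic.

```rust
/// Where a node input comes from, in a hypothetical simplified graph.
enum Source {
    /// A runtime input to the model (its value is dynamic).
    GraphInput,
    /// A constant or initializer.
    Constant,
    /// The output of an earlier node, by index.
    Node(usize),
}

struct Node {
    op: &'static str,
    inputs: Vec<Source>,
}

/// Flag each node whose output depends only on input shapes.
/// `nodes` must be in topological order. A `Shape` node reads only the
/// shape of its input; any other node qualifies if all of its inputs are
/// constants or outputs of nodes that qualify. Nodes with no inputs are
/// conservatively excluded.
fn shape_only_nodes(nodes: &[Node]) -> Vec<bool> {
    let mut shape_only = vec![false; nodes.len()];
    for (i, node) in nodes.iter().enumerate() {
        let qualifies = node.op == "Shape"
            || (!node.inputs.is_empty()
                && node.inputs.iter().all(|src| match src {
                    Source::GraphInput => false,
                    Source::Constant => true,
                    Source::Node(j) => shape_only[*j],
                }));
        shape_only[i] = qualifies;
    }
    shape_only
}

fn main() {
    // Shape -> Slice -> Sqrt -> Div, modeled loosely on an attention scale.
    let nodes = vec![
        Node { op: "Shape", inputs: vec![Source::GraphInput] },
        Node { op: "Slice", inputs: vec![Source::Node(0), Source::Constant, Source::Constant] },
        Node { op: "Sqrt", inputs: vec![Source::Node(1)] },
        Node { op: "Div", inputs: vec![Source::GraphInput, Source::Node(2)] },
    ];
    let flags = shape_only_nodes(&nodes);
    let count = flags.iter().filter(|&&f| f).count();
    println!("{count} of {} nodes depend only on input shapes", nodes.len());
}
```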
A related problem is that models sometimes have subgraphs which depend on input shapes but end up producing constant values. In attention, for example, where Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, the scale factor sqrt(d_k) is usually a constant, but model graphs sometimes compute it from the shapes of tensors that are partially dynamic. As a result, the scale factor is recomputed on every run, and optimizations that require it to be a constant scalar are blocked.
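As a worked example with an illustrative head size of d_k = 64: once d_k is known statically, the whole scale-factor computation reduces to a single constant.

```math
\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V,
\qquad d_k = 64 \;\Rightarrow\; \sqrt{d_k} = 8, \quad \frac{1}{\sqrt{d_k}} = 0.125
```

When d_k is only available through a Shape node applied to a partially dynamic tensor, the runtime cannot substitute 0.125 unless shape inference proves that the relevant dimension is static.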
Example from the decoder model used in the Nougat example:
The output of the Shape -> Slice subgraph here turns out to be constant. If that were known, the nodes up to the Sqrt could be eliminated by constant propagation. The MatMul(Mul(X, A), Mul(Y, B)) subgraph could then be fused into a FusedMatMul, and the Transpose would be fused into that as well. Because the input to the Transpose comes from the KV cache, its cost grows with each iteration of the decoder, so this has a large impact on inference time.
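The fusion relies on scalar multiplication commuting with matrix multiplication: if A and B are scalars, MatMul(Mul(X, A), Mul(Y, B)) equals (A * B) * MatMul(X, Y), so the two Mul nodes can be replaced by one scale factor applied inside a single fused matmul. The Rust sketch below checks this on a small example. The `matmul` helper is a naive illustration rather than rten's kernel, and folding the Transpose is not modeled.

```rust
/// Naive row-major matrix multiply: (m x k) * (k x n) -> (m x n).
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            out[i * n + j] = sum;
        }
    }
    out
}

/// Elementwise multiply by a scalar, i.e. Mul(X, s) with scalar s.
fn scale(x: &[f32], s: f32) -> Vec<f32> {
    x.iter().map(|v| v * s).collect()
}

/// Unfused graph: MatMul(Mul(X, A), Mul(Y, B)) with scalar A and B.
fn unfused(x: &[f32], y: &[f32], a: f32, b: f32, m: usize, k: usize, n: usize) -> Vec<f32> {
    matmul(&scale(x, a), &scale(y, b), m, k, n)
}

/// Fused form: one matmul whose output is scaled by alpha = A * B.
fn fused(x: &[f32], y: &[f32], a: f32, b: f32, m: usize, k: usize, n: usize) -> Vec<f32> {
    let alpha = a * b;
    scale(&matmul(x, y, m, k, n), alpha)
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0]; // 2 x 2
    let y = [5.0, 6.0, 7.0, 8.0]; // 2 x 2
    // Powers of two keep the f32 arithmetic exact, so results compare equal.
    let (a, b) = (0.5, 0.25);
    let u = unfused(&x, &y, a, b, 2, 2, 2);
    let f = fused(&x, &y, a, b, 2, 2, 2);
    assert_eq!(u, f);
    println!("{:?}", f); // prints [2.375, 2.75, 5.375, 6.25]
}
```

With real weights the two forms can differ in the last few bits because of rounding order, so a runtime would treat them as numerically equivalent rather than bit-identical.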
ONNX shape inference tools can already work out here that the input to the Slice (the output of Shape) is (unknown, 16, unknown, 64), so they should be able to determine that the Slice output is the constant 64.
If the existing shape inference tool were extended to cover this case, it could replace the Slice output with its constant value.
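A simplified sketch of such a rule: given the inferred shape of the tensor passed to Shape, in which some dimensions are symbolic, a Slice along axis 0 with step 1 can be replaced by a constant whenever every selected dimension is fixed. `Dim` and `fold_shape_slice` are hypothetical names, and this is not the implementation from the PR linked below; it only follows ONNX Slice's handling of negative and out-of-range indices for the step-1 case.

```rust
/// A dimension that is either statically known or only known at runtime.
#[derive(Clone, Debug)]
enum Dim {
    Fixed(i64),
    Symbolic(String),
}

/// Try to constant-fold `Slice(Shape(x), start, end)` with step 1, given the
/// (possibly partial) inferred shape of `x`. Returns `None` if any selected
/// dimension is symbolic, since the result is then only known at runtime.
fn fold_shape_slice(shape: &[Dim], start: i64, end: i64) -> Option<Vec<i64>> {
    let len = shape.len() as i64;
    // Negative indices count from the end; results are clamped to [0, len].
    let normalize = |i: i64| -> i64 {
        if i < 0 {
            (i + len).max(0)
        } else {
            i.min(len)
        }
    };
    let (start, end) = (normalize(start), normalize(end));
    if start >= end {
        return Some(Vec::new()); // Empty slice: trivially constant.
    }
    shape[start as usize..end as usize]
        .iter()
        .map(|d| match d {
            Dim::Fixed(n) => Some(*n),
            Dim::Symbolic(_) => None,
        })
        .collect()
}

fn main() {
    // Partial shape (unknown, 16, unknown, 64); the symbolic names are illustrative.
    let shape = vec![
        Dim::Symbolic("batch".to_string()),
        Dim::Fixed(16),
        Dim::Symbolic("seq".to_string()),
        Dim::Fixed(64),
    ];
    // Slicing out the last dimension folds to a constant.
    println!("{:?}", fold_shape_slice(&shape, -1, i64::MAX)); // prints Some([64])
    // Slicing a symbolic dimension cannot be folded.
    println!("{:?}", fold_shape_slice(&shape, 2, 3)); // prints None
}
```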
https://github.com/robertknight/rten/pull/805 handles the case where Shape + Slice is used to extract a dimension from a shape, and the value of that dimension is fixed according to shape inference.