opt_einsum_fx
Einsum optimization using opt_einsum and PyTorch FX graph rewriting
It could be useful to provide some way of marking a Tensor as a scalar and adding it to the propagation logic.
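A minimal sketch of what such a hook could look like, assuming scalars are flagged via an `is_scalar` entry in `Node.meta`; both the meta key and the helper names below are hypothetical and not part of opt_einsum_fx:

```python
import operator

import torch
import torch.fx


def mark_scalar(node: torch.fx.Node) -> None:
    """Tag an FX node as carrying a scalar value (hypothetical user-facing hook)."""
    node.meta["is_scalar"] = True


def propagate_scalars(graph: torch.fx.Graph) -> None:
    """Propagate the scalar flag through elementwise multiplies and divides."""
    scalar_ops = {operator.mul, torch.mul, operator.truediv, torch.div}
    for node in graph.nodes:
        if node.op == "call_function" and node.target in scalar_ops:
            # A product or quotient whose inputs are all scalars (marked
            # nodes or plain Python numbers) is itself a scalar.
            if all(
                not isinstance(a, torch.fx.Node) or a.meta.get("is_scalar", False)
                for a in node.args
            ):
                node.meta["is_scalar"] = True
```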
Consider:
- [ ] `torch.reshape`
- [ ] `torch.cross`
- [ ] `torch.dot`
- [ ] `torch.transpose`
- [ ] `torch.nn.functional.bilinear`
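For reference, `torch.dot`, `torch.transpose`, and `torch.nn.functional.bilinear` can be spelled directly as einsums, which is what would let them participate in fusion (shapes below are chosen arbitrarily for illustration):

```python
import torch

a, b = torch.randn(4), torch.randn(4)
assert torch.allclose(torch.dot(a, b), torch.einsum("i,i->", a, b))

x = torch.randn(2, 3, 5)
assert torch.allclose(torch.transpose(x, 0, 2), torch.einsum("ijk->kji", x))

m1, m2, w = torch.randn(7, 3), torch.randn(7, 5), torch.randn(4, 3, 5)
# bilinear (without bias) contracts both inputs against a 3-d weight
assert torch.allclose(
    torch.nn.functional.bilinear(m1, m2, w),
    torch.einsum("oij,ni,nj->no", w, m1, m2),
)
```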
Make a transformer that rewrites `x.transpose(1, 2)` into `torch.einsum('abc...->acb...', x)` so that these operations can be fused with the rest of the einsums.
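A rough sketch of such a pass using `torch.fx.Transformer`, handling only the `transpose(1, 2)` case from above (a full pass would build the subscript string from the actual dim arguments and the input's rank):

```python
import torch
import torch.fx


class TransposeToEinsum(torch.fx.Transformer):
    """Sketch: rewrite x.transpose(1, 2) calls into an equivalent einsum."""

    def call_method(self, target, args, kwargs):
        if target == "transpose" and tuple(args[1:]) == (1, 2) and not kwargs:
            # 'abc...->acb...' swaps dims 1 and 2 and leaves the rest alone.
            return torch.einsum("abc...->acb...", args[0])
        return super().call_method(target, args, kwargs)


# Usage (hypothetical):
# graph_module = torch.fx.symbolic_trace(model)
# einsum_ready = TransposeToEinsum(graph_module).transform()
```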
When accumulating scalar constants at graph optimization time, arbitrary-precision arithmetic should be used so that floating-point rounding error does not build up across the folded constants.
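For example, Python's `fractions.Fraction` could multiply the folded constants exactly and round only once at the end; this is a sketch of the idea, not how opt_einsum_fx currently does it:

```python
from fractions import Fraction

coeffs = [0.1, 3, 1 / 7, 2.5]  # scalar constants collected from the graph
acc = Fraction(1)
for c in coeffs:
    acc *= Fraction(c)        # exact rational arithmetic, no intermediate rounding
final_scalar = float(acc)     # a single rounding when emitting the constant
```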
Accumulate the theoretical speedup, scaling factor, and intermediate sizes across the einsums processed in a graph and report them, e.g. as a logged summary or a statistics object returned alongside the optimized graph.
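One possible shape for this, assuming the per-einsum numbers come from `opt_einsum.contract_path`'s `PathInfo` (the equations and shapes below are arbitrary; attribute names follow opt_einsum's `PathInfo`):

```python
import numpy as np
import opt_einsum

equations = [
    ("ij,jk,kl->il", [(32, 64), (64, 64), (64, 16)]),
    ("bi,ij,bj->b", [(8, 32), (32, 32), (8, 32)]),
]

total_naive = total_opt = 0.0
largest = 0
for eq, shapes in equations:
    _, info = opt_einsum.contract_path(eq, *(np.empty(s) for s in shapes))
    total_naive += info.naive_cost
    total_opt += info.opt_cost
    largest = max(largest, info.largest_intermediate)
    print(info)  # per-einsum report: speedup, scaling, intermediate sizes

print(f"Graph-level theoretical speedup: {total_naive / total_opt:.2f}x")
print(f"Largest intermediate across einsums: {largest} elements")
```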