opt_einsum_fx
[Feature] convert transpose into einsum
Add a graph transformer that rewrites x.transpose(1, 2) as torch.einsum('abc...->acb...', x),
so that these transposes can then be fused with the rest of the einsums.
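
To make the request concrete, here is a rough sketch of what such a pass could look like using plain torch.fx. The names transpose_equation and transpose_to_einsum are hypothetical and not part of opt_einsum_fx. The sketch assumes the transpose is traced as a call_method node with two non-negative integer dims passed positionally, and it skips anything else.

```python
import string


def transpose_equation(dim0: int, dim1: int) -> str:
    """Build an einsum equation equivalent to x.transpose(dim0, dim1).

    Hypothetical helper. Only non-negative dims are handled, because
    resolving a negative dim needs the tensor rank. Dimensions after
    max(dim0, dim1) are covered by the trailing ellipsis.
    """
    if dim0 < 0 or dim1 < 0:
        raise ValueError("negative dims are not supported in this sketch")
    n = max(dim0, dim1) + 1
    if n > len(string.ascii_lowercase):
        raise ValueError("too many dimensions for single-letter indices")
    lhs = list(string.ascii_lowercase[:n])
    rhs = lhs.copy()
    rhs[dim0], rhs[dim1] = rhs[dim1], rhs[dim0]  # swap the two index letters
    return f"{''.join(lhs)}...->{''.join(rhs)}..."


def transpose_to_einsum(gm):
    """Rewrite x.transpose(i, j) nodes of a torch.fx.GraphModule as einsums.

    Hypothetical pass; modifies gm in place and returns it.
    """
    import torch  # imported lazily so transpose_equation works without torch

    for node in list(gm.graph.nodes):
        if node.op != "call_method" or node.target != "transpose":
            continue
        if len(node.args) != 3:
            continue  # dims passed as keywords: left untouched
        x, dim0, dim1 = node.args
        if not (isinstance(dim0, int) and isinstance(dim1, int)):
            continue  # dims computed at runtime: left untouched
        if dim0 < 0 or dim1 < 0:
            continue  # negative dims: left untouched
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(
                torch.einsum, (transpose_equation(dim0, dim1), x)
            )
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```

Usage would be along the lines of transpose_to_einsum(torch.fx.symbolic_trace(model)), followed by the existing einsum fusion and optimization passes. Whether the resulting einsum actually gets fused depends on those passes.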