t3f
Automatic operation fusion
For example, instead of writing a separate function project_matmul,
we could implement something like project_matmul = fuse(lambda a, b: project(a, b), lambda b, c: matmul(b, c))
To do that, let's express each operation as a sequence of recurrent steps of the form

def recurrent(a, b):
    res = 1.0
    for a_core, b_core in zip(a.tt_cores, b.tt_cores):
        # 'rule' stands for the op-specific einsum subscripts.
        res = einsum('rule', a_core, b_core, res)
    return res
and of independent steps

def independent(a, b):
    res_cores = []
    for a_core, b_core in zip(a.tt_cores, b.tt_cores):
        res_cores.append(einsum('rule', a_core, b_core))
    return res_cores
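For intuition, here is what the recurrent pattern looks like for a plain dot product of two TT tensors, sketched in NumPy (core shape (left_rank, mode, right_rank) is an assumption for illustration; t3f's actual core layout may differ):

```python
import numpy as np

def tt_full(cores):
    # Contract TT cores (left_rank, mode, right_rank) into the dense tensor.
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=1)
    # Drop the boundary ranks of size 1.
    return res.reshape(res.shape[1:-1])

def tt_dot(a_cores, b_cores):
    # Recurrent step: carry a (rank_a, rank_b) matrix across the cores.
    res = np.ones((1, 1))
    for a_core, b_core in zip(a_cores, b_cores):
        res = np.einsum('ij,ink,jnl->kl', res, a_core, b_core)
    return res[0, 0]

# Example: two random 2-core TT tensors.
rng = np.random.default_rng(0)
a_cores = [rng.standard_normal((1, 3, 2)), rng.standard_normal((2, 4, 1))]
b_cores = [rng.standard_normal((1, 3, 3)), rng.standard_normal((3, 4, 1))]
recurrent_value = tt_dot(a_cores, b_cores)
dense_value = np.sum(tt_full(a_cores) * tt_full(b_cores))
```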
Then, we can automatically concatenate the einsums of the individual operations into a single big einsum (per core), and by using opt_einsum guarantee that the resulting einsum will be fast.
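A toy illustration of why merging helps, on dense matrices rather than TT cores: chaining separate einsums hard-codes the contraction order, while one merged expression lets the optimizer choose it. NumPy's einsum exposes this via optimize=True; the opt_einsum package does the same via opt_einsum.contract with more options:

```python
import numpy as np

A = np.arange(12.0).reshape(3, 4)
B = np.arange(20.0).reshape(4, 5)
x = np.arange(5.0)

# Two separate einsums: the order (B @ x first) is fixed by the code.
separate = np.einsum('ij,j->i', A, np.einsum('jk,k->j', B, x))

# One merged einsum: optimize=True searches for a good contraction path.
merged = np.einsum('ij,jk,k->i', A, B, x, optimize=True)
```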
Off the top of my head, we can support any combination of
- matmul(A, B)
- add a + b
- elementwise product a * b
Additionally, as the last operation of the combination, we can support
- dot product a^t b
- gram matrix G_ij = ai^t bj
- projection on the tangent space P_x y
- trace
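For instance, the elementwise product fits the independent-step pattern: each core of a * b combines the corresponding cores of a and b, with each mode slice being a Kronecker product of the input slices (again assuming (left_rank, mode, right_rank) cores purely for illustration):

```python
import numpy as np

def tt_elementwise(a_cores, b_cores):
    # Independent step: each output core is built from one pair of input
    # cores; the slice for mode value n is kron(a_core[:, n, :], b_core[:, n, :]).
    res_cores = []
    for a_core, b_core in zip(a_cores, b_cores):
        r1, n, r2 = a_core.shape
        q1, _, q2 = b_core.shape
        core = np.einsum('inj,knl->iknjl', a_core, b_core)
        res_cores.append(core.reshape(r1 * q1, n, r2 * q2))
    return res_cores

def tt_full(cores):
    # Contract TT cores into the dense tensor (boundary ranks are 1).
    res = cores[0]
    for core in cores[1:]:
        res = np.tensordot(res, core, axes=1)
    return res.reshape(res.shape[1:-1])

rng = np.random.default_rng(1)
a_cores = [rng.standard_normal((1, 3, 2)), rng.standard_normal((2, 4, 1))]
b_cores = [rng.standard_normal((1, 3, 3)), rng.standard_normal((3, 4, 1))]
product_dense = tt_full(tt_elementwise(a_cores, b_cores))
```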
By combining these ops, we can, for example, automatically get fast versions of
- x^t A y (already implemented as a separate fast operation)
- ||A B||
- A B x
- P_x A y (already implemented)
- ||a * b||
- P_x A B y
- ||A + B||
- P_x (a * b)
- x^t A B y
- ||(Ax) * (By)||
- trace(A^T B A)
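As a concrete sketch of the first item: once the recurrent steps of matmul and the dot product are concatenated, x^T A y collapses to one einsum per core, never materializing A y. TT-matrix cores are assumed here to have shape (left_rank, row_mode, col_mode, right_rank); this layout is an assumption for illustration:

```python
import numpy as np

def fused_bilinear(x_cores, A_cores, y_cores):
    # One einsum per core; carried state is a (rank_x, rank_A, rank_y) array.
    res = np.ones((1, 1, 1))
    for x_c, A_c, y_c in zip(x_cores, A_cores, y_cores):
        res = np.einsum('paq,pmd,amnb,qnr->dbr', res, x_c, A_c, y_c,
                        optimize=True)
    return res[0, 0, 0]

def tt_vec_full(cores):
    # Dense vector from TT cores of shape (left_rank, mode, right_rank).
    res = cores[0]
    for c in cores[1:]:
        res = np.tensordot(res, c, axes=1)
    return res.reshape(-1)

def tt_mat_full(cores):
    # Dense matrix from TT-matrix cores (left_rank, row, col, right_rank).
    res = cores[0]
    for c in cores[1:]:
        res = np.tensordot(res, c, axes=1)
    res = res.reshape(res.shape[1:-1])
    d = len(cores)
    # Group row modes together, then column modes.
    res = res.transpose(list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2)))
    return res.reshape(int(np.prod(res.shape[:d])), -1)

rng = np.random.default_rng(2)
x_cores = [rng.standard_normal((1, 3, 2)), rng.standard_normal((2, 4, 1))]
y_cores = [rng.standard_normal((1, 5, 2)), rng.standard_normal((2, 6, 1))]
A_cores = [rng.standard_normal((1, 3, 5, 3)), rng.standard_normal((3, 4, 6, 1))]
fused = fused_bilinear(x_cores, A_cores, y_cores)
dense = tt_vec_full(x_cores) @ tt_mat_full(A_cores) @ tt_vec_full(y_cores)
```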
Does anyone need this?
A potential way to design the API:

with t3f.Fuse() as f:
    Ax = t3f.matmul(A, x)
    xAx = t3f.flat_inner(x, Ax)
    fast_xAx = f.optimize(xAx)
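For what it's worth, the lazy-recording part of such an API is easy to prototype. Below is a toy, purely illustrative Fuse that records einsum nodes and inlines them into one optimized expression in optimize(). It works on dense arrays with a generic einsum method instead of the real per-core t3f ops, and all names are hypothetical:

```python
import string
import numpy as np

class Node:
    """A lazily recorded einsum call: subscripts plus input operands."""
    def __init__(self, subscripts, operands):
        self.subscripts = subscripts  # e.g. 'ij,j->i'
        self.operands = operands      # arrays or other Nodes

class Fuse:
    """Toy sketch: record einsums lazily, merge them in optimize()."""
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def einsum(self, subscripts, *operands):
        # Nothing is computed here; we only record the expression.
        return Node(subscripts, operands)

    def optimize(self, node):
        fresh = iter(string.ascii_letters)
        _, out = node.subscripts.split('->')
        out_labels = ''.join(next(fresh) for _ in out)
        terms, arrays = self._flatten(node, out_labels, fresh)
        expr = ','.join(terms) + '->' + out_labels
        # One big einsum; optimize=True picks the contraction order.
        return np.einsum(expr, *arrays, optimize=True)

    def _flatten(self, node, out_labels, fresh):
        # Relabel this node's indices with globally unique letters and
        # recursively inline child Nodes into the flat expression.
        inputs, out = node.subscripts.split('->')
        relabel = dict(zip(out, out_labels))
        terms, arrays = [], []
        for term, op in zip(inputs.split(','), node.operands):
            for ch in term:
                if ch not in relabel:
                    relabel[ch] = next(fresh)
            glob = ''.join(relabel[ch] for ch in term)
            if isinstance(op, Node):
                sub_terms, sub_arrays = self._flatten(op, glob, fresh)
                terms += sub_terms
                arrays += sub_arrays
            else:
                terms.append(glob)
                arrays.append(op)
        return terms, arrays

# Usage mirroring the proposed API:
A = np.arange(16.0).reshape(4, 4)
x = np.arange(4.0)
with Fuse() as f:
    Ax = f.einsum('ij,j->i', A, x)   # stand-in for t3f.matmul(A, x)
    xAx = f.einsum('i,i->', x, Ax)   # stand-in for t3f.flat_inner(x, Ax)
    fast_xAx = f.optimize(xAx)
```

A real implementation would record each t3f op's per-core einsum rule rather than a raw subscript string, and merge the expressions core by core.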