
Automatic operation fusion

Open · Bihaqo opened this issue 5 years ago · 1 comment

For example, instead of writing a separate function project_matmul by hand, we could build it automatically with something like project_matmul = fuse(lambda a, b: project(a, b), lambda b, c: matmul(b, c)).

To do that, let's express each operation as a sequence of recurrent steps of the form

def recurrent(a, b):
  res = 1.0
  for a_core, b_core in zip(a.tt_cores, b.tt_cores):
    # 'rule' stands for an operation-specific einsum string
    res = einsum('rule', a_core, b_core, res)
  return res

and of independent steps

def independent(a, b):
  res_cores = []
  for a_core, b_core in zip(a.tt_cores, b.tt_cores):
    # 'rule' stands for an operation-specific einsum string
    res_cores.append(einsum('rule', a_core, b_core))
  return res_cores

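As a concrete instance of the recurrent-step form, here is a minimal numpy sketch of the TT inner product (what t3f calls flat_inner), with the 'rule' placeholder filled in by an explicit einsum string. The helpers tt_random and tt_full are illustrative inventions for checking the result against the dense tensors, not part of t3f's API:

```python
import numpy as np

rng = np.random.default_rng(0)

def tt_random(shape, rank):
    """Random TT cores with shapes (r_prev, n_k, r_next)."""
    ranks = [1] + [rank] * (len(shape) - 1) + [1]
    return [rng.standard_normal((ranks[i], n, ranks[i + 1]))
            for i, n in enumerate(shape)]

def tt_full(cores):
    """Contract TT cores into a dense tensor (only for checking)."""
    res = cores[0]
    for core in cores[1:]:
        res = np.einsum('...a,aib->...ib', res, core)
    return res.squeeze(axis=(0, -1))

def tt_dot(a_cores, b_cores):
    """Recurrent-step inner product <a, b>: one einsum per core pair."""
    res = np.ones((1, 1))  # (rank_a, rank_b) message carried between steps
    for a_core, b_core in zip(a_cores, b_cores):
        # Concrete version of the 'rule' placeholder above.
        res = np.einsum('ab,aic,bid->cd', res, a_core, b_core)
    return res[0, 0]

a = tt_random((4, 5, 6), rank=3)
b = tt_random((4, 5, 6), rank=2)
print(np.allclose(tt_dot(a, b),
                  np.dot(tt_full(a).ravel(), tt_full(b).ravel())))
```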
Then, we can automatically concatenate the einsums of the individual operations into a single big einsum (per core), and by using opt_einsum guarantee that the resulting einsum will be fast.

Off the top of my head, we can support any combination of

  1. matmul(A, B)
  2. add a + b
  3. elementwise product a * b

Additionally, as the last operation of the combination, we can support

  1. dot product a^t b
  2. gram matrix G_ij = ai^t bj
  3. projection on the tangent space P_x y
  4. trace

By combining these ops we can, for example, automatically get fast versions of

  1. x^t A y (already implemented as a separate fast operation)
  2. ||A B||
  3. A B x
  4. P_x A y (already implemented)
  5. ||a * b||
  6. Px A B y
  7. ||A + B||
  8. P_x (a * b)
  9. x^t A B y
  10. ||(Ax) * (By)||
  11. trace(A^T B A)
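To illustrate what fusion buys us, here is a hedged numpy sketch of item 5, the squared norm ||a * b||^2, as one fused einsum per core. Without fusion, materializing a * b would square the TT ranks first. The names tt_had_norm_sq, tt_random, and tt_full are illustrative helpers invented for this sketch, not t3f functions:

```python
import numpy as np

rng = np.random.default_rng(1)

def tt_random(shape, rank):
    """Random TT cores with shapes (r_prev, n_k, r_next)."""
    ranks = [1] + [rank] * (len(shape) - 1) + [1]
    return [rng.standard_normal((ranks[i], n, ranks[i + 1]))
            for i, n in enumerate(shape)]

def tt_full(cores):
    """Contract TT cores into a dense tensor (only for checking)."""
    res = cores[0]
    for core in cores[1:]:
        res = np.einsum('...a,aib->...ib', res, core)
    return res.squeeze(axis=(0, -1))

def tt_had_norm_sq(a_cores, b_cores):
    """Fused ||a * b||^2: never materializes the elementwise product,
    whose TT ranks would be the products of the input ranks."""
    # Message carried over two copies of the (rank_a, rank_b) indices.
    res = np.ones((1, 1, 1, 1))
    for ac, bc in zip(a_cores, b_cores):
        # One big einsum per core; optimize=True lets numpy pick a good
        # contraction order (opt_einsum would do the same job).
        res = np.einsum('abcd,aie,bif,cig,dih->efgh',
                        res, ac, bc, ac, bc, optimize=True)
    return res[0, 0, 0, 0]

a = tt_random((3, 4, 5), rank=2)
b = tt_random((3, 4, 5), rank=3)
dense = tt_full(a) * tt_full(b)
print(np.allclose(tt_had_norm_sq(a, b), np.sum(dense ** 2)))
```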

Does anyone need this?

Bihaqo · Jan 09 '19 10:01

A potential way to design the API:

with t3f.Fuse() as f:
  Ax = t3f.matmul(A, x)
  xAx = t3f.flat_inner(x, Ax)
  fast_xAx = f.optimize(xAx)
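One hedged sketch of how such a Fuse context could work: record the ops lazily instead of executing them, then let optimize() evaluate (or, in a real implementation, merge the per-core einsums of) the recorded graph. Everything here is hypothetical API design, and dense numpy arrays stand in for TT objects:

```python
import numpy as np

class LazyOp:
    """A node in a recorded expression graph (hypothetical sketch)."""
    def __init__(self, fn, args):
        self.fn, self.args = fn, args

    def evaluate(self):
        vals = [a.evaluate() if isinstance(a, LazyOp) else a
                for a in self.args]
        return self.fn(*vals)

class Fuse:
    """Records ops instead of executing them eagerly."""
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def matmul(self, a, b):
        return LazyOp(np.matmul, (a, b))

    def flat_inner(self, a, b):
        return LazyOp(lambda x, y: float(np.dot(np.ravel(x), np.ravel(y))),
                      (a, b))

    def optimize(self, node):
        # Placeholder: a real implementation would walk the graph and build
        # one opt_einsum expression per TT core before evaluating.
        return node.evaluate()

A = np.random.default_rng(2).standard_normal((3, 3))
x = np.random.default_rng(3).standard_normal((3, 1))
with Fuse() as f:
    Ax = f.matmul(A, x)
    xAx = f.optimize(f.flat_inner(x, Ax))
```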

Bihaqo · Jan 27 '19 21:01