tfga
Improve geometric product performance
Right now the geometric product is computed by contracting both inputs with a 3-tensor C_ijk (the Cayley tensor), which is very sparse:
y_k = a_i b_j C_ijk (summed over i and j)
Initially in this project we attempted to exploit this sparsity by slicing out all the zeroes of the 3-tensor. However, this turned out to be ~50 times slower on GPU for full multivector-multivector products in STA. We need to find a better, GPU-friendly way to exploit this sparsity.
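For reference, the dense contraction and the amount of sparsity can be reproduced with a small sketch. It uses NumPy instead of TensorFlow, and `blade_product` and `build_cayley` are hypothetical helpers written for this illustration (not tfga functions). It assumes the STA metric [1, -1, -1, -1] and blades ordered by grade.

```python
import itertools
import numpy as np

def blade_product(a, b, metric):
    """Multiply two basis blades given as sorted tuples of basis-vector indices."""
    factors = list(a) + list(b)
    sign = 1
    # Bubble sort: every swap of two distinct basis vectors flips the sign.
    for end in range(len(factors), 1, -1):
        for i in range(end - 1):
            if factors[i] > factors[i + 1]:
                factors[i], factors[i + 1] = factors[i + 1], factors[i]
                sign = -sign
    # Contract equal neighbours: e_i e_i = metric[i].
    out = []
    for f in factors:
        if out and out[-1] == f:
            out.pop()
            sign *= metric[f]
        else:
            out.append(f)
    return sign, tuple(out)

def build_cayley(metric):
    """Hypothetical helper: dense Cayley tensor with blades ordered by grade."""
    n = len(metric)
    blades = [c for g in range(n + 1) for c in itertools.combinations(range(n), g)]
    index = {blade: i for i, blade in enumerate(blades)}
    cayley = np.zeros((len(blades),) * 3)
    for i, a in enumerate(blades):
        for j, b in enumerate(blades):
            sign, out = blade_product(a, b, metric)
            cayley[i, j, index[out]] = sign
    return cayley

cayley = build_cayley([1, -1, -1, -1])  # STA: 16 basis blades

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 4, 16))
b = rng.normal(size=(3, 4, 16))

# Dense geometric product: y_k = a_i b_j C_ijk
y = np.einsum("...i,...j,ijk->...k", a, b, cayley)

nonzero = np.count_nonzero(cayley)
print(y.shape)               # (3, 4, 16)
print(nonzero, cayley.size)  # 256 4096
```

Only one entry per (i, j) pair is non-zero here, so about 6% of the tensor carries information.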
Another idea:
- Keep track of blade values and indices like initially done
- In the geometric product, instead of using gather on the blade values, create a Cayley tensor that shuffles the first 2 axes of the original Cayley tensor according to the blade indices of the 2 inputs (and drops the ones that do not appear). Also slice out the output zeros like before. The partial Cayley tensor and its resulting blade indices could be cached for each pair of blade indices.
- This still leaves the one-hot zeroes on the 3rd axis (i.e. the ones where we densified the sparse multiplication table)
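The last point can be checked on a toy case. The sketch below uses NumPy and a hand-written multiplication table for the 2D Euclidean algebra (blades 1, e1, e2, e12). That algebra is an assumption chosen to keep the table small, not the one benchmarked in this issue.

```python
import numpy as np

# Multiplication table of the 2D Euclidean algebra as (i, j) -> (k, sign),
# with blade order 1, e1, e2, e12 (indices 0..3).
TABLE = {
    (0, 0): (0, 1), (0, 1): (1, 1), (0, 2): (2, 1), (0, 3): (3, 1),
    (1, 0): (1, 1), (1, 1): (0, 1), (1, 2): (3, 1), (1, 3): (2, 1),
    (2, 0): (2, 1), (2, 1): (3, -1), (2, 2): (0, 1), (2, 3): (1, -1),
    (3, 0): (3, 1), (3, 1): (2, -1), (3, 2): (1, 1), (3, 3): (0, -1),
}

# Densify the sparse table into a Cayley tensor.
cayley = np.zeros((4, 4, 4))
for (i, j), (k, sign) in TABLE.items():
    cayley[i, j, k] = sign

# Two vectors: inputs live on e1, e2; their product lives on 1 and e12.
idx_a, idx_b, idx_out = [1, 2], [1, 2], [0, 3]
partial = cayley[np.ix_(idx_a, idx_b, idx_out)]

print(partial.shape)                      # (2, 2, 2)
print(np.count_nonzero(partial))          # 4
# Each (i, j) pair has exactly one non-zero entry along the output axis.
print(np.count_nonzero(partial, axis=2))
```

Even after slicing, half of the entries are zero, because every input pair contributes to exactly one output blade.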
Example (but with cached get_partial_cayley):
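import tensorflow as tf
import tfga
from tfga import GeometricAlgebra

# Assumed setup (not shown in the original snippet): PGA with e0 squaring to zero.
ga = GeometricAlgebra(metric=[0, 1, 1, 1])

# Blade indices of the 16 PGA basis blades:
#   s
#   0
#   e0   e1   e2   e3
#   1    2    3    4
#   e01  e02  e03  e12  e13  e23
#   5    6    7    8    9    10
#   e012 e013 e023 e123
#   11   12   13   14
#   e0123
#   15

def get_partial_cayley(cayley, blade_indices_a, blade_indices_b, blade_indices_out):
    # Keep only the input blades that appear and the output blades that can be non-zero.
    cayley = tf.gather(cayley, blade_indices_a, axis=0)
    cayley = tf.gather(cayley, blade_indices_b, axis=1)
    cayley = tf.gather(cayley, blade_indices_out, axis=2)
    return cayley

# a holds blade values for e01, e02, e03 (indices 5, 6, 7)
a = tf.ones([3, 4, 3])
# b holds blade values for e0, e1 (indices 1, 2)
b = tf.ones([3, 4, 2])

# The product a * b can only be non-zero on e0, e012, e013 (indices 1, 11, 12)
partial_cayley = get_partial_cayley(ga.cayley, [5, 6, 7], [1, 2], [1, 11, 12])
print(partial_cayley.shape)  # (3, 2, 3)

# Reference: full product using the dense Cayley tensor
x = ga.geom_prod(ga.from_tensor(a, [5, 6, 7]), ga.from_tensor(b, [1, 2]))
# Partial product using the sliced Cayley tensor
y = tfga.mv_ops.mv_multiply(a, b, partial_cayley)
print(x[0, 0])
print(y[0, 0])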
Started some initial work here: https://github.com/RobinKa/tfga/tree/feature/faster-cayley-prod
Initial results for PGA (16 basis blades), runtime of multiplying two even-graded multivectors (8 basis blades):
- Full (old): 1x
- Partial cayley (not cached): 0.43x
- Partial cayley (cached with dict and str(indices) as key): 0.17x
- Theoretical: 8^3 / 16^3 = 0.125x
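The theoretical figure is just the ratio of tensor sizes, since the dense contraction touches every entry of the Cayley tensor it is given. A small sketch of that arithmetic, with `relative_cost` as a hypothetical helper:

```python
def relative_cost(n_a, n_b, n_out, n_blades):
    """Entries in the partial Cayley tensor relative to the full one."""
    return (n_a * n_b * n_out) / n_blades ** 3

# Even-graded PGA multivectors: 8 of the 16 basis blades on both inputs and the output.
even_ratio = relative_cost(8, 8, 8, 16)
print(even_ratio)  # 0.125

# The bivector-times-vector example: 3 and 2 input blades, 3 output blades.
small_ratio = relative_cost(3, 2, 3, 16)
print(small_ratio)
```

By this measure the cached result of 0.17x is already close to the 0.125x bound for the even-graded case.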
Wow, that is a really nice speedup