TensorComprehensions
Support for triangular dependencies in TC language
One of our users reported on the Slack channel that they were trying to translate the following NumPy code
```python
import numpy as np

def A_matvec_batch(A, X):
    n, m = X.shape
    Y = np.zeros((n, m))
    for j in range(m):
        for i in range(n):
            # Avoiding out-of-bounds accesses to X gives the following limits for s:
            # s_min is -1 for every row except row 0, where it is 0
            # s_max (exclusive) is 2 for every row except the last, where it is 1
            s_min = max(-1, -i)
            s_max = min(2, n - i)
            for s in range(s_min, s_max):  # this loop simplifies differentiation
                Y[i, j] += A[i, s + 1] * X[i + s, j]
    return Y
```
into this TC:
```python
lang = """
def A_mv(float(n, 3) A, float(n, size) X) -> (O) {
    O(i, j) +=! A(i, s+1) * X(i+s, j) where s in (-1<=-i ? -i : -1):(2<=n-i ? 2 : n-i)
}
"""
```
However, this TC fails to compile because range inference runs into a triangular dependency: the range of s depends on the value of i. TC does not currently support this.
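To make the dependency concrete, here is a minimal NumPy sketch (plain Python, not TC code) that enumerates the (i, s) iteration domain of the loop nest and shows it is not a rectangle: rows 0 and n-1 visit fewer values of s than interior rows. It also shows one common way to remove the dependency in the NumPy reference, namely zero-padding X by one row at each end so that s can range over the fixed interval -1:2 for every i. The helper names (s_bounds, iteration_domain, A_mv_padded) are hypothetical, this is not @abadams' proposal, and whether an analogous padded TC compiles has not been checked here.

```python
import numpy as np


def s_bounds(i, n):
    """Half-open range [s_min, s_max) of the offset s for row i.

    These keep i + s inside [0, n), so they depend on i: this is the
    triangular dependency that TC range inference cannot handle today.
    """
    return max(-1, -i), min(2, n - i)


def iteration_domain(n):
    """All (i, s) pairs visited by the loop nest (not a rectangle in i, s)."""
    return [(i, s) for i in range(n) for s in range(*s_bounds(i, n))]


def A_matvec_batch(A, X):
    """Reference loop nest from the issue, with the bounds factored out."""
    n, m = X.shape
    Y = np.zeros((n, m))
    for j in range(m):
        for i in range(n):
            for s in range(*s_bounds(i, n)):
                Y[i, j] += A[i, s + 1] * X[i + s, j]
    return Y


def A_mv_padded(A, X):
    """Same result over the rectangular domain i in [0, n), s in [-1, 2).

    X is zero-padded with one row above and one below, so an access that
    would fall outside X (i + s == -1 or i + s == n) reads a padding zero.
    """
    n, m = X.shape
    Xp = np.zeros((n + 2, m))
    Xp[1:n + 1] = X  # Xp[k + 1] == X[k] for 0 <= k < n
    Y = np.zeros((n, m))
    for s in range(-1, 2):  # bounds no longer depend on i
        # Row i gets A[i, s + 1] * X[i + s]; Xp is shifted by one row.
        # A[:, s + 1:s + 2] has shape (n, 1) and broadcasts over the m columns.
        Y += A[:, s + 1:s + 2] * Xp[1 + s:1 + s + n]
    return Y


print(iteration_domain(3))
# [(0, 0), (0, 1), (1, -1), (1, 0), (1, 1), (2, -1), (2, 0)]

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))
X = rng.standard_normal((5, 4))
print(np.allclose(A_matvec_batch(A, X), A_mv_padded(A, X)))  # True
```

Padding moves the boundary handling out of the index bounds and into the data, which is why the loop over s no longer needs limits that depend on i. The cost is an extra (n + 2) x m buffer.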
The good news is that this is purely a language and range-inference issue, not a limitation of the polyhedral backend.
@abadams has a nice proposal for handling this; we will discuss it further and add support for such TCs.
cc @abadams @nicolasvasilache @zdevito