[Question]: Regarding tuple iteration
I'm trying to implement Flash Attention 4 in CuTile and got stuck on the polynomial exponential. Essentially, Flash Attention 4 uses a polynomial approximation of exp2 to reduce MUFU pressure.
However, when I try to implement it in CuTile, I get a compilation error. My current version is:
@ct.function(host=False)
def _poly_eval(coeffs: ConstTuple, x: ct.Tile) -> ct.Tile:
    result = ct.zeros(x.shape, dtype=x.dtype)
    num_coeffs = len(coeffs)
    for i in range(0, num_coeffs):
        result = result * x + coeffs[i]
    return result

@ct.function(host=False)
def _poly_exp2(x: ct.Tile) -> ct.Tile:
    poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625)
    ... # rest of ex2 approx
    result = _poly_eval(poly_ex2_deg3, x_frac)
    ...
    return result
However I get:
cuda.tile._exception.TileTypeError: Invalid argument #2 of getitem(): Expected an integer constant, but given value is not constant
On: result = result * x + coeffs[i]
How am I supposed to iterate over the coeffs tuple?
@odelame thank you for your question. This is a limitation we are actively working on lifting.
The issue is that i inside the for loop is not a constant, and tuple indexing requires a constant index. As a workaround, if the length of coeffs is known, you can manually expand the tuple indexing, i.e.
def _poly_eval_deg3(coeffs: ConstTuple, x: ct.Tile) -> ct.Tile:
    result = ct.zeros(x.shape, dtype=x.dtype)
    result = ((coeffs[0] * x + coeffs[1]) * x + coeffs[2]) * x + coeffs[3]
    return result
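
To sanity-check the unrolling outside the compiler, here is a minimal plain-Python sketch. It uses floats in place of ct.Tile and makes no cuTile calls. The function names are made up for illustration, and `poly_eval_deg3_ascending` is a hypothetical variant that is not in the thread. It rests on the assumption that the coefficients are listed from the constant term upward, which the values suggest (the second one is close to ln 2).

```python
# Plain-Python stand-in for the tile code: floats instead of ct.Tile, no cuTile calls.
POLY_EX2_DEG3 = (
    1.0,
    0.695146143436431884765625,
    0.227564394474029541015625,
    0.077119089663028717041015625,
)


def poly_eval_loop(coeffs, x):
    """Horner loop from the question: coeffs[0] ends up as the highest-degree coefficient."""
    result = 0.0
    for c in coeffs:
        result = result * x + c
    return result


def poly_eval_deg3_unrolled(coeffs, x):
    """Manually unrolled form from the reply: every tuple index is a literal constant."""
    return ((coeffs[0] * x + coeffs[1]) * x + coeffs[2]) * x + coeffs[3]


def poly_eval_deg3_ascending(coeffs, x):
    """Hypothetical variant: same unrolling, but assuming coeffs[0] is the constant term."""
    return ((coeffs[3] * x + coeffs[2]) * x + coeffs[1]) * x + coeffs[0]


if __name__ == "__main__":
    for x in (0.0, 0.25, 0.5, 0.75):
        print(
            x,
            poly_eval_loop(POLY_EX2_DEG3, x),
            poly_eval_deg3_unrolled(POLY_EX2_DEG3, x),
            poly_eval_deg3_ascending(POLY_EX2_DEG3, x),
            2.0 ** x,
        )
```

The loop and the unrolled expression agree, so the workaround preserves the behaviour of the original `_poly_eval`. One thing that may be worth double-checking separately: with these coefficients, only the ascending-order variant tracks 2**x on [0, 1). If the tuple really is ordered from the constant term upward, the indices in the unrolled expression would need to be reversed.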