allo
The compiler cannot find loop `j`, which is defined earlier in the kernel
Describe the bug
When applying the `buffer_at` transformation, the compiler cannot find the band of loop `j`, even though loop `j` is defined in the kernel.
To Reproduce
The error occurs after customizing the kernel:

```python
import allo
from allo.ir.types import float32, int32

# domainEdge, lj1, lj2, concrete_type, m and n are defined elsewhere in the
# reporter's script (not shown in this report).


def kernel_md[T: (float32, int32), M: int32, N: int32](
    position_x: "T[M]", position_y: "T[M]", position_z: "T[M]",
    NL: "T[M, N]", force_x: "T[M]", fx: "T[M]",
    delx: "T[N]", dely: "T[N]", delz: "T[N]",
    r2inv: "T[N]", r6inv: "T[N]",
):
    for i in range(M):
        for j in range(N):
            jidx: int32 = NL[i, j]
            delx[j] = position_x[i] - position_x[jidx]
            dely[j] = position_y[i] - position_y[jidx]
            delz[j] = position_z[i] - position_z[jidx]
            if (delx[j] * delx[j] + dely[j] * dely[j] + delz[j] * delz[j]) == 0:
                r2inv[j] = (domainEdge * domainEdge * 3.0) * 1000
            else:
                r2inv[j] = 1.0 / (delx[j] * delx[j] + dely[j] * dely[j] + delz[j] * delz[j])
            r6inv[j] = r2inv[j] * r2inv[j] * r2inv[j]
            fx[i] = fx[i] + delx[j] * r2inv[j] * r6inv[j] * (lj1 * r6inv[j] - lj2)
        # force_x is written once per i, after the j loop finishes
        force_x[i] = fx[i]


sch0 = allo.customize(kernel_md, instantiate=[concrete_type, m, n])
print(sch0.module)
sch0.split("i", factor=8)
sch0.split("j", factor=8)
sch0.buffer_at(sch0.force_x, axis="j")
```
Buggy output

```
Traceback (most recent call last):
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/examples/polybench/md_2_knn.py", line 60, in <module>
    mod_test =md(float32, M,N)
              ^^^^^^^^^^^^^^^^
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/examples/polybench/md_2_knn.py", line 47, in md
    sch0.buffer_at(sch0.force_x, axis="j")
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/allo/customize.py", line 110, in wrapper
    res = fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/allo/customize.py", line 429, in buffer_at
    band_name, axis = find_loop_in_bands(func, axis)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/allo/ir/transform.py", line 104, in find_loop_in_bands
    raise RuntimeError(f"Cannot find the band of loop {axis_name}")
RuntimeError: Cannot find the band of loop j
```
Expected behavior
Allo should automatically create an intermediate buffer for `force_x` and attach it inside loop `j`.
This does not seem to be an Allo bug but a misuse of the schedule primitives. There are two problems here:
- Since you have already split the `j` loop, there is no `j` loop when you call the `buffer_at` function, only `j.outer` and `j.inner`.
- There is no need to create an intermediate buffer for `force_x`, as `force_x` is already outside the `j` loop.
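To make the first point concrete, here is a toy pure-Python model of how loop names change under a split. It does not call Allo: the `LoopBand` class and its `split`/`find` methods are hypothetical stand-ins that only mimic the naming described in this thread (`j` becomes `j.outer` and `j.inner`), so treat it as an illustration rather than Allo's implementation.

```python
class LoopBand:
    """Hypothetical toy model of a loop band: an ordered list of loop names."""

    def __init__(self, names):
        self.names = list(names)

    def split(self, name, factor):
        # Replace `name` with an outer/inner pair, mimicking the naming
        # described in this thread. Trip counts are not modeled, so
        # `factor` is only validated here.
        if factor < 1:
            raise ValueError("factor must be positive")
        idx = self.names.index(name)
        self.names[idx:idx + 1] = [f"{name}.outer", f"{name}.inner"]

    def find(self, name):
        # Look up a loop by its *current* name, as buffer_at has to do.
        if name not in self.names:
            raise RuntimeError(f"Cannot find the band of loop {name}")
        return self.names.index(name)


band = LoopBand(["i", "j"])
band.split("i", factor=8)
band.split("j", factor=8)
print(band.names)  # ['i.outer', 'i.inner', 'j.outer', 'j.inner']

try:
    band.find("j")  # the original name no longer exists after the split
except RuntimeError as err:
    print(err)  # Cannot find the band of loop j

print(band.find("j.inner"))  # 3
```

In this model, any axis passed after a split has to use one of the derived names.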
Please let us know if you need further guidance on optimizing this kernel.
We should print out proper error messages instead of letting the compiler crash.
We already print an error message in this case:

```
  File "/Users/rhodama/CORNELL/Design_project/bin_sp24/allo/allo/ir/transform.py", line 104, in find_loop_in_bands
    raise RuntimeError(f"Cannot find the band of loop {axis_name}")
RuntimeError: Cannot find the band of loop j in "kernel_md"
```
We could additionally list the available loop axes for users to choose from, but we cannot tell users where to insert buffers, since that would require additional (and costly) analysis.
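A minimal sketch of what listing the available axes could look like, in plain Python. `find_loop_or_suggest` is a hypothetical helper, not Allo's `find_loop_in_bands`; it assumes the loop names of each band are already available as a dict from a band name (here the made-up `band_0`) to a list of loop names, and it uses a simple prefix check to flag loops derived from a split.

```python
def find_loop_or_suggest(bands, func_name, axis):
    """Hypothetical helper: return (band_name, axis) or raise listing valid axes."""
    for band_name, loops in bands.items():
        if axis in loops:
            return band_name, axis
    available = [name for loops in bands.values() for name in loops]
    # Loops produced by split() keep the original name as a prefix, e.g. "j.outer".
    derived = [name for name in available if name.startswith(f"{axis}.")]
    msg = f'Cannot find the band of loop {axis} in "{func_name}".'
    msg += f" Available loops: {', '.join(available)}."
    if derived:
        msg += f" Did you mean one of: {', '.join(derived)}?"
    raise RuntimeError(msg)


# Hypothetical band layout after splitting both i and j.
bands = {"band_0": ["i.outer", "i.inner", "j.outer", "j.inner"]}
try:
    find_loop_or_suggest(bands, "kernel_md", "j")
except RuntimeError as err:
    print(err)
```

For the schedule in this issue, the sketch prints `Cannot find the band of loop j in "kernel_md". Available loops: i.outer, i.inner, j.outer, j.inner. Did you mean one of: j.outer, j.inner?`. It only suggests valid names; as noted above, deciding where a buffer should go is left to the user.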
Let's keep this issue open for now. I have some suggestions on the error messaging mechanism.