
[Transform API] Simple tiling can be tedious to implement

Open kaushikcfd opened this issue 1 year ago • 3 comments

Consider the simple batched matvec example:

import loopy as lp

knl = lp.make_kernel(
    "{[e,i,j]: 0<=e<1000 and 0<=i,j<10}",
    """
    out[e, i] = sum(j, D[i, j] * u[e, j])
    """,
    [lp.GlobalArg("u,D,out", "float64", lp.auto)])

# Tile "j"+prefetch to reduce the cache reuse distance of "D"
knl = lp.split_iname(knl, "j", 5)
knl = lp.add_prefetch(knl, "D", sweep_inames=["i", "j_inner"])

print(lp.generate_code_v2(knl).device_code())

This results in a linearization error with:

* Duplicate j_outer within instructions (id:insn_j_outer_j_inner_update)
* Duplicate i within instructions (id:insn_j_outer_j_inner_init or id:insn)

Notice how the duplication options force us to realize the reduction, which makes this simple transformation quite tedious to implement. We should have some interface to make this easier, since (at least in the context of einsums) this is a pretty common transformation.
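In plain-Python terms (sizes shrunk, variable names illustrative), "realizing the reduction" turns the single sum node into an explicit init/update/write-back chain; those are exactly the instructions the duplication options then ask us to disambiguate:

```python
# Realized-reduction form of out[e, i] = sum(j, D[i, j] * u[e, j]),
# with j already split into (j_outer, j_inner) tiles of width 5.
E, N, TILE = 4, 10, 5  # illustrative sizes (the issue uses E=1000, N=10)
D = [[float(i * N + j) for j in range(N)] for i in range(N)]
u = [[float(e + 2 * j) for j in range(N)] for e in range(E)]

out = [[0.0] * N for _ in range(E)]
for e in range(E):
    for i in range(N):
        acc = 0.0                          # init   (insn_j_outer_j_inner_init)
        for j_outer in range(N // TILE):
            for j_inner in range(TILE):
                j = j_outer * TILE + j_inner
                acc += D[i][j] * u[e][j]   # update (insn_j_outer_j_inner_update)
        out[e][i] = acc                    # write-back (insn)
```

With a scalar accumulator and i outside j_outer, there is no loop level at which a tile of D could be fetched once per j_outer, which is what forces the iname duplication (and the tedium).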

kaushikcfd avatar Aug 09 '22 23:08 kaushikcfd

Some options could be:

  1. a kernel could have an attribute post_realize_reduction_transforms_callback
  2. making reduction nodes taggable and defining some implementation tags that would help in this case.

kaushikcfd avatar Aug 09 '22 23:08 kaushikcfd

For posterity, the resulting kernel is this:

  for D_dim_0, D_dim_1, j_outer
↱       D_fetch[D_dim_0, D_dim_1] = D[D_dim_0, D_dim_1 + 5*j_outer]  {id=D_fetch_rule}
│ end D_dim_0, D_dim_1, j_outer
│ for i, e
└     out[e, i] = reduce(sum, [j_outer, j_inner], D_fetch[i, j_inner]*u[e, j_inner + j_outer*5])  {id=insn}
  end i, e

and after reduction realization, it becomes

    for j_outer, D_dim_0, D_dim_1
↱         D_fetch[D_dim_0, D_dim_1] = D[D_dim_0, D_dim_1 + 5*j_outer]  {id=D_fetch_rule}
│   end j_outer, D_dim_0, D_dim_1
│   for e, i
│↱      acc_j_outer_j_inner = 0.0  {id=insn_j_outer_j_inner_init}
││      for j_inner, j_outer
├└↱         acc_j_outer_j_inner = acc_j_outer_j_inner + D_fetch[i, j_inner]*u[e, j_inner + j_outer*5]  {id=insn_j_outer_j_inner_update}
│ │     end j_inner, j_outer
└ └     out[e, i] = acc_j_outer_j_inner  {id=insn}
    end e, i

I guess my first question here is what you would like to happen. Duplicating j_outer will lead to invalid code, I think. FWIW, I get different duplication options:

* Duplicate j_outer within instructions (id:D_fetch_rule)
* Duplicate j_outer within instructions (id:insn_j_outer_j_inner_update)
* Duplicate e within instructions (id:insn_j_outer_j_inner_init or id:insn)

inducer avatar Aug 18 '22 19:08 inducer

I would first privatize the temporary acc_j_outer_j_inner in the iname i, then duplicate i in the instructions insn and insn_j_outer_j_inner_init. That way, the accumulator's state persists across j_outer tiles as we compute each one.
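In plain-Python terms the schedule this produces would look as follows (sizes shrunk, names illustrative): the accumulator becomes an array over i, so j_outer can enclose both the tile prefetch and the updates for every i:

```python
E, N, TILE = 4, 10, 5  # illustrative sizes (the issue uses E=1000, N=10)
D = [[float(i * N + j) for j in range(N)] for i in range(N)]
u = [[float(e + 2 * j) for j in range(N)] for e in range(E)]

out = [[0.0] * N for _ in range(E)]
for e in range(E):
    acc = [0.0] * N      # acc_j_outer_j_inner, privatized over i
    for j_outer in range(N // TILE):
        # prefetch one tile of D; legal now, since the tile is fully
        # consumed before the next j_outer iteration overwrites it
        D_fetch = [[D[i][j_outer * TILE + j_inner] for j_inner in range(TILE)]
                   for i in range(N)]
        for i in range(N):
            for j_inner in range(TILE):
                acc[i] += D_fetch[i][j_inner] * u[e][j_outer * TILE + j_inner]
    for i in range(N):   # duplicated "i": write-back (insn)
        out[e][i] = acc[i]

# matches the untiled batched matvec
ref = [[sum(D[i][j] * u[e][j] for j in range(N)) for i in range(N)]
       for e in range(E)]
assert out == ref
```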

kaushikcfd avatar Aug 18 '22 19:08 kaushikcfd