
Memlet WCR propagation bug

Open lamyiowce opened this issue 2 years ago • 1 comments

Describe the bug

WCR edges are not propagated correctly. As a side effect, applying greedy fusion to an erroneous graph results in a graph that gives incorrect results.

To Reproduce

import numpy as np
import scipy.sparse
import dace
from dace.transformation.auto.auto_optimize import greedy_fuse

def test_greedy_fuse_bug():
    N = 3
    dtype = np.float32

    np.random.seed(42)

    # Create input.
    graph = (scipy.sparse.random(N, N, density=0.5, format='csr')
             + scipy.sparse.eye(N, format='csr'))
    graph.data = np.ones_like(graph.data)
    _, col = graph.indptr, graph.indices
    col = np.copy(col)
    num_entries = col.shape[0]

    out_e = np.random.rand(num_entries).astype(dtype=dtype)

    @dace.program
    def gat(columns, e):
        softmax_sum = np.zeros((N,), dtype=dtype)

        for j in dace.map[0:num_entries]:
            colj = columns[j]
            e[j] = np.exp(e[j])
            softmax_sum[colj] += e[j]

        for j in dace.map[0:num_entries]:
            colj = columns[j]
            e[j] = e[j] / softmax_sum[colj]

    sdfg = gat.to_sdfg(columns=col, e=out_e)
    greedy_fuse(sdfg, device=dace.dtypes.DeviceType.CPU, validate_all=True)
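    # Copy the input first: both the SDFG and the reference run modify `e` in place.
    expected_e = np.copy(out_e)
    sdfg(columns=col, e=out_e)

    # Reference result from the plain Python function.
    gat.f(columns=col, e=expected_e)

    # check_equal is a comparison helper that is not defined in this snippet.
    check_equal(expected_e, out_e, 'attention_weights')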

See the highlighted edge in the attached image.

lamyiowce, Jun 12 '23 16:06

I stumbled across the same thing. There is also an example of this in the samples in spmv.py, which looks very similar to your example.

I was wondering, instead of having all edges have the WCR, shouldn't it be the exact inverse of what's happening now? That is, the WCR should only show on the edges in the outer SDFG and no WCR on the memlets inside the nested SDFG.

My reasoning would be that, semantically, there are no conflicts in the nested SDFG, since it's just a scalar operation in this example. So if we look at the nested SDFG on its own[^1], the WCR doesn't make much sense. The conflicts only occur in the outer SDFG, when results from multiple executions of the nested SDFG conflict with one another.
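To make that concrete, here is a rough plain-Python sketch of the argument. It does not use the DaCe API at all, and every name in it (`inner_body`, `scatter_without_wcr`, `scatter_with_wcr`) is made up for illustration. The inner body stands in for the nested SDFG and only ever produces one scalar, so nothing can conflict there. The conflict shows up in the outer loop, where several iterations write to the same output element.

```python
def inner_body(e_j):
    """Stand-in for the nested SDFG: one scalar in, one scalar out."""
    return e_j


def scatter_without_wcr(columns, e, n):
    """Outer loop with plain writes: the last writer to a slot wins."""
    out = [0.0] * n
    for j, col in enumerate(columns):
        out[col] = inner_body(e[j])
    return out


def scatter_with_wcr(columns, e, n, wcr=lambda old, new: old + new):
    """Outer loop where conflicting writes are combined by `wcr`."""
    out = [0.0] * n
    for j, col in enumerate(columns):
        out[col] = wcr(out[col], inner_body(e[j]))
    return out


# Iterations 0 and 2 both target slot 0; iterations 1 and 4 both target slot 1.
columns = [0, 1, 0, 2, 1]
e = [1.0, 2.0, 3.0, 4.0, 5.0]

lost_updates = scatter_without_wcr(columns, e, 3)
summed = scatter_with_wcr(columns, e, 3)
```

With these inputs the plain-write version drops the earlier contributions to slots 0 and 1, while the version with a sum as the conflict resolution keeps all of them. In both versions `inner_body` is identical, which is the point: the resolution function belongs to the outer write, not to the scalar computation.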

[^1]: My understanding is that this is the point of nested SDFGs, correct me if I'm wrong

JanKleine, Jun 15 '23 08:06