
[BUG] Array gradients should be zeroed in storage operations

Open daedalus5 opened this issue 1 month ago • 0 comments

Bug Description

When an array is written to, it represents either an intermediate state in the graph or a final state. In either case, it is not an input to the graph. We may kickstart an adjoint chain by setting the output gradients to one. However, the forward values stored in that output array are merely functions of other inputs, so end-to-end the adjoint of the output array itself is zero.

We could consider pairing array storage operations with gradient clearing in the backward pass. Doing so would prevent adjoints from accumulating across (improper) inter- and intra-kernel overwrites.
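Written as an adjoint rule, the proposal could be summarized as follows for a store that overwrites y with f(x). This is a sketch of the intended semantics, not a description of the adjoint code Warp currently generates:

```latex
\begin{aligned}
\text{forward:}  \quad & y \leftarrow f(x) \\
\text{backward:} \quad & \bar{x} \mathrel{+}= \left(\frac{\partial f}{\partial x}\right)^{\!\top} \bar{y},
                  \qquad \bar{y} \leftarrow 0
\end{aligned}
```

The final assignment is the proposed change. The value y held before the store never reaches anything downstream of it, so once the store has been differentiated, the adjoint of that earlier value is zero. Without the reset, it keeps whatever was accumulated for the later value.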

Example:

import warp as wp

@wp.kernel
def test_kernel_1(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    a = x[i]
    b = 2.0 * a
    y[i] = b


@wp.kernel
def test_kernel_2(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    c = x[i]
    d = 3.0 * c
    y[i] = d


N = 4

x_a = wp.ones(N, dtype=float, requires_grad=True)
x_b = wp.ones(N, dtype=float, requires_grad=True)
y = wp.zeros_like(x_b)

with wp.Tape() as tape:
    wp.launch(test_kernel_1, dim=N, inputs=[x_a, y])
    wp.launch(test_kernel_2, dim=N, inputs=[x_b, y])

y.grad = wp.ones_like(y)

tape.backward()

print(y.numpy())
print(x_a.grad.numpy())
print(x_b.grad.numpy())
print(y.grad.numpy())

# prints:
# [3. 3. 3. 3.]
# [2. 2. 2. 2.]
# [3. 3. 3. 3.]
# [1. 1. 1. 1.]

# should print:
# [3. 3. 3. 3.]
# [0. 0. 0. 0.]
# [3. 3. 3. 3.]
# [0. 0. 0. 0.]

This would represent a departure from how Warp has handled gradients so far. Currently we keep track of "partial" adjoints for non-leaf arrays and don't zero them along the way. This works, but it can hide user mistakes like the one above, where x_a still receives a gradient even though the value written from it into y is overwritten by the second launch. Zeroing would ensure that array overwrites don't propagate gradients to earlier inputs.
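To make the two policies concrete, here is a minimal NumPy model of a tape that replays the two launches from the example above. ToyTape, store_scaled and run are hypothetical names, not Warp APIs. The model assumes each kernel is an elementwise scaled store dst[i] = scale * src[i] and says nothing about how Warp's generated adjoint code would implement the zeroing.

```python
import numpy as np


class ToyTape:
    """Hypothetical reverse-mode tape over named NumPy arrays.

    zero_on_store=False mimics the current accumulate-only behaviour;
    zero_on_store=True mimics the behaviour proposed in this issue.
    """

    def __init__(self, zero_on_store):
        self.zero_on_store = zero_on_store
        self.ops = []  # recorded (src, dst, scale) stores, in launch order

    def store_scaled(self, arrays, src, dst, scale):
        # forward: overwrite dst with scale * src (like y[i] = 2.0 * x[i])
        arrays[dst][:] = scale * arrays[src]
        self.ops.append((src, dst, scale))

    def backward(self, grads):
        # replay the recorded stores in reverse order
        for src, dst, scale in reversed(self.ops):
            grads[src] += scale * grads[dst]
            if self.zero_on_store:
                # the value dst held before this store never reached the output
                grads[dst][:] = 0.0


def run(zero_on_store, n=4):
    arrays = {"x_a": np.ones(n), "x_b": np.ones(n), "y": np.zeros(n)}
    grads = {name: np.zeros(n) for name in arrays}
    tape = ToyTape(zero_on_store)
    tape.store_scaled(arrays, "x_a", "y", 2.0)  # models test_kernel_1
    tape.store_scaled(arrays, "x_b", "y", 3.0)  # models test_kernel_2
    grads["y"][:] = 1.0  # seed, like y.grad = wp.ones_like(y)
    tape.backward(grads)
    return arrays, grads


for zero_on_store in (False, True):
    arrays, grads = run(zero_on_store)
    print(zero_on_store, arrays["y"], grads["x_a"], grads["x_b"], grads["y"])
```

The accumulate-only run reproduces the "prints" values from the example, and the zeroing run reproduces the "should print" values. The only difference between them is the single reset in backward.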

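In some (admittedly edge) cases, zeroing would also yield a more accurate adjoint. In the test below, C is loaded, updated in place by wp.tile_matmul, and stored back into the same array: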

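import numpy as np
import warp as wp
from warp.tests.unittest_utils import assert_np_equal

# Tile sizes are not given in this issue; these values are illustrative
# placeholders (the surrounding Warp test file defines its own).
TILE_M = wp.constant(8)
TILE_N = wp.constant(4)
TILE_K = wp.constant(8)
TILE_DIM = 64  # threads per block


@wp.kernel()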
def tile_math_matmul_kernel(
    ga: wp.array2d(dtype=wp.float16), gb: wp.array2d(dtype=wp.float32), gc: wp.array2d(dtype=wp.float64)
):
    i, j = wp.tid()
    a = wp.tile_load(ga, shape=(TILE_M, TILE_K), offset=(i * TILE_M, j * TILE_K))
    b = wp.tile_load(gb, shape=(TILE_K, TILE_N), offset=(i * TILE_K, j * TILE_N))
    c = wp.tile_load(gc, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N))
    wp.tile_matmul(a, b, c, alpha=0.5, beta=-1.3)
    wp.tile_store(gc, c, offset=(i * TILE_M, j * TILE_N))


def test_tile_math_matmul(test, device):
    rng = np.random.default_rng(42)

    A = rng.random((TILE_M, TILE_K), dtype=np.float64).astype(np.float16)
    B = rng.random((TILE_K, TILE_N), dtype=np.float32)
    C = rng.random((TILE_M, TILE_N), dtype=np.float64)

    A_wp = wp.array(A, requires_grad=True, device=device)
    B_wp = wp.array(B, requires_grad=True, device=device)
    C_wp = wp.array(C, requires_grad=True, device=device)

    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_math_matmul_kernel,
            dim=[1, 1],
            inputs=[A_wp, B_wp, C_wp],
            block_dim=TILE_DIM,
            device=device,
        )

    # verify forward pass
    assert_np_equal(C_wp.numpy(), 0.5 * A @ B - 1.3 * C, tol=1e-2)

    adj_C = np.ones_like(C)

    tape.backward(grads={C_wp: wp.array(adj_C, device=device)})

    assert_np_equal(A_wp.grad.numpy(), 0.5 * adj_C @ B.T, tol=1e-2)
    assert_np_equal(B_wp.grad.numpy(), 0.5 * A.T @ adj_C, tol=1e-2)
    assert_np_equal(C_wp.grad.numpy(), adj_C - 1.3 * adj_C, tol=1e-2)

If array storage operations zeroed gradients as proposed, the assertion for C could state the proper adjoint, which contains only the contribution through beta:

assert_np_equal(C_wp.grad.numpy(), -1.3 * adj_C, tol=1e-2)
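Because the loss is linear in the initial values of C, a finite-difference check in plain NumPy can confirm which of the two assertions for C describes the true gradient with respect to C's input values. The check is independent of Warp. It uses float64 throughout and arbitrary matrix sizes (assumptions, not the test's tile sizes), with the same alpha = 0.5 and beta = -1.3.

```python
import numpy as np

ALPHA, BETA = 0.5, -1.3


def forward(A, B, C):
    # same update as the kernel: C_out = alpha * A @ B + beta * C_in
    return ALPHA * A @ B + BETA * C


def loss(A, B, C, adj_C):
    # scalar loss whose gradient with respect to C_out is adj_C
    return np.sum(adj_C * forward(A, B, C))


def fd_grad_C(A, B, C, adj_C, eps=1e-3):
    # central finite differences with respect to the *initial* values of C
    g = np.zeros_like(C)
    for idx in np.ndindex(C.shape):
        C_plus = C.copy()
        C_plus[idx] += eps
        C_minus = C.copy()
        C_minus[idx] -= eps
        g[idx] = (loss(A, B, C_plus, adj_C) - loss(A, B, C_minus, adj_C)) / (2 * eps)
    return g


rng = np.random.default_rng(42)
A = rng.random((8, 8))
B = rng.random((8, 4))
C = rng.random((8, 4))
adj_C = np.ones_like(C)

g = fd_grad_C(A, B, C, adj_C)
print(np.allclose(g, BETA * adj_C, atol=1e-6))          # True: the proposed assertion
print(np.allclose(g, adj_C + BETA * adj_C, atol=1e-6))  # False: what the current test asserts
```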

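We could also store the result into an additional array D instead of overwriting C. Since C is then never written to, its gradient starts at zero, and the current behavior already produces the correct adjoint for C.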

Alternatively, we could emit a warning whenever a write-after-write (WRITE -> WRITE) to the same array is detected in a graph.
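A write-after-write check could run over the launches recorded on a tape. The sketch below is hypothetical. It assumes each launch can be summarized as the names of the arrays it reads and writes, which is not an existing Warp Tape API. It flags an array that is written again before any later launch reads it.

```python
import warnings


def find_write_after_write(launches):
    """Hypothetical WRITE -> WRITE check over recorded launches.

    launches: iterable of (kernel_name, reads, writes), where reads and writes
    are collections of array names. Returns (array, first_writer, second_writer)
    tuples and emits one warning per hazard found.
    """
    last_writer = {}  # array name -> kernel that last wrote it, cleared on read
    hazards = []
    for kernel, reads, writes in launches:
        for name in reads:
            # a read consumes the previous write, so it is no longer "dangling"
            last_writer.pop(name, None)
        for name in writes:
            if name in last_writer:
                hazards.append((name, last_writer[name], kernel))
                warnings.warn(
                    f"array '{name}' written by {kernel} overwrites the result of "
                    f"{last_writer[name]} before it is read; gradients may be misleading"
                )
            last_writer[name] = kernel
    return hazards
```

For the first example, find_write_after_write([("test_kernel_1", ["x_a"], ["y"]), ("test_kernel_2", ["x_b"], ["y"])]) reports y. Because the check works at launch granularity, it would not flag the tile example, where C is read in the same launch before it is overwritten.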

System Information

No response

daedalus5, Oct 30 '25 14:10