gpt icon indicating copy to clipboard operation
gpt copied to clipboard

`accumulate_gradient` does not check for necessary traces for `tensor` objects

Open simon-pfahler opened this issue 1 year ago • 0 comments

grid = g.grid([2,2,2,2], g.double)

ms = g.ad.reverse.node(g.tensor(g.object_type.ot_matrix_spin(4)))
vsc = g.vspincolor(grid)
g.norm2(ms*vsc)()

The gradient backpropagation does not work because the color trace is not done before adding the gradients, see comment in the code below.

def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None):
    lhs_gradient = lhs.gradient
    if getter is not None:
        lhs_gradient = getter(lhs_gradient)
    rhs_field = is_field(rhs_gradient)
    lhs_field = is_field(lhs_gradient)
    if rhs_field and not lhs_field:
        rhs_gradient = g.sum(rhs_gradient)
    if g.util.is_num(lhs_gradient) and isinstance(rhs_gradient, g.expr):
        rhs_gradient = g(rhs_gradient)

    # the following checks are only done when lhs is a lattice, but should also be similarily done for a tensor
    if isinstance(lhs_gradient, g.lattice) and isinstance(rhs_gradient, g.expr): 
        grid, rhs_otype, is_list, nlist = rhs_gradient.container()
        assert not is_list  # for now
        lhs_otype = lhs_gradient.otype

        if lhs_otype.__name__ != rhs_otype.__name__:
            if rhs_otype.spintrace[2] is not None:
                rhs_spintrace_otype = rhs_otype.spintrace[2]()
                if accumulate_compatible(lhs_otype, rhs_spintrace_otype):
                    rhs_gradient = g(g.spin_trace(rhs_gradient))
                    rhs_otype = rhs_gradient.otype
                elif rhs_spintrace_otype.colortrace[2] is not None:
                    if accumulate_compatible(lhs_otype, rhs_spintrace_otype.colortrace[2]()):
                        rhs_gradient = g(g.trace(rhs_gradient))
                        rhs_otype = rhs_gradient.otype
            if rhs_otype.colortrace[2] is not None:
                if accumulate_compatible(lhs_otype, rhs_otype.colortrace[2]()):
                    rhs_gradient = g(g.color_trace(rhs_gradient))
                    rhs_otype = rhs_gradient.otype

    if setter is not None:
        setter(lhs.gradient, lhs_gradient + rhs_gradient)
    else:
        lhs.gradient += rhs_gradient

simon-pfahler avatar Jun 06 '24 14:06 simon-pfahler