gpt
gpt copied to clipboard
`accumulate_gradient` does not check for necessary traces for `tensor` objects
grid = g.grid([2,2,2,2], g.double)
ms = g.ad.reverse.node(g.tensor(g.object_type.ot_matrix_spin(4)))
vsc = g.vspincolor(grid)
g.norm2(ms*vsc)()
The gradient backpropagation does not work because the color trace is not done before adding the gradients, see comment in the code below.
def accumulate_gradient(lhs, rhs_gradient, getter=None, setter=None):
lhs_gradient = lhs.gradient
if getter is not None:
lhs_gradient = getter(lhs_gradient)
rhs_field = is_field(rhs_gradient)
lhs_field = is_field(lhs_gradient)
if rhs_field and not lhs_field:
rhs_gradient = g.sum(rhs_gradient)
if g.util.is_num(lhs_gradient) and isinstance(rhs_gradient, g.expr):
rhs_gradient = g(rhs_gradient)
# the following checks are only done when lhs is a lattice, but should also be similarily done for a tensor
if isinstance(lhs_gradient, g.lattice) and isinstance(rhs_gradient, g.expr):
grid, rhs_otype, is_list, nlist = rhs_gradient.container()
assert not is_list # for now
lhs_otype = lhs_gradient.otype
if lhs_otype.__name__ != rhs_otype.__name__:
if rhs_otype.spintrace[2] is not None:
rhs_spintrace_otype = rhs_otype.spintrace[2]()
if accumulate_compatible(lhs_otype, rhs_spintrace_otype):
rhs_gradient = g(g.spin_trace(rhs_gradient))
rhs_otype = rhs_gradient.otype
elif rhs_spintrace_otype.colortrace[2] is not None:
if accumulate_compatible(lhs_otype, rhs_spintrace_otype.colortrace[2]()):
rhs_gradient = g(g.trace(rhs_gradient))
rhs_otype = rhs_gradient.otype
if rhs_otype.colortrace[2] is not None:
if accumulate_compatible(lhs_otype, rhs_otype.colortrace[2]()):
rhs_gradient = g(g.color_trace(rhs_gradient))
rhs_otype = rhs_gradient.otype
if setter is not None:
setter(lhs.gradient, lhs_gradient + rhs_gradient)
else:
lhs.gradient += rhs_gradient