torchdyn

Shape of gradient does not match the parameter shape of the vector field when using the adjoint method

Status: Open · ljxw88 opened this issue 2 years ago · 0 comments

In torchdyn/numerics/sensitivity.py, function _gather_odefunc_adjoint(), line 71:

dμ = torch.cat([el.flatten() if el is not None else torch.zeros(1) for el in dμ], dim=-1)
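To see the failure in isolation, here is a minimal sketch outside torchdyn. It assumes a hypothetical vector field, VectorField, whose `unused` parameter never enters forward(). With allow_unused=True, torch.autograd.grad returns None for that parameter, and the expression from line 71 swaps in a one-element tensor, so the concatenated vector is shorter than the total parameter count:

```python
import torch
import torch.nn as nn


class VectorField(nn.Module):
    """Hypothetical vector field: `unused` never enters forward(), so it gets no gradient."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)                   # 4 weights + 2 biases = 6 elements
        self.unused = nn.Parameter(torch.randn(3, 3))   # 9 elements, never used

    def forward(self, t, x):
        return self.linear(x)


def original_dmu(vf, t, x, lam):
    """Mimics the concatenation on line 71, using -lam as the vector in the VJP."""
    dx = vf(t, x)
    # allow_unused=True makes autograd return None for parameters outside the graph
    dmu = torch.autograd.grad(dx, tuple(vf.parameters()), -lam, allow_unused=True)
    # A missing gradient is replaced by a single zero, whatever the parameter's size
    return torch.cat([el.flatten() if el is not None else torch.zeros(1) for el in dmu], dim=-1)


vf = VectorField()
x, lam, t = torch.randn(5, 2), torch.randn(5, 2), torch.tensor(0.0)
n_params = sum(p.numel() for p in vf.parameters())
dmu = original_dmu(vf, t, x, lam)
print(dmu.numel(), n_params)  # 7 15
```

The 9-element parameter contributes only one entry, so dμ has 7 elements while the vector field has 15 parameter elements.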

should be changed to:

params = tuple(vf.parameters())
dμ = torch.cat([el.flatten() if el is not None else torch.zeros_like(p).flatten() for el, p in zip(dμ, params)], dim=-1)

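Otherwise, whenever a parameter receives no gradient (el is None), the placeholder torch.zeros(1) contributes a single element instead of one element per parameter entry, so the length of dμ no longer matches the total number of parameters in the vector field. torch.zeros(1) is also created on the default device (normally the CPU), so the concatenation fails when the model runs on a GPU.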
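The corrected concatenation can be checked the same way. This is a sketch under the same assumptions (the hypothetical VectorField is redefined so the block runs on its own). It pairs each gradient with its parameter and uses torch.zeros_like(p) as the placeholder, which copies the parameter's shape, dtype and device; the result then has the same length as torch.nn.utils.parameters_to_vector(vf.parameters()):

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector


class VectorField(nn.Module):
    """Hypothetical vector field with one parameter (`unused`) that receives no gradient."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.unused = nn.Parameter(torch.randn(3, 3))

    def forward(self, t, x):
        return self.linear(x)


def fixed_dmu(vf, t, x, lam):
    params = tuple(vf.parameters())
    dx = vf(t, x)
    dmu = torch.autograd.grad(dx, params, -lam, allow_unused=True)
    # Pair each gradient with its parameter so a missing gradient becomes
    # zeros of the parameter's full shape (with matching dtype and device).
    return torch.cat(
        [el.flatten() if el is not None else torch.zeros_like(p).flatten()
         for el, p in zip(dmu, params)],
        dim=-1,
    )


vf = VectorField()
x, lam, t = torch.randn(5, 2), torch.randn(5, 2), torch.tensor(0.0)
dmu = fixed_dmu(vf, t, x, lam)
flat_params = parameters_to_vector(vf.parameters())
print(dmu.shape == flat_params.shape)  # True
```

Zipping with the parameters, rather than indexing a separate list of shapes, keeps each placeholder tied to the parameter it stands in for, since autograd returns gradients in the same order as the inputs it was given.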

ljxw88, Nov 15 '23 18:11