torchdiffeq copied to clipboard
RuntimeError in odeint_adjoint
Hello, I have run my code by using odeint successfully, however when I use the odeint_adjoint, it comes out the error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
here is my code:
class M50_Func(nn.Module):
def __init__(self, ET_net, Q_net, params, interps, ode_lib='torchdiffeq'):
self.f, self.Smax, self.Qmax, self.Df, self.Tmax, self.Tmin = params
self.ET_net = ET_net
self.ode_lib = ode_lib
self.Q_net = Q_net
self.precp_interp, self.temp_interp, self.lday_interp = interps
def forward(self, t, S):
from models.common_net import Ps, Pr, M, step_fct
S_snow, S_water = S[0][0], S[0][1]
precp = self.precp_interp.evaluate(t).to(torch.float32)
temp = self.temp_interp.evaluate(t).to(torch.float32)
lday = self.lday_interp.evaluate(t).to(torch.float32)
# precp = torch.from_numpy(self.precp_interp(t.numpy()).astype(np.float32)).to(device)
# temp = torch.from_numpy(self.temp_interp(t.numpy()).astype(np.float32)).to(device)
# lday = torch.from_numpy(self.lday_interp(t.numpy()).astype(np.float32)).to(device)
ET_output = self.ET_net(torch.tensor([S_snow, S_water, temp]))
Q_output = self.Q_net(torch.tensor([S_water, precp]))
melt_output = M(S_snow, temp, self.Df, self.Tmax)
dS_1 = Ps(precp, temp, self.Tmin) - melt_output
dS_2 = Pr(precp, temp, self.Tmin) + melt_output - step_fct(S_water) * lday * torch.exp(
ET_output) - step_fct(S_water) * torch.exp(Q_output)
return torch.tensor([dS_1, dS_2]).unsqueeze(0)
class M50_Solver(BaseLearner):
def __init__(self, solve_func: nn.Module, rtol=1e-6, atol=1e-6, ode_lib='torchdiffeq',
loss_metric=torch.nn.MSELoss(), eval_metric_list=None, lr=0.01, optimizer=None):
super().__init__(solve_func, loss_metric, eval_metric_list, lr, optimizer)
self.solve_func = solve_func
self.ode_lib = ode_lib
self.rtol = rtol
self.atol = atol
def forward(self, x, t_eval):
if len(x.shape) > 2:
x = x[0]
if len(t_eval.shape) > 1:
t_eval = t_eval[0]
t_eval =
y0 = torch.tensor([[x[0, 0], x[0, 1]]])
sol = odeint_adjoint(self.solve_func, y0=y0, t=t_eval, rtol=self.rtol, atol=self.atol,
adjoint_options={"norm": "seminorm"})
# adjoint_params=list(self.solve_func.ET_net.parameters())
# + list(self.solve_func.Q_net.parameters()))
# sol = odeint(self.solve_func, y0=y0, t=t_eval, rtol=self.rtol, atol=self.atol)
sol_1 = sol[:, 0, 1]
y_hat = torch.exp(self.solve_func.Q_net(torch.concat([sol_1.unsqueeze(1), x[:, 2].unsqueeze(1)], dim=1)))
return y_hat
The BaseLearner
extends from the pytorch_lightning.LightningModule