Traceback (most recent call last):
File "train_ctfp.py", line 160, in <module>
loss.backward()
File "/home/fry/anaconda3/envs/CTFP1/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/fry/anaconda3/envs/CTFP1/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
Variable._execution_engine.run_backward(
File "/home/fry/anaconda3/envs/CTFP1/lib/python3.8/site-packages/torch/autograd/function.py", line 87, in apply
return self._forward_cls.backward(self, *args) # type: ignore[attr-defined]
File "/home/fry/anaconda3/envs/CTFP1/lib/python3.8/site-packages/torchdiffeq/_impl/adjoint.py", line 126, in backward
aug_state = odeint(
File "/home/fry/anaconda3/envs/CTFP1/lib/python3.8/site-packages/torchdiffeq/_impl/odeint.py", line 72, in odeint
shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
File "/home/fry/anaconda3/envs/CTFP1/lib/python3.8/site-packages/torchdiffeq/_impl/misc.py", line 207, in _check_inputs
rtol = _tuple_tol('rtol', rtol, shapes)
File "/home/fry/anaconda3/envs/CTFP1/lib/python3.8/site-packages/torchdiffeq/_impl/misc.py", line 115, in _tuple_tol
assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0".format(name)
AssertionError: If using tupled rtol it must have the same length as the tuple y0