mpc.pytorch
Fail to backprop through LQRStepFn
When I tried to backpropagate a loss through LQRStepFn, the following error occurred:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[8], line 10
---> 10 loss.backward()
File d:\Anaconda3\envs\graduation\lib\site-packages\torch\_tensor.py:521, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
511 if has_torch_function_unary(self):
512 return handle_torch_function(
513 Tensor.backward,
514 (self,),
(...)
519 inputs=inputs,
520 )
--> 521 torch.autograd.backward(
522 self, gradient, retain_graph, create_graph, inputs=inputs
523 )
File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\__init__.py:289, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
284 retain_graph = create_graph
286 # The reason we repeat the same comment below is that
287 # some Python versions print out the first line of a multi-line function
288 # calls in the traceback and some print out the last line
--> 289 _engine_run_backward(
290 tensors,
291 grad_tensors_,
292 retain_graph,
293 create_graph,
294 inputs,
295 allow_unreachable=True,
296 accumulate_grad=True,
297 )
File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\graph.py:769, in _engine_run_backward(t_outputs, *args, **kwargs)
767 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
768 try:
--> 769 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
770 t_outputs, *args, **kwargs
771 ) # Calls into the C++ engine to run the backward pass
772 finally:
773 if attach_logging_hooks:
File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\function.py:306, in BackwardCFunction.apply(self, *args)
300 raise RuntimeError(
301 "Implementing both 'backward' and 'vjp' for a custom "
302 "Function is not allowed. You should only implement one "
303 "of them."
304 )
305 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 306 return user_fn(self, *args)
TypeError: backward() takes from 3 to 5 positional arguments but 7 were given
I suspect the problem is in class LQRStepFn in lqr_step.py. The number of outputs returned by forward() does not match the number of gradient arguments accepted by backward():
class LQRStepFn(Function):
    # @profile
    @staticmethod
    def forward(ctx, x_init, C, c, F, f=None):
        ........
        return new_x, new_u, torch.Tensor([n_total_qp_iter]), \
            for_out.costs, for_out.full_du_norm, for_out.mean_alphas

    @staticmethod
    def backward(ctx, dl_dx, dl_du, temp=None, temp2=None):
        start = time.time()
        x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors
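The same TypeError can be reproduced without mpc.pytorch. The sketch below uses a hypothetical toy Function, ThreeOutputs (not part of the repository), whose forward() returns three outputs while backward() accepts only two gradients. Autograd calls backward() with ctx plus one gradient per output, so it passes one argument too many. The exact wording of the message depends on the Python version.

```python
import torch
from torch.autograd import Function


class ThreeOutputs(Function):
    """Hypothetical toy Function: forward() returns three outputs."""

    @staticmethod
    def forward(ctx, x):
        return x * 2, x * 3, x.sum()

    @staticmethod
    def backward(ctx, grad_a, grad_b):
        # Deliberately wrong: only two of the three output gradients are accepted,
        # mirroring LQRStepFn (6 outputs, 4 gradient parameters).
        return grad_a * 2 + grad_b * 3


def run_toy_backward():
    """Backprop through ThreeOutputs and return the TypeError message, if any."""
    x = torch.ones(3, requires_grad=True)
    a, b, c = ThreeOutputs.apply(x)
    try:
        a.sum().backward()
    except TypeError as err:
        return str(err)
    return None


message = run_toy_backward()
print(message)
```

Even though only the first output feeds the loss, autograd still hands backward() a gradient for every output.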
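In LQRStepFn, forward() returns 6 outputs, but backward() accepts only 4 gradients (5 positional arguments including ctx). Because autograd calls backward() with ctx plus one gradient per forward() output, backward() receives 7 positional arguments, which is exactly what the TypeError reports. After adding 2 more arguments to backward(), the loss backpropagates without error: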
    @staticmethod
    def backward(ctx, dl_dx, dl_du, temp=None, temp2=None, temp3=None, temp4=None):
        start = time.time()
        x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors
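A possibly more robust variant of the same fix (a sketch, not the repository's code) is to collect the trailing gradients with *args, so backward() keeps working if forward() later gains more outputs. ctx.mark_non_differentiable can additionally flag bookkeeping outputs such as the iteration count. The class ThreeOutputsFixed below is a hypothetical toy example.

```python
import torch
from torch.autograd import Function


class ThreeOutputsFixed(Function):
    """Hypothetical toy Function with a bookkeeping output, like n_total_qp_iter."""

    @staticmethod
    def forward(ctx, x):
        count = torch.tensor([float(x.numel())])
        # The count carries no gradient; autograd still passes backward() a
        # (zero-filled) gradient slot for it.
        ctx.mark_non_differentiable(count)
        return x * 2, x * 3, count

    @staticmethod
    def backward(ctx, grad_a, grad_b, *unused_grads):
        # One gradient arrives per forward() output; *unused_grads absorbs the rest.
        # Return one gradient per forward() input (here only x).
        return grad_a * 2 + grad_b * 3


x = torch.ones(3, requires_grad=True)
a, b, count = ThreeOutputsFixed.apply(x)
(a.sum() + b.sum()).backward()
print(x.grad)  # tensor([5., 5., 5.])
```

Applied to LQRStepFn, the equivalent change would be def backward(ctx, dl_dx, dl_du, *unused_grads):, assuming the remaining outputs (iteration count, costs, norms, step sizes) are not meant to receive gradients. Whichever signature is used, backward() must still return one gradient (or None) per forward() input, i.e. five values for x_init, C, c, F, f.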