
Fails to backprop through LQRStepFn

anby-dmr opened this issue 8 months ago

When I tried to backpropagate a loss through LQRStepFn, the following error occurred:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 10
---> 10 loss.backward()

File d:\Anaconda3\envs\graduation\lib\site-packages\torch\_tensor.py:521, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    511 if has_torch_function_unary(self):
    512     return handle_torch_function(
    513         Tensor.backward,
    514         (self,),
   (...)
    519         inputs=inputs,
    520     )
--> 521 torch.autograd.backward(
    522     self, gradient, retain_graph, create_graph, inputs=inputs
    523 )

File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\__init__.py:289, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    284     retain_graph = create_graph
    286 # The reason we repeat the same comment below is that
    287 # some Python versions print out the first line of a multi-line function
    288 # calls in the traceback and some print out the last line
--> 289 _engine_run_backward(
    290     tensors,
    291     grad_tensors_,
    292     retain_graph,
    293     create_graph,
    294     inputs,
    295     allow_unreachable=True,
    296     accumulate_grad=True,
    297 )

File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\graph.py:769, in _engine_run_backward(t_outputs, *args, **kwargs)
    767     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    768 try:
--> 769     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    770         t_outputs, *args, **kwargs
    771     )  # Calls into the C++ engine to run the backward pass
    772 finally:
    773     if attach_logging_hooks:

File d:\Anaconda3\envs\graduation\lib\site-packages\torch\autograd\function.py:306, in BackwardCFunction.apply(self, *args)
    300     raise RuntimeError(
    301         "Implementing both 'backward' and 'vjp' for a custom "
    302         "Function is not allowed. You should only implement one "
    303         "of them."
    304     )
    305 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 306 return user_fn(self, *args)

TypeError: backward() takes from 3 to 5 positional arguments but 7 were given

I believe the problem is in lqr_step.py, in class LQRStepFn: the signatures of forward() and backward() do not match. Autograd calls backward() with one gradient tensor per forward() output (plus ctx), so the argument counts have to line up:

    class LQRStepFn(Function):
        # @profile
        @staticmethod
        def forward(ctx, x_init, C, c, F, f=None):
            ........
            return new_x, new_u, torch.Tensor([n_total_qp_iter]), \
              for_out.costs, for_out.full_du_norm, for_out.mean_alphas

        @staticmethod
        def backward(ctx, dl_dx, dl_du, temp=None, temp2=None):
            start = time.time()
            x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors

In the snippet above, forward() returns 6 outputs, but backward() only accepts 4 gradient arguments. Autograd therefore calls backward() with ctx plus 6 gradients, i.e. 7 arguments, while the declared signature accepts at most ctx plus 4, which is exactly the "takes from 3 to 5 positional arguments but 7 were given" in the traceback. After adding 2 more gradient arguments to backward(), the loss backpropagates without error:

        @staticmethod
        def backward(ctx, dl_dx, dl_du, temp=None, temp2=None, temp3=None, temp4=None):
            # dl_dx and dl_du are the gradients for new_x and new_u; the
            # four temp arguments receive the gradients of the remaining
            # forward() outputs (n_total_qp_iter, costs, full_du_norm,
            # mean_alphas), which this backward() ignores.
            start = time.time()
            x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors
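
For context, torch.autograd always invokes backward() with one gradient tensor per forward() output, in addition to ctx, and by default materializes zeros for outputs that do not contribute to the loss. Here is a minimal standalone sketch (a toy Function of my own, not mpc.pytorch code) that demonstrates the rule:

    import torch
    from torch.autograd import Function

    class SixOutputFn(Function):
        # Toy Function with six outputs, mirroring LQRStepFn's output count.
        @staticmethod
        def forward(ctx, x):
            return x + 1, 2 * x, x.sum(), x.mean(), x.norm(), x * 0

        @staticmethod
        def backward(ctx, g1, g2, g3, g4, g5, g6):
            # One gradient parameter per output is required, even for
            # outputs that never enter the loss (those arrive as zeros).
            # Return one gradient per forward() input (here: just x).
            return g1 + 2 * g2

    x = torch.randn(3, requires_grad=True)
    out = SixOutputFn.apply(x)
    loss = out[0].sum() + out[1].sum()
    loss.backward()  # succeeds; with fewer gradient params it would raise the TypeError above
    assert torch.allclose(x.grad, torch.full_like(x, 3.0))

A more future-proof signature is backward(ctx, *grad_outputs), which keeps working even if forward() later gains or loses outputs.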
