functorch

Batching rule for GRU cell

Open 0aqz0 opened this issue 2 years ago • 3 comments

Hi, I have a GRU model and want to compute its Jacobian with functorch. However, I get warnings about a performance drop because the batching rules for aten::_thnn_fused_gru_cell and aten::_thnn_differentiable_gru_cell_backward have not been implemented yet.

It would be great if someone could prioritize implementing them.

/home/haodong/miniconda3/envs/pose/lib/python3.8/site-packages/torch/nn/modules/rnn.py:1279: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_thnn_fused_gru_cell. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /__w/functorch/functorch/functorch/csrc/BatchedFallback.cpp:83.)
  ret = _VF.gru_cell(
/home/haodong/miniconda3/envs/pose/lib/python3.8/site-packages/torch/autograd/__init__.py:276: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_thnn_differentiable_gru_cell_backward. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /__w/functorch/functorch/functorch/csrc/BatchedFallback.cpp:83.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Here is my model.

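import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import vmap, jacrev


class Transition(nn.Module):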
    def __init__(self, in_features, hidden1=256, hidden2=256):
        super(Transition, self).__init__()
        self.in_features = in_features
        self.gru = nn.GRUCell(in_features, hidden1)
        self.fc1 = nn.Linear(hidden1, hidden2)
        self.fc2 = nn.Linear(hidden2, in_features)

    def transition(self, mu, hidden):
        # mu: (batch, feature)
        hidden = self.gru(mu, hidden)
        # hidden: (batch, hidden1); a GRUCell advances one time step, so there is no sequence to index
        out = F.relu(self.fc1(hidden))
        out = self.fc2(out)
        return mu + out, hidden

    def forward(self, l_mu, l_var, hidden):
        ###############################################
        # update mu
        ###############################################
        new_mu, hidden = self.transition(l_mu, hidden)

        ###############################################
        # update covariance
        ###############################################
        # calculate jacobian matrix
        J = vmap(jacrev(self.transition, argnums=0))(l_mu, hidden)[0]
        # covariance matrix of l at step t+1: J @ l_var @ J^T
        new_var = torch.matmul(torch.matmul(J, l_var), J.permute(0, 2, 1))

        return new_mu, new_var, hidden
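
For reference, a batched Jacobian like the one above can be cross-checked against a plain per-sample loop that does not go through vmap at all. This is only a minimal sketch: it assumes a standalone `nn.GRUCell` in place of the full model, identity covariances, and made-up sizes, and `per_sample_jacobian` is a hypothetical helper, not part of functorch.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian


def per_sample_jacobian(fn, mu, hidden):
    """Stack d fn(mu_i, hidden_i) / d mu_i over the batch with a Python loop."""
    rows = []
    for mu_i, h_i in zip(mu, hidden):
        # Add a batch dimension of 1 so fn sees the usual (batch, feature) layout.
        J_i = jacobian(
            lambda m: fn(m.unsqueeze(0), h_i.unsqueeze(0)).squeeze(0), mu_i
        )
        rows.append(J_i)
    return torch.stack(rows)


torch.manual_seed(0)
cell = nn.GRUCell(4, 4)                  # stand-in for the model (assumption)
mu, hidden = torch.randn(3, 4), torch.randn(3, 4)
l_var = torch.eye(4).expand(3, 4, 4)     # identity covariances (assumption)

J = per_sample_jacobian(cell, mu, hidden)   # (batch, out, in)
new_var = J @ l_var @ J.transpose(1, 2)     # J @ Sigma @ J^T for each sample
```

The loop is slow for large batches, so it is meant as a correctness reference rather than a replacement for vmap.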

0aqz0, Aug 09 '22 04:08

This may be a bit difficult to do in a performant manner. One thing we could do is decompose the GRU and rely on AOTAutograd to generate a fast kernel for this. Not sure how feasible that is, cc @Chillee ?
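
To illustrate what "decompose the GRU" could look like, here is a rough sketch that writes the cell with basic tensor ops, following the equations in the `nn.GRUCell` documentation, so that vmap does not need to handle the fused kernel. `gru_cell_decomposed` is a hypothetical name, the sizes are arbitrary, and the AOTAutograd step is not shown; whether this is actually fast enough is exactly the open question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

try:  # torch >= 2.0
    from torch.func import vmap, jacrev
except ImportError:  # older standalone functorch
    from functorch import vmap, jacrev


def gru_cell_decomposed(x, h, w_ih, w_hh, b_ih, b_hh):
    """GRU cell written with basic ops, following the nn.GRUCell equations."""
    gi = F.linear(x, w_ih, b_ih)   # input projections for the r, z, n gates
    gh = F.linear(h, w_hh, b_hh)   # hidden projections for the r, z, n gates
    i_r, i_z, i_n = gi.chunk(3, dim=-1)
    h_r, h_z, h_n = gh.chunk(3, dim=-1)
    r = torch.sigmoid(i_r + h_r)   # reset gate
    z = torch.sigmoid(i_z + h_z)   # update gate
    n = torch.tanh(i_n + r * h_n)  # candidate state
    return (1 - z) * n + z * h


torch.manual_seed(0)
cell = nn.GRUCell(5, 7)            # sizes chosen only for illustration
weights = (cell.weight_ih, cell.weight_hh, cell.bias_ih, cell.bias_hh)
x, h = torch.randn(3, 5), torch.randn(3, 7)

# Per-sample Jacobian of the new hidden state w.r.t. the input x.
J = vmap(jacrev(lambda xi, hi: gru_cell_decomposed(xi, hi, *weights)))(x, h)
```

Reusing the parameters of an existing `nn.GRUCell` keeps the decomposed version numerically comparable to the built-in one, which makes it easy to check before swapping it into a model.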

zou3519, Aug 10 '22 17:08

I suspect that would be more likely, yeah.

Chillee, Aug 10 '22 18:08

potentially related: https://github.com/pytorch/pytorch/issues/82577, cc @samdow

zou3519, Aug 10 '22 18:08