Batching rule for GRU cell
Hi, I have a GRU model and want to calculate its Jacobian with functorch. There is a performance drop because the batching rules for aten::_thnn_fused_gru_cell and aten::_thnn_differentiable_gru_cell_backward have not been implemented yet.
I would appreciate it if someone could prioritize implementing them.
/home/haodong/miniconda3/envs/pose/lib/python3.8/site-packages/torch/nn/modules/rnn.py:1279: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_thnn_fused_gru_cell. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /__w/functorch/functorch/functorch/csrc/BatchedFallback.cpp:83.)
ret = _VF.gru_cell(
/home/haodong/miniconda3/envs/pose/lib/python3.8/site-packages/torch/autograd/__init__.py:276: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_thnn_differentiable_gru_cell_backward. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /__w/functorch/functorch/functorch/csrc/BatchedFallback.cpp:83.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Here is my model.
import torch
import torch.nn as nn
import torch.nn.functional as F
from functorch import vmap, jacrev


class Transition(nn.Module):
    def __init__(self, in_features, hidden1=256, hidden2=256):
        super(Transition, self).__init__()
        self.in_features = in_features
        self.gru = nn.GRUCell(in_features, hidden1)
        self.fc1 = nn.Linear(hidden1, hidden2)
        self.fc2 = nn.Linear(hidden2, in_features)

    def transition(self, mu, hidden):
        # mu: (batch, feature)
        hidden = self.gru(mu, hidden)
        # out: (batch, feature), choose the last time step
        out = F.relu(self.fc1(hidden))
        out = self.fc2(out)
        return mu + out, hidden

    def forward(self, l_mu, l_var, hidden):
        ###############################################
        # update mu
        ###############################################
        new_mu, hidden = self.transition(l_mu, hidden)
        ###############################################
        # update covariance
        ###############################################
        # calculate the per-sample Jacobian of the new mu w.r.t. mu
        J = vmap(jacrev(self.transition, argnums=0))(l_mu, hidden)[0]
        # covariance matrix of l at t+1: J @ l_var @ J^T
        new_var = torch.matmul(torch.matmul(J, l_var), J.permute(0, 2, 1))
        return new_mu, new_var, hidden
This may be a bit difficult to do in a performant manner. One thing we could do is decompose the GRU into primitive ops and rely on AOTAutograd to generate a fast kernel for it. Not sure how feasible that is, cc @Chillee?
I suspect that would be more likely, yeah.
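For anyone who wants to try the decomposition idea by hand in the meantime, here is a minimal sketch. It rewrites the GRU cell with plain linear, sigmoid and tanh ops, so vmap should not need the fused-kernel fallback. The helper name decomposed_gru_cell is hypothetical, and the (r, z, n) gate order is taken from the nn.GRUCell documentation. This does not involve AOTAutograd, and I have not measured whether it is actually faster than the fallback path.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

try:  # torch >= 2.0 ships these under torch.func
    from torch.func import vmap, jacrev
except ImportError:  # older setups with the standalone functorch package
    from functorch import vmap, jacrev


def decomposed_gru_cell(cell, x, h):
    """Hypothetical helper: GRU cell built from primitive ops.

    Reuses the parameters of an existing nn.GRUCell (with bias=True), whose
    weights are stacked in (reset, update, new) gate order.
    """
    gi = F.linear(x, cell.weight_ih, cell.bias_ih)  # input projections
    gh = F.linear(h, cell.weight_hh, cell.bias_hh)  # hidden projections
    i_r, i_z, i_n = gi.chunk(3, dim=-1)
    h_r, h_z, h_n = gh.chunk(3, dim=-1)
    r = torch.sigmoid(i_r + h_r)       # reset gate
    z = torch.sigmoid(i_z + h_z)       # update gate
    n = torch.tanh(i_n + r * h_n)      # candidate state
    return (1 - z) * n + z * h         # new hidden state


torch.manual_seed(0)
cell = nn.GRUCell(4, 8)
x = torch.randn(5, 4)
h = torch.randn(5, 8)

ref = cell(x, h)                       # built-in cell
out = decomposed_gru_cell(cell, x, h)  # decomposed version

# Per-sample Jacobian of the new hidden state w.r.t. the input.
J = vmap(jacrev(lambda xi, hi: decomposed_gru_cell(cell, xi, hi), argnums=0))(x, h)
```

In the model above, self.gru(mu, hidden) inside transition could be swapped for decomposed_gru_cell(self.gru, mu, hidden), keeping the same parameters.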
potentially related: https://github.com/pytorch/pytorch/issues/82577, cc @samdow