Default Partition: Backward module does not contain computation of the parameters' gradients.
I am trying to use functorch to train a model in a more JAX-like way. I use aot_function to get a forward graph module and a backward graph module, but the backward module does not contain the computation of the parameters' gradients.
After reading the source code, I believe the function below erases that part of the computation, because the parameter gradients are not among the outputs requested from the backward module.
https://github.com/pytorch/functorch/blob/6c3b57f3a3fd54a2f3e3db12c2059669112bed6c/functorch/_src/partitioners.py#L94
I think it would be useful to offer an API that captures these computations, since they are essential for training a neural network. Is this something you plan to develop in the future?
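For context, here is roughly how I produce the two graph modules (a minimal sketch, assuming torchvision's alexnet; the print_compiler helper is just for illustration):

import torch
import torchvision
from functorch.compile import aot_function, default_partition

model = torchvision.models.alexnet()

def print_compiler(fx_module, example_inputs):
    # aot_function hands the partitioned forward / backward GraphModules
    # to these compiler callbacks; here we just dump their code.
    print(fx_module.code)
    return fx_module

compiled_fn = aot_function(
    lambda x: model(x).sum(),   # reduce to a scalar so no explicit tangent is needed
    fw_compiler=print_compiler,
    bw_compiler=print_compiler,
    partition_fn=default_partition,
)
x = torch.randn(1, 3, 224, 224, requires_grad=True)
compiled_fn(x).backward()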
Here is the backward module I generated for alexnet. As you can see, it does not involve the computation of the parameters' gradients.
import torch
from torch.nn import *
class alexnet_backward(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_tensor_constant3', torch.empty([1000, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant4', torch.empty([4096, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant5', torch.empty([4096, 9216], dtype=torch.float32))
        self._param_constant8 = torch.nn.Parameter(torch.empty([256, 256, 3, 3], dtype=torch.float32))
        self._param_constant6 = torch.nn.Parameter(torch.empty([256, 384, 3, 3], dtype=torch.float32))
        self._param_constant4 = torch.nn.Parameter(torch.empty([384, 192, 3, 3], dtype=torch.float32))
        self._param_constant2 = torch.nn.Parameter(torch.empty([192, 64, 5, 5], dtype=torch.float32))
        self._param_constant0 = torch.nn.Parameter(torch.empty([64, 3, 11, 11], dtype=torch.float32))
        self.load_state_dict(torch.load(r'alexnet_backward/state_dict.pt'))

    def forward(self, primals_1, relu_, getitem, relu__1, getitem_2, relu__2, relu__3, relu__4, getitem_4, div_, relu__5, div__1, relu__6, tangents_1):
        detach = torch.ops.aten.detach(relu_)
        max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices(relu_, [3, 3], [2, 2])
        getitem_1 = max_pool2d_with_indices[1]; max_pool2d_with_indices = None
        detach_1 = torch.ops.aten.detach(relu__1)
        max_pool2d_with_indices_1 = torch.ops.aten.max_pool2d_with_indices(relu__1, [3, 3], [2, 2])
        getitem_3 = max_pool2d_with_indices_1[1]; max_pool2d_with_indices_1 = None
        detach_2 = torch.ops.aten.detach(relu__2)
        detach_3 = torch.ops.aten.detach(relu__3)
        detach_4 = torch.ops.aten.detach(relu__4)
        max_pool2d_with_indices_2 = torch.ops.aten.max_pool2d_with_indices(relu__4, [3, 3], [2, 2])
        getitem_5 = max_pool2d_with_indices_2[1]; max_pool2d_with_indices_2 = None
        detach_5 = torch.ops.aten.detach(relu__5); relu__5 = None
        detach_6 = torch.ops.aten.detach(relu__6); relu__6 = None
        _tensor_constant3 = self._tensor_constant3
        mm = torch.ops.aten.mm(tangents_1, _tensor_constant3); tangents_1 = _tensor_constant3 = None
        detach_7 = torch.ops.aten.detach(detach_6); detach_6 = None
        threshold_backward = torch.ops.aten.threshold_backward(mm, detach_7, 0); mm = detach_7 = None
        _tensor_constant4 = self._tensor_constant4
        mm_2 = torch.ops.aten.mm(threshold_backward, _tensor_constant4); threshold_backward = _tensor_constant4 = None
        mul_2 = torch.ops.aten.mul(mm_2, div__1); mm_2 = div__1 = None
        detach_8 = torch.ops.aten.detach(detach_5); detach_5 = None
        threshold_backward_1 = torch.ops.aten.threshold_backward(mul_2, detach_8, 0); mul_2 = detach_8 = None
        _tensor_constant5 = self._tensor_constant5
        mm_4 = torch.ops.aten.mm(threshold_backward_1, _tensor_constant5); threshold_backward_1 = _tensor_constant5 = None
        mul_3 = torch.ops.aten.mul(mm_4, div_); mm_4 = div_ = None
        view_4 = torch.ops.aten.view(mul_3, [1, 256, 6, 6]); mul_3 = None
        _adaptive_avg_pool2d_backward = torch.ops.aten._adaptive_avg_pool2d_backward(view_4, getitem_4); view_4 = getitem_4 = None
        max_pool2d_with_indices_backward = torch.ops.aten.max_pool2d_with_indices_backward(_adaptive_avg_pool2d_backward, relu__4, [3, 3], [2, 2], [0, 0], [1, 1], False, getitem_5); _adaptive_avg_pool2d_backward = relu__4 = getitem_5 = None
        detach_9 = torch.ops.aten.detach(detach_4); detach_4 = None
        threshold_backward_2 = torch.ops.aten.threshold_backward(max_pool2d_with_indices_backward, detach_9, 0); max_pool2d_with_indices_backward = detach_9 = None
        _param_constant8_2 = self._param_constant8
        convolution_backward = torch.ops.aten.convolution_backward(threshold_backward_2, relu__3, _param_constant8_2, [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]); threshold_backward_2 = relu__3 = _param_constant8_2 = None
        getitem_6 = convolution_backward[0]; convolution_backward = None
        detach_10 = torch.ops.aten.detach(detach_3); detach_3 = None
        threshold_backward_3 = torch.ops.aten.threshold_backward(getitem_6, detach_10, 0); getitem_6 = detach_10 = None
        _param_constant6_2 = self._param_constant6
        convolution_backward_1 = torch.ops.aten.convolution_backward(threshold_backward_3, relu__2, _param_constant6_2, [256], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]); threshold_backward_3 = relu__2 = _param_constant6_2 = None
        getitem_9 = convolution_backward_1[0]; convolution_backward_1 = None
        detach_11 = torch.ops.aten.detach(detach_2); detach_2 = None
        threshold_backward_4 = torch.ops.aten.threshold_backward(getitem_9, detach_11, 0); getitem_9 = detach_11 = None
        _param_constant4_2 = self._param_constant4
        convolution_backward_2 = torch.ops.aten.convolution_backward(threshold_backward_4, getitem_2, _param_constant4_2, [384], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, True]); threshold_backward_4 = getitem_2 = _param_constant4_2 = None
        getitem_12 = convolution_backward_2[0]; convolution_backward_2 = None
        max_pool2d_with_indices_backward_1 = torch.ops.aten.max_pool2d_with_indices_backward(getitem_12, relu__1, [3, 3], [2, 2], [0, 0], [1, 1], False, getitem_3); getitem_12 = relu__1 = getitem_3 = None
        detach_12 = torch.ops.aten.detach(detach_1); detach_1 = None
        threshold_backward_5 = torch.ops.aten.threshold_backward(max_pool2d_with_indices_backward_1, detach_12, 0); max_pool2d_with_indices_backward_1 = detach_12 = None
        _param_constant2_2 = self._param_constant2
        convolution_backward_3 = torch.ops.aten.convolution_backward(threshold_backward_5, getitem, _param_constant2_2, [192], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, True]); threshold_backward_5 = getitem = _param_constant2_2 = None
        getitem_15 = convolution_backward_3[0]; convolution_backward_3 = None
        max_pool2d_with_indices_backward_2 = torch.ops.aten.max_pool2d_with_indices_backward(getitem_15, relu_, [3, 3], [2, 2], [0, 0], [1, 1], False, getitem_1); getitem_15 = relu_ = getitem_1 = None
        detach_13 = torch.ops.aten.detach(detach); detach = None
        threshold_backward_6 = torch.ops.aten.threshold_backward(max_pool2d_with_indices_backward_2, detach_13, 0); max_pool2d_with_indices_backward_2 = detach_13 = None
        _param_constant0_2 = self._param_constant0
        convolution_backward_4 = torch.ops.aten.convolution_backward(threshold_backward_6, primals_1, _param_constant0_2, [64], [4, 4], [2, 2], [1, 1], False, [0, 0], 1, [True, True, True]); threshold_backward_6 = primals_1 = _param_constant0_2 = None
        getitem_18 = convolution_backward_4[0]; convolution_backward_4 = None
        return [getitem_18]
Another question: I also tried to output the code of the joint_forward_backward GraphModule, which does contain the computation of the weight gradients. However, I noticed that the parameters of the Linear layers are redefined in its __init__ method, as you can see in the code below: _tensor_constant0 and _tensor_constant5 are supposed to be one and the same parameter, just with different shapes (_tensor_constant5 is the transpose of _tensor_constant0), yet two separate buffers are registered for it. Any suggestion on how to avoid this?
class FxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_tensor_constant0', torch.empty([9216, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant1', torch.empty([4096, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant2', torch.empty([4096, 1000], dtype=torch.float32))
        self.register_buffer('_tensor_constant3', torch.empty([1000, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant4', torch.empty([4096, 4096], dtype=torch.float32))
        self.register_buffer('_tensor_constant5', torch.empty([4096, 9216], dtype=torch.float32))
        self._param_constant0 = torch.nn.Parameter(torch.empty([64, 3, 11, 11], dtype=torch.float32))
        self._param_constant1 = torch.nn.Parameter(torch.empty([64], dtype=torch.float32))
        self._param_constant2 = torch.nn.Parameter(torch.empty([192, 64, 5, 5], dtype=torch.float32))
        self._param_constant3 = torch.nn.Parameter(torch.empty([192], dtype=torch.float32))
        self._param_constant4 = torch.nn.Parameter(torch.empty([384, 192, 3, 3], dtype=torch.float32))
        self._param_constant5 = torch.nn.Parameter(torch.empty([384], dtype=torch.float32))
        self._param_constant6 = torch.nn.Parameter(torch.empty([256, 384, 3, 3], dtype=torch.float32))
        self._param_constant7 = torch.nn.Parameter(torch.empty([256], dtype=torch.float32))
        self._param_constant8 = torch.nn.Parameter(torch.empty([256, 256, 3, 3], dtype=torch.float32))
        self._param_constant9 = torch.nn.Parameter(torch.empty([256], dtype=torch.float32))
        self._param_constant10 = torch.nn.Parameter(torch.empty([4096], dtype=torch.float32))
        self._param_constant11 = torch.nn.Parameter(torch.empty([4096], dtype=torch.float32))
        self._param_constant12 = torch.nn.Parameter(torch.empty([1000], dtype=torch.float32))
        self.load_state_dict(torch.load(r'forward_backward/state_dict.pt'))
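For reference, one way to look at this joint graph (a sketch; printing_partition and noop_compiler are illustrative helpers, not functorch APIs) is to wrap the partition function passed to aot_function:

import torch
import torchvision
from functorch.compile import aot_function, default_partition

model = torchvision.models.alexnet()

def printing_partition(joint_module, joint_inputs):
    # Print the joint forward+backward graph before it is split into
    # the separate forward and backward modules.
    print(joint_module.code)
    return default_partition(joint_module, joint_inputs)

def noop_compiler(fx_module, example_inputs):
    return fx_module

compiled_fn = aot_function(
    lambda x: model(x).sum(),
    fw_compiler=noop_compiler,
    bw_compiler=noop_compiler,
    partition_fn=printing_partition,
)
compiled_fn(torch.randn(1, 3, 224, 224, requires_grad=True))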
@ConnollyLeon If you want to compile it, use aot_module, which will lift the parameters up to become inputs of the function.
If you're just trying to accelerate it, you can use memory_efficient_fusion, which has some preconfigured settings that should work well for acceleration on CUDA.
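A rough sketch of both suggestions (assuming torchvision's alexnet; print_compiler is an illustrative helper):

import torch
import torchvision
from functorch.compile import aot_module, memory_efficient_fusion

model = torchvision.models.alexnet()

def print_compiler(fx_module, example_inputs):
    print(fx_module.code)
    return fx_module

# aot_module lifts the parameters to graph inputs, so the backward graph
# also produces their gradients.
compiled_model = aot_module(model, fw_compiler=print_compiler, bw_compiler=print_compiler)
compiled_model(torch.randn(1, 3, 224, 224)).sum().backward()

# memory_efficient_fusion wraps a module (or function) with settings tuned
# for fusion on CUDA.
fused_model = memory_efficient_fusion(model)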
@Chillee Thanks for your reply. But why do the weight parameters of the Linear layers turn into tensor_constant buffers in the FxModule? Could you please help explain this?
@ConnollyLeon If you trace with aot_function, then it'll only treat the inputs to that function as "changeable values", and it'll assume everything else is constant (including parameters!).
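In other words, if you want the parameter gradients to show up in the traced graphs without aot_module, you can make the parameters explicit inputs yourself, e.g. with make_functional (a sketch; loss_fn and print_compiler are illustrative names):

import torch
import torchvision
from functorch import make_functional
from functorch.compile import aot_function

model = torchvision.models.alexnet()
func_model, params = make_functional(model)  # params are now passed in explicitly

def loss_fn(params, x):
    return func_model(params, x).sum()

def print_compiler(fx_module, example_inputs):
    print(fx_module.code)
    return fx_module

compiled_loss = aot_function(loss_fn, fw_compiler=print_compiler, bw_compiler=print_compiler)

x = torch.randn(1, 3, 224, 224)
# Because the parameters are inputs, the backward graph now includes their gradients.
grads = torch.autograd.grad(compiled_loss(params, x), params)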