
Test functorch transform behavior on BatchNorm and other weird in-place operations

Open zou3519 opened this issue 4 years ago • 10 comments

Most in-place operations look the same: the first argument is mutable and no other arguments are mutable.

F.batch_norm is an exception to that rule: it has multiple mutable arguments. We should check that our transforms work correctly on it.
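For example, a minimal illustration of the multiple in-place mutations (sketch):

import torch
import torch.nn.functional as F

x = torch.randn(8, 4)
running_mean = torch.zeros(4)
running_var = torch.ones(4)

before = running_mean.clone()
# With training=True, F.batch_norm mutates running_mean and running_var
# (its 2nd and 3rd arguments) in place, instead of the usual pattern of
# mutating only the first argument.
out = F.batch_norm(x, running_mean, running_var, training=True, momentum=0.1)
print(torch.equal(before, running_mean))  # False: running_mean was mutated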

zou3519 avatar Nov 16 '21 16:11 zou3519

Hi @zou3519, is this error related to the current issue? It seems all models with batchnorm can't be used for jacobian calculation.

Adding model.eval() makes it work because it no longer mutates num_batches_tracked, but I think it's not a good workaround.

import torch
import torch.nn as nn
import functorch
from torch.nn.utils import _stateless
import functools

model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4), nn.Linear(4, 5), nn.Softmax(dim=-1))
_input = torch.randn(32, 3)
weight = model[0].weight.clone()
weight.requires_grad_()


def func(param: torch.Tensor, _input: torch.Tensor = None):
    _output: torch.Tensor = _stateless.functional_call(
        model, {'0.weight': param}, _input)
    return _output  # (N, C)


jac: torch.Tensor = functorch.jacrev(functools.partial(func, _input=_input))(weight)
Traceback (most recent call last):
  File "C:\Users\ain-s\OneDrive\workspace\trojanzoo\test.py", line 33, in <module>
    jac: torch.Tensor = functorch.jacrev(functools.partial(func, _input=_input))(weight)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 441, in wrapper_fn
    vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 270, in vjp
    primals_out = func(*diff_primals)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 624, in f_wrapper
    return f(*replaced_args)
  File "C:\Users\ain-s\OneDrive\workspace\trojanzoo\test.py", line 28, in func
    _output: torch.Tensor = _stateless.functional_call(
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\utils\stateless.py", line 141, in functional_call
    out = module(args, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\container.py", line 139, in forward
    input = module(input)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\batchnorm.py", line 148, in forward
    self.num_batches_tracked.add_(1)  # type: ignore[has-type]
RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::add_.Tensor) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.

ain-soph avatar Aug 04 '22 04:08 ain-soph

Additionally, even if I set model.eval() to avoid the previous issue, I still can't apply vmap to it.


def vmap_func(_input: torch.Tensor):
    return functorch.jacrev(functools.partial(func, _input=_input.unsqueeze(0)))(weight)


jac = functorch.vmap(vmap_func)(_input)
Traceback (most recent call last):
  File "C:\Users\ain-s\OneDrive\workspace\trojanzoo\test.py", line 41, in <module>
    jac = functorch.vmap(vmap_func)(_input)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\vmap.py", line 365, in wrapped
    batched_outputs = func(*batched_inputs, **kwargs)
  File "C:\Users\ain-s\OneDrive\workspace\trojanzoo\test.py", line 38, in vmap_func
    return functorch.jacrev(functools.partial(func, _input=_input.unsqueeze(0)))(weight)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 441, in wrapper_fn
    vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 270, in vjp
    primals_out = func(*diff_primals)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 624, in f_wrapper
    return f(*replaced_args)
  File "C:\Users\ain-s\OneDrive\workspace\trojanzoo\test.py", line 31, in func
    _output: torch.Tensor = _stateless.functional_call(
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\utils\stateless.py", line 141, in functional_call
    out = module(args, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\container.py", line 139, in forward
    input = module(input)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\functional.py", line 2438, in batch_norm
    return torch.batch_norm(
RuntimeError: Batch norm got a batched tensor as input while the running_mean or running_var, which will be updated in place, were not batched.
If you are using a module and do not need eval mode, please set `track_running_stats` to be False.If you are using a prebuilt module and do not need eval mode, please see the functorch website for resources on how to patch your module to work with vmap

ain-soph avatar Aug 04 '22 04:08 ain-soph

It seems all models with batchnorm can't be used for jacobian calculation.

The error message is complaining that the buffers to batch_norm are not being passed into the function func that is being transformed (func only accepts param and _input). Please rewrite func to also accept and use the buffers, something like:

def func(param: torch.Tensor, _input: torch.Tensor = None, running_mean=None, running_var=None):
    _output: torch.Tensor = _stateless.functional_call(
        model, {'0.weight': param, '1.running_var': running_var, '1.running_mean': running_mean}, _input)
    return _output  # (N, C)
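and then call it with the buffers passed in explicitly, something like this (rough sketch, continuing the snippet above; in training mode, num_batches_tracked may need similar treatment):

# Rough sketch: read the BatchNorm buffers off the module and pass them in as
# explicit inputs. jacrev differentiates w.r.t. the first argument (param) by default.
buffers = dict(model.named_buffers())
jac = functorch.jacrev(func)(
    weight,
    _input,
    buffers['1.running_mean'].clone(),
    buffers['1.running_var'].clone(),
)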

Additionally, even if I set model.eval() to avoid the previous issue, I still can't apply vmap to it.

Setting the model to eval() works for me:

import torch
import torch.nn as nn
import functorch
from torch.nn.utils import _stateless
import functools

model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4), nn.Linear(4, 5), nn.Softmax(dim=-1))
model.eval()

_input = torch.randn(32, 3)
weight = model[0].weight.clone()
weight.requires_grad_()


def func(param: torch.Tensor, _input: torch.Tensor = None):
    _output: torch.Tensor = _stateless.functional_call(
        model, {'0.weight': param}, _input)
    return _output  # (N, C)


jac: torch.Tensor = functorch.jacrev(functools.partial(func, _input=_input))(weight)


def vmap_func(_input: torch.Tensor):
    return functorch.jacrev(functools.partial(func, _input=_input.unsqueeze(0)))(weight)


jac = functorch.vmap(vmap_func)(_input)

cc @samdow -- is there more we should say about batch norm here?

zou3519 avatar Aug 04 '22 14:08 zou3519

Btw, @ain-soph -- I noticed you're using stateless.functional_call rather than functorch.make_functional_with_buffers. We're trying to figure out how to consolidate those two APIs, given that they do almost the same thing. Is there one that you prefer over the other?

zou3519 avatar Aug 04 '22 14:08 zou3519

@zou3519 got it all! Only small things to add 😄

Setting the model to eval() works for me

This was something that was updated somewhat recently. I believe it came in a little too late for the 0.2.0 release but ~will be available for 0.2.1 and~ (EDIT: my bad, it's not available in 0.2.1) is currently available on master

is there more we should say about batch norm here?

The only other thing: if the second case (vmap over jacrev, during training mode) is important to your use case, it would be great to hear more about it. Specifically, understanding how the Jacobian is used can help us figure out what behavior makes sense when updating the running mean and variance.

samdow avatar Aug 04 '22 14:08 samdow

@zou3519 Thanks for your kind explanation.

Btw, @ain-soph -- I noticed you're using stateless.functional_call rather than functorch.make_functional_with_buffers. We're trying to figure out how to consolidate those two APIs, given that they do almost the same thing. Is there one that you prefer over the other?

Personally, I prefer stateless.functional_call because it lets me do a dict match (by parameter name), while make_functional requires me to match params by list index. When there is a huge model and I only want to modify one layer's weight, make_functional is difficult to use.
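A rough illustration of what I mean by dict match vs. list index (sketch, assuming a small model):

import torch
import torch.nn as nn
import functorch
from torch.nn.utils import _stateless

model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
_input = torch.randn(32, 3)
new_weight = model[0].weight.detach().clone()

# stateless.functional_call: override one parameter by its qualified name.
out1 = _stateless.functional_call(model, {'0.weight': new_weight}, _input)

# make_functional_with_buffers: parameters come back as a flat tuple, so
# replacing one layer's weight means locating its position in that tuple.
fmodel, params, buffers = functorch.make_functional_with_buffers(model)
idx = [name for name, _ in model.named_parameters()].index('0.weight')
new_params = list(params)
new_params[idx] = new_weight
out2 = fmodel(tuple(new_params), buffers, _input)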

However, stateless.functional_call has a big problem: it can't be used with nn.DataParallel to utilize multiple GPUs. I don't know if there is any solution yet, because it requires the input module to be an nn.Module.

ain-soph avatar Aug 04 '22 14:08 ain-soph

@zou3519 I just checked the previously mentioned vmap snippet on release v0.2.0, and it still raises the error I posted.

Not quite sure whether it's solved on master though.

ain-soph avatar Aug 04 '22 15:08 ain-soph

@samdow I'm currently calculating the NTK matrix for each input and averaging over the batch. Then I use the largest eigenvalue of the averaged NTK matrix divided by the smallest one as a score function to optimize the model (architecture) parameters.

This is somewhat motivated by the TENAS paper (although their original code calculates a different matrix of size (batch, batch) rather than the NTK of size (num_class, num_class)).
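Roughly, the idea looks something like this (illustrative sketch only, not my actual code, with batchnorm swapped out for ReLU since that's exactly the part that breaks):

import functools
import torch
import torch.nn as nn
import functorch
from torch.nn.utils import _stateless

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 5), nn.Softmax(dim=-1))
_input = torch.randn(32, 3)
weight = model[0].weight.clone().requires_grad_()

def func(param: torch.Tensor, _input: torch.Tensor = None) -> torch.Tensor:
    return _stateless.functional_call(model, {'0.weight': param}, _input)  # (N, C)

def per_sample_ntk(x: torch.Tensor) -> torch.Tensor:
    # Jacobian of the (1, C) output w.r.t. weight, flattened to (C, num_params)
    jac = functorch.jacrev(functools.partial(func, _input=x.unsqueeze(0)))(weight)
    jac = jac.reshape(jac.shape[1], -1)
    return jac @ jac.T  # (C, C) NTK for this single input

ntk = functorch.vmap(per_sample_ntk)(_input).mean(dim=0)  # average over the batch
eigvals = torch.linalg.eigvalsh(ntk)
score = eigvals.max() / eigvals.min()  # condition-number-style score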

I’ll provide a minimal snippet here.

ain-soph avatar Aug 04 '22 15:08 ain-soph

I use a trick to disable track_running_stats for all batchnorm modules. However, applying vmap is still problematic.

for a in model.modules():
    if isinstance(a, nn.modules.batchnorm._BatchNorm):
        if a.track_running_stats:
            a.track_running_stats = False
            a.register_buffer("running_mean", None)
            a.register_buffer("running_var", None)
            a.register_buffer("num_batches_tracked", None)

The snippet now becomes:

import torch
import torch.nn as nn
import functorch
from torch.nn.utils import _stateless
import functools

model = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4), nn.Linear(4, 5), nn.Softmax(dim=-1))
model.eval()

for a in model.modules():
    if isinstance(a, nn.modules.batchnorm._BatchNorm):
        if a.track_running_stats:
            a.track_running_stats = False
            a.register_buffer("running_mean", None)
            a.register_buffer("running_var", None)
            a.register_buffer("num_batches_tracked", None)

_input = torch.randn(32, 3)
weight = model[0].weight.clone()
weight.requires_grad_()


def func(param: torch.Tensor, _input: torch.Tensor = None):
    _output: torch.Tensor = _stateless.functional_call(
        model, {'0.weight': param}, _input)
    return _output  # (N, C)


jac: torch.Tensor = functorch.jacrev(functools.partial(func, _input=_input))(weight)


def vmap_func(_input: torch.Tensor):
    return functorch.jacrev(functools.partial(func, _input=_input.unsqueeze(0)))(weight)


jac = functorch.vmap(vmap_func)(_input)

It still raises the error even though model.eval() was called, because in BatchNorm's forward there is bn_training = (self.running_mean is None) and (self.running_var is None), and bn_training is always True when both buffers are None.

Traceback (most recent call last):
  File "C:\Users\ain-s\OneDrive\workspace\trojanzoo\test.py", line 37, in <module>
    jac = functorch.vmap(vmap_func)(_input)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\vmap.py", line 365, in wrapped
    batched_outputs = func(*batched_inputs, **kwargs)
  File "C:\Users\ain-s\OneDrive\workspace\trojanzoo\test.py", line 34, in vmap_func
    return functorch.jacrev(functools.partial(func, _input=_input.unsqueeze(0)))(weight)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 441, in wrapper_fn
    vjp_out = vjp(f_wrapper, *primals, has_aux=has_aux)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 270, in vjp
    primals_out = func(*diff_primals)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\functorch\_src\eager_transforms.py", line 624, in f_wrapper
    return f(*replaced_args)
  File "C:\Users\ain-s\OneDrive\workspace\trojanzoo\test.py", line 25, in func
    _output: torch.Tensor = _stateless.functional_call(
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\utils\stateless.py", line 141, in functional_call
    out = module(args, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\container.py", line 139, in forward
    input = module(input)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\modules\batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\functional.py", line 2436, in batch_norm
    _verify_batch_size(input.size())
  File "C:\Users\ain-s\miniconda3\lib\site-packages\torch\nn\functional.py", line 2404, in _verify_batch_size
    raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 4])
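A minimal standalone check of that bn_training behavior (sketch):

import torch
import torch.nn as nn

# With track_running_stats=False, running_mean and running_var are None, so even
# in eval mode bn_training = (running_mean is None) and (running_var is None)
# evaluates to True, batch statistics are used, and a single-sample input fails.
bn = nn.BatchNorm1d(4, track_running_stats=False)
bn.eval()
try:
    bn(torch.randn(1, 4))
except ValueError as e:
    print(e)  # Expected more than 1 value per channel when training, ...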

ain-soph avatar Aug 05 '22 01:08 ain-soph

@ain-soph, setting the model to eval() should work on master: https://gist.github.com/zou3519/f581a5de3bbd4721a49923752973c938. You're correct that it doesn't work on 0.2.0 or earlier releases.

zou3519 avatar Aug 05 '22 15:08 zou3519