oneflow icon indicating copy to clipboard operation
oneflow copied to clipboard

带有 BatchNorm2d 的模型在开启 amp 和 grad acc 时会报错

Open marigoold opened this issue 2 years ago • 29 comments

Summary

带有 BatchNorm2d 的模型在开启 amp 和 grad acc 时会报错。

  • 注释掉模型中 BN 层,保留 amp 和 grad acc,不会报错
  • 注释掉 grad acc,保留 BN 层和 amp,不会报错
  • 注释掉 amp,保留 BN 层和 grad acc,依然报错

Code to reproduce bug

import oneflow as flow

class Model(flow.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = flow.nn.Sequential(
            flow.nn.Conv2d(3, 3, 3),
            flow.nn.BatchNorm2d(3),
            flow.nn.ReLU(),
        )
    def forward(self, x):
        x = self.model(x)
        x = x.flatten().sum()
        return x

class TrainGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()


        self.model = Model().cuda()
        optimizer = flow.optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        self.add_optimizer(optimizer)

        self.config.enable_amp(True)
        grad_scaler = flow.amp.GradScaler(
            init_scale=2 ** 30,
            growth_factor=2.0,
            backoff_factor=0.5,
            growth_interval=2000,
        )
        self.set_grad_scaler(grad_scaler)

        self.config.set_gradient_accumulation_steps(2)

    def build(self, image):
        loss = self.model(image)
        loss.backward()
        return loss

train_graph = TrainGraph()
inp = flow.randn(8, 3, 112, 112).cuda()
train_graph(inp)

报错:

F20220909 12:45:02.264842 2431218 op_graph.h:98] (9 vs 2)
  File "/workspace/oneflow_main/oneflow/oneflow/core/graph/op_graph.h", line 98, in OpGraph
    Init(job)
  File "/workspace/oneflow_main/oneflow/oneflow/core/graph/op_graph.cpp", line 178, in Init
    InferLogicalBlobDesc(job)
  File "/workspace/oneflow_main/oneflow/oneflow/core/graph/op_graph.cpp", line 355, in InferLogicalBlobDesc
    TopoForEachNodeWithErrorCaptured([&](OpNode* op_node ... InferLogicalOutBlobDescsIf()); return Maybe<void>::Ok(); })
  File "/workspace/oneflow_main/oneflow/oneflow/core/graph/graph.h", line 657, in TopoForEachNodeWithErrorCaptured
    Handler(cur_node)
  File "/workspace/oneflow_main/oneflow/oneflow/core/graph/op_graph.cpp", line 391, in operator()
    op_node->mut_op()->InferLogicalOutBlobDescsIf()
  File "/workspace/oneflow_main/oneflow/oneflow/core/operator/operator.cpp", line 328, in InferLogicalOutBlobDescsIf
    InferLogicalOutBlobDescs(BlobDesc4BnInOp, *JUST(GetOpParallelDesc()))
  File "/workspace/oneflow_main/oneflow/oneflow/core/operator/user_op.cpp", line 629, in InferLogicalOutBlobDescs
    val_->data_type_infer_fn(&infer_ctx)
  File "/workspace/oneflow_main/oneflow/oneflow/user/ops/normalization_op.cpp", line 190, in operator()
    CheckParamDataType("moving_mean")
  File "/workspace/oneflow_main/oneflow/oneflow/user/ops/normalization_op.cpp", line 43, in operator()

Error Type: oneflow.ErrorProto.check_failed_error

看了一下,这里是 op 检查 tensor 和 moving_mean 的 datatype 不同,所以挂掉了。tensor 是 kFloat,而后者是 kFloat16。

System Information

  • What is your OneFlow installation (pip, source, dockerhub): source
  • OS: Ubuntu 20.04
  • OneFlow version (run python3 -m oneflow --doctor):
version: 0.8.1+cu114.git.77ae9c3bea
git_commit: 77ae9c3bea
cmake_build_type: Release
rdma: False
mlir: True
  • Python version: Python 3.9.12
  • CUDA driver version: Build cuda_11.4.r11.4/compiler.30521435_0
  • GPU models: NVIDIA GeForce RTX 2080 Ti
  • Other info:

marigoold avatar Sep 09 '22 04:09 marigoold

这个bug应该是一直以来就存在的,主要原因是repeat op在clear list里,amp算法可能会把repeat op推导为half,导致normalization op的输入moving mean和moving var在repeat之前就转换成了half,而amp算法只是标记了moving mean和moving var是no cast的(即不插入cast op转换成half),但并没有考虑输入已经是half的情况下怎么处理。

hjchen2 avatar Sep 09 '22 05:09 hjchen2

这个bug应该是一直以来就存在的,主要原因是repeat op在clear list里,amp算法可能会把repeat op推导为half,导致normalization op的输入moving mean和moving var在repeat之前就转换成了half,而amp算法只是标记了moving mean和moving var是no cast的(即不插入cast op转换成half),但并没有考虑输入已经是half的情况下怎么处理。

是的,Libai 里面没有带 BN 的模型,所以这个 bug 一直没暴露出来

marigoold avatar Sep 09 '22 05:09 marigoold

两种比较简单的fix办法是:

    1. 把repeat直接从clear list里移除掉。
    1. 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。

hjchen2 avatar Sep 09 '22 05:09 hjchen2

两种比较简单的fix办法是:

    1. 把repeat直接从clear list里移除掉。
    1. 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。

moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn

liujuncheng avatar Sep 09 '22 06:09 liujuncheng

两种比较简单的fix办法是:

    1. 把repeat直接从clear list里移除掉。
    1. 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。

moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn

那这里就有两个问题,1)有bn的情况下不能支持grad acc,2)如果这里不是repeat和bn,而是其他op,也是可能导致一样的问题?

hjchen2 avatar Sep 09 '22 06:09 hjchen2

两种比较简单的fix办法是:

    1. 把repeat直接从clear list里移除掉。
    1. 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。

moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn

那这里就有两个问题,1)有bn的情况下不能支持grad acc,2)如果这里不是repeat和bn,而是其他op,也是可能导致一样的问题?

这个问题主要是前向mutable消费的问题,原来的处理方法是这个 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/normalization_exponential_average_auto_tick_rewrite_pass.cpp

liujuncheng avatar Sep 09 '22 06:09 liujuncheng

两种比较简单的fix办法是:

    1. 把repeat直接从clear list里移除掉。
    1. 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。

moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn

那这里就有两个问题,1)有bn的情况下不能支持grad acc,2)如果这里不是repeat和bn,而是其他op,也是可能导致一样的问题?

这个问题主要是前向mutable消费的问题,原来的处理方法是这个 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/normalization_exponential_average_auto_tick_rewrite_pass.cpp

我的意思是某些op可能存在某个输入是no cast的,但它是非mutable消费的,然后这个op本身又被推导为half,这种情况下如果不针对no cast的输入加一个fp32的cast总是会导致类似的错误。

hjchen2 avatar Sep 09 '22 06:09 hjchen2

我的意思是某些op可能存在某个输入是no cast的,但它是非mutable消费的,然后这个op本身又被推导为half,这种情况下如果不针对no cast的输入加一个fp32的cast总是会导致类似的错误。

嗯嗯,我明白了。不过原来有BN的情况下也是支持repeat的

liujuncheng avatar Sep 09 '22 06:09 liujuncheng

而amp算法只是标记了moving mean和moving var是no cast的(即不插入cast op转换成half)

还有哪些 op 是 no cast ? no cast 跟 black 是同样的区别吗? @hjchen2

chengtbf avatar Sep 09 '22 06:09 chengtbf

总结一下为什么老版 (一年之前)的 GradAcc 没问题

  1. 老版 GradAcc 的 pass (GradientAccumulationRewritePass)在 (NormalizationExponentialAverageAutoTickPass) 这个 pass 之后。
  2. NormalizationExponentialAverageAutoTickPass 会专门给 bn 的 moving mean 链接一个 tick,这个 tick 是 bn 的 input 发出的,时序是经过 repeat 的 tick。
  3. GradientAccumulationRewritePass 并不是处理全部的 Variable,而是处理 source op 中的 Variable,由于之前 bn 的两个前向 Variable 已经插入了 tick ,因此跳过了这些 Variable,不会插入 repeat。
  4. 因此在 老版的 grad acc 下, bn 的这两个 前向更新的 Variable 不会插入 repeat,是直连消费的。 因此 amp 也不会出错。

代码:

  1. 老版 pass 顺序

https://github.com/Oneflow-Inc/oneflow/blob/b47eda61a0304b3a7d71a1eb2bd830f82d041538/oneflow/core/job/job_build_and_infer_ctx.cpp#L981

  1. 老版 NormalizationExponentialAverageAutoTickPass 插入 tick

https://github.com/Oneflow-Inc/oneflow/blob/dfe64e20cca9fb9baf8c8b96eef538d43d93b3ba/oneflow/core/job_rewriter/normalization_exponential_average_auto_tick_rewrite_pass.cpp#L54

  1. 老版 grad acc pass 只查找 input 为 空的 Variable:

https://github.com/Oneflow-Inc/oneflow/blob/b47eda61a0304b3a7d71a1eb2bd830f82d041538/oneflow/core/job_rewriter/gradient_accumulation_rewrite_pass.cpp#L51

chengtbf avatar Sep 09 '22 06:09 chengtbf

还有哪些 op 是 no cast ? no cast 跟 black 是同样的区别吗? @hjchen2

REGISTER_NO_CAST_REGISTRY("normalization", "moving_mean", 0)
REGISTER_NO_CAST_REGISTRY("normalization", "moving_variance", 0)
REGISTER_NO_CAST_REGISTRY("normalization", "gamma", 0)
REGISTER_NO_CAST_REGISTRY("normalization", "beta", 0)

REGISTER_NO_CAST_REGISTRY("normalization_grad", "gamma", 0)

REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "moving_mean", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "moving_variance", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "gamma", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "beta", 0)

REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "gamma", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "beta", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "inv_variance", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "reserve_space", 0)

REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "inv_variance", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "inv_variance", 0)

上面列了一些目前no cast的输入,no cast和black不一样,black是说这个op必须是fp32,no cast是op是half,但某个输入是fp32,其他非no cast的输入还是half

hjchen2 avatar Sep 09 '22 06:09 hjchen2

moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn

这个在之前是问题。但是即将合并的 PR: https://github.com/Oneflow-Inc/oneflow/pull/8961 里, repeat 是 inplace 的 repeat,所以通过 repeat 连接到 BN 也是合法的,此时 BN 修改 repeat 的 tensor,依然是修改 Variable 本身。

但如果是这样的话, repeat 就需要处理这种 no cast 转换的问题。 @hjchen2 @liujuncheng

chengtbf avatar Sep 09 '22 06:09 chengtbf

另一种解决方案是:

依然沿袭着 bn 的两个 前向 Variable 不插入 repeat 的逻辑。在 新版(不是 v3,是 v2) grad acc 的 Variable functional 接口里特判是否是特殊的前向 Variable,对前向 Var 不做处理,沿用当前 NormalizationExponentialAverageAutoTickPass 插入 tick 的方式。

不过目前还不确定是否可以在 functional 阶段判断出来当前的 Variable 是 moving mean

chengtbf avatar Sep 09 '22 06:09 chengtbf

另一种解决方案是:

依然沿袭着 bn 的两个 前向 Variable 不插入 repeat 的逻辑。在 新版(不是 v3,是 v2) grad acc 的 Variable functional 接口里特判是否是特殊的前向 Variable,对前向 Var 不做处理,沿用当前 NormalizationExponentialAverageAutoTickPass 插入 tick 的方式。

我觉得不用,只需要repeat支持inplace就可以,剩下的就是看要不要把repeat移出clear list,还是修改一下amp算法,让这些no cast的上游op不会被推导为half

hjchen2 avatar Sep 09 '22 07:09 hjchen2

但如果是这样的话, repeat 就需要处理这种 no cast 转换的问题。 @hjchen2 @liujuncheng

我觉得是这样的

hjchen2 avatar Sep 09 '22 07:09 hjchen2

要不要把repeat移出clear list

我觉得 repeat 还是需要是 clear 的吧,否则 amp 在通常情况下就不高效了? repeat 需要支持推导输入是 half 的情形,这个在绝大多数 amp + acc 下都是 half。

chengtbf avatar Sep 09 '22 07:09 chengtbf

要不要把repeat移出clear list

我觉得 repeat 还是需要是 clear 的吧,否则 amp 在通常情况下就不高效了? repeat 需要支持推导输入是 half 的情形,这个在绝大多数 amp + acc 下都是 half。

如果是inplace的,repeat是不是half的关系都不大吧

hjchen2 avatar Sep 09 '22 07:09 hjchen2

要不要把repeat移出clear list

我觉得 repeat 还是需要是 clear 的吧,否则 amp 在通常情况下就不高效了? repeat 需要支持推导输入是 half 的情形,这个在绝大多数 amp + acc 下都是 half。

如果是inplace的,repeat是不是half的关系都不大吧

但是不是说 moving mean 和 moving_variance 都需要是 float32 才行吗? 不做 cast ?所以 amp 算法里一定不能在 moving_variance 的 Variable op 到 bn 之间插入非 repeat 以外的其他 op,尤其是 cast。

chengtbf avatar Sep 09 '22 07:09 chengtbf

要不要把repeat移出clear list

我觉得 repeat 还是需要是 clear 的吧,否则 amp 在通常情况下就不高效了? repeat 需要支持推导输入是 half 的情形,这个在绝大多数 amp + acc 下都是 half。

如果是inplace的,repeat是不是half的关系都不大吧

但是不是说 moving mean 和 moving_variance 都需要是 float32 才行吗? 不做 cast ?所以 amp 算法里一定不能在 moving_variance 的 Variable op 到 bn 之间插入非 repeat 以外的其他 op,尤其是 cast。

是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了

hjchen2 avatar Sep 09 '22 07:09 hjchen2

是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了

那 repeat 换成 gray ? 可以解决这个问题么

chengtbf avatar Sep 09 '22 07:09 chengtbf

是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了

那 repeat 换成 gray ? 可以解决这个问题么

也不行

hjchen2 avatar Sep 09 '22 07:09 hjchen2

是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了

那 repeat 换成 gray ? 可以解决这个问题么

也不行

😂 那 repeat 应该是什么呢

chengtbf avatar Sep 09 '22 07:09 chengtbf

是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了

那 repeat 换成 gray ? 可以解决这个问题么

也不行

😂 那 repeat 应该是什么呢

不是gray,也不是clear就可以了

hjchen2 avatar Sep 09 '22 07:09 hjchen2

是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了 那 repeat 换成 gray ? 可以解决这个问题么 也不行 😂 那 repeat 应该是什么呢 不是gray,也不是clear就可以了

那就是 black 了吧? 反而不对了? 我记得如果有个逻辑是 有个 op 如果不在所有的 list 里,默认就是 clear 或者 black (忘记是哪种了) @leaves-zwx

chengtbf avatar Sep 09 '22 07:09 chengtbf

是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了 那 repeat 换成 gray ? 可以解决这个问题么 也不行 😂 那 repeat 应该是什么呢 不是gray,也不是clear就可以了

那就是 black 了吧? 反而不对了? 我记得如果有个逻辑是 有个 op 如果不在所有的 list 里,默认就是 clear 或者 black (忘记是哪种了) @leaves-zwx

repeat只会在variable后面吧,那就把它看成和variable一样的就行了

hjchen2 avatar Sep 09 '22 08:09 hjchen2

repeat只会在variable后面吧,那就把它看成和variable一样的就行了

还有可能在 source op 前面, tick 后面,repeat tick 给 source op (不过不参与 amp)

如果跟 Variable 一样,那 cast 就会在 repeat 后面,而不是 repeat 前面? 这个情况下, cast 会重复执行多次。 且还需要考虑 zero 下的影响。

chengtbf avatar Sep 09 '22 08:09 chengtbf

repeat只会在variable后面吧,那就把它看成和variable一样的就行了

还有可能在 source op 前面, tick 后面,repeat tick 给 source op (不过不参与 amp)

如果跟 Variable 一样,那 cast 就会在 repeat 后面,而不是 repeat 前面? 这个情况下, cast 会重复执行多次。 且还需要考虑 zero 下的影响。

好像不会对 zero 产生影响,zero 的修改 sbp 的op在cast之后。

strint avatar Sep 09 '22 08:09 strint

另一种解决方案是: 依然沿袭着 bn 的两个 前向 Variable 不插入 repeat 的逻辑。在 新版(不是 v3,是 v2) grad acc 的 Variable functional 接口里特判是否是特殊的前向 Variable,对前向 Var 不做处理,沿用当前 NormalizationExponentialAverageAutoTickPass 插入 tick 的方式。

我觉得不用,只需要repeat支持inplace就可以,剩下的就是看要不要把repeat移出clear list,还是修改一下amp算法,让这些no cast的上游op不会被推导为half

这个问题的最终结论是? @hjchen2

chengtbf avatar Sep 13 '22 03:09 chengtbf

另一种解决方案是: 依然沿袭着 bn 的两个 前向 Variable 不插入 repeat 的逻辑。在 新版(不是 v3,是 v2) grad acc 的 Variable functional 接口里特判是否是特殊的前向 Variable,对前向 Var 不做处理,沿用当前 NormalizationExponentialAverageAutoTickPass 插入 tick 的方式。

我觉得不用,只需要repeat支持inplace就可以,剩下的就是看要不要把repeat移出clear list,还是修改一下amp算法,让这些no cast的上游op不会被推导为half

这个问题的最终结论是? @hjchen2

还没有最终结论,我只是有两个改法建议:

  • 等grad acc新版合并之后,把repeat从clear list里删除,这是最简单粗暴的做法,但不够彻底。
  • 修改amp算法,找到no cast的输入边上游的第一个非mutable消费输入的op,如果能找到,就在它前面插入一个any->float32的cast,如果找不到,那访问到的所有op都不能转成half,或者直接将no cast输入边的上游op都简单粗暴的标记为不能转成half也行。

hjchen2 avatar Sep 13 '22 04:09 hjchen2