oneflow
oneflow copied to clipboard
带有 BatchNorm2d 的模型在开启 amp 和 grad acc 时会报错
Summary
带有 BatchNorm2d 的模型在开启 amp 和 grad acc 时会报错。
- 注释掉模型中 BN 层,保留 amp 和 grad acc,不会报错
- 注释掉 grad acc,保留 BN 层和 amp,不会报错
- 注释掉 amp,保留 BN 层和 grad acc,依然报错
Code to reproduce bug
import oneflow as flow
class Model(flow.nn.Module):
def __init__(self) -> None:
super().__init__()
self.model = flow.nn.Sequential(
flow.nn.Conv2d(3, 3, 3),
flow.nn.BatchNorm2d(3),
flow.nn.ReLU(),
)
def forward(self, x):
x = self.model(x)
x = x.flatten().sum()
return x
class TrainGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.model = Model().cuda()
optimizer = flow.optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
self.add_optimizer(optimizer)
self.config.enable_amp(True)
grad_scaler = flow.amp.GradScaler(
init_scale=2 ** 30,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
)
self.set_grad_scaler(grad_scaler)
self.config.set_gradient_accumulation_steps(2)
def build(self, image):
loss = self.model(image)
loss.backward()
return loss
train_graph = TrainGraph()
inp = flow.randn(8, 3, 112, 112).cuda()
train_graph(inp)
报错:
F20220909 12:45:02.264842 2431218 op_graph.h:98] (9 vs 2)
File "/workspace/oneflow_main/oneflow/oneflow/core/graph/op_graph.h", line 98, in OpGraph
Init(job)
File "/workspace/oneflow_main/oneflow/oneflow/core/graph/op_graph.cpp", line 178, in Init
InferLogicalBlobDesc(job)
File "/workspace/oneflow_main/oneflow/oneflow/core/graph/op_graph.cpp", line 355, in InferLogicalBlobDesc
TopoForEachNodeWithErrorCaptured([&](OpNode* op_node ... InferLogicalOutBlobDescsIf()); return Maybe<void>::Ok(); })
File "/workspace/oneflow_main/oneflow/oneflow/core/graph/graph.h", line 657, in TopoForEachNodeWithErrorCaptured
Handler(cur_node)
File "/workspace/oneflow_main/oneflow/oneflow/core/graph/op_graph.cpp", line 391, in operator()
op_node->mut_op()->InferLogicalOutBlobDescsIf()
File "/workspace/oneflow_main/oneflow/oneflow/core/operator/operator.cpp", line 328, in InferLogicalOutBlobDescsIf
InferLogicalOutBlobDescs(BlobDesc4BnInOp, *JUST(GetOpParallelDesc()))
File "/workspace/oneflow_main/oneflow/oneflow/core/operator/user_op.cpp", line 629, in InferLogicalOutBlobDescs
val_->data_type_infer_fn(&infer_ctx)
File "/workspace/oneflow_main/oneflow/oneflow/user/ops/normalization_op.cpp", line 190, in operator()
CheckParamDataType("moving_mean")
File "/workspace/oneflow_main/oneflow/oneflow/user/ops/normalization_op.cpp", line 43, in operator()
Error Type: oneflow.ErrorProto.check_failed_error
看了一下,这里是 op 检查 tensor 和 moving_mean 的 datatype 不同,所以挂掉了。tensor 是 kFloat,而后者是 kFloat16。
System Information
- What is your OneFlow installation (pip, source, dockerhub): source
- OS: Ubuntu 20.04
- OneFlow version (run
python3 -m oneflow --doctor
):
version: 0.8.1+cu114.git.77ae9c3bea
git_commit: 77ae9c3bea
cmake_build_type: Release
rdma: False
mlir: True
- Python version: Python 3.9.12
- CUDA driver version: Build cuda_11.4.r11.4/compiler.30521435_0
- GPU models: NVIDIA GeForce RTX 2080 Ti
- Other info:
这个bug应该是一直以来就存在的,主要原因是repeat op在clear list里,amp算法可能会把repeat op推导为half,导致normalization op的输入moving mean和moving var在repeat之前就转换成了half,而amp算法只是标记了moving mean和moving var是no cast的(即不插入cast op转换成half),但并没有考虑输入已经是half的情况下怎么处理。
这个bug应该是一直以来就存在的,主要原因是repeat op在clear list里,amp算法可能会把repeat op推导为half,导致normalization op的输入moving mean和moving var在repeat之前就转换成了half,而amp算法只是标记了moving mean和moving var是no cast的(即不插入cast op转换成half),但并没有考虑输入已经是half的情况下怎么处理。
是的,Libai 里面没有带 BN 的模型,所以这个 bug 一直没暴露出来
两种比较简单的fix办法是:
-
- 把repeat直接从clear list里移除掉。
-
- 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。
两种比较简单的fix办法是:
- 把repeat直接从clear list里移除掉。
- 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。
moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn
两种比较简单的fix办法是:
- 把repeat直接从clear list里移除掉。
- 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。
moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn
那这里就有两个问题,1)有bn的情况下不能支持grad acc,2)如果这里不是repeat和bn,而是其他op,也是可能导致一样的问题?
两种比较简单的fix办法是:
- 把repeat直接从clear list里移除掉。
- 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。
moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn
那这里就有两个问题,1)有bn的情况下不能支持grad acc,2)如果这里不是repeat和bn,而是其他op,也是可能导致一样的问题?
这个问题主要是前向mutable消费的问题,原来的处理方法是这个 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/normalization_exponential_average_auto_tick_rewrite_pass.cpp
两种比较简单的fix办法是:
- 把repeat直接从clear list里移除掉。
- 在no cast的时候(https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/auto_mixed_precision.cpp#L130 )插入一个转换为fp32的cast op,然后可以考虑再加一个消除fp32 cast成fp32的pass。
moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn
那这里就有两个问题,1)有bn的情况下不能支持grad acc,2)如果这里不是repeat和bn,而是其他op,也是可能导致一样的问题?
这个问题主要是前向mutable消费的问题,原来的处理方法是这个 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/job_rewriter/normalization_exponential_average_auto_tick_rewrite_pass.cpp
我的意思是某些op可能存在某个输入是no cast的,但它是非mutable消费的,然后这个op本身又被推导为half,这种情况下如果不针对no cast的输入加一个fp32的cast总是会导致类似的错误。
我的意思是某些op可能存在某个输入是no cast的,但它是非mutable消费的,然后这个op本身又被推导为half,这种情况下如果不针对no cast的输入加一个fp32的cast总是会导致类似的错误。
嗯嗯,我明白了。不过原来有BN的情况下也是支持repeat的
而amp算法只是标记了moving mean和moving var是no cast的(即不插入cast op转换成half)
还有哪些 op 是 no cast ? no cast 跟 black 是同样的区别吗? @hjchen2
总结一下为什么老版 (一年之前)的 GradAcc 没问题
- 老版 GradAcc 的 pass (GradientAccumulationRewritePass)在 (NormalizationExponentialAverageAutoTickPass) 这个 pass 之后。
- NormalizationExponentialAverageAutoTickPass 会专门给 bn 的 moving mean 链接一个 tick,这个 tick 是 bn 的 input 发出的,时序是经过 repeat 的 tick。
- GradientAccumulationRewritePass 并不是处理全部的 Variable,而是处理 source op 中的 Variable,由于之前 bn 的两个前向 Variable 已经插入了 tick ,因此跳过了这些 Variable,不会插入 repeat。
- 因此在 老版的 grad acc 下, bn 的这两个 前向更新的 Variable 不会插入 repeat,是直连消费的。 因此 amp 也不会出错。
代码:
- 老版 pass 顺序
https://github.com/Oneflow-Inc/oneflow/blob/b47eda61a0304b3a7d71a1eb2bd830f82d041538/oneflow/core/job/job_build_and_infer_ctx.cpp#L981
- 老版 NormalizationExponentialAverageAutoTickPass 插入 tick
https://github.com/Oneflow-Inc/oneflow/blob/dfe64e20cca9fb9baf8c8b96eef538d43d93b3ba/oneflow/core/job_rewriter/normalization_exponential_average_auto_tick_rewrite_pass.cpp#L54
- 老版 grad acc pass 只查找 input 为 空的 Variable:
https://github.com/Oneflow-Inc/oneflow/blob/b47eda61a0304b3a7d71a1eb2bd830f82d041538/oneflow/core/job_rewriter/gradient_accumulation_rewrite_pass.cpp#L51
还有哪些 op 是 no cast ? no cast 跟 black 是同样的区别吗? @hjchen2
REGISTER_NO_CAST_REGISTRY("normalization", "moving_mean", 0)
REGISTER_NO_CAST_REGISTRY("normalization", "moving_variance", 0)
REGISTER_NO_CAST_REGISTRY("normalization", "gamma", 0)
REGISTER_NO_CAST_REGISTRY("normalization", "beta", 0)
REGISTER_NO_CAST_REGISTRY("normalization_grad", "gamma", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "moving_mean", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "moving_variance", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "gamma", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu", "beta", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "gamma", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "beta", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "inv_variance", 0)
REGISTER_NO_CAST_REGISTRY("normalization_add_relu_grad", "reserve_space", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_grad", "inv_variance", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "mean", 0)
REGISTER_NO_CAST_REGISTRY("layer_norm_param_grad", "inv_variance", 0)
上面列了一些目前no cast的输入,no cast和black不一样,black是说这个op必须是fp32,no cast是op是half,但某个输入是fp32,其他非no cast的输入还是half
moving mean 和 moving var 是不能通过repeat在连接到BN,因为是mutable消费,必须是var直接连到 Bn
这个在之前是问题。但是即将合并的 PR: https://github.com/Oneflow-Inc/oneflow/pull/8961 里, repeat 是 inplace 的 repeat,所以通过 repeat 连接到 BN 也是合法的,此时 BN 修改 repeat 的 tensor,依然是修改 Variable 本身。
但如果是这样的话, repeat 就需要处理这种 no cast 转换的问题。 @hjchen2 @liujuncheng
另一种解决方案是:
依然沿袭着 bn 的两个 前向 Variable 不插入 repeat 的逻辑。在 新版(不是 v3,是 v2) grad acc 的 Variable functional 接口里特判是否是特殊的前向 Variable,对前向 Var 不做处理,沿用当前 NormalizationExponentialAverageAutoTickPass 插入 tick 的方式。
不过目前还不确定是否可以在 functional 阶段判断出来当前的 Variable 是 moving mean
另一种解决方案是:
依然沿袭着 bn 的两个 前向 Variable 不插入 repeat 的逻辑。在 新版(不是 v3,是 v2) grad acc 的 Variable functional 接口里特判是否是特殊的前向 Variable,对前向 Var 不做处理,沿用当前 NormalizationExponentialAverageAutoTickPass 插入 tick 的方式。
我觉得不用,只需要repeat支持inplace就可以,剩下的就是看要不要把repeat移出clear list,还是修改一下amp算法,让这些no cast的上游op不会被推导为half
但如果是这样的话, repeat 就需要处理这种 no cast 转换的问题。 @hjchen2 @liujuncheng
我觉得是这样的
要不要把repeat移出clear list
我觉得 repeat 还是需要是 clear 的吧,否则 amp 在通常情况下就不高效了? repeat 需要支持推导输入是 half 的情形,这个在绝大多数 amp + acc 下都是 half。
要不要把repeat移出clear list
我觉得 repeat 还是需要是 clear 的吧,否则 amp 在通常情况下就不高效了? repeat 需要支持推导输入是 half 的情形,这个在绝大多数 amp + acc 下都是 half。
如果是inplace的,repeat是不是half的关系都不大吧
要不要把repeat移出clear list
我觉得 repeat 还是需要是 clear 的吧,否则 amp 在通常情况下就不高效了? repeat 需要支持推导输入是 half 的情形,这个在绝大多数 amp + acc 下都是 half。
如果是inplace的,repeat是不是half的关系都不大吧
但是不是说 moving mean 和 moving_variance 都需要是 float32 才行吗? 不做 cast ?所以 amp 算法里一定不能在 moving_variance 的 Variable op 到 bn 之间插入非 repeat 以外的其他 op,尤其是 cast。
要不要把repeat移出clear list
我觉得 repeat 还是需要是 clear 的吧,否则 amp 在通常情况下就不高效了? repeat 需要支持推导输入是 half 的情形,这个在绝大多数 amp + acc 下都是 half。
如果是inplace的,repeat是不是half的关系都不大吧
但是不是说 moving mean 和 moving_variance 都需要是 float32 才行吗? 不做 cast ?所以 amp 算法里一定不能在 moving_variance 的 Variable op 到 bn 之间插入非 repeat 以外的其他 op,尤其是 cast。
是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了
是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了
那 repeat 换成 gray ? 可以解决这个问题么
是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了
那 repeat 换成 gray ? 可以解决这个问题么
也不行
是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了
那 repeat 换成 gray ? 可以解决这个问题么
也不行
😂 那 repeat 应该是什么呢
是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了
那 repeat 换成 gray ? 可以解决这个问题么
也不行
😂 那 repeat 应该是什么呢
不是gray,也不是clear就可以了
是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了 那 repeat 换成 gray ? 可以解决这个问题么 也不行 😂 那 repeat 应该是什么呢 不是gray,也不是clear就可以了
那就是 black 了吧? 反而不对了? 我记得如果有个逻辑是 有个 op 如果不在所有的 list 里,默认就是 clear 或者 black (忘记是哪种了) @leaves-zwx
是,但目前看到的这几个op的no cast输入都是variable,本来就不应该插入其他的op,如果repeat不是clear,就不会在repeat之前插入cast了 那 repeat 换成 gray ? 可以解决这个问题么 也不行 😂 那 repeat 应该是什么呢 不是gray,也不是clear就可以了
那就是 black 了吧? 反而不对了? 我记得如果有个逻辑是 有个 op 如果不在所有的 list 里,默认就是 clear 或者 black (忘记是哪种了) @leaves-zwx
repeat只会在variable后面吧,那就把它看成和variable一样的就行了
repeat只会在variable后面吧,那就把它看成和variable一样的就行了
还有可能在 source op 前面, tick 后面,repeat tick 给 source op (不过不参与 amp)
如果跟 Variable 一样,那 cast 就会在 repeat 后面,而不是 repeat 前面? 这个情况下, cast 会重复执行多次。 且还需要考虑 zero 下的影响。
repeat只会在variable后面吧,那就把它看成和variable一样的就行了
还有可能在 source op 前面, tick 后面,repeat tick 给 source op (不过不参与 amp)
如果跟 Variable 一样,那 cast 就会在 repeat 后面,而不是 repeat 前面? 这个情况下, cast 会重复执行多次。 且还需要考虑 zero 下的影响。
好像不会对 zero 产生影响,zero 的修改 sbp 的op在cast之后。
另一种解决方案是: 依然沿袭着 bn 的两个 前向 Variable 不插入 repeat 的逻辑。在 新版(不是 v3,是 v2) grad acc 的 Variable functional 接口里特判是否是特殊的前向 Variable,对前向 Var 不做处理,沿用当前 NormalizationExponentialAverageAutoTickPass 插入 tick 的方式。
我觉得不用,只需要repeat支持inplace就可以,剩下的就是看要不要把repeat移出clear list,还是修改一下amp算法,让这些no cast的上游op不会被推导为half
这个问题的最终结论是? @hjchen2
另一种解决方案是: 依然沿袭着 bn 的两个 前向 Variable 不插入 repeat 的逻辑。在 新版(不是 v3,是 v2) grad acc 的 Variable functional 接口里特判是否是特殊的前向 Variable,对前向 Var 不做处理,沿用当前 NormalizationExponentialAverageAutoTickPass 插入 tick 的方式。
我觉得不用,只需要repeat支持inplace就可以,剩下的就是看要不要把repeat移出clear list,还是修改一下amp算法,让这些no cast的上游op不会被推导为half
这个问题的最终结论是? @hjchen2
还没有最终结论,我只是有两个改法建议:
- 等grad acc新版合并之后,把repeat从clear list里删除,这是最简单粗暴的做法,但不够彻底。
- 修改amp算法,找到no cast的输入边上游的第一个非mutable消费输入的op,如果能找到,就在它前面插入一个any->float32的cast,如果找不到,那访问到的所有op都不能转成half,或者直接将no cast输入边的上游op都简单粗暴的标记为不能转成half也行。