
BatchNorm module raises an error on half-precision input

Open BBuf opened this issue 3 years ago • 2 comments

import oneflow as flow

m = flow.nn.BatchNorm2d(10).to("cuda")
m.half()

print(m.running_mean.dtype)
x = flow.randn(1, 10, 20, 20).to("cuda").half()

print(m(x))

The equivalent code runs fine in torch, but oneflow crashes during the op's dtype inference:

oneflow.float16
Traceback (most recent call last):
  File "../../debug.py", line 9, in <module>
    print(m(x))
  File "/home/zhangxiaoyu/oneflow/python/oneflow/nn/module.py", line 158, in __call__
    res = self.forward(*args, **kwargs)
  File "/home/zhangxiaoyu/oneflow/python/oneflow/nn/modules/batchnorm.py", line 134, in forward
    return flow._C.normalization(
**RuntimeError: InferDataType Failed. Expected kFloat16, but got kFloat**
  File "../oneflow/core/framework/op_interpreter/op_interpreter_util.cpp", line 140, in Dispatch
    Dispatch<TensorTuple>(op_expr, inputs, ctx)
  File "../oneflow/core/framework/op_interpreter/op_interpreter_util.cpp", line 131, in Dispatch
    Dispatch(op_expr, inputs, outputs.get(), ctx)
  File "../oneflow/core/framework/op_interpreter/op_interpreter.cpp", line 96, in Apply
    internal_->Apply(op_expr, inputs, outputs, ctx)
  File "../oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp", line 84, in NaiveInterpret
    [&]() -> Maybe<const LocalTensorInferResult> { LocalTensorMetaInferArgs ... mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); }()
  File "../oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp", line 83, in operator()
    user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args)
  File "../oneflow/core/framework/local_tensor_infer_cache.cpp", line 207, in GetOrInfer
    Infer(*user_op_expr, infer_args)
  File "../oneflow/core/framework/local_tensor_infer_cache.cpp", line 178, in Infer
    user_op_expr.InferPhysicalTensorDesc( infer_args.attrs ... ) -> TensorMeta* { return &output_mut_metas.at(i); })
  File "../oneflow/core/framework/op_expr.cpp", line 548, in InferPhysicalTensorDesc
    dtype_infer_fn_(&infer_ctx)
  File "../oneflow/user/ops/normalization_op.cpp", line 200, in operator()
    CheckParamDataType("moving_mean")
  File "../oneflow/user/ops/normalization_op.cpp", line 43, in operator()
    
Error Type: oneflow.ErrorProto.check_failed_error
F20221107 06:11:31.696609 3746810 cuda_stream.cpp:158] Check failed: cublasCreate(&cublas_handle_) : CUBLAS_STATUS_NOT_INITIALIZED (1) 
*** Check failure stack trace: ***
    @     0x7ff67136e0bc  google::LogMessageFatal::~LogMessageFatal()
    @     0x7ff69c329b3a  oneflow::ep::CudaStream::CudaStream()
    @     0x7ff69c325a9a  oneflow::ep::CudaDevice::CreateStream()
    @     0x7ff6a0bd96d5  oneflow::vm::EpStreamPolicyBase::GetOrCreateEpStream()
    @     0x7ff6a0bde0c1  oneflow::vm::OpCallInstructionPolicy::Compute()
    @     0x7ff6a0bdb39f  oneflow::vm::InstructionPolicy::ComputeIf()
    @     0x7ff6a0bda4ef  oneflow::vm::EpStreamPolicyBase::Run()
    @     0x7ff6a0bfd73b  oneflow::vm::VirtualMachineEngine::DispatchInstruction<>()
    @     0x7ff6a0bfbea3  oneflow::vm::VirtualMachineEngine::DispatchAndPrescheduleInstructions()
    @     0x7ff6a0bfcf81  oneflow::vm::VirtualMachineEngine::Schedule()
    @     0x7ff6a0be5bba  oneflow::VirtualMachine::ScheduleLoop()
    @     0x7ff6707e2de4  (unknown)
    @     0x7ff6f8b54609  start_thread
    @     0x7ff6f8a79133  clone
Aborted (core dumped)

It looks like we require all BN parameters other than the input to match the op's expected dtype, so they cannot be float16 or bfloat16, whereas PyTorch has no such restriction (it keeps the running statistics in float32 even when the input is half).
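To make the dtype convention concrete, here is a minimal numpy sketch of inference-mode batch norm following the PyTorch-style mixed-precision contract described above: the input is float16, the per-channel statistics and affine parameters stay float32, the arithmetic is done in float32, and the result is cast back to the input dtype. This is an illustration of the convention, not OneFlow's or PyTorch's actual kernel; the function name and signature are hypothetical.

```python
import numpy as np

def batchnorm_infer(x, running_mean, running_var, weight, bias, eps=1e-5):
    """Inference-mode batch norm that tolerates a float16 input with
    float32 statistics: compute in float32, cast back to x's dtype.
    (Hypothetical helper; mirrors the PyTorch convention, not a real API.)"""
    x32 = x.astype(np.float32)
    shape = (1, -1, 1, 1)  # broadcast per-channel params over (N, C, H, W)
    y = (x32 - running_mean.reshape(shape)) / np.sqrt(running_var.reshape(shape) + eps)
    y = y * weight.reshape(shape) + bias.reshape(shape)
    return y.astype(x.dtype)

# Same shapes as the repro: half input, float32 stats -- no dtype error.
x = np.random.randn(1, 10, 20, 20).astype(np.float16)
mean = np.zeros(10, dtype=np.float32)
var = np.ones(10, dtype=np.float32)
w = np.ones(10, dtype=np.float32)
b = np.zeros(10, dtype=np.float32)
y = batchnorm_infer(x, mean, var, w, b)
print(y.dtype)  # float16
```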

BBuf avatar Nov 07 '22 06:11 BBuf

Could we check the dtypes in the functor and insert a cast there?

liujuncheng avatar Nov 07 '22 06:11 liujuncheng

> Could we check the dtypes in the functor and insert a cast there?

Sure, that works.
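A rough sketch of what that functor-level fix could look like, modeled in Python with numpy standing in for the C++ functor and op. `normalization_strict` imitates the existing op, which rejects statistics whose dtype differs from the input (the `InferDataType` check in `normalization_op.cpp`); `normalization_functor` is the proposed wrapper that casts mismatched statistics to the input's dtype before dispatch. Both names and the simplified math are hypothetical, not OneFlow's real API.

```python
import numpy as np

def normalization_strict(x, moving_mean, moving_var, eps=1e-5):
    """Stand-in for the current op: rejects stats whose dtype differs
    from x, mirroring the InferDataType check that raises in the repro."""
    if moving_mean.dtype != x.dtype or moving_var.dtype != x.dtype:
        raise TypeError(
            f"InferDataType Failed. Expected {x.dtype}, but got {moving_mean.dtype}"
        )
    shape = (1, -1, 1, 1)
    return (x - moving_mean.reshape(shape)) / np.sqrt(moving_var.reshape(shape) + eps)

def normalization_functor(x, moving_mean, moving_var, eps=1e-5):
    """Proposed fix, sketched: the functor casts mismatched stats to
    x's dtype before calling the strict op, so a half input with
    float32 running stats no longer fails type inference."""
    if moving_mean.dtype != x.dtype:
        moving_mean = moving_mean.astype(x.dtype)
    if moving_var.dtype != x.dtype:
        moving_var = moving_var.astype(x.dtype)
    return normalization_strict(x, moving_mean, moving_var, eps)

x = np.random.randn(1, 10, 4, 4).astype(np.float16)
mean = np.zeros(10, dtype=np.float32)  # float32 stats, like running_mean
var = np.ones(10, dtype=np.float32)
y = normalization_functor(x, mean, var)
print(y.dtype)  # float16
```

Casting in the functor keeps the op's dtype check strict while matching PyTorch's observable behavior; the trade-off is that the cast loses the extra precision of float32 statistics unless the kernel itself computes in float32 internally.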

BBuf avatar Nov 07 '22 06:11 BBuf