oneflow
BatchNorm module raises an error on half-precision input
import oneflow as flow
m = flow.nn.BatchNorm2d(10).to("cuda")
m.half()
print(m.running_mean.dtype)
x = flow.randn(1, 10, 20, 20).to("cuda").half()
print(m(x))
PyTorch runs this fine, but OneFlow fails in the op's dtype inference:
oneflow.float16
Traceback (most recent call last):
File "../../debug.py", line 9, in <module>
print(m(x))
File "/home/zhangxiaoyu/oneflow/python/oneflow/nn/module.py", line 158, in __call__
res = self.forward(*args, **kwargs)
File "/home/zhangxiaoyu/oneflow/python/oneflow/nn/modules/batchnorm.py", line 134, in forward
return flow._C.normalization(
**RuntimeError: InferDataType Failed. Expected kFloat16, but got kFloat**
File "../oneflow/core/framework/op_interpreter/op_interpreter_util.cpp", line 140, in Dispatch
Dispatch<TensorTuple>(op_expr, inputs, ctx)
File "../oneflow/core/framework/op_interpreter/op_interpreter_util.cpp", line 131, in Dispatch
Dispatch(op_expr, inputs, outputs.get(), ctx)
File "../oneflow/core/framework/op_interpreter/op_interpreter.cpp", line 96, in Apply
internal_->Apply(op_expr, inputs, outputs, ctx)
File "../oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp", line 84, in NaiveInterpret
[&]() -> Maybe<const LocalTensorInferResult> { LocalTensorMetaInferArgs ... mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); }()
File "../oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp", line 83, in operator()
user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args)
File "../oneflow/core/framework/local_tensor_infer_cache.cpp", line 207, in GetOrInfer
Infer(*user_op_expr, infer_args)
File "../oneflow/core/framework/local_tensor_infer_cache.cpp", line 178, in Infer
user_op_expr.InferPhysicalTensorDesc( infer_args.attrs ... ) -> TensorMeta* { return &output_mut_metas.at(i); })
File "../oneflow/core/framework/op_expr.cpp", line 548, in InferPhysicalTensorDesc
dtype_infer_fn_(&infer_ctx)
File "../oneflow/user/ops/normalization_op.cpp", line 200, in operator()
CheckParamDataType("moving_mean")
File "../oneflow/user/ops/normalization_op.cpp", line 43, in operator()
Error Type: oneflow.ErrorProto.check_failed_error
F20221107 06:11:31.696609 3746810 cuda_stream.cpp:158] Check failed: cublasCreate(&cublas_handle_) : CUBLAS_STATUS_NOT_INITIALIZED (1)
*** Check failure stack trace: ***
@ 0x7ff67136e0bc google::LogMessageFatal::~LogMessageFatal()
@ 0x7ff69c329b3a oneflow::ep::CudaStream::CudaStream()
@ 0x7ff69c325a9a oneflow::ep::CudaDevice::CreateStream()
@ 0x7ff6a0bd96d5 oneflow::vm::EpStreamPolicyBase::GetOrCreateEpStream()
@ 0x7ff6a0bde0c1 oneflow::vm::OpCallInstructionPolicy::Compute()
@ 0x7ff6a0bdb39f oneflow::vm::InstructionPolicy::ComputeIf()
@ 0x7ff6a0bda4ef oneflow::vm::EpStreamPolicyBase::Run()
@ 0x7ff6a0bfd73b oneflow::vm::VirtualMachineEngine::DispatchInstruction<>()
@ 0x7ff6a0bfbea3 oneflow::vm::VirtualMachineEngine::DispatchAndPrescheduleInstructions()
@ 0x7ff6a0bfcf81 oneflow::vm::VirtualMachineEngine::Schedule()
@ 0x7ff6a0be5bba oneflow::VirtualMachine::ScheduleLoop()
@ 0x7ff6707e2de4 (unknown)
@ 0x7ff6f8b54609 start_thread
@ 0x7ff6f8a79133 clone
Aborted (core dumped)
It looks like we don't allow any of BN's parameters other than the input to be float16 or bfloat16, whereas PyTorch has no such restriction.
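For context, a minimal sketch (plain Python, illustrative names only, not OneFlow's actual implementation) of what `Module.half()` does in both frameworks: every floating-point parameter and buffer is converted to float16, while non-floating buffers are left alone. This is why `running_mean` prints as `oneflow.float16` in the repro above, and why the op's strict dtype check on the stats then becomes a problem.

```python
def half(tensors):
    """Convert all floating-point entries to float16; leave others alone.

    `tensors` maps a parameter/buffer name to its dtype name, standing in
    for a module's state. Mirrors the semantics of Module.half().
    """
    floating = {"float32", "float64"}
    return {name: "float16" if dtype in floating else dtype
            for name, dtype in tensors.items()}

state = {"weight": "float32", "bias": "float32",
         "running_mean": "float32", "num_batches_tracked": "int64"}
print(half(state))
```

Note that integer buffers such as `num_batches_tracked` keep their dtype, which matches how both frameworks treat non-floating state.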
Could we add a check in the functor and insert a cast there?
Yes, that works.
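A minimal sketch (plain Python, hypothetical names, not OneFlow's actual functor API) of the proposed fix: before dispatching the normalization op, compare each stat/affine parameter's dtype with the input's dtype and record the casts needed to make them agree, instead of failing the `InferDataType` check outright.

```python
def resolve_param_dtypes(x_dtype, param_dtypes):
    """Decide the dtypes the op should see after an implicit functor-side cast.

    x_dtype      -- dtype name of the input tensor (e.g. "float16")
    param_dtypes -- maps param name (moving_mean, gamma, ...) to its dtype

    Returns (resolved, casts): `resolved` gives every param the input's
    dtype; `casts` lists only the params that actually need converting.
    """
    casts = {name: x_dtype
             for name, dtype in param_dtypes.items() if dtype != x_dtype}
    resolved = {name: x_dtype for name in param_dtypes}
    return resolved, casts

resolved, casts = resolve_param_dtypes(
    "float16",
    {"moving_mean": "float32", "moving_variance": "float32",
     "gamma": "float16", "beta": "float16"})
print(casts)
```

In the real functor the `casts` entries would become actual `Cast` ops on the mismatched tensors; when everything already matches the input dtype, `casts` is empty and dispatch proceeds unchanged.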