oneflow
Aborted (core dumped) in `flow.nn.functional.sparse_softmax_cross_entropy`
Summary
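A crash is triggered when `labels` and `logits` have different data types (`int64` and `float32` in the example below). Instead of raising a Python exception, the process aborts on the failed kernel check `labels[i] < depth (2 vs. 2)`. The labels in the example (1 to 5) also exceed the number of classes in `logits` (2).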
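The failing check in the log, `labels[i] < depth`, is a per-element bounds test on the class indices. Below is a minimal pure-Python sketch of that test. It assumes `depth` is the size of the last dimension of `logits` (2 in the reproduction). The helper name `validate_sparse_labels` is hypothetical and not part of OneFlow. It raises a catchable `ValueError` at the point where the kernel aborts.

```python
def validate_sparse_labels(labels, depth):
    """Raise ValueError if any class index falls outside [0, depth).

    Hypothetical helper mirroring the `labels[i] < depth` check from the
    kernel log; it is not a OneFlow API.
    """
    bad = [(i, lab) for i, lab in enumerate(labels) if not 0 <= lab < depth]
    if bad:
        i, lab = bad[0]  # report the first offending position
        raise ValueError(
            f"labels[{i}] = {lab} is out of range for depth {depth} "
            f"({len(bad)} invalid label(s))"
        )


# Labels and class count taken from the reproduction script.
labels = [1, 2, 3, 4, 5]
depth = 2  # logits has shape (5, 2)
try:
    validate_sparse_labels(labels, depth)
except ValueError as err:
    print(err)
```

Running such a check on the inputs before calling the op would surface the bad label as an ordinary Python error rather than a core dump.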
Code to reproduce the bug
import oneflow as flow
labels = flow.tensor([1, 2, 3, 4, 5], dtype=flow.int64)
logits = flow.tensor([[0.1, 0.9], [0.4, 0.6], [0.7, 0.3], [0.8, 0.2], [0.9, 0.1]], dtype=flow.float32)
loss = flow.nn.functional.sparse_softmax_cross_entropy(labels, logits)
Output:
F20241205 09:14:55.806568 2441830 sparse_softmax_cross_entropy_kernel.cpp:88] Check failed: labels[i] < depth (2 vs. 2)
*** Check failure stack trace: ***
@ 0x7f30831d09ca google::LogMessage::Fail()
@ 0x7f30831d0cb2 google::LogMessage::SendToLog()
@ 0x7f30831d0537 google::LogMessage::Flush()
@ 0x7f30831d30a9 google::LogMessageFatal::~LogMessageFatal()
@ 0x7f307ef3049b oneflow::user_op::SparseSoftmaxCrossEntropyKernel<>::Compute()
@ 0x7f307ef4e54d oneflow::one::StatefulOpKernel::Compute()
@ 0x7f307d1e8cab oneflow::vm::OpCallInstructionUtil::Compute()
@ 0x7f307d1e6787 oneflow::vm::OpCallInstructionPolicy::Compute()
@ 0x7f307d1e25bc oneflow::vm::Instruction::Compute()
@ 0x7f307d1e0a6f oneflow::vm::EpStreamPolicyBase::Run()
@ 0x7f307d1ec086 oneflow::vm::StreamPolicy::RunIf()
@ 0x7f307d1f36de oneflow::vm::ThreadCtx::TryReceiveAndRun()
@ 0x7f307d1f5d2d oneflow::(anonymous namespace)::WorkerLoop()
@ 0x7f307d1f611f _ZNSt6thread11_State_implINS_8_InvokerISt5tupleIJPFvPN7oneflow2vm9ThreadCtxERKSt8functionIFvS6_EEES6_ZNS3_14VirtualMachine15CreateThreadCtxENS3_6SymbolINS3_6DeviceEEENS3_10StreamTypeEmEUlS6_E3_EEEEE6_M_runEv
@ 0x7f30831e540f execute_native_thread_routine
@ 0x7f316aca1b43 (unknown)
@ 0x7f316ad33a00 (unknown)
Aborted (core dumped)
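For comparison, sparse softmax cross-entropy computes a loss for each sample from `softmax(logits[i])` and the class index `labels[i]`: `logsumexp(logits[i]) - logits[i][labels[i]]`. This quantity is only defined when `0 <= labels[i] < depth`. The pure-Python sketch below is an illustrative reference under that definition, not OneFlow's kernel. It assumes the op returns one loss per sample. The in-range labels `[1, 1, 0, 0, 0]` are hypothetical replacements for the out-of-range ones in the reproduction.

```python
import math


def sparse_softmax_cross_entropy_ref(labels, logits):
    """Per-sample sparse softmax cross-entropy (illustrative reference only).

    loss[i] = logsumexp(logits[i]) - logits[i][labels[i]]
    """
    losses = []
    for i, (lab, row) in enumerate(zip(labels, logits)):
        depth = len(row)
        if not 0 <= lab < depth:
            # Where the kernel aborts, raise a catchable Python error instead.
            raise ValueError(f"labels[{i}] = {lab} must be in [0, {depth})")
        m = max(row)  # subtract the max for a numerically stable logsumexp
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(lse - row[lab])
    return losses


logits = [[0.1, 0.9], [0.4, 0.6], [0.7, 0.3], [0.8, 0.2], [0.9, 0.1]]
labels = [1, 1, 0, 0, 0]  # hypothetical in-range labels (depth = 2)
print([round(x, 4) for x in sparse_softmax_cross_entropy_ref(labels, logits)])
```

Passing the original labels `[1, 2, 3, 4, 5]` to this reference raises a `ValueError` at `labels[1]`, matching the `2 vs. 2` values in the failed check above.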
System Information
- What is your OneFlow installation (pip, source, dockerhub): pip
- OS: Ubuntu 22.04.3 LTS
- OneFlow version (run `python3 -m oneflow --doctor`):
path: ['/home/miniconda3/envs/oneflow/lib/python3.9/site-packages/oneflow']
version: 0.9.0
git_commit: 381b12c
cmake_build_type: Release
rdma: True
mlir: True
- Python version: 3.9.13
- CUDA driver version: 12.2
- GPU models: NVIDIA GeForce RTX 4090
- Other info: None