iree

CUDA out of memory due to huge memory allocation request

Open pxanthopoulos opened this issue 1 year ago • 5 comments

What happened?

I compiled GPT and tried to run it using iree-run-module. It errored with the following message:

iree/runtime/src/iree/hal/drivers/cuda/memory_pools.c:236: RESOURCE_EXHAUSTED; CUDA error 'CUDA_ERROR_OUT_OF_MEMORY' (2): out of memory; while invoking native function hal.device.queue.alloca; while calling import;
[ 1] bytecode module.tf2onnx$async:14026 [
    model.mlir:192:12,
    model.mlir:1511:13
  ]
[ 0] bytecode module.tf2onnx:66 model.mlir:2:3; invoking function 'tf2onnx'

Just before the error message, the allocator statistics are dumped:

[[ iree_hal_allocator_t memory statistics ]]
  HOST_LOCAL:         1536B peak /            0B allocated /            0B freed /            0B live
DEVICE_LOCAL:    466140144B peak /         1216B allocated /         1216B freed /            0B live

And right before that is the command that caused the error:

[module.tf2onnx$async+000036AE]    %r8 = vm.call @hal.device.queue.alloca(%r7(!hal.device/0x0x56299d139b70), %i262(4294967295), %r3(!hal.fence/0x0x56299dcc7050), %r5(!hal.fence/0x0x56299dcc7cd0), %i189(0), %i37(48), %i36(3075), %i290(4294964224))

If I understand the statistics correctly, this is not caused by my GPU running out of memory: the requested allocation is ~4 GB (4294964224 bytes), peak device usage so far was ~466 MB, and my GPU has 32 GB of memory. So I went digging.

I inserted a print statement at iree/runtime/src/iree/hal/drivers/cuda/memory_pools.c:234, inside the function iree_hal_cuda_memory_pools_alloca(...), right before the call to cuMemAllocFromPoolAsync(...) (more precisely, right before the IREE_CURESULT_TO_STATUS wrapper), to see how much memory the runtime tried to allocate. The allocation size was a 64-bit unsigned integer on my machine, so I used the %zu format specifier.

...
  CUdeviceptr device_ptr = 0;
  printf("~~~~~~~~~~~~~ TRYING TO ALLOCATE %zu BYTES\n", allocation_size);
  iree_status_t status = IREE_CURESULT_TO_STATUS(
      pools->cuda_symbols,
      cuMemAllocFromPoolAsync(&device_ptr, (size_t)allocation_size, memory_pool,
                              stream),
      "cuMemAllocFromPoolAsync");
...

The printed amount was actually 18446744073709548544 bytes (2^64 - 3072), which explains the error. I traced the call chain back to the file iree/runtime/src/iree/modules/hal/module.c, specifically to the shim definition at lines 1096-1122. In that function, the size argument of the allocation call is cast to the HAL device size type. So I added two print statements to see the original value of the argument and the allocation size after the cast.

...
  iree_device_size_t allocation_size = iree_hal_cast_device_size(args->i7);

  const iree_hal_buffer_params_t params = {
      .type = memory_types,
      .usage = buffer_usage,
  };
  printf("\n~~~~~~~~~~~~~ CALLING iree_hal_device_queue_alloca TO ALLOCATE %zu BYTES\n", allocation_size);
  printf("~~~~~~~~~~~~~ PREVIOUS SIZE BEFORE CAST: %ld BYTES\n", args->i7);
  iree_hal_buffer_t* buffer = NULL;
  IREE_RETURN_IF_ERROR(iree_hal_device_queue_alloca(
      device, queue_affinity, iree_hal_fence_semaphore_list(wait_fence),
      iree_hal_fence_semaphore_list(signal_fence), pool, params,
      allocation_size, &buffer));
...

The size before the cast was -3072 (through compiler errors, I determined the right format specifier for the args->i7 variable). Interpreted as an int32_t, 4294964224 (the supposed allocation size, see the trace) wraps around to -3072, and -3072 converted to a 64-bit unsigned size is exactly the 18446744073709548544 printed from memory_pools.c. So I suspect an integer overflow is the cause of all this.
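The wrap-around described above can be checked with a minimal C sketch. It assumes, as suspected in this report, that the size passes through a signed 32-bit value before being widened to a 64-bit argument; the thread does not pin down where that actually happens. `reinterpret_as_i32`, `traced_size_to_device_size`, and `checked_device_size` are hypothetical helpers for illustration, not IREE functions.

```c
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

// Reinterpret 32 bits as a signed int32_t. memcpy keeps this well defined;
// int32_t is guaranteed two's complement, so 0xFFFFF400 reads as -3072.
static int32_t reinterpret_as_i32(uint32_t bits) {
  int32_t value;
  memcpy(&value, &bits, sizeof(value));
  return value;
}

// Suspected path (assumption): the 32-bit size is treated as signed,
// sign-extended into a 64-bit argument, then converted to an unsigned
// 64-bit device size. Conversion to an unsigned type wraps modulo 2^64.
static uint64_t traced_size_to_device_size(uint32_t traced_size) {
  int64_t as_i64 = (int64_t)reinterpret_as_i32(traced_size);
  return (uint64_t)as_i64;
}

// Hypothetical guard (not an IREE API): reject negative sizes instead of
// letting them wrap to values near 2^64.
static bool checked_device_size(int64_t arg, uint64_t* out_size) {
  if (arg < 0) return false;
  *out_size = (uint64_t)arg;
  return true;
}
```

With these helpers, `traced_size_to_device_size(4294964224u)` yields 18446744073709548544, the value printed from memory_pools.c. A check like `checked_device_size` would turn the wrap-around into an explicit error, but it would not fix whatever produces the negative size in the first place.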

Thank you in advance.

Steps to reproduce your issue

  1. Download the GPT ONNX model, either directly from https://drive.google.com/file/d/1w-TgnDylg43YUOQtffjo1SeRTCF_h7lE/view?usp=sharing or by generating it yourself (requires pip install transformers tensorflow tf-keras tf2onnx):
      • Export the pretrained Hugging Face TF model from openai-community/openai-gpt (see gist).
      • Convert it to ONNX using the tf2onnx package with the command python -m tf2onnx.convert --saved-model ./gpt-tf/ --output model.onnx --opset 17

  2. Import it with the command iree-import-onnx model.onnx -o model.mlir (requires pip install iree-compiler[onnx])

  3. Compile it with the command iree-compile --iree-hal-target-backends=cuda --iree-cuda-target=sm_70 --dump-compilation-phases-to=./model-phases/ --iree-vm-target-index-bits=64 --iree-stream-resource-index-bits=64 model.mlir -o model.vmfb > output.mlir 2>&1

  4. Run the compiled module with the command iree-run-module --trace_execution=true --print_statistics=true --device=cuda --module=model.vmfb --function=tf2onnx --input="1x4xsi32=1" --input="1x4xsi32=1" --input="1x4xsi32=1" > trace.txt 2>&1

  5. View the trace file trace.txt

What component(s) does this issue relate to?

Runtime

Version information

Commit hash of IREE: 1f3382d7305d7b2920fe7cb6072b07ca81945f28

Versions of pip packages used:

  • transformers: 4.45.2
  • tensorflow: 2.17.0
  • tf-keras: 2.17.0
  • tf2onnx: 1.16.1
  • iree-compiler[onnx]: 20240828.999

Additional context

Build environment and commands:

My system has an 80-thread Intel(R) Xeon(R) Gold 5218R CPU @ 2.10GHz and a Tesla V100-PCIE-32GB GPU.

I'm using the following Dockerfile:

FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04

RUN apt-get update && apt-get upgrade -y 
RUN apt-get install -y --allow-change-held-packages libcudnn8 libcudnn8-dev libnccl-dev libnccl2
RUN apt-get install -y cmake ninja-build lld ccache python-is-python3 python3-pip git libtinfo5

RUN ccache --max-size=40G

RUN python -m pip install --upgrade pip

WORKDIR /workspace

COPY clang+llvm-18.1.8-x86_64-linux-gnu-ubuntu-18.04/ /usr/

ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH}

To build the compiler and runtime, I'm using the following commands:

  1. python -m pip install -r runtime/bindings/python/iree/runtime/build_requirements.txt

  2. cmake -G Ninja -B ../iree-build/ -S . \
    -DCMAKE_BUILD_TYPE=Debug \
    -DCMAKE_C_COMPILER=clang \
    -DCMAKE_CXX_COMPILER=clang++ \
    -DCMAKE_C_COMPILER_LAUNCHER=ccache \
    -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
    \
    -DIREE_BUILD_DOCS=ON \
    \
    -DIREE_BUILD_BINDINGS_TFLITE=OFF \
    -DIREE_BUILD_BINDINGS_TFLITE_JAVA=OFF \
    \
    -DIREE_HAL_DRIVER_DEFAULTS=OFF \
    -DIREE_HAL_DRIVER_CUDA=ON \
    -DIREE_HAL_DRIVER_LOCAL_SYNC=ON \
    -DIREE_HAL_DRIVER_LOCAL_TASK=ON \
    \
    -DIREE_TARGET_BACKEND_DEFAULTS=OFF \
    -DIREE_TARGET_BACKEND_LLVM_CPU=ON \
    -DIREE_TARGET_BACKEND_CUDA=ON \
    \
    -DIREE_DEV_MODE=ON \
    -DIREE_ENABLE_ASSERTIONS=ON \
    -DIREE_ENABLE_SPLIT_DWARF=ON \
    -DIREE_ENABLE_LLD=ON

  3. cmake --build ../iree-build/ -j 30

pxanthopoulos, Oct 13 '24 18:10

Thank you for the analysis. This is most likely an index/shape issue in the new onnx path and will need triage. It may be incomplete support for a negative index on an op that then gets multiplied through to what you are seeing.

Let me raise it to the appropriate folks.

stellaraccident, Oct 13 '24 18:10

I'm able to reproduce the error on CUDA, and I see the same failure with HIP on an MI250. Trace:

[module.tf2onnx$async+0000395A]    %r9 = vm.call @hal.device.queue.alloca(%r8(!hal.device/0x0x55d5a0d56930), %i266(4294967295), %r5(null), %r6(!hal.fence/0x0x55d5a06d7930), %i192(0), %i40(48), %i39(3075), %i292(4294964224))
---
[[ iree_hal_allocator_t memory statistics ]]
  HOST_LOCAL:         1536B peak /            0B allocated /            0B freed /            0B live
DEVICE_LOCAL:    466140272B peak /         1216B allocated /         1216B freed /            0B live
---
iree/runtime/src/iree/hal/drivers/hip/memory_pools.c:236: RESOURCE_EXHAUSTED; HIP driver error 'hipErrorOutOfMemory' (2): out of memory; while invoking native function hal.device.queue.alloca; while calling import;
[ 1] bytecode module.tf2onnx$async:14710 [
    model.mlir:192:12,
    model.mlir:1510:13,
    model.mlir:1498:13,
    model.mlir:1502:13,
    model.mlir:1499:13
  ]
[ 0] bytecode module.tf2onnx:66 model.mlir:2:3; invoking function 'tf2onnx'

On llvm-cpu, compilation crashes at the final stage (HAL -> VM). Backtrace:

iree-compile: /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2420: llvm::LogicalResult legalizeUnresolvedMaterialization(mlir::RewriterBase &, (anonymous namespace)::UnresolvedMaterializationRewrite *): Assertion `newMaterialization.getType() == outputType && "materialization callback produced value of incorrect type"' failed.
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0.      Program arguments: ../iree-build/tools/iree-compile --iree-hal-target-backends=llvm-cpu --dump-compilation-phases-to=./model-phases-cpu/ --iree-stream-resource-index-bits=64 model2.mlir -o modelcpu.vmfb
 #0 0x00007f019c6b4907 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /mnt/NOD/IREE/iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:13
 #1 0x00007f019c6b2b50 llvm::sys::RunSignalHandlers() /mnt/NOD/IREE/iree/third_party/llvm-project/llvm/lib/Support/Signals.cpp:106:18
 #2 0x00007f019c6b4fca SignalHandler(int) /mnt/NOD/IREE/iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:413:1
 #3 0x00007f0195442520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f01954969fc pthread_kill (/lib/x86_64-linux-gnu/libc.so.6+0x969fc)
 #5 0x00007f0195442476 gsignal (/lib/x86_64-linux-gnu/libc.so.6+0x42476)
 #6 0x00007f01954287f3 abort (/lib/x86_64-linux-gnu/libc.so.6+0x287f3)
 #7 0x00007f019542871b (/lib/x86_64-linux-gnu/libc.so.6+0x2871b)
 #8 0x00007f0195439e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
 #9 0x00007f01a0b38315 mlir::InFlightDiagnostic& mlir::InFlightDiagnostic::append<char const (&) [53]>(char const (&) [53]) & /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/include/mlir/IR/Diagnostics.h:340:5
#10 0x00007f01a0b38315 mlir::InFlightDiagnostic&& mlir::InFlightDiagnostic::operator<<<char const (&) [53]>(char const (&) [53]) && /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/include/mlir/IR/Diagnostics.h:334:22
#11 0x00007f01a0b38315 legalizeUnresolvedMaterialization(mlir::RewriterBase&, (anonymous namespace)::UnresolvedMaterializationRewrite*) /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2427:29
#12 0x00007f01a0b38315 mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:2494:18
#13 0x00007f01a0b3d70b mlir::applyPartialConversion(llvm::ArrayRef<mlir::Operation*>, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3183:22
#14 0x00007f01a0b3d70b mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:3189:10
#15 0x00007f019e05b6e2 mlir::iree_compiler::IREE::VM::ConversionPass::runOnOperation() /mnt/NOD/IREE/iree/compiler/src/iree/compiler/Dialect/VM/Transforms/Conversion.cpp:159:16
#16 0x00007f019c862625 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_7::operator()() const /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:0:17
#17 0x00007f019c862625 void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_7>(long) /mnt/NOD/IREE/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:12
#18 0x00007f019c862625 llvm::function_ref<void ()>::operator()() const /mnt/NOD/IREE/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#19 0x00007f019c862625 void mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275:7
#20 0x00007f019c862625 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:520:21
#21 0x00007f019c862d98 llvm::LogicalResult::failed() const /mnt/NOD/IREE/iree/third_party/llvm-project/llvm/include/llvm/Support/LogicalResult.h:43:43
#22 0x00007f019c862d98 llvm::failed(llvm::LogicalResult) /mnt/NOD/IREE/iree/third_party/llvm-project/llvm/include/llvm/Support/LogicalResult.h:71:58
#23 0x00007f019c862d98 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:592:9
#24 0x00007f019c865109 mlir::PassManager::run(mlir::Operation*) /mnt/NOD/IREE/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:0:0
#25 0x00007f019c608de0 llvm::LogicalResult::failed() const /mnt/NOD/IREE/iree/third_party/llvmproject/llvm/include/llvm/Support/LogicalResult.h:43:43
#26 0x00007f019c608de0 llvm::failed(llvm::LogicalResult) /mnt/NOD/IREE/iree/third_party/llvmproject/llvm/include/llvm/Support/LogicalResult.h:71:58
#27 0x00007f019c608de0 mlir::iree_compiler::embed::(anonymous namespace)::Invocation::runPipeline(iree_compiler_pipeline_t) /mnt/NOD/IREE/iree/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp:997:7
#28 0x00007f019c608de0 ireeCompilerInvocationPipeline /mnt/NOD/IREE/iree/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp:1432:23
#29 0x00007f019c828fc8 mlir::iree_compiler::runIreecMain(int, char**)::$_2::operator()(iree_compiler_source_t*) const /mnt/NOD/IREE/iree/compiler/src/iree/compiler/Tools/iree_compile_lib.cc:254:11
#30 0x00007f019c828801 mlir::iree_compiler::runIreecMain(int, char**) /mnt/NOD/IREE/iree/compiler/src/iree/compiler/Tools/iree_compile_lib.cc:0:10
#31 0x00007f0195429d90 (/lib/x86_64-linux-gnu/libc.so.6+0x29d90)
#32 0x00007f0195429e40 __libc_start_main (/lib/x86_64-linux-gnu/libc.so.6+0x29e40)
#33 0x000055bd8c98f6b5 _start (../iree-build/tools/iree-compile+0x16b5)

The CPU compilation failure may or may not be related to the GPU runtime failures. I'm working with @AmosLewis on getting a minimal repro to debug this further.

PhaneeshB, Oct 15 '24 22:10

While trying to get a minimal repro for the above issue, I found another issue just before it, with onnx.Reshape.

It turns out that onnx.Reshape with dynamic dims in its data input, combined with a -1 in the target shape, causes an OUT_OF_RANGE error. The output of this Reshape is fed into the indices input of onnx.Gather, which results in the RESOURCE_EXHAUSTED error above. Handling Reshape correctly may resolve the problem with Gather, so I'm documenting the steps to reproduce the error on ROCm (MI250).

Reproducer (torch-onnx MLIR), where const_fold_opt__2625 holds -1:

module {
  func.func @tf2onnx(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.ml = 2 : si64}, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.16.1 15c810"} {
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<const_starts__778> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<const_fold_opt__2625> : tensor<1xsi32>} : () -> !torch.vtensor<[1],si32> 
    %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<const_ends__1662> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<const_axes__1378> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %none = torch.constant.none
    %4 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?],si32>) -> !torch.vtensor<[2],si64> 
    %5 = torch.operator "onnx.Cast"(%4) {torch.onnx.to = 6 : si64} : (!torch.vtensor<[2],si64>) -> !torch.vtensor<[2],si32> 
    %6 = torch.operator "onnx.Slice"(%5, %2, %0, %3) : (!torch.vtensor<[2],si32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si32> 
    %7 = torch.operator "onnx.Concat"(%1, %6) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si32>, !torch.vtensor<[1],si32>) -> !torch.vtensor<[2],si32> 
    %8 = torch.operator "onnx.Cast"(%7) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[2],si32>) -> !torch.vtensor<[2],si64> 
    %9 = torch.operator "onnx.Reshape"(%arg0, %8) : (!torch.vtensor<[?,?],si32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],si32> 
    return %9 : !torch.vtensor<[?,?],si32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      const_starts__778: "0x080000000200000000000000",
      const_fold_opt__2625: "0x08000000FFFFFFFF",
      const_ends__1662: "0x080000000100000000000000",
      const_axes__1378: "0x080000000000000000000000"
    }
  }
#-}
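To see where %8 comes from, here is a minimal C sketch that decodes the resource blobs and replays %4 to %7 for a 1x4 input. It assumes the usual MLIR hex-blob layout (a 4-byte little-endian alignment header, here 0x08, followed by little-endian element data) and follows ONNX Slice's (data, starts, ends, axes) operand order, so %2 (const_ends__1662 = 1) acts as `starts` and %0 (const_starts__778 = 2) as `ends`; the resource names do not match their operand roles. The si64/si32 casts are no-ops for these values and are omitted, and slice bounds are assumed in range. `decode_first_element` and `compute_target_shape` are hypothetical names.

```c
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

// Hex blobs copied from the dialect_resources section of the reproducer.
#define HEX_STARTS_778 "0x080000000200000000000000"  // si64 payload
#define HEX_FOLD_2625 "0x08000000FFFFFFFF"           // si32 payload
#define HEX_ENDS_1662 "0x080000000100000000000000"   // si64 payload

// Decode the first element of a hex blob, assuming "0x" + a 4-byte alignment
// header + little-endian data. elem_bytes must be 4 (si32) or 8 (si64).
static int64_t decode_first_element(const char* hex, size_t elem_bytes) {
  const char* payload = hex + 2 + 8;  // skip "0x" and the alignment header
  uint64_t raw = 0;
  for (size_t i = 0; i < elem_bytes; ++i) {
    char byte_str[3] = {payload[2 * i], payload[2 * i + 1], '\0'};
    raw |= (uint64_t)strtoul(byte_str, NULL, 16) << (8 * i);
  }
  if (elem_bytes == 4) {
    uint32_t low = (uint32_t)raw;
    int32_t v32;
    memcpy(&v32, &low, sizeof(v32));  // sign-correct si32 reinterpretation
    return v32;
  }
  int64_t v64;
  memcpy(&v64, &raw, sizeof(v64));
  return v64;
}

// Replay %4..%7 for an input of shape [d0, d1]. Writes the target shape to
// out and returns its rank.
static size_t compute_target_shape(int64_t d0, int64_t d1, int64_t out[2]) {
  const int64_t shape[2] = {d0, d1};                             // %4 onnx.Shape
  const int64_t starts = decode_first_element(HEX_ENDS_1662, 8);  // %2
  const int64_t ends = decode_first_element(HEX_STARTS_778, 8);   // %0
  size_t rank = 0;
  out[rank++] = decode_first_element(HEX_FOLD_2625, 4);  // %1, Concat input 0
  // %6 onnx.Slice on axis 0 (%3): shape[starts:ends], appended by %7 Concat.
  for (int64_t i = starts; i < ends; ++i) out[rank++] = shape[i];
  return rank;
}
```

For the 1x4 input used in the run command, `compute_target_shape(1, 4, out)` produces [-1, 4].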

With the commands below, the expected value of %8 is [-1, 4].

Commands:

iree-compile  --iree-hal-target-backends=rocm  --iree-hip-target=gfx90a --dump-compilation-phases-to=./model-phases-rocm/ model.torch.onnx.mlir -o model.vmfb > output.mlir 2>&1

iree-run-module --trace_execution=true --print_statistics=true --device=hip://<UUID> --module=model.vmfb --function=tf2onnx --input="1x4xsi32=1"  > trace.txt 2>&1

Runtime Error:

iree/runtime/src/iree/hal/buffer.c:591: OUT_OF_RANGE; attempted to access an address outside of the valid buffer range (offset=0, adjusted_length=18446744073709551600, end=18446744073709551599, buffer byte_length=16); invalid subspan of an existing buffer (source_offset=0, length=18446744073709551600); while invoking native function hal.buffer_view.create; while calling import; 
[ 1] bytecode module.tf2onnx$async:1930 /home/pbarwari/NOD/SHARK-TestSuite/alt_e2eshark/test-run/mygpt4_trunc_Reshape_0/model.torch_onnx.mlir:14:5
[ 0] bytecode module.tf2onnx:62 /home/pbarwari/NOD/SHARK-TestSuite/alt_e2eshark/test-run/mygpt4_trunc_Reshape_0/model.torch_onnx.mlir:2:3; invoking function 'tf2onnx'
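The numbers in this error are consistent with the unresolved -1 being multiplied straight into the byte length: 18446744073709551600 is 2^64 - 16, i.e. -16 viewed as an unsigned 64-bit value, and -1 × 4 × 4 bytes (si32) = -16, while the valid buffer is 1 × 4 × 4 = 16 bytes. A minimal sketch, assuming the length is computed as a plain signed product of the target shape and the element size; `byte_length` is a hypothetical helper, not IREE's actual code.

```c
#include <stddef.h>
#include <stdint.h>

// Hypothetical byte-length computation for a dense tensor: the element size
// times the product of all dims, in signed 64-bit arithmetic. A -1 that was
// never resolved makes the result negative.
static int64_t byte_length(const int64_t* dims, size_t rank, int64_t elem_size) {
  int64_t n = elem_size;
  for (size_t i = 0; i < rank; ++i) n *= dims[i];
  return n;
}
```

With dims {-1, 4} and a 4-byte element this returns -16, which becomes 18446744073709551600 once converted to an unsigned 64-bit length; with the resolved dims {1, 4} it returns the expected 16.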

From the ONNX graph, this is the computation we are looking at: [image]

PhaneeshB, Oct 16 '24 22:10

Can you please post the IR dump after all passes for this model?

MaheshRavishankar, Oct 17 '24 00:10

%9 = torch.operator "onnx.Reshape"(%arg0, %8) : (!torch.vtensor<[?,?],si32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],si32> 

will be lowered to

%42 = torch.aten.reshape %arg0, %41 : !torch.vtensor<[?,?],si32>, !torch.list<int> -> !torch.vtensor<[?,?],si32>

@PhaneeshB Since the input data shape for torch.aten.reshape is ‘?x?’, we need a way to convert the ‘?’ to a positive value, i.e. the real dimension size (materialize the -1 in the IR). And since aten.reshape is decomposed to the aten.view op, we need to add/debug this semantic in ConvertAtenViewOp, ConvertAtenViewOpStrict, or ConvertAtenViewOpToReshape in TorchToLinalg/DataMovement.cpp. Here is the previous work on the view op: https://github.com/llvm/torch-mlir/pull/2470 https://github.com/llvm/torch-mlir/issues/2567

Based on https://pytorch.org/docs/stable/generated/torch.reshape.html, a single dimension may be -1, in which case it is inferred from the remaining dimensions and the number of elements in the input.
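The inference rule quoted above can be sketched in a few lines of C. This is a minimal illustration of the semantics the lowering would need to materialize, not torch-mlir's implementation; `resolve_reshape_shape` is a hypothetical helper. It covers only the torch.reshape rule: ONNX Reshape additionally treats a 0 entry as "copy the input dimension" unless allowzero is set, which this sketch does not handle.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

// Resolve a single -1 entry in a reshape target shape in place, inferring it
// from the total element count and the remaining dims. Returns false for an
// invalid shape: more than one -1, another negative dim, or a count that the
// known dims do not divide evenly.
static bool resolve_reshape_shape(int64_t num_elements, int64_t* shape, size_t rank) {
  int64_t known = 1;
  size_t infer_index = rank;  // rank means "no -1 seen yet"
  for (size_t i = 0; i < rank; ++i) {
    if (shape[i] == -1) {
      if (infer_index != rank) return false;  // at most one -1 allowed
      infer_index = i;
    } else if (shape[i] < 0) {
      return false;
    } else {
      known *= shape[i];
    }
  }
  if (infer_index == rank) return known == num_elements;
  if (known == 0 || num_elements % known != 0) return false;
  shape[infer_index] = num_elements / known;
  return true;
}
```

For the Reshape reproducer, resolving [-1, 4] against the 1x4 input (4 elements) gives [1, 4], so the output buffer needs 16 bytes rather than a wrapped negative length.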

AmosLewis, Oct 17 '24 05:10

For the GPT model, I will open a new issue because compilation fails with an assertion failure. However, I tested some other models that exhibited the same behaviour (such as GPT2), and the problem is solved for them. I am on commit 5c45591244fe7499f37329e631ddff04493295d6. If you can confirm the problem is solved for the reproduction examples, I think this can be closed.

@PhaneeshB @AmosLewis

pxanthopoulos, Nov 12 '24 12:11