
Error while loading TensorFlow plugins - cuFFT, cuDNN, cuBLAS

Humbulani1234 opened this issue 1 year ago • 2 comments

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

binary

TensorFlow version

tf 2.17

Custom code

No

OS platform and distribution

Linux Ubuntu 22.04

Mobile device

No response

Python version

3.10

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

TensorFlow==2.16.1 does not produce the log output errors shown below. I read through the source for tensorflow==2.17.0 and 2.16.1 to try to find out what the issue might be. This is what I found:

The piece of code that registers the cuFFT plugin is located in tensorflow-2.16.1/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc (change the TensorFlow version number as appropriate) and is reproduced below:

void initialize_cufft() {
  absl::Status status =
      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
          cuda::kCudaPlatformId, "cuFFT",
          [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
            gpu::GpuExecutor *cuda_executor =
                dynamic_cast<gpu::GpuExecutor *>(parent);
            if (cuda_executor == nullptr) {
              LOG(ERROR) << "Attempting to initialize an instance of the cuFFT "
                         << "support library with a non-CUDA StreamExecutor";
              return nullptr;
            }

            return new gpu::CUDAFft(cuda_executor);
          });
  if (!status.ok()) {
    LOG(ERROR) << "Unable to register cuFFT factory: " << status.message();
  }
}

This function registers the cuFFT factory with the PluginRegistry object defined in tensorflow-2.16.1/third_party/xla/xla/stream_executor/plugin_registry.h. That class carries a very important comment, reproduced below:

// The PluginRegistry is a singleton that maintains the set of registered
// "support library" plugins. Currently, there are four kinds of plugins:
// BLAS, DNN, and FFT. Each interface is defined in the corresponding
// gpu_{kind}.h header.

// Registers the specified factory with the specified platform.
// Returns a non-successful status if the factory has already been registered
// with that platform (but execution should be otherwise unaffected).

So the class is a singleton, and if a factory has already been registered once, a second attempt to register it will fail, but TensorFlow should otherwise work as expected.
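
For clarity, below is a minimal sketch of the usual C++ singleton pattern; I am assuming PluginRegistry::Instance() follows roughly this shape (the name PluginRegistryLike and the main() check are mine, for illustration only). The point is that every caller in the process shares one registry, and therefore one set of registered factories:

#include <cassert>

class PluginRegistryLike {
 public:
  // Every call returns the same object, created lazily on first use.
  static PluginRegistryLike* Instance() {
    static PluginRegistryLike* instance = new PluginRegistryLike();
    return instance;
  }

  PluginRegistryLike(const PluginRegistryLike&) = delete;
  PluginRegistryLike& operator=(const PluginRegistryLike&) = delete;

 private:
  PluginRegistryLike() = default;
};

int main() {
  // Two independent call sites still observe the same registry object.
  assert(PluginRegistryLike::Instance() == PluginRegistryLike::Instance());
  return 0;
}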

And below is the function responsible for the registration, from the file: tensorflow-2.16.1/third_party/xla/xla/stream_executor/plugin_registry.cc:


template <typename FACTORY_TYPE>
absl::Status PluginRegistry::RegisterFactoryInternal(
    const std::string& plugin_name, FACTORY_TYPE factory,
    std::optional<FACTORY_TYPE>* factories) {
  absl::MutexLock lock{&GetPluginRegistryMutex()};

  if (factories->has_value()) {
    return absl::AlreadyExistsError(
        absl::StrFormat("Attempting to register factory for plugin %s when "
                        "one has already been registered",
                        plugin_name));
  }

  (*factories) = factory;
  return absl::OkStatus();
}
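
To see concretely why this guard produces only the log lines below without breaking anything, here is a stripped-down sketch of the same check with the XLA types replaced by placeholders of my own (FakeStatus instead of absl::Status, an int standing in for the factory): the first registration succeeds and stays in place, the second only yields the error string.

#include <iostream>
#include <optional>
#include <string>

// Placeholder for absl::Status: just an ok flag plus a message.
struct FakeStatus {
  bool ok;
  std::string message;
};

using FftFactory = int;  // stand-in for the real factory function type

FakeStatus RegisterFactoryOnce(const std::string& plugin_name, FftFactory factory,
                               std::optional<FftFactory>* factories) {
  if (factories->has_value()) {
    return {false, "Attempting to register factory for plugin " + plugin_name +
                       " when one has already been registered"};
  }
  *factories = factory;
  return {true, ""};
}

int main() {
  std::optional<FftFactory> cufft_factory;
  FakeStatus first = RegisterFactoryOnce("cuFFT", 1, &cufft_factory);
  FakeStatus second = RegisterFactoryOnce("cuFFT", 2, &cufft_factory);
  std::cout << "first ok: " << first.ok << "\n";    // prints 1
  std::cout << "second ok: " << second.ok << "\n";  // prints 0
  std::cout << second.message << "\n";
  // The factory stored by the first call is untouched, so the plugin
  // stays usable; the second call only produces the error message.
  return 0;
}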

I am not entirely sure when and where the very first cuFFT factory registration with the PluginRegistry happens for TensorFlow to display this error. There must be a point between running import tensorflow and the call to the initialize_cufft function above at which the factory is already registered, and since the registry is a singleton, the second registration fails, hence the error. I hope someone can elaborate further on this or provide better clarity.
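
I do not know exactly how TensorFlow wires this up, so treat the following purely as an illustration of one plausible mechanism, with names of my own: registration functions like initialize_cufft are commonly invoked from a file-scope static initializer, which runs as a side effect of the shared library being loaded by import tensorflow. If the same registration code were linked into two libraries that both end up loaded, the initializer would run twice against the single shared registry, and the second call would print the error:

#include <iostream>

namespace {

int registrations = 0;  // stand-in for the factory slot in the singleton registry

void InitializeCufftLikePlugin() {
  if (registrations++ > 0) {
    std::cerr << "Unable to register cuFFT factory: already registered\n";
    return;
  }
  std::cout << "cuFFT factory registered\n";
}

// The comma expression forces the call during static initialization, i.e.
// before main() runs -- or, for TensorFlow, while the Python import is
// loading the shared object.
const bool kFirstCopy = (InitializeCufftLikePlugin(), true);
// A second copy of the same registration code in another loaded library would
// behave like this second static, hitting the "already registered" branch.
const bool kSecondCopy = (InitializeCufftLikePlugin(), true);

}  // namespace

int main() { return 0; }

In this sketch the second call just logs and returns, which matches the observed behaviour: despite the three error lines, the GPU is still listed at the end of the log output.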

Standalone code to reproduce the issue

python -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"

Relevant log output

2024-08-26 16:31:19.008920: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-26 16:31:19.027228: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-26 16:31:19.032798: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-26 16:31:19.047347: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-26 16:31:20.104940: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1724682680.723847    4894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1724682680.805189    4894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
I0000 00:00:1724682680.805836    4894 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-08-26 16:31:20.806043: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2432] TensorFlow was not built with CUDA kernel binaries compatible with compute capability 5.0. CUDA kernels will be jit-compiled from PTX, which could take 30 minutes or longer.
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

Humbulani1234 avatar Aug 26 '24 14:08 Humbulani1234

The post from itbear-shu is a scam; I received a similar comment in an issue I opened today.

misterBart avatar Aug 26 '24 15:08 misterBart

Thank you for reporting the issue. This is a known problem; related issues are still open and the developers are working on it.

Please take a look at the issues below, which track the same problem, and follow them for updates:

https://github.com/tensorflow/tensorflow/issues/71791#issuecomment-2237115569
https://github.com/tensorflow/tensorflow/issues/70947
https://github.com/tensorflow/tensorflow/issues/62075

Thank you!

tilakrayal avatar Aug 28 '24 06:08 tilakrayal

Closing as duplicate of #62075

belitskiy avatar Aug 28 '24 14:08 belitskiy