
`configurePlugin` is called repeatedly for my BF16 SDPA plugin — how to run initialization graph only once?

Open KarlDe1 opened this issue 3 weeks ago • 6 comments

Describe the issue

I’m implementing a TensorRT plugin for SDPA with BF16 input/output. My goal is to build the compute graph only once during initialization, so I placed the graph-construction logic inside configurePlugin().

However, I found that configurePlugin() is invoked every time the model executes, which forces the plugin to rebuild the graph repeatedly and causes significant overhead.

My expectation was that TensorRT would call configurePlugin() only at engine build time (IBuilder phase), but in practice it is also called during execution (runtime phase).

More specifically:

  • Why is configurePlugin() called at runtime for each execution?

  • Is there a plugin API or a recommended place for one-time initialization logic?

  • Or is there a separate mechanism for caching prebuilt graph structures inside the plugin?

System Environment:

  • cudnn_frontend version: v1.16.0
  • cudnn_backend version: 9.16.0.29_cuda12
  • GPU arch: RTX5060Ti / Thor U
  • cuda runtime version: 12.8
  • cuda driver version: 580.95.05
  • OS: ubuntu22.04
std::shared_ptr<fe::graph::Graph> create_sdpa_forward_graph(int64_t const b,
                                                            int64_t const h_q,
                                                            int64_t const h_k,
                                                            int64_t const h_v,
                                                            int64_t const s_q,
                                                            int64_t const s_kv,
                                                            int64_t const d_qk,
                                                            int64_t const d_v,
                                                            float const attn_scale = 1.0f,
                                                            bool const causal_mask = true) {
  // Create a graph and set common global properties.
  auto graph = std::make_shared<fe::graph::Graph>();
  graph->set_io_data_type(fe::DataType_t::BFLOAT16)
      .set_intermediate_data_type(fe::DataType_t::FLOAT)
      .set_compute_data_type(fe::DataType_t::FLOAT);

  auto Q = graph->tensor(fe::graph::Tensor_attributes()
                             .set_name("Q")
                             .set_uid(Q_UID)
                             .set_dim({b, h_q, s_q, d_qk})
                             .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}));

  auto K = graph->tensor(fe::graph::Tensor_attributes()
                             .set_name("K")
                             .set_uid(K_UID)
                             .set_dim({b, h_k, s_kv, d_qk})
                             .set_stride({h_k * s_kv * d_qk, s_kv * d_qk, d_qk, 1}));

  auto V = graph->tensor(fe::graph::Tensor_attributes()
                             .set_name("V")
                             .set_uid(V_UID)
                             .set_dim({b, h_v, s_kv, d_v})
                             .set_stride({h_v * s_kv * d_v, s_kv * d_v, d_v, 1}));

  auto sdpa_options =
      fe::graph::SDPA_attributes().set_name("flash_attention").set_attn_scale(attn_scale);

  if (causal_mask) {
    sdpa_options.set_diagonal_alignment(cudnn_frontend::DiagonalAlignment_t::TOP_LEFT)
        .set_diagonal_band_right_bound(0);
  }

  auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options);

  O->set_output(true)
      .set_dim({b, h_q, s_q, d_v})
      .set_stride({h_q * s_q * d_v, s_q * d_v, d_v, 1})
      .set_uid(O_UID);

  return graph;
}

void SdpaCudnn::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputDesc,
                                int32_t nbInputs,
                                const nvinfer1::DynamicPluginTensorDesc* out,
                                int32_t nbOutputs) noexcept {
  // [batch, head, seq, dim]
  std::vector<int64_t> q_shape = get_shape_from_dims(inputDesc[INOUT_POS::q].desc.dims);
  std::vector<int64_t> k_shape = get_shape_from_dims(inputDesc[INOUT_POS::k].desc.dims);
  std::vector<int64_t> v_shape = get_shape_from_dims(inputDesc[INOUT_POS::v].desc.dims);

  int64_t b = q_shape[0];     // batch size
  int64_t h_q = q_shape[1];   // number of query heads
  int64_t h_k = k_shape[1];   // number of key heads
  int64_t h_v = v_shape[1];   // number of value heads
  int64_t s_q = q_shape[2];   // q tensor is padded to this seq length
  int64_t s_kv = k_shape[2];  // k and v tensors are padded to this seq length
  int64_t d_qk = q_shape[3];  // per-head dim of q/k
  int64_t d_v = v_shape[3];   // per-head dim of v

  float attn_scale = 1.0f / std::sqrt(static_cast<float>(d_qk));
  cudnn_handle_ptr_ = create_cudnn_handle();

  graph_ptr_ = create_sdpa_forward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale, true);
  if (!graph_ptr_->build(*cudnn_handle_ptr_, {fe::HeurMode_t::A}).is_good()) {
    LFATAL("graph is not good.");
  }

  o_tensor_sz_ = b * h_q * s_q * d_v * sizeof(__nv_bfloat16);  // BF16 elements (<cuda_bf16.h>)
  CUDA_CHECK(cudaMalloc(&o_tensor_ptr_, o_tensor_sz_));
}

int32_t SdpaCudnn::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                           const nvinfer1::PluginTensorDesc* outputDesc,
                           const void* const* inputs,
                           void* const* outputs,
                           void* workspace,
                           cudaStream_t stream) noexcept {
  if (cudnn_handle_ptr_ == nullptr || graph_ptr_ == nullptr) {
    LFATAL("CUDNN handle or graph is not initialized.");
  }

  const void* q_ptr = inputs[INOUT_POS::q];
  const void* k_ptr = inputs[INOUT_POS::k];
  const void* v_ptr = inputs[INOUT_POS::v];

  std::unordered_map<fe::graph::Tensor_attributes::uid_t, void*> variant_pack = {
      {Q_UID, const_cast<void*>(q_ptr)},
      {K_UID, const_cast<void*>(k_ptr)},
      {V_UID, const_cast<void*>(v_ptr)},
      {O_UID, o_tensor_ptr_}};

  auto handle = *cudnn_handle_ptr_;
  cudnnSetStream(handle, stream);

  if (!graph_ptr_->execute(handle, variant_pack, workspace).is_good()) {
    LFATAL("SDPA CUDNN graph execution failed.");
  }

  CUDA_CHECK(cudaStreamSynchronize(stream));

  CUDA_CHECK(cudaMemcpyAsync(outputs[INOUT_POS::output],
                             o_tensor_ptr_,
                             o_tensor_sz_,
                             cudaMemcpyDeviceToDevice,
                             stream));
  return 0;
}

KarlDe1 avatar Dec 03 '25 08:12 KarlDe1

@poweiw Can you help with this? @KarlDe1 saw that configurePlugin() is invoked on each enqueueV3() call, which should be unexpected.

zhenhuaw-me avatar Dec 04 '25 07:12 zhenhuaw-me

Hi @poweiw,

My model contains multiple usages of my custom plugin, but all shapes in the model are fixed. At the moment, I am not sure what could be causing configurePlugin to be called repeatedly.

Looking forward to your reply.

KarlDe1 avatar Dec 04 '25 08:12 KarlDe1

related_issue

KarlDe1 avatar Dec 04 '25 09:12 KarlDe1

@KarlDe1 configurePlugin is not designed for one-time initializations, as it will be called whenever outputs like getOutputDimensions(), supportsFormatCombination(), or getOutputDataType() have changed. Can you take a look at the output of these functions?
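
For reference, a deterministic getOutputDimensions() for this plugin could look like the rough sketch below (assuming the INOUT_POS indexing from the code above); the key point is that it is a pure function of its inputs, so repeated calls return identical results:

nvinfer1::DimsExprs SdpaCudnn::getOutputDimensions(int32_t outputIndex,
                                                   const nvinfer1::DimsExprs* inputs,
                                                   int32_t nbInputs,
                                                   nvinfer1::IExprBuilder& exprBuilder) noexcept {
  // O inherits Q's [b, h_q, s_q] and takes V's per-head dim d_v.
  nvinfer1::DimsExprs out = inputs[INOUT_POS::q];
  out.d[3] = inputs[INOUT_POS::v].d[3];
  return out;
}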

For one-time initializations, it's better to do them in initialize() or attachToContext(). If you need the dimensions, a workaround would be lazy initialization (i.e., a warmup run) to build the cudnn graph before the majority of enqueue calls come in.
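
For example, a rough sketch of that lazy-initialization pattern (untested; it reuses the member names from your code, while init_flag_ is a hypothetical std::once_flag member and buildGraphFromShapes a hypothetical helper; requires <mutex>):

int32_t SdpaCudnn::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                           const nvinfer1::PluginTensorDesc* outputDesc,
                           const void* const* inputs,
                           void* const* outputs,
                           void* workspace,
                           cudaStream_t stream) noexcept {
  // One-time setup on the first enqueue, where inputDesc carries the actual
  // shapes (PluginTensorDesc::dims). std::call_once guarantees the graph is
  // built exactly once per plugin instance, even under concurrent enqueues.
  std::call_once(init_flag_, [&] {
    cudnn_handle_ptr_ = create_cudnn_handle();
    buildGraphFromShapes(inputDesc);  // hypothetical helper wrapping the
                                      // configurePlugin() body shown above
  });

  // ... the rest of enqueue stays as it is today ...
  return 0;
}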

poweiw avatar Dec 04 '25 19:12 poweiw

@poweiw

Thanks for your reply. I have two questions:

(1)

I would like to clarify what “changed” means here. Does it mean that configurePlugin() will be triggered when different plugin instances return different values from these functions? In my case, the overall ONNX model uses fixed shapes, but different instances of this plugin indeed have different shapes. Would this cause TensorRT to call configurePlugin() on every enqueue, even though the model shape itself is static?

(2)

Yes, in my case I do need to know the input shapes for each plugin instance before I can perform the initialization. Other than doing lazy initialization inside enqueue(), is there any more elegant or recommended way to handle this initialization based on the actual input shapes?

KarlDe1 avatar Dec 05 '25 02:12 KarlDe1

  1. I don't think that's expected, but I can't effectively repro the issue without an example.
  2. This is the only way for now. At least for the use cases I'm aware of, this should be effective enough and shouldn't require a huge extra code change (just check for nullptr, set an init flag, or use std::call_once). Please file a new feature request ticket if necessary so we can evaluate it internally.

poweiw avatar Dec 05 '25 19:12 poweiw