
xe: sdpa: pass scale as a scalar kernel parameter (host side scalar memory descriptors)

Open pv-pterab-s opened this issue 7 months ago • 1 comments

THIS PR IS A DRAFT AND NOT YET READY FOR REVIEW

REBASE/SQUASH WILL OCCUR AFTER FULL CI PLATFORM TESTING

This PR implements three major changes:

  1. Add host-side-scalar memory descriptors, which tell a primitive descriptor that a scalar input is stored in host memory and is to be passed to OpenCL kernels as a scalar parameter rather than a pointer.
  2. Modify SDPA to accept a host-side-scalar descriptor for the scale input. With such a descriptor, scale is passed to the OpenCL kernel as a scalar kernel parameter; pre-existing descriptor types keep their old behavior.
  3. Update the internal SDPA tests to cover host-side-scalar support.

Because scale (a single value) is passed as a scalar kernel parameter rather than through a device memory pointer, this PR avoids a costly host-to-device memory transfer each time scale changes between SDPA invocations.
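The choice between the two argument-passing paths can be sketched as follows. This is only an illustration, assuming a simplified argument record; `scale_kind`, `kernel_arg`, and `make_scale_arg` are hypothetical names, not oneDNN or OpenCL types, and the real kernel-argument plumbing in the PR may look quite different.

```cpp
#include <cassert>

// Hypothetical stand-ins for illustration only; none of these are oneDNN types.
enum class scale_kind { device_buffer, host_side_scalar };

// One kernel argument: either a value copied into the argument list,
// or a handle to a device allocation the kernel must dereference.
struct kernel_arg {
    bool by_value;      // true: 'scalar' is valid; false: 'buffer' is valid
    float scalar;       // value passed directly to the kernel
    const void *buffer; // device allocation handle (opaque in this sketch)
};

// Pick how scale reaches the kernel based on the descriptor kind.
// host_side_scalar: read the value on the host and pass it by value.
// device_buffer:    keep the old behavior and pass the device handle.
inline kernel_arg make_scale_arg(
        scale_kind kind, const float *host_value, const void *device_handle) {
    if (kind == scale_kind::host_side_scalar) {
        assert(host_value != nullptr);
        return kernel_arg{true, *host_value, nullptr};
    }
    return kernel_arg{false, 0.0f, device_handle};
}
```

In the by-value path no device allocation is touched, so a new scale needs no upload before the next launch.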

To use these additions, create the SDPA primitive with the scale descriptor set to host_side_scalar. Then, at execution, pass the scale value as a 1-element host-side memory object (created on a CPU engine). Example:

dnnl::engine gpu_engine = dnnl::engine(engine::kind::gpu, 0);
dnnl::stream gpu_stream = dnnl::stream(gpu_engine);
dnnl::engine cpu_engine = dnnl::engine(engine::kind::cpu, 0);
dnnl::stream cpu_stream = dnnl::stream(cpu_engine);

// Create a host-side-scalar memory descriptor for SDPA's primitive descriptor
dnnl_memory_desc_t tmp_scale_md;
dnnl_memory_desc_create_host_side_scalar(&tmp_scale_md, memory::data_type::f16);
memory::desc scale_md = memory::desc(tmp_scale_md);

// Create a 1-element host-side memory object for SDPA's execution
memory m_scale = memory(memory::desc({1, 1, 1, 1, 1},
                             memory::data_type::f16,
                             memory::format_tag::abcde), cpu_engine);

// Provide host-side-scalar scale descriptor to SDPA: scale will be passed as scalar kernel parameter
sdpa::primitive_desc sdpa_prim_pd = sdpa::primitive_desc(gpu_engine ... scale_md ...);

std::unordered_map<int, memory> sdpa_args;
sdpa_args.insert({DNNL_ARG_QUERIES, m_query});
sdpa_args.insert({DNNL_ARG_KEYS, m_keys});
sdpa_args.insert({DNNL_ARG_VALUES, m_value});
sdpa_args.insert({DNNL_ARG_DST, m_output});

// Provide the host-side scale memory to SDPA at execution
sdpa_args.insert({DNNL_ARG_SCALE, m_scale});
sdpa_prim.execute(gpu_stream, sdpa_args);
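The example declares scale as f16, so a float scale value has to be narrowed to 16 bits before it is written into the 1-element host buffer. Below is a minimal sketch of that step, assuming IEEE 754 binary16 with round-to-nearest-even. `float_to_half_bits` and `store_scale_f16` are hypothetical helpers written for this illustration and are not part of oneDNN.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper (not a oneDNN API): convert a float to the bit pattern
// of an IEEE 754 binary16 value, rounding to nearest, ties to even.
inline std::uint16_t float_to_half_bits(float f) {
    std::uint32_t x;
    std::memcpy(&x, &f, sizeof(x));
    const std::uint32_t sign = (x >> 16) & 0x8000u;
    const std::uint32_t exp = (x >> 23) & 0xFFu;
    std::uint32_t mant = x & 0x7FFFFFu;

    if (exp == 0xFFu) // infinity or NaN
        return static_cast<std::uint16_t>(sign | 0x7C00u | (mant ? 0x200u : 0u));

    const int e = static_cast<int>(exp) - 127 + 15; // re-biased exponent
    if (e >= 31) // too large for binary16: overflow to infinity
        return static_cast<std::uint16_t>(sign | 0x7C00u);

    if (e <= 0) { // result is subnormal or zero
        if (e < -10) return static_cast<std::uint16_t>(sign);
        mant |= 0x800000u; // make the implicit leading 1 explicit
        const std::uint32_t shift = static_cast<std::uint32_t>(14 - e);
        std::uint32_t half_mant = mant >> shift;
        const std::uint32_t rem = mant & ((1u << shift) - 1u);
        const std::uint32_t halfway = 1u << (shift - 1u);
        if (rem > halfway || (rem == halfway && (half_mant & 1u))) ++half_mant;
        return static_cast<std::uint16_t>(sign | half_mant);
    }

    // Normal number: keep the top 10 mantissa bits and round on the rest.
    std::uint32_t half = (static_cast<std::uint32_t>(e) << 10) | (mant >> 13);
    const std::uint32_t rem = mant & 0x1FFFu;
    if (rem > 0x1000u || (rem == 0x1000u && (half & 1u))) ++half; // carry may bump the exponent
    return static_cast<std::uint16_t>(sign | half);
}

// Hypothetical helper: write a float scale into a 1-element f16 host buffer.
inline void store_scale_f16(void *host_buffer, float scale) {
    assert(host_buffer != nullptr);
    const std::uint16_t bits = float_to_half_bits(scale);
    std::memcpy(host_buffer, &bits, sizeof(bits));
}
```

Assuming `m_scale` is the CPU-engine memory object from the example, a call such as `store_scale_f16(m_scale.get_data_handle(), 0.125f);` before `execute` would be one way to update the value between invocations.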

pv-pterab-s avatar Jun 11 '25 01:06 pv-pterab-s

You should also update the verbose log here: https://github.com/uxlfoundation/oneDNN/blob/main/src/common/verbose.cpp#L1562

umar456 avatar Jun 11 '25 18:06 umar456

make test enable os_win disable test_device_cpu enable test_device_gpu disable build_cpu_runtime_omp disable build_cpu_runtime_sycl disable build_cpu_runtime_tbb enable build_graph enable compiler_icx-previous enable compiler_gnu9 enable compiler_clang14 enable compiler_vs2022 disable build_gpu_runtime_sycl disable benchdnn_all enable benchdnn_softmax enable benchdnn_graph enable arch_gpu_xe-hpc enable arch_gpu_xe-hpg-atsm enable arch_gpu_xe-hpg-dg2 disable arch_gpu_xe-lp disable arch_gpu_xe-lpg disable arch_gpu_xe-lpg+ enable arch_gpu_xe2-hpg-bmg disable arch_gpu_xe2-lpg

pv-pterab-s avatar Jun 18 '25 17:06 pv-pterab-s

> What about benchdnn validation?

SDPA is an internal-only primitive, so the only checking is in the internal testing and not benchdnn.

pv-pterab-s avatar Jun 20 '25 12:06 pv-pterab-s

> > What about benchdnn validation?
>
> SDPA is an internal-only primitive, so the only checking is in the internal testing and not benchdnn.

OK (only because this PR has no Graph extension for this feature). In that case, new API tests should at least appear in the gtests folder.

dzarukin avatar Jun 22 '25 21:06 dzarukin

Closing this one.

  • The API RFC #3236 has been implemented via #3506 by @mzhukova.
  • The SDPA ukernel support has been implemented via #3909 by @skazakov1.

TaoLv avatar Sep 24 '25 06:09 TaoLv