Candle Inference ~8.5x Slower Than PyTorch on CPU
Issue
Inference with the sentence-transformers/all-MiniLM-L6-v2 model using candle-transformers (v0.8.4) is roughly 8.5x slower than the standard implementation using the transformers library with PyTorch. The average time per batch (size 4) is around 122.30 ms in Candle versus ~14.34 ms in PyTorch.
Environment
- Candle Version: 0.8.4 (using candle-core, candle-nn, candle-transformers)
- Candle Features: cuda, cudnn, mkl
- CUDA Version: 12.8 (V12.8.93)
- Model: sentence-transformers/all-MiniLM-L6-v2
- Batch Size: 4
- Hardware: Intel Core i9-13900HX
- Operating System: Windows 11
Profiling Data (Batch Size = 4)
Time taken for the forward pass.
| Framework | Operation | Mean (ms) | Std (ms) | Min (ms) | Max (ms) |
|---|---|---|---|---|---|
| Python (PyTorch) | Inference | 14.34 | 3.15 | 8.00 | 120.78 |
| Rust (Candle) | Inference | 122.30 | 43.38 | 18.00 | 319.00 |
Expected Result
Inference time should be closer to the PyTorch performance, ideally around 14-20 ms per batch on CPU. On GPU, I am observing better performance with Candle than with vanilla PyTorch.
Actual Result
The average inference time with Candle is ~122.30 ms per batch, with significant variance (standard deviation: 43.38 ms). This is much slower than expected and slower than previous reports comparing Candle and Python implementations for this model.
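For context, the per-batch timings above were collected with a simple wall-clock loop along these lines (a sketch; `run_forward_pass` is a hypothetical stand-in for the tokenize + forward call of the actual benchmark):

```rust
use std::time::Instant;

// Sketch of the wall-clock harness behind the per-batch numbers above;
// `run_forward_pass` is a hypothetical stand-in for the real tokenize +
// model forward call.
fn time_batches(mut run_forward_pass: impl FnMut(), n_batches: usize) -> Vec<f64> {
    let mut times_ms = Vec::with_capacity(n_batches);
    for _ in 0..n_batches {
        let start = Instant::now();
        run_forward_pass();
        times_ms.push(start.elapsed().as_secs_f64() * 1e3);
    }
    let mean = times_ms.iter().sum::<f64>() / times_ms.len() as f64;
    let min = times_ms.iter().cloned().fold(f64::INFINITY, f64::min);
    let max = times_ms.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    println!("mean {mean:.2} ms, min {min:.2} ms, max {max:.2} ms");
    times_ms
}
```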
Hi @msminhas93! Can you please make a flamegraph using https://github.com/flamegraph-rs/flamegraph?
Hey Eric! Here is the flamegraph. It seems like the bulk of the time is spent in rayon. Since I'm not using rayon in my code, it seems candle is using it. I'm not that great with flamegraphs yet, so my observation could be incorrect.
Hi @msminhas93, how are you building your binary?
@BerserkerMother the normal `cargo build --release`, with `debug = 1` set under `[profile.release]` for the flamegraph.
You may want to check whether your simd instructions were properly detected, e.g. by printing candle::utils::with_avx() for x86.
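For example, something like this quick check (a sketch; the helpers live in `candle_core::utils` and the exact set may vary across versions):

```rust
// Print the SIMD / backend flags that candle detected at build time.
fn main() {
    println!("avx:     {}", candle_core::utils::with_avx());
    println!("neon:    {}", candle_core::utils::with_neon());
    println!("simd128: {}", candle_core::utils::with_simd128());
    println!("f16c:    {}", candle_core::utils::with_f16c());
    println!("mkl:     {}", candle_core::utils::has_mkl());
}
```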
Shouldn't he build with RUSTFLAGS="-C target-feature=+avx" for this?
The best is probably to have the same .cargo/config.toml at the root of the project as in the main candle repo here.
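Something along these lines, for instance (a sketch of a `.cargo/config.toml` that enables native CPU features; the actual file in the candle repo may differ):

```toml
# .cargo/config.toml at the project root (sketch; the candle repo's file may
# differ). Lets rustc target the host CPU's full instruction set (AVX2, FMA, ...).
[build]
rustflags = ["-C", "target-cpu=native"]
```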
I tried adding the config.toml but it doesn't seem to help. AVX is enabled:
`println!("AVX enabled: {}", utils::with_avx());` prints `AVX enabled: true`
Rayon does seem to be the bulk of the flamegraph. I don't get it.
I enabled mkl and that helped somewhat, reducing the mean from 122.3 ms to 83.3 ms. This is the flamegraph of the mkl-enabled run.
Flamegraphs are pretty hard to look at as indeed lots of things will be hidden by rayon. Instead, could you use the tracing functionality, e.g. something similar to this? It should produce a trace that is compatible with perfetto. Re mkl, could you report the min and max time taken in this case? In your original stats, it seemed that one of the forward passes could have been much slower than the rest, which may be indicative of some initialization work taking place.
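For reference, the pattern used in the candle examples looks roughly like this (a sketch; it relies on the `tracing-chrome` and `tracing-subscriber` crates, and the resulting `trace-*.json` can be opened in Perfetto):

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run so the trace file is flushed on drop.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... run the benchmark here; candle layers instrumented with tracing
    // spans end up in the generated trace-<timestamp>.json.
}
```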
I didn’t know that. Thank you very much! Do you think we can improve the book to include such things? I was stuck trying to find execution times with Flamegraph, but now I know there’s a much better way.
Yeah, the book is not in a very good state and lags quite a bit compared to the current candle stuff + probably doesn't focus on the most relevant things. Maybe some rewrite effort is needed there, but it's hard to ensure that it stays current. @greenrazer maybe that's something that you would be interested in looking at?
Also just to mention another trick, you may want to set RAYON_NUM_THREADS=1 for performance investigation. Obviously it will impact performance by disabling multi-threading but for some of the kernels (at least gemm I think) it will also turn off rayon entirely so could result in much nicer flamegraphs.
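If it's more convenient than setting the variable in the Windows shell, it can also be set from the program itself before any candle/rayon work happens (a sketch):

```rust
fn main() {
    // Must run before the rayon global pool is created, i.e. before the first
    // candle CPU op, and while the process is still single-threaded.
    // (On the 2024 edition `set_var` is `unsafe`; setting the variable in the
    // shell before launching works just as well.)
    std::env::set_var("RAYON_NUM_THREADS", "1");

    // ... run the benchmark ...
}
```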
@LaurentMazare I'll be happy to help and learn along the way.
Rayon is adding overhead for inference, at least for this model and the mkl build. These are the numbers for different thread counts, mkl vs no mkl, and Python. I'll look at the chrome profiling and post my findings soon. The data has 1904 batches at a batch size of 4.
Comparison of 'Inference (ms)' Statistics:
| Run | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|
| profiling_results_rust_mkl_1_threads.csv | 1904 | 53.9296 | 14.6833 | 15 | 46 | 53 | 59 | 159 |
| profiling_results_rust_mkl_4_threads.csv | 1904 | 60.9974 | 18.3994 | 15 | 51 | 59 | 68 | 196 |
| profiling_results_rust_mkl_8_threads.csv | 1904 | 73.8015 | 22.3596 | 17 | 57 | 73 | 87 | 223 |
| profiling_results_rust_no_mkl_1_threads.csv | 1904 | 64.583 | 15.0718 | 14 | 54 | 65 | 74 | 135 |
| profiling_results_rust_no_mkl_4_threads.csv | 1904 | 50.7468 | 11.8582 | 11 | 42 | 51 | 58 | 108 |
| profiling_results_rust_no_mkl_8_threads.csv | 1904 | 50.4695 | 11.7823 | 11 | 42 | 51 | 58 | 110 |
| profiling_results_python.csv | 1904 | 14.3405 | 3.1515 | 8.00014 | 13 | 14.0002 | 15.0008 | 120.781 |
Well, I think I'm going down a rabbit hole. The chrome profiler didn't show much, so I did some profiling with VTune and the results are bizarre.
Without setting any rayon thread env variable, this is the top-level breakdown for 100 batches.
Performance Profiling Results
| Function / Call Stack | CPU Time | Instructions Retired | Module | Function (Full) | Source File | Start Address |
|---|---|---|---|---|---|---|
| core::cell::Cell::get | 4.197s | 16,691,100,000 | candle_batch_inf.exe | core::cell::Cell::get | cell.rs | 0x1406ae195 |
| func@0x18007d91b | 2.098s | 452,353,000 | ntdll.dll | func@0x18007d91b | [Unknown] | 0x18007d91b |
| func@0x140249070 | 1.799s | 416,068,000 | ntoskrnl.exe | func@0x140249070 | [Unknown] | 0x140249070 |
| crossbeam_epoch::epoch::Epoch::pinned | 1.540s | 1,403,020,000 | candle_batch_inf.exe | crossbeam_epoch::epoch::Epoch::pinned | epoch.rs | 0x1406ae186 |
| gemm_common::simd::x86::impl10::vectorize | 1.185s | 6,664,345,000 | candle_batch_inf.exe | gemm_common::simd::x86::impl10::vectorize | simd.rs | 0x14069a458 |
| func@0x14068ec83 | 1.084s | 2,341,592,000 | ntoskrnl.exe | func@0x14068ec83 | [Unknown] | 0x14068ec83 |
| candle_core::cpu::erf::erf_impl | 1.079s | 9,168,010,000 | candle_batch_inf.exe | candle_core::cpu::erf::erf_impl() | erf.rs | 0x1404cd7a0 |
| core::sync::atomic::atomic_load | 0.825s | 149,978,000 | candle_batch_inf.exe | core::sync::atomic::atomic_load | atomic.rs | 0x140d68aa2 |
| crossbeam_epoch::internal::Global::try_advance | 0.806s | 1,178,053,000 | candle_batch_inf.exe | crossbeam_epoch::internal::Global::try_advance() | internal.rs | 0x140d58730 |
| crossbeam_epoch::internal::Local::pin | 0.599s | 3,064,873,000 | candle_batch_inf.exe | crossbeam_epoch::internal::Local::pin | internal.rs | 0x1406ae15a |
| crossbeam_deque::deque::Stealer::steal<rayon_core::job::JobRef> | 0.588s | 5,217,783,000 | candle_batch_inf.exe | crossbeam_deque::deque::Stealer::steal<rayon_core::job::JobRef>() | deque.rs | 0x1406adab0 |
Notes:
- The table lists functions by CPU time in descending order, highlighting the most time-consuming operations.
Mostly crossbeam-related stuff when rayon isn't disabled.
With RAYON_NUM_THREADS=1, this is the top-level breakdown for 100 batches.
Performance Profiling Results
| Function / Call Stack | CPU Time | Instructions Retired | Module | Function (Full) | Source File | Start Address |
|---|---|---|---|---|---|---|
| candle_core::cpu::erf::erf_impl | 1.015s | 9,211,552,000 | candle_batch_inf.exe | candle_core::cpu::erf::erf_impl() | erf.rs | 0x1404cdfa0 |
| gemm_common::simd::x86::impl10::vectorize | 0.354s | 3,425,304,000 | candle_batch_inf.exe | gemm_common::simd::x86::impl10::vectorize | simd.rs | 0x14069ac58 |
| candle_core::cpu::erf::evaluate::polynomial | 0.226s | 1,572,350,000 | candle_batch_inf.exe | candle_core::cpu::erf::evaluate::polynomial | erf.rs | 0x1404ce0f4 |
| candle_core::cpu_backend::utils::binary_map::closure3 | 0.198s | 3,737,355,000 | candle_batch_inf.exe | candle_core::cpu_backend::utils::binary_map::closure3 | utils.rs | 0x14059233a |
| gemm_f32::microkernel::fma::f32::x3x4::KernelIter::execute | 0.189s | 3,241,460,000 | candle_batch_inf.exe | gemm_f32::microkernel::fma::f32::x3x4::KernelIter::execute | microkernel.rs | 0x14069255c |
| core::ptr::write | 0.164s | 2,160,167,000 | candle_batch_inf.exe | core::ptr::write | mod.rs | 0x1405980e0 |
| candle_core::op::impl8::f64 | 0.160s | 1,664,272,000 | candle_batch_inf.exe | candle_core::op::impl8::f64 | op.rs | 0x1405980d0 |
| exp | 0.158s | 1,603,797,000 | ucrtbase.dll | exp | [Unknown] | 0x1800a79f0 |
| core::core_arch::x86::avx::_mm256_setr_ps | 0.142s | 2,044,055,000 | candle_batch_inf.exe | core::core_arch::x86::avx::_mm256_setr_ps | avx.rs | 0x1406925a0 |
| gemm_f32::microkernel::fma::f32::x3x4::KernelIter::execute | 0.142s | 2,392,391,000 | candle_batch_inf.exe | gemm_f32::microkernel::fma::f32::x3x4::KernelIter::execute | microkernel.rs | 0x1406924f0 |
| candle_core::op::impl8::f32 | 0.139s | 1,678,786,000 | candle_batch_inf.exe | candle_core::op::impl8::f32 | op.rs | 0x1405980d0 |
| core::core_arch::x86::avx::_mm256_setr_ps | 0.129s | 2,285,955,000 | candle_batch_inf.exe | core::core_arch::x86::avx::_mm256_setr_ps | avx.rs | 0x140692530 |
With and without rayon, the processing time seems similar. This is with mkl disabled, though. When rayon is disabled, erf becomes the top function/call-stack entry.
On CPU, ort (https://github.com/pykeio/ort/blob/main/examples/sentence-transformers/semantic-similarity.rs) is faster than the PyTorch version for the same model.
| Metric | Mean | Std | Min | Max |
|---|---|---|---|---|
| Inference (ms) | 8.498950 | 2.709629 | 2.000000 | 67.000000 |
Core utilization and vectorization are better overall. Note this is over the whole dataset (1904 batches), not just the first 100 batches I posted for candle.
I’m evaluating a switch from my existing Python inference stack to Rust for an edge-deployment scenario. The target node is an IBM x3650 M4 (Intel Xeon, 12 cores / 24 threads, 16 GB RAM) on Windows Server 2016, but the service will only get a fraction of those resources because other workloads share the machine.
Based on your experience working on this project @msminhas93, do the latency-and-footprint gains you’ve seen with Rust warrant the migration effort in this context?