wasmtime-py
Performance questions
Incredible work! I was trying to see how much overhead there is when calling a wasm function from Python. I am using WSL2 on Windows with a recent Fedora and CPython 3.10. Re-using the exact gcd.wat from the examples, it looks like every call has a "cost" of about 30μs.
The code I used is a simple timer to compare performance:
import wasmtime.loader  # import hook: lets `import gcd` load gcd.wat as a Python module
import time
from math import gcd as math_gcd
from gcd import gcd as wasm_gcd

def python_gcd(x, y):
    while y:
        x, y = y, x % y
    return abs(x)

N = 1_000
for gcdf in math_gcd, python_gcd, wasm_gcd:
    start_time = time.perf_counter()
    for _ in range(N):
        g = gcdf(16516842, 154654684)
    total_time = time.perf_counter() - start_time
    # average time per call, in microseconds
    print(f"{total_time / N * 1_000_000:8.3f}μs")
This returns about:
0.152μs # C code from math.gcd
0.752μs # Python code from python_gcd
31.389μs # gcd.wat
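The same per-call measurement can be sketched with the standard-library timeit module, which takes the best of several repeats and is therefore less sensitive to noise from other processes. This is only a sketch: wasm_gcd from the gcd.wat example could be passed to per_call_us (a helper name made up for this example) in the same way. The lambda wrapper adds a small constant cost of its own to every measurement.

```python
import timeit
from math import gcd as math_gcd


def python_gcd(x, y):
    # Euclid's algorithm, same as in the benchmark above
    while y:
        x, y = y, x % y
    return abs(x)


def per_call_us(func, args, number=10_000, repeat=5):
    """Return the best observed per-call time of func(*args), in microseconds."""
    timer = timeit.Timer(lambda: func(*args))
    # min() over several repeats filters out interference from other processes
    best = min(timer.repeat(repeat=repeat, number=number))
    return best / number * 1_000_000


for f in (math_gcd, python_gcd):
    print(f"{f.__name__:>10}: {per_call_us(f, (16516842, 154654684)):8.3f}μs")
```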
Note that I tested this with an empty "hello world", and the ~30μs is still there. I am wondering about two things:
- Is this overhead inevitable and inherent to the design, or are there ways to reduce it?
- I noticed that in gcd.wat the input parameters are expected to be i32, yet the call does not fail when the parameters exceed 2**32. How does this work?
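On the i32 question: a WebAssembly i32 is just 32 bits, so a Python int has to be narrowed somewhere. We have not checked exactly where wasmtime-py does this, but the bindings are built on ctypes (see the `_ctypes` entries in the profiles later in this thread), and ctypes performs no overflow checking on its integer types. If the value passes through a ctypes 32-bit field, it would be silently wrapped to its low 32 bits. A sketch of that wrapping, with to_i32 as a hypothetical helper doing the same two's-complement arithmetic in pure Python:

```python
import ctypes


def to_i32(value: int) -> int:
    """Wrap a Python int to a signed 32-bit integer (two's complement)."""
    value &= 0xFFFF_FFFF  # keep only the low 32 bits
    return value - (1 << 32) if value >= (1 << 31) else value


# ctypes does no overflow checking: only the low 32 bits survive
print(ctypes.c_int32(2**32 + 7).value)  # 7
print(to_i32(2**32 + 7))                # 7
print(to_i32(2**31))                    # -2147483648
```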
Not much attention has been paid to performance in the bindings in this repository, so this isn't altogether surprising. I have never personally written high-performance Python and consequently probably made a ton of mistakes when writing these bindings.
If you're looking for performance though I would recommend the Rust bindings rather than the Python bindings, as there we have indeed focused on performance and the overhead is significantly smaller.
Thanks. I was actually wondering whether wasm could be an option for Python libraries that rely on C/C++/Rust: you would only need to build one file and load it with wasmtime-py, instead of building many platform-specific wheels, which can be a headache. So I wanted to compare performance as a first step.
Thanks a lot for your response, I will close the issue.
That's definitely an intended use case for bindings like this, and I'll emphasize again that the cost here isn't intrinsic to these bindings or Python, it's probably just that I don't know how to write high-performance Python. PRs are of course always welcome for improvements as well.
I've been doing some experiments to identify bottlenecks. The first one, which is now solved, is moving data in and out of wasm; this was demonstrated by moving large data (several MB of images).
The next obvious bottleneck is that there seems to be a fixed context-switch cost when calling from Python into WASM, or vice versa.
I can demonstrate this bottleneck with cdb_djp_hash.c, a simple 32-bit string hash.
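For comparison, here is a sketch of a pure-Python baseline of the kind timed below. It assumes cdb_djp_hash.c implements the classic cdb hash (start at 5381, then h = (h * 33) ^ byte for each byte, truncated to 32 bits); the linked file may differ in detail.

```python
def cdb_hash(data: bytes) -> int:
    """cdb-style 32-bit string hash (assumed variant: h = (h * 33) ^ byte)."""
    h = 5381
    for byte in data:
        # (h << 5) + h == h * 33; the mask emulates uint32 overflow in C
        h = (((h << 5) + h) ^ byte) & 0xFFFF_FFFF
    return h


# Work grows linearly with the input length
for n in (13, 26, 130, 1300):
    print(n, cdb_hash(b"x" * n))
```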
I then estimated the cost of hashing strings of lengths 13, 26, 130, and 1300 bytes, that is, 1x, 2x, 10x, and 100x the work, each in a loop of 100x.
The shocking finding was that the WASM version took the same time, ~40 ms, regardless of the string length. For the pure Python and C versions, on the other hand, the time was proportional to the string length: pure Python started out faster than WASM for the 13-byte string (1.7 ms vs. ~40.0 ms in WASM), but took 170 ms for the 1300-byte string (vs. ~40.0 ms in WASM).
This indicates that most of the 40 ms is a constant switching cost, independent of the actual work done inside the loop.
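The reasoning can be made explicit with a linear cost model, time ≈ fixed + per_byte × length. Fitting it to the two data points quoted above for each implementation (pure Python: 1.7 ms at 13 bytes, 170 ms at 1300 bytes; WASM: ~40 ms at both) gives WASM a per-byte slope of zero, so all of its time is fixed overhead, and puts the break-even length at roughly 300 bytes. A sketch using only those numbers; fit_line is a small hand-written least-squares helper, and the 26- and 130-byte timings could be added to the lists if available.

```python
def fit_line(xs, ys):
    """Ordinary least-squares fit of ys ≈ intercept + slope * xs."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return mean_y - slope * mean_x, slope


lengths = [13, 1300]      # string lengths in bytes
python_ms = [1.7, 170.0]  # pure Python timings quoted above
wasm_ms = [40.0, 40.0]    # WASM timings quoted above (~40 ms each)

py_fixed, py_per_byte = fit_line(lengths, python_ms)
wasm_fixed, wasm_per_byte = fit_line(lengths, wasm_ms)

# Length at which both implementations take the same time
break_even = (wasm_fixed - py_fixed) / (py_per_byte - wasm_per_byte)
print(f"python: {py_fixed:.3f} ms + {py_per_byte:.4f} ms/byte")
print(f"wasm:   {wasm_fixed:.3f} ms + {wasm_per_byte:.4f} ms/byte")
print(f"break-even at ~{break_even:.0f} bytes")
```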
I've added some commented-out profiling code to identify exactly where this time is being wasted, but did not have enough time to continue:
https://github.com/muayyad-alsadi/wasm-demos/blob/main/cdb_djp_hash/cdb_djp_hash.py#L88
Profiling 10k calls:
from cProfile import Profile

pr = Profile()
pr.enable()
for i in range(10000):
    cdb_djp_hash_wasm(large_a)  # both defined in the linked cdb_djp_hash.py
pr.disable()
# pr.print_stats()
# pr.print_stats('cumulative')
pr.print_stats('calls')  # sort by call count
This produces some "fishy" call counts.
Out of the 800 ms, only 169 ms was spent in the actual WASM call, wasmtime_func_call.
Ordered by: call count
ncalls tottime percall cumtime percall filename:lineno(function)
210000 0.018 0.000 0.018 0.000 {built-in method builtins.isinstance}
140000 0.014 0.000 0.014 0.000 {built-in method builtins.hasattr}
130000 0.013 0.000 0.013 0.000 {built-in method builtins.len}
80000 0.011 0.000 0.011 0.000 {built-in method _ctypes.byref}
70000 0.008 0.000 0.008 0.000 {built-in method _ctypes.POINTER}
60000 0.009 0.000 0.009 0.000 {built-in method __new__ of type object at 0x7f34f4f430a0}
50000 0.033 0.000 0.049 0.000 _types.py:45(_from_ptr)
50000 0.025 0.000 0.047 0.000 _types.py:86(__del__)
40000 0.022 0.000 0.022 0.000 _bindings.py:178(wasm_valtype_kind)
40000 0.003 0.000 0.003 0.000 {method 'append' of 'list' objects}
30000 0.007 0.000 0.010 0.000 _value.py:161(_unwrap_raw)
30000 0.018 0.000 0.018 0.000 _bindings.py:2440(wasmtime_val_delete)
30000 0.006 0.000 0.006 0.000 _value.py:109(__init__)
30000 0.020 0.000 0.044 0.000 _value.py:117(__del__)
20000 0.032 0.000 0.183 0.000 _value.py:129(_convert)
20000 0.014 0.000 0.043 0.000 _types.py:12(i32)
20000 0.017 0.000 0.040 0.000 _types.py:54(__eq__)
20000 0.033 0.000 0.065 0.000 _types.py:94(_from_list)
20000 0.008 0.000 0.011 0.000 _func.py:256(enter_wasm)
20000 0.012 0.000 0.012 0.000 _bindings.py:120(wasm_valtype_delete)
20000 0.010 0.000 0.010 0.000 _bindings.py:172(wasm_valtype_new)
20000 0.028 0.000 0.033 0.000 _value.py:39(i32)
20000 0.006 0.000 0.018 0.000 {built-in method builtins.next}
20000 0.002 0.000 0.002 0.000 {built-in method _ctypes.addressof}
10000 0.006 0.000 0.006 0.000 _value.py:167(_value)
10000 0.006 0.000 0.016 0.000 _value.py:183(value)
10000 0.008 0.000 0.013 0.000 _types.py:139(_from_ptr)
10000 0.007 0.000 0.056 0.000 _types.py:148(params)
10000 0.007 0.000 0.038 0.000 _types.py:157(results)
10000 0.006 0.000 0.015 0.000 _types.py:169(__del__)
10000 0.093 0.000 0.685 0.000 _func.py:59(__call__)
10000 0.011 0.000 0.036 0.000 _func.py:52(type)
10000 0.011 0.000 0.193 0.000 _func.py:83(<listcomp>)
10000 0.007 0.000 0.018 0.000 _memory.py:56(data_ptr)
10000 0.006 0.000 0.015 0.000 _memory.py:65(data_len)
10000 0.054 0.000 0.112 0.000 wasmtime_fast_memory.py:47(__setitem__)
And when doing it with a NumPy view of the memory:
np_mem = np.frombuffer(fast_mem.get_buffer_ptr(), dtype=np.uint8)
ncalls tottime percall cumtime percall filename:lineno(function)
170000 0.015 0.000 0.015 0.000 {built-in method builtins.isinstance}
140000 0.014 0.000 0.014 0.000 {built-in method builtins.hasattr}
120000 0.012 0.000 0.012 0.000 {built-in method builtins.len}
70000 0.009 0.000 0.009 0.000 {built-in method _ctypes.POINTER}
60000 0.009 0.000 0.009 0.000 {built-in method __new__ of type object at 0x7f988c7430a0}
60000 0.009 0.000 0.009 0.000 {built-in method _ctypes.byref}
50000 0.033 0.000 0.050 0.000 _types.py:45(_from_ptr)
50000 0.024 0.000 0.046 0.000 _types.py:86(__del__)
40000 0.022 0.000 0.022 0.000 _bindings.py:178(wasm_valtype_kind)
40000 0.003 0.000 0.003 0.000 {method 'append' of 'list' objects}
30000 0.018 0.000 0.018 0.000 _bindings.py:2440(wasmtime_val_delete)
30000 0.006 0.000 0.006 0.000 _value.py:109(__init__)
30000 0.020 0.000 0.045 0.000 _value.py:117(__del__)
30000 0.008 0.000 0.010 0.000 _value.py:161(_unwrap_raw)
20000 0.008 0.000 0.011 0.000 _func.py:256(enter_wasm)
20000 0.011 0.000 0.011 0.000 _bindings.py:172(wasm_valtype_new)
20000 0.034 0.000 0.040 0.000 _value.py:39(i32)
20000 0.032 0.000 0.189 0.000 _value.py:129(_convert)
20000 0.013 0.000 0.044 0.000 _types.py:12(i32)
20000 0.017 0.000 0.040 0.000 _types.py:54(__eq__)
20000 0.032 0.000 0.065 0.000 _types.py:94(_from_list)
20000 0.012 0.000 0.012 0.000 _bindings.py:120(wasm_valtype_delete)
20000 0.006 0.000 0.017 0.000 {built-in method builtins.next}
10000 0.011 0.000 0.200 0.000 _func.py:83(<listcomp>)
We are doing 10k calls, so I should see ~10k wasmtime_func_call calls, and there are indeed 10k.
This makes sense, and so do the entries with 20k or 40k calls: those operations are done per value, or before and after each function call.
The counts make sense, but why are we comparing lists?
10000 0.011 0.000 0.206 0.000 _func.py:83(<listcomp>)
20000 0.033 0.000 0.195 0.000 _value.py:129(_convert)
What does not make sense is having 210k operations. They seem to be type-related, even though all of my work uses primitive types that do not need those operations, and nothing needs to go through variable-width lists, etc. For example:
40000 0.003 0.000 0.003 0.000 {method 'append' of 'list' objects}
TL;DR: actionable items
- Use fixed-size structures instead of lists.
- Document a way to shortcut _convert.
- Do the conversion with some sort of map or tuple construct rather than a dynamic list, and avoid looping in Python.
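A sketch of the direction these items point in: resolve the per-parameter conversion once, when the function is set up, store it in a tuple, and at call time only fill a fixed-size ctypes array, with no list building or type comparisons per call. RawVal is a hypothetical stand-in for wasmtime_val_raw_t (a union of i32/i64/f32/f64 fields) so the example runs without wasmtime, and make_packer is a made-up name; this is not the bindings' actual code. The remaining loop runs over a precomputed tuple and could be unrolled further for fixed arities.

```python
import ctypes


class RawVal(ctypes.Union):
    # Hypothetical stand-in for wasmtime_val_raw_t
    _fields_ = [("i32", ctypes.c_int32), ("i64", ctypes.c_int64),
                ("f32", ctypes.c_float), ("f64", ctypes.c_double)]


def make_packer(param_kinds, n_results):
    """Precompute everything that depends only on the function type."""
    kinds = tuple(param_kinds)  # e.g. ("i32", "i32"); reusable on every call
    # One buffer holds arguments in and results out, so size it for both
    array_type = RawVal * max(len(kinds), n_results)

    def pack(*args):
        if len(args) != len(kinds):
            raise TypeError(f"expected {len(kinds)} arguments, got {len(args)}")
        raw = array_type()  # fixed size, zero-initialized
        for i, kind in enumerate(kinds):
            setattr(raw[i], kind, args[i])  # raw[i] is a view into the array
        return raw

    return pack


pack_gcd = make_packer(["i32", "i32"], n_results=1)
raw = pack_gcd(6, 27)
print(raw[0].i32, raw[1].i32)  # 6 27
```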
> Re-using the exact gcd.wat from the examples, it looks like any call has a "cost" of about 30μs.
I think I fixed that in #137
@alexprengere, would you please test my approach and report how much of a performance gain it achieves?
I was able to eliminate the 40 ms of wasted time.
Here is the code:
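# gcd_alt.py
# Note: this relies on private wasmtime-py internals (_ffi, _func, _bindings,
# store._context, func._func) as they existed at the time of writing, so it
# may break with other wasmtime-py versions.
import ctypes
from wasmtime import Store, Module, Instance, WasmtimeError
from wasmtime import _ffi as ffi
from wasmtime._func import enter_wasm
from wasmtime._bindings import wasmtime_val_raw_t

store = Store()
module = Module.from_file(store.engine, './gcd.wat')
instance = Instance(store, module, [])

def func_init(func, store):
    # Precompute everything that depends only on the function's type
    ty = func.type(store)
    ty_params = ty.params
    ty_results = ty.results
    # wasmtime_val_raw_t field name for each parameter, e.g. ('i32', 'i32').
    # This must be a tuple, not a generator: a generator would be used up by
    # the first call, and later calls would not set any arguments.
    params_str = tuple(str(i) for i in ty_params)
    params_n = len(ty_params)
    results_n = len(ty_results)
    # The same buffer carries the arguments in and the results out
    n = max(params_n, results_n)
    raw_type = wasmtime_val_raw_t * n
    func.raw_type = raw_type

    def _create_raw(*params):
        raw = raw_type()
        for i, param_str in enumerate(params_str):
            setattr(raw[i], param_str, params[i])
        return raw

    func._create_raw = _create_raw

_gcd_in = instance.exports(store)["gcd"]
func_init(_gcd_in, store)

def gcd(a, b):
    raw = _gcd_in._create_raw(a, b)
    raw_ptr_casted = ctypes.cast(raw, ctypes.POINTER(wasmtime_val_raw_t))
    # Call without per-call type checking; enter_wasm supplies the trap out-pointer
    with enter_wasm(store) as trap:
        error = ffi.wasmtime_func_call_unchecked(
            store._context,
            ctypes.byref(_gcd_in._func),
            raw_ptr_casted,
            trap)
        if error:
            raise WasmtimeError._from_ptr(error)
    # The result is written back into slot 0 of the same buffer
    return raw[0].i32

print("gcd(6, 27) = %d" % gcd(6, 27))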
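One pitfall when caching per-function data in a closure, as func_init does with params_str: a generator expression can be iterated only once, so it must be materialized (for example into a tuple) if it is reused on every call. With a generator, the import-time print(...) in gcd_alt.py would use it up, and every later call, including the benchmark's, would run with zeroed arguments. A minimal demonstration without wasmtime; make_setter is a hypothetical helper for illustration only.

```python
def make_setter(names, materialize):
    # Field names are computed once, when the function is set up
    fields = (str(n) for n in names)
    if materialize:
        fields = tuple(fields)  # a tuple can be iterated on every call

    def setter(*values):
        # Pair each cached field name with the value passed on this call
        return dict(zip(fields, values))

    return setter


broken = make_setter(["x", "y"], materialize=False)
fixed = make_setter(["x", "y"], materialize=True)

print(broken(6, 27), broken(6, 27))  # {'x': 6, 'y': 27} {}
print(fixed(6, 27), fixed(6, 27))    # {'x': 6, 'y': 27} {'x': 6, 'y': 27}
```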
And here is the benchmark:
import wasmtime.loader
import time
from math import gcd as math_gcd
from gcd import gcd as wasm_gcd
from gcd_alt import gcd as wasm_gcd_alt

def python_gcd(x, y):
    while y:
        x, y = y, x % y
    return abs(x)

N = 1_000
for gcdf in math_gcd, python_gcd, wasm_gcd, wasm_gcd_alt:
    start_time = time.perf_counter()
    for _ in range(N):
        g = gcdf(16516842, 154654684)
    total_time = time.perf_counter() - start_time
    print(total_time)
And here are the results (total seconds for all N = 1,000 calls):
0.0002523580042179674 # math_gcd
0.0014094869984546676 # python_gcd
0.043804362998344004 # wasm_gcd
0.005873051006346941 # wasm_gcd_alt <-------- my WASM
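Since the script prints the total time for N = 1,000 calls, dividing by N gives per-call figures comparable with the μs numbers at the top of the thread. A quick conversion of the totals above:

```python
N = 1_000
totals_s = {  # totals printed by the benchmark above, in seconds
    "math_gcd": 0.0002523580042179674,
    "python_gcd": 0.0014094869984546676,
    "wasm_gcd": 0.043804362998344004,
    "wasm_gcd_alt": 0.005873051006346941,
}

# seconds per N calls -> microseconds per call
us_per_call = {name: t / N * 1_000_000 for name, t in totals_s.items()}
for name, us in us_per_call.items():
    print(f"{name:>13}: {us:7.3f} μs/call")

speedup = us_per_call["wasm_gcd"] / us_per_call["wasm_gcd_alt"]
print(f"wasm_gcd_alt is ~{speedup:.1f}x faster than wasm_gcd")
```

That is roughly 44 μs per call down to about 5.9 μs, which for this tiny workload is still about four times slower than the pure-Python gcd.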
We have three types of performance bottlenecks:
- Overhead of passing large data (#135, #81, #134): fixed, merged, and released as part of the 7.0.0 release.
- Overhead of calling a small wasm function from Python (#137, #139): fixed but not yet merged. Here the function itself takes almost no time (e.g. the gcd of two integers, or a simple string hash like the cdb hash), yet a constant amount of time is wasted on every call from Python into wasm, regardless of what the function does.
- Overhead of calling Python from wasm, where the loop is inside wasm and the small function lives in Python.
I've addressed the first two, and I'll open a ticket with details for the last one.