Performance questions

Open alexprengere opened this issue 2 years ago • 8 comments

Incredible work! I was trying to see what kind of overhead there is to call a wasm function from Python. I am using WSL2 on Windows, with a recent Fedora and CPython 3.10. Re-using the exact gcd.wat from the examples, it looks like any call has a "cost" of about 30μs.

The code I used is a simple timer to compare performance:

import wasmtime.loader
import time
from math import gcd as math_gcd
from gcd import gcd as wasm_gcd

def python_gcd(x, y):
    while y:
        x, y = y, x % y
    return abs(x)

N = 1_000
for gcdf in math_gcd, python_gcd, wasm_gcd:
    start_time = time.perf_counter()
    for _ in range(N):
        g = gcdf(16516842, 154654684)
    total_time = time.perf_counter() - start_time
    print(f"{total_time / N * 1_000_000:8.3f}μs")

This returns about:

   0.152μs  # C code from math.gcd
   0.752μs  # Python code from python_gcd
  31.389μs  # gcd.wat
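A slightly more robust way to take this kind of measurement, as a minimal sketch: it uses only the standard library's timeit, takes the best of several repeats, and compares against a do-nothing Python function so the loop and call overhead can be separated from the work. The names per_call_us and noop are made up for illustration, and math.gcd stands in for whatever function is being measured.

```python
import timeit
from math import gcd


def per_call_us(func, *args, number=10_000, repeat=5):
    """Return the best observed time per call, in microseconds."""
    timer = timeit.Timer(lambda: func(*args))
    # Taking the minimum over repeats reduces noise from other processes.
    best = min(timer.repeat(repeat=repeat, number=number))
    return best / number * 1_000_000


def noop(x, y):
    # Hypothetical baseline: costs only the Python call itself.
    return 0


baseline = per_call_us(noop, 16516842, 154654684)
measured = per_call_us(gcd, 16516842, 154654684)
print(f"baseline {baseline:8.3f}μs")
print(f"math.gcd {measured:8.3f}μs")
```

Swapping gcd for the wasm-backed function gives a per-call figure that can be compared directly with the baseline.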

Note that I tested this with an empty "hello world", and the 30μs are still there. I am wondering about 2 things:

  • is this overhead inevitable and linked to the design, or are there ways to reduce it?
  • I noticed in the gcd.wat code that input parameters are expected to be i32, but the code does not fail when parameters exceed 2**32, so I was wondering how this works
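One possible explanation for the second question, offered as an assumption and not something confirmed in this thread: if the bindings store arguments through ctypes 32-bit integer fields, ctypes performs no overflow checking and silently keeps only the low 32 bits. The sketch below shows that behavior with plain ctypes.c_int32; it does not touch wasmtime-py, and as_i32 is a made-up helper name.

```python
import ctypes


def as_i32(value):
    """Return value as ctypes stores it in a signed 32-bit slot."""
    # ctypes integer types do no overflow checking; high bits are dropped.
    return ctypes.c_int32(value).value


small = as_i32(5)
wrapped = as_i32(2**32 + 5)   # only the low 32 bits survive
negative = as_i32(2**31)      # wraps into the negative range
print(small, wrapped, negative)
```

If that is what happens, a too-large argument would not raise an error but would be passed to wasm as a different number.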

alexprengere avatar Sep 22 '22 12:09 alexprengere

Not much attention has been paid to performance in the bindings in this repository, so this isn't altogether surprising. I have never personally written high-performance Python and consequently probably made a ton of mistakes when writing these bindings.

If you're looking for performance though I would recommend the Rust bindings rather than the Python bindings, as there we have indeed focused on performance and the overhead is significantly smaller.

alexcrichton avatar Sep 22 '22 14:09 alexcrichton

Thanks, I was actually wondering if using wasm was a possibility for Python libraries that rely on C/C++/Rust, as you would only need to build the one file, then read it using wasmtime-py, instead of generating many platform wheels, which can be a headache. So I wanted to compare performance as a first step.

Thanks a lot for your response, I will close the issue.

alexprengere avatar Sep 22 '22 14:09 alexprengere

That's definitely an intended use case for bindings like this, and I'll emphasize again that the cost here isn't intrinsic to these bindings or Python, it's probably just that I don't know how to write high-performance Python. PRs are of course always welcome for improvements as well.

alexcrichton avatar Sep 22 '22 14:09 alexcrichton

I've been doing some experiments to identify bottlenecks. The first one, which is now solved, was moving data in and out; this was demonstrated by moving large data (several MB of images).

The next obvious bottleneck is that there seems to be a fixed context-switch cost between Python and WASM, or vice versa.

I can demonstrate this bottleneck with cdb_djp_hash.c, a simple 32-bit string hash.

I estimated the cost of hashing strings of different lengths (13, 26, 130, and 1300 bytes, that is 1x, 2x, 10x, and 100x the work), each in a loop of 100 calls.

The shocking finding was that the WASM version took the same time, ~40ms, regardless of the length of the string. For the pure Python and C versions, on the other hand, the time was proportional to the length of the string. Pure Python started out faster than WASM for the 13-byte string, at 1.7ms (vs. ~40.0ms in WASM), but for the 1300-byte string pure Python took 170ms (vs. ~40.0ms in WASM).

This indicates that most of the 40ms is a constant switching cost, regardless of the actual work done inside the loop.
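The reasoning above can be written as a two-point linear fit, time = fixed + per_byte * length. The sketch below plugs in the timings quoted in this comment; fit_line is a made-up helper, and the model itself (a fixed cost plus a cost proportional to length) is an assumption.

```python
def fit_line(x1, t1, x2, t2):
    """Fit t = fixed + slope * x through two measurements."""
    slope = (t2 - t1) / (x2 - x1)
    fixed = t1 - slope * x1
    return fixed, slope


# (string length in bytes, total time in ms) as quoted above
py_fixed, py_slope = fit_line(13, 1.7, 1300, 170.0)
wasm_fixed, wasm_slope = fit_line(13, 40.0, 1300, 40.0)

print(f"pure Python: fixed {py_fixed:.3f}ms, {py_slope:.4f}ms per byte")
print(f"WASM:        fixed {wasm_fixed:.3f}ms, {wasm_slope:.4f}ms per byte")
```

Under this model the pure Python time is almost entirely per-byte work, while the WASM time is almost entirely the fixed part.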

I've added some commented-out profiling to identify exactly where this time is being wasted, but did not have enough time to continue:

https://github.com/muayyad-alsadi/wasm-demos/blob/main/cdb_djp_hash/cdb_djp_hash.py#L88

muayyad-alsadi avatar Mar 20 '23 09:03 muayyad-alsadi

Profiling 10k calls:

from cProfile import Profile

# cdb_djp_hash_wasm and large_a are defined elsewhere in the script
pr = Profile()
pr.enable()
for i in range(10000):
    cdb_djp_hash_wasm(large_a)
pr.disable()
# pr.print_stats()
# pr.print_stats('cumulative')
pr.print_stats('calls')

So we get these "fishy" call counts.

Out of the 800ms, only 169ms was taken by the actual WASM call, wasmtime_func_call.


   Ordered by: call count

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   210000    0.018    0.000    0.018    0.000 {built-in method builtins.isinstance}
   140000    0.014    0.000    0.014    0.000 {built-in method builtins.hasattr}
   130000    0.013    0.000    0.013    0.000 {built-in method builtins.len}
    80000    0.011    0.000    0.011    0.000 {built-in method _ctypes.byref}
    70000    0.008    0.000    0.008    0.000 {built-in method _ctypes.POINTER}
    60000    0.009    0.000    0.009    0.000 {built-in method __new__ of type object at 0x7f34f4f430a0}
    50000    0.033    0.000    0.049    0.000 _types.py:45(_from_ptr)
    50000    0.025    0.000    0.047    0.000 _types.py:86(__del__)
    40000    0.022    0.000    0.022    0.000 _bindings.py:178(wasm_valtype_kind)
    40000    0.003    0.000    0.003    0.000 {method 'append' of 'list' objects}
    30000    0.007    0.000    0.010    0.000 _value.py:161(_unwrap_raw)
    30000    0.018    0.000    0.018    0.000 _bindings.py:2440(wasmtime_val_delete)
    30000    0.006    0.000    0.006    0.000 _value.py:109(__init__)
    30000    0.020    0.000    0.044    0.000 _value.py:117(__del__)
    20000    0.032    0.000    0.183    0.000 _value.py:129(_convert)
    20000    0.014    0.000    0.043    0.000 _types.py:12(i32)
    20000    0.017    0.000    0.040    0.000 _types.py:54(__eq__)
    20000    0.033    0.000    0.065    0.000 _types.py:94(_from_list)
    20000    0.008    0.000    0.011    0.000 _func.py:256(enter_wasm)
    20000    0.012    0.000    0.012    0.000 _bindings.py:120(wasm_valtype_delete)
    20000    0.010    0.000    0.010    0.000 _bindings.py:172(wasm_valtype_new)
    20000    0.028    0.000    0.033    0.000 _value.py:39(i32)
    20000    0.006    0.000    0.018    0.000 {built-in method builtins.next}
    20000    0.002    0.000    0.002    0.000 {built-in method _ctypes.addressof}
    10000    0.006    0.000    0.006    0.000 _value.py:167(_value)
    10000    0.006    0.000    0.016    0.000 _value.py:183(value)
    10000    0.008    0.000    0.013    0.000 _types.py:139(_from_ptr)
    10000    0.007    0.000    0.056    0.000 _types.py:148(params)
    10000    0.007    0.000    0.038    0.000 _types.py:157(results)
    10000    0.006    0.000    0.015    0.000 _types.py:169(__del__)
    10000    0.093    0.000    0.685    0.000 _func.py:59(__call__)
    10000    0.011    0.000    0.036    0.000 _func.py:52(type)
    10000    0.011    0.000    0.193    0.000 _func.py:83(<listcomp>)
    10000    0.007    0.000    0.018    0.000 _memory.py:56(data_ptr)
    10000    0.006    0.000    0.015    0.000 _memory.py:65(data_len)
    10000    0.054    0.000    0.112    0.000 wasmtime_fast_memory.py:47(__setitem__)
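A table sorted by call count can make it hard to see which wrapper dominates. One option, sketched here with only the standard library's cProfile and pstats, is to sort by cumulative time and limit the number of rows. The functions convert and call are hypothetical stand-ins for the binding code, not wasmtime-py functions.

```python
import cProfile
import io
import pstats


def convert(value):
    # Stand-in for per-argument conversion work.
    return int(value)


def call(a, b):
    # Stand-in for a binding call that converts each argument.
    return convert(a) + convert(b)


profiler = cProfile.Profile()
profiler.enable()
for _ in range(1_000):
    call(6, 27)
profiler.disable()

# Collect the report in a string, sorted by cumulative time, top 5 rows.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("cumulative").print_stats(5)
report = buffer.getvalue()
print(report)
```

The same sort_stats and print_stats calls should work on the Profile object used above.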

and when doing it with a NumPy view of the memory:

np_mem = np.frombuffer(fast_mem.get_buffer_ptr(), dtype=np.uint8)

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   170000    0.015    0.000    0.015    0.000 {built-in method builtins.isinstance}
   140000    0.014    0.000    0.014    0.000 {built-in method builtins.hasattr}
   120000    0.012    0.000    0.012    0.000 {built-in method builtins.len}
    70000    0.009    0.000    0.009    0.000 {built-in method _ctypes.POINTER}
    60000    0.009    0.000    0.009    0.000 {built-in method __new__ of type object at 0x7f988c7430a0}
    60000    0.009    0.000    0.009    0.000 {built-in method _ctypes.byref}
    50000    0.033    0.000    0.050    0.000 _types.py:45(_from_ptr)
    50000    0.024    0.000    0.046    0.000 _types.py:86(__del__)
    40000    0.022    0.000    0.022    0.000 _bindings.py:178(wasm_valtype_kind)
    40000    0.003    0.000    0.003    0.000 {method 'append' of 'list' objects}
    30000    0.018    0.000    0.018    0.000 _bindings.py:2440(wasmtime_val_delete)
    30000    0.006    0.000    0.006    0.000 _value.py:109(__init__)
    30000    0.020    0.000    0.045    0.000 _value.py:117(__del__)
    30000    0.008    0.000    0.010    0.000 _value.py:161(_unwrap_raw)
    20000    0.008    0.000    0.011    0.000 _func.py:256(enter_wasm)
    20000    0.011    0.000    0.011    0.000 _bindings.py:172(wasm_valtype_new)
    20000    0.034    0.000    0.040    0.000 _value.py:39(i32)
    20000    0.032    0.000    0.189    0.000 _value.py:129(_convert)
    20000    0.013    0.000    0.044    0.000 _types.py:12(i32)
    20000    0.017    0.000    0.040    0.000 _types.py:54(__eq__)
    20000    0.032    0.000    0.065    0.000 _types.py:94(_from_list)
    20000    0.012    0.000    0.012    0.000 _bindings.py:120(wasm_valtype_delete)
    20000    0.006    0.000    0.017    0.000 {built-in method builtins.next}
    10000    0.011    0.000    0.200    0.000 _func.py:83(<listcomp>)

We are doing 10k calls, so I should see ~10k calls to wasmtime_func_call, and this makes sense. The entries with 20k or 40k calls run once per value, or before and after each function call, so those counts make sense too. But why do we compare lists?

    10000    0.011    0.000    0.206    0.000 _func.py:83(<listcomp>)
    20000    0.033    0.000    0.195    0.000 _value.py:129(_convert)

But what does not make sense is having 210k operations that seem to be type checks, although all of my work is on primitive types that do not need those operations, and things do not need to be in variable-length lists, etc. For example:

    40000    0.003    0.000    0.003    0.000 {method 'append' of 'list' objects}

TL;DR: actionable items

  • make it fixed size instead of a list
  • document a way to shortcut the conversion
  • do the conversion using some sort of map or tuple construct rather than a dynamic list, and avoid having the loop in Python
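To illustrate what these items could look like, here is a minimal sketch that resolves the parameter kinds once and reuses a fixed-size ctypes array type on every call. RawVal is a hypothetical stand-in for a raw value slot (it is not the wasmtime-py type), and make_packer is a made-up name.

```python
import ctypes


class RawVal(ctypes.Union):
    # Hypothetical raw value slot; NOT the real wasmtime-py type.
    _fields_ = [
        ("i32", ctypes.c_int32),
        ("i64", ctypes.c_int64),
        ("f32", ctypes.c_float),
        ("f64", ctypes.c_double),
    ]


def make_packer(param_kinds):
    """Build a packer for a fixed signature; type lookups happen once."""
    kinds = tuple(param_kinds)                 # fixed-size, immutable
    array_type = RawVal * max(len(kinds), 1)   # created once, reused

    def pack(*args):
        if len(args) != len(kinds):
            raise TypeError(f"expected {len(kinds)} arguments")
        raw = array_type()
        for i, kind in enumerate(kinds):
            # Write each argument straight into its slot by field name.
            setattr(raw[i], kind, args[i])
        return raw

    return pack


pack_gcd = make_packer(["i32", "i32"])
first = pack_gcd(6, 27)
second = pack_gcd(16516842, 154654684)
print(first[0].i32, first[1].i32, second[0].i32, second[1].i32)
```

There is still a short loop in Python, but no per-call isinstance checks, list appends, or type object creation.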

muayyad-alsadi avatar Mar 20 '23 13:03 muayyad-alsadi

Re-using the exact gcd.wat from the examples, it looks like any call has a "cost" of about 30μs.

I think I fixed that in #137

@alexprengere would you please test my approach and report how much of a performance gain it achieves?

muayyad-alsadi avatar Mar 22 '23 08:03 muayyad-alsadi

I was able to eliminate the 40ms of wasted time.

here is the code

# gcd_alt.py
import ctypes

from wasmtime import Store, Module, Instance, WasmtimeError
from functools import partial

from wasmtime import _ffi as ffi
from wasmtime._func import enter_wasm
from wasmtime._bindings import wasmtime_val_raw_t

store = Store()
module = Module.from_file(store.engine, './gcd.wat')
instance = Instance(store, module, [])

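def func_init(func, store):
    # Look up the signature once so each call can skip the type checks.
    ty = func.type(store)
    ty_params = ty.params
    ty_results = ty.results
    # Field names of wasmtime_val_raw_t ("i32", "i64", ...), one per parameter.
    # This must be a list, not a generator: a generator would be exhausted
    # after the first call and later calls would pass zeroed arguments.
    params_str = [str(i) for i in ty_params]
    params_n = len(ty_params)
    results_n = len(ty_results)
    # Parameters and results share one buffer, so size it for the larger.
    n = max(params_n, results_n)
    raw_type = wasmtime_val_raw_t*n
    func.raw_type = raw_type
    def _create_raw(*params):
        raw = raw_type()
        for i, param_str in enumerate(params_str):
            setattr(raw[i], param_str, params[i])
        return raw
    func._create_raw = _create_raw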

_gcd_in = instance.exports(store)["gcd"]
func_init(_gcd_in, store)

def gcd(a, b):
    raw = _gcd_in._create_raw(a, b)
    raw_ptr_casted = ctypes.cast(raw, ctypes.POINTER(wasmtime_val_raw_t))
    with enter_wasm(store) as trap:
        error = ffi.wasmtime_func_call_unchecked(
            store._context,
            ctypes.byref(_gcd_in._func),
            raw_ptr_casted,
            trap)
        if error:
            raise WasmtimeError._from_ptr(error)
    return raw[0].i32

print("gcd(6, 27) = %d" % gcd(6, 27))

and here is the benchmark

import wasmtime.loader
import time
from math import gcd as math_gcd
from gcd import gcd as wasm_gcd
from gcd_alt import gcd as wasm_gcd_alt

def python_gcd(x, y):
    while y:
        x, y = y, x % y
    return abs(x)

N = 1_000
for gcdf in math_gcd, python_gcd, wasm_gcd, wasm_gcd_alt:
    start_time = time.perf_counter()
    for _ in range(N):
        g = gcdf(16516842, 154654684)
    total_time = time.perf_counter() - start_time
    print(total_time)

and here is the result (total seconds for N = 1_000 calls)

0.0002523580042179674  # math_gcd
0.0014094869984546676  # python_gcd
0.043804362998344004   # wasm_gcd
0.005873051006346941   # wasm_gcd_alt  <-------- my WASM
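To put these totals on the same per-call scale as the original report, here is a small arithmetic sketch using only the numbers printed above and N = 1_000 from the benchmark:

```python
N = 1_000

# Total seconds for N calls, copied from the benchmark output above.
totals = {
    "math_gcd": 0.0002523580042179674,
    "python_gcd": 0.0014094869984546676,
    "wasm_gcd": 0.043804362998344004,
    "wasm_gcd_alt": 0.005873051006346941,
}

# Convert each total to microseconds per call.
per_call_us = {name: total / N * 1_000_000 for name, total in totals.items()}
speedup = totals["wasm_gcd"] / totals["wasm_gcd_alt"]

for name, value in per_call_us.items():
    print(f"{name:>12}: {value:8.3f}μs per call")
print(f"wasm_gcd_alt is about {speedup:.1f}x faster than wasm_gcd")
```

That works out to roughly 44μs per call before and roughly 6μs per call after, which is still slower than the pure Python loop in this run.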

muayyad-alsadi avatar Mar 22 '23 08:03 muayyad-alsadi

We have 3 types of performance bottlenecks:

  • the overhead of passing large data (#135, #81, #134), which is fixed, merged, and released as part of the 7.0.0 release
  • the overhead of calling a small wasm function from Python (#137, #139), which is fixed but not yet merged. In this case the function itself takes almost no time (e.g. the gcd of two integers, or a simple string hash like the cdb hash), yet there is a constant amount of time wasted on every call from Python, regardless of what's inside the function
  • the overhead of calling Python from wasm, where the loop is inside wasm and the small function lives in Python

I've addressed the first two, and I'll create a ticket for the last one with details.

muayyad-alsadi avatar Apr 05 '23 11:04 muayyad-alsadi