diffeqpy

Trick to avoid python overhead

Status: Open. aseyboldt opened this issue 4 years ago • 12 comments

I figured out a way to avoid the Python overhead when the right-hand side function is implemented with numba, and I thought people might be interested. The trick is to compile the function with numba.cfunc, pass its pointer to Julia, and call it from a small wrapper function with ccall. A small notebook showing the basic idea: https://gist.github.com/aseyboldt/c1ebffcc1e2e217943a9372898eb8d87
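The same mechanism can be reproduced with only the standard library. In the rough sketch below (an illustration, not the code from the notebook), `ctypes.CFUNCTYPE` stands in for `numba.cfunc`: it turns a Python function into something C can call and exposes its raw address, and rebuilding a callable from that bare integer stands in for Julia's `ccall`. The C signature follows the one used later in this thread; `RHS_PROTO` and `lorenz_rhs` are hypothetical names. Unlike a numba cfunc, the body here still runs in the Python interpreter, so this shows the plumbing, not the speedup.

```python
import ctypes

# Hypothetical C signature, matching the one used later in the thread:
# int64 f(int64 n_states, int64 n_params, double *du, double *u, double *p, double t)
RHS_PROTO = ctypes.CFUNCTYPE(
    ctypes.c_int64,
    ctypes.c_int64,
    ctypes.c_int64,
    ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_double),
    ctypes.c_double,
)


def lorenz_rhs(n_states, n_params, du, u, p, t):
    # The pointers are indexed like C arrays.
    x, y, z = u[0], u[1], u[2]
    sigma, rho, beta = p[0], p[1], p[2]
    du[0] = sigma * (y - x)
    du[1] = x * (rho - z) - y
    du[2] = x * y - beta * z
    return 0


# Analogue of numba.cfunc: a C-callable wrapper around the Python function.
# It must stay referenced for as long as its address is in use.
c_callback = RHS_PROTO(lorenz_rhs)

# Analogue of `callback.address`: the raw function pointer as an integer.
address = ctypes.cast(c_callback, ctypes.c_void_p).value

# Analogue of ccall through a raw pointer: rebuild a callable from the address.
call_through_pointer = RHS_PROTO(address)

u = (ctypes.c_double * 3)(1.0, 0.0, 0.0)
p = (ctypes.c_double * 3)(10.0, 28.0, 8.0 / 3.0)
du = (ctypes.c_double * 3)()
status = call_through_pointer(3, 3, du, u, p, 0.0)
print(status, list(du))
```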

aseyboldt, Apr 08 '20 13:04

Nice, thanks for working that out. Would you mind turning that into a README example? That should work well in general.

ChrisRackauckas, Apr 08 '20 16:04

I can do that. It would be nice to find a better way of passing the function pointer to the Julia code; right now I'm using a global on the Julia side. Could we write a function in diffeqpy that takes a Python function as input and returns a wrapped Julia function for it?

aseyboldt, Apr 08 '20 21:04

@tkf do you know of a better way?

Could we write a function in diffeqpy that takes a Python function as input and returns a wrapped Julia function for it?

In theory, I think so. If the function is simple enough I might use ModelingToolkit for that, but more generally a wrapped JIT approach like this would probably be a good one to support.

ChrisRackauckas, Apr 08 '20 21:04

This should do the trick:

from julia import Main

wrap_python_callback = Main.eval(
    """
    using PyCall

    function wrap_python_callback(callback)
        py_func_ptr = callback.address
        func_ptr = convert(Csize_t, py_func_ptr)
        func_ptr = convert(Ptr{Nothing}, func_ptr)
        
        function rhs_wrapper(du, u, p, t)
          n_states = length(du)
          n_params = length(p)
          ccall(
              convert(Ptr{Nothing}, func_ptr),
              Nothing,
              (Int64, Int64, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64),
              n_states,
              n_params,
              du,
              u,
              p,
              t
            )
        end
    end
    """
)

from diffeqpy import de

# f is the numba cfunc; u0, tspan and p are defined as in the notebook
prob = de.ODEProblem(wrap_python_callback(f), u0, tspan, p)
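For readers without Julia at hand, the closure pattern in `wrap_python_callback` can be mimicked in Python. In this hypothetical analogue, the factory converts the address to a callable once, outside the inner function, and the returned `rhs(du, u, p, t)` only forwards the lengths and buffers, much as `rhs_wrapper` does with `ccall`. ctypes stands in for numba, and `make_rhs_wrapper` and `lorenz` are illustrative names.

```python
import ctypes

# Same hypothetical C signature as the numba cfunc in this thread.
RHS_PROTO = ctypes.CFUNCTYPE(
    ctypes.c_int64, ctypes.c_int64, ctypes.c_int64,
    ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_double), ctypes.c_double,
)


@RHS_PROTO
def lorenz(n_states, n_params, du, u, p, t):
    du[0] = p[0] * (u[1] - u[0])
    du[1] = u[0] * (p[1] - u[2]) - u[1]
    du[2] = u[0] * u[1] - p[2] * u[2]
    return 0


def make_rhs_wrapper(callback):
    # Done once per wrapper, like the pointer conversion outside rhs_wrapper.
    address = ctypes.cast(callback, ctypes.c_void_p).value
    func = RHS_PROTO(address)

    def rhs(du, u, p, t):
        # Per evaluation, only the call through the raw pointer happens.
        return func(len(du), len(p), du, u, p, t)

    return rhs


u = (ctypes.c_double * 3)(1.0, 1.0, 1.0)
p = (ctypes.c_double * 3)(10.0, 28.0, 8.0 / 3.0)
du = (ctypes.c_double * 3)()
rhs = make_rhs_wrapper(lorenz)
status = rhs(du, u, p, 0.0)
```

Like the Julia version, `rhs` here captures only the pointer-derived `func`, not `callback` itself; `lorenz` stays alive only because it is a module-level name, which is exactly the lifetime question raised further down.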

I guess we could also compile with fixed array shapes on the numba side, I'll play a bit with that...

By the way: are there plans to also wrap some of the sensitivity analysis methods (forward, adjoint, ...) from DifferentialEquations?

aseyboldt, Apr 09 '20 07:04

@aseyboldt This is very cool! Thanks for posting this. Using a closure seems to be a nice way to do it.

I guess we could also compile with fixed array shapes on the numba side, I'll play a bit with that...

My hunch is that it does not matter much, as the arrays have to be heap-allocated anyway if you pass them as Ptr via ccall. I don't have much experience with numba, so I may be wrong.

(For small-dimensional ODEs where this matters, I think a better way is to convert a SymPy expression into the ODE function, see #32. "Better" in the sense that it suits you if you already have a Python function or library you want to reuse. If you are open to learning new things, ModelingToolkit may be better.)

By the way: are there plans to also wrap some of the sensitivity analysis methods (forward, adjoint, ...) from DifferentialEquations?

I'm not familiar with JAX at all, but if you can get a function pointer to the Jacobian generated by JAX, as you do with numba, I guess it's possible to hook it into the sensitivity analysis?

tkf, Apr 09 '20 07:04

Actually, it may be dangerous as-is. Does it hurt the performance if you do

wrap_python_callback = Main.eval(
    """
    using PyCall

    function wrap_python_callback(callback)
        
        function rhs_wrapper(du, u, p, t)
          py_func_ptr = callback.address
          func_ptr = convert(Csize_t, py_func_ptr)
          func_ptr = convert(Ptr{Nothing}, func_ptr)
...

? If so, I think you need to be careful that the callback passed to the Julia side is not GC'ed on the Python side. It depends on how numba implements this, but I think it'd be safe to leak the callback on the Python side to ensure that the function is never GC'ed while Julia is using it.
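Deliberately leaking the callback can be as simple as a module-level container that holds a strong reference forever. A minimal sketch, assuming a hypothetical `keep_alive` helper (not a diffeqpy function) and a plain Python object standing in for the numba cfunc; `weakref.finalize` reports when the object is actually freed.

```python
import gc
import weakref

# Strong references that are never dropped: anything stored here cannot be
# garbage-collected, so a pointer derived from it stays valid.
_KEEPALIVE = []


def keep_alive(obj):
    # Hypothetical helper: intentionally leak obj for the process lifetime.
    _KEEPALIVE.append(obj)


class FakeCFunc:
    """Placeholder for a numba cfunc object."""


finalized = []
cb = FakeCFunc()
weakref.finalize(cb, finalized.append, "cb")
keep_alive(cb)

del cb
gc.collect()
alive_while_registered = not finalized  # True: the registry still owns cb

_KEEPALIVE.clear()  # only for demonstration; real code would never release it
gc.collect()
```

The cost is memory for every wrapped callback, which is usually negligible when only a handful of right-hand sides are compiled per session.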

tkf, Apr 09 '20 08:04

Good point about keeping the callback alive. Unfortunately, your simpler solution is much slower: in that example it takes 20 ms instead of 3.5 ms. I really don't know Julia well, so maybe this is a very basic question: what exactly does solve accept as a function argument? It seems to convert it to an ODEFunction struct (https://github.com/SciML/DiffEqBase.jl/blob/master/src/diffeqfunction.jl#L21). I guess it just calls the constructor function there? Would it be possible to add a field of type Any to ODEFunction and store the callback object in it?
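On the Python side, the counterpart of that extra `Any` field is to store the owning object next to the pointer, so both share one lifetime. The ctypes documentation gives the same advice for its own callbacks: keep a reference to the CFUNCTYPE object for as long as C code may call it. A minimal sketch with ctypes standing in for numba; `PointerCallback` and `decay` are hypothetical names, not part of diffeqpy or DiffEqBase.

```python
import ctypes
import gc

RHS_PROTO = ctypes.CFUNCTYPE(
    ctypes.c_int64, ctypes.c_int64, ctypes.c_int64,
    ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_double), ctypes.c_double,
)


class PointerCallback:
    """Hypothetical wrapper bundling a raw function pointer with its owner.

    Holding ``owner`` keeps the C callback alive exactly as long as this
    wrapper, similar in spirit to an extra ``Any`` field on ODEFunction.
    """

    def __init__(self, owner):
        self.owner = owner
        self.address = ctypes.cast(owner, ctypes.c_void_p).value
        self._func = RHS_PROTO(self.address)

    def __call__(self, du, u, p, t):
        return self._func(len(du), len(p), du, u, p, t)


def decay(n_states, n_params, du, u, p, t):
    # du_i = -p_0 * u_i
    for i in range(n_states):
        du[i] = -p[0] * u[i]
    return 0


callback = RHS_PROTO(decay)
wrapper = PointerCallback(callback)
del callback  # now only wrapper.owner references the ctypes callback
gc.collect()

u = (ctypes.c_double * 2)(1.0, 2.0)
p = (ctypes.c_double * 1)(0.5)
du = (ctypes.c_double * 2)()
status = wrapper(du, u, p, 0.0)
```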

aseyboldt, Apr 09 '20 08:04

Is it possible that PyJulia leaks objects that are passed into functions in general? I can't get it to finalize an object after passing it to a Julia function:

import gc
import weakref

from julia import Main

class Foo:
    pass

def foo():
    obj = Foo()

    def destroy():
        print('destroying...')
    weakref.finalize(obj, destroy)
    
    use_python_obj = Main.eval(
        """
        using PyCall

        function use_python_val(obj)
            1
        end
        """
    )
    
    use_python_obj(obj)
    gc.collect()

foo()

I think this should print 'destroying...', but it doesn't. It does if we don't pass obj to a Julia function. Could this be a PyJulia/PyCall bug where the refcount isn't decremented properly?

aseyboldt, Apr 09 '20 09:04

I opened a separate issue about this leak for PyJulia: https://github.com/JuliaPy/pyjulia/issues/365

aseyboldt, Apr 09 '20 10:04

Sorry for all the noise, but that leak turned out not to exist :-) I forgot to call the Julia GC...
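The resolution can be modelled in pure Python: an object is finalized only after every holder has released it, so when a second runtime holds a reference, that runtime's collector has to run too. In this hypothetical simulation, the `foreign_refs` list plays the part of references held by Julia, and clearing it stands in for `Main.GC.gc()`; it is a model of the behaviour, not PyJulia code.

```python
import gc
import weakref


class Foo:
    pass


log = []
foreign_refs = []  # hypothetical stand-in for references held by Julia


def make_and_pass():
    obj = Foo()
    weakref.finalize(obj, log.append, "destroying...")
    foreign_refs.append(obj)  # like passing obj into a Julia function


make_and_pass()
gc.collect()  # Python GC alone cannot free obj: the "foreign" side holds it
before = list(log)

foreign_refs.clear()  # analogue of running the Julia GC
gc.collect()
after = list(log)
```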

But from what I can tell from a small experiment, just referencing the callback in the inner function seems to capture it, even if we don't use it. The callback does get cleaned up if we comment out the marked line. @tkf Do you know if that is guaranteed, or if it could change in the future due to compiler optimizations?

import gc
import weakref

import numba
from julia import Main


def foo():
    c_sig = numba.types.int64(
        numba.types.int64,
        numba.types.int64,
        numba.types.CPointer(numba.types.float64),
        numba.types.CPointer(numba.types.float64),
        numba.types.CPointer(numba.types.float64),
        numba.types.float64
    )

    @numba.cfunc(c_sig)
    def f(n_states, n_params, du_, u_, p_, t):
        u = numba.carray(u_, 3)
        du = numba.carray(du_, 3)
        p = numba.carray(p_, 3)
        x, y, z = u
        sigma, rho, beta = p
        du[0] = sigma * (y - x)
        du[1] = x * (rho - z) - y
        du[2] = x * y - beta * z
        return 0

    def destroy():
        print('destroying...')
    weakref.finalize(f, destroy)
    
    wrap_python_callback = Main.eval(
        """
        using PyCall

        function wrap_python_callback(callback)
            py_func_ptr = callback.address
            func_ptr = convert(Csize_t, py_func_ptr)
            func_ptr = convert(Ptr{Nothing}, func_ptr)

            function f(du,u,p,t)
                callback   # !!!!!!!!!!!!!!!!!!!!!
                n_states = length(du)
                n_params = length(p)
                ccall(
                    convert(Ptr{Nothing}, func_ptr),
                    Nothing,
                    (Int64, Int64, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64),
                    n_states,
                    n_params,
                    du,
                    u,
                    p,
                    t
                )
                #return callback
            end
        end
        """
    )
    out = wrap_python_callback(f)
    Main.GC.gc()
    gc.collect()
    return out

foo()
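For comparison, Python settles the same question at compile time: any name an inner function references, even as a bare expression statement, becomes a free variable held in a closure cell, so the captured object lives as long as the inner function. The sketch below uses hypothetical helper names and `id()` as a placeholder for `callback.address`, and inspects `__code__.co_freevars` to show the difference. It says nothing about Julia, whose compiler is free to behave differently.

```python
def make_with_reference(callback):
    address = id(callback)  # placeholder for callback.address

    def rhs():
        callback  # bare reference, like the marked line in the Julia version
        return address

    return rhs


def make_without_reference(callback):
    address = id(callback)

    def rhs():
        return address

    return rhs


cb = object()
with_ref = make_with_reference(cb)
without_ref = make_without_reference(object())

# Objects stored in the closure cells of the first function.
captured = [cell.cell_contents for cell in with_ref.__closure__]
```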

aseyboldt, Apr 09 '20 11:04

Do you know if that is guaranteed, or if it could change in the future due to compiler optimizations?

Yeah, I think it's possible that the compiler could eliminate this in the future.

I think what we need is GC.@preserve:

function wrap_python_callback(callback)
    py_func_ptr = callback.address
    func_ptr = convert(Csize_t, py_func_ptr)
    func_ptr = convert(Ptr{Nothing}, func_ptr)

    function f(du,u,p,t)
        n_states = length(du)
        n_params = length(p)
        GC.@preserve callback ccall(
            convert(Ptr{Nothing}, func_ptr),
            Nothing,
            (Int64, Int64, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64),
            n_states,
            n_params,
            du,
            u,
            p,
            t
        )
    end
end

Does it still hurt the performance?

tkf, Apr 09 '20 21:04

BTW, my longer-term answer is that I think ModelingToolkit would be a good interface from Python, so I plan to update the docs to showcase how to use it. In the backend it would generate Julia code, so the overhead would be cut out, and it would give a nice front-end DSL.

ChrisRackauckas, Apr 14 '20 01:04

Using the new JIT functions does well

ChrisRackauckas, Oct 21 '23 11:10