diffeqpy
Trick to avoid python overhead
I figured out a way to avoid the python overhead if the right-hand-side function is implemented using numba, and I thought maybe people might be interested.
The trick is to use `numba.cfunc`, pass the pointer to Julia, and then call it in a small wrapper function with `ccall`. A small notebook to show the basic idea:
https://gist.github.com/aseyboldt/c1ebffcc1e2e217943a9372898eb8d87
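The pointer round-trip can be demonstrated with only the standard library: `ctypes.CFUNCTYPE` plays the role of `numba.cfunc` (it produces a C-callable with a raw address), and calling back through that address mimics what Julia's `ccall` does. A sketch of the idea, not the notebook's actual code:

```python
import ctypes

# C signature matching the Julia ccall below:
# (int64, int64, double*, double*, double*, double) -> void
RHS_SIG = ctypes.CFUNCTYPE(
    None,
    ctypes.c_int64,
    ctypes.c_int64,
    ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_double),
    ctypes.c_double,
)

def rhs(n_states, n_params, du_, u_, p_, t):
    # Lorenz right-hand side, writing the result into du_
    sigma, rho, beta = p_[0], p_[1], p_[2]
    x, y, z = u_[0], u_[1], u_[2]
    du_[0] = sigma * (y - x)
    du_[1] = x * (rho - z) - y
    du_[2] = x * y - beta * z

callback = RHS_SIG(rhs)

# The raw address that would be handed to Julia (callback.address in numba):
func_ptr = ctypes.cast(callback, ctypes.c_void_p).value

# Call back through the bare pointer, as Julia's ccall would:
fn = RHS_SIG(func_ptr)
u = (ctypes.c_double * 3)(1.0, 0.0, 0.0)
p = (ctypes.c_double * 3)(10.0, 28.0, 8.0 / 3.0)
du = (ctypes.c_double * 3)()
fn(3, 3, du, u, p, 0.0)
```

After the call, `du` holds `[-10.0, 28.0, 0.0]` for these initial values; the point is only that a Python-created function pointer can be invoked through its raw address with no Python interpreter in the loop.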
Nice, thanks for working that out. Would you mind turning that into a README example? That should work well in general.
I can do that. It would be nice if we could come up with a better way of passing the function pointer to the Julia code; right now I'm using a global on the Julia side. Could we write a function in diffeqpy that takes a Python function as input and returns a wrapped Julia function for it?
@tkf do you know of a better way?
> Could we write a function in diffeqpy that takes a python function as input and returns a wrapped julia function for it?
In theory, I think so. If the function is simple enough I might use ModelingToolkit for that, but more generally a wrapped JIT approach like this would probably be a good one to support.
This should do the trick:
```python
wrap_python_callback = Main.eval(
    """
    using PyCall

    function wrap_python_callback(callback)
        py_func_ptr = callback.address
        func_ptr = convert(Csize_t, py_func_ptr)
        func_ptr = convert(Ptr{Nothing}, func_ptr)

        function rhs_wrapper(du, u, p, t)
            n_states = length(du)
            n_params = length(p)
            ccall(
                convert(Ptr{Nothing}, func_ptr),
                Nothing,
                (Int64, Int64, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64),
                n_states, n_params, du, u, p, t
            )
        end
    end
    """
)

prob = de.ODEProblem(wrap_python_callback(f), u0, tspan, p)
```
I guess we could also compile with fixed array shapes on the numba side, I'll play a bit with that...
By the way: Are there plans to also wrap some of the sensitivity (forward, adjoint..) from DifferentialEquations?
@aseyboldt This is very cool! Thanks for posting this. Using a closure seems to be a nice way to do it.

> I guess we could also compile with fixed array shapes on the numba side, I'll play a bit with that...

My hunch is that it does not matter much, as the arrays have to be heap-allocated anyway if you pass them as `Ptr` via `ccall`. I don't have serious experience with numba, so I may be wrong.
(For small-dimensional diffeq where this is important, I think a better way is to convert sympy expression to the diffeq #32. "Better" in the sense if you have existing Python function/library that you want to use. If you are open to learn new things, ModelingToolkit may be better.)
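For the sympy route, the idea would roughly be to write the RHS symbolically and let `sympy.lambdify` generate a plain numeric function from it. A sketch, assuming sympy is available (the symbols just mirror the Lorenz example used elsewhere in this thread):

```python
import sympy as sp

x, y, z, sigma, rho, beta = sp.symbols('x y z sigma rho beta')

# Lorenz right-hand side as symbolic expressions
rhs_exprs = [
    sigma * (y - x),
    x * (rho - z) - y,
    x * y - beta * z,
]

# Generate a numeric function from the symbolic form
rhs = sp.lambdify((x, y, z, sigma, rho, beta), rhs_exprs)

du = rhs(1.0, 0.0, 0.0, 10.0, 28.0, 8.0 / 3.0)  # [-10.0, 28.0, 0.0]
```

Because the expression tree is available symbolically, a backend could in principle emit Julia code from it instead of Python, which is the sense in which this avoids the callback overhead entirely.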
> By the way: Are there plans to also wrap some of the sensitivity (forward, adjoint..) from DifferentialEquations?

I'm not familiar with jax at all, but if you can get the function pointer of the jacobian generated via jax, like you do with numba, I guess it's possible to hook it into the sensitivity analysis?
Actually, it may be dangerous as-is. Does it hurt the performance if you do the following?
```python
wrap_python_callback = Main.eval(
    """
    using PyCall

    function wrap_python_callback(callback)
        function rhs_wrapper(du, u, p, t)
            py_func_ptr = callback.address
            func_ptr = convert(Csize_t, py_func_ptr)
            func_ptr = convert(Ptr{Nothing}, func_ptr)
            ...
    """
)
```
If so, I think you need to be careful that the `callback` passed to the Julia side is not GC'ed on the Python side. It depends on how Numba implements this, but I think it'd be safe to leak `callback` on the Python side, so that the function is never GC'ed while Julia is using it.
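Such a deliberate leak can be as small as a module-level registry; a minimal sketch (the `keep_alive` helper is hypothetical, not part of diffeqpy):

```python
import weakref

# Deliberate "leak": every callback handed to Julia is pinned here, so its
# refcount can never drop to zero while the raw pointer may still be in use.
_live_callbacks = []

def keep_alive(callback):
    """Pin the callback object for the lifetime of the process and
    return it unchanged, so it can be passed on to Julia."""
    _live_callbacks.append(callback)
    return callback
```

The trade-off is that pinned callbacks are never reclaimed, which is usually acceptable for a handful of RHS functions but worth knowing about if callbacks are created in a loop.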
Good point about keeping `callback` alive.
Unfortunately your simpler solution is much slower: in that example it takes 20 ms instead of 3.5 ms.
I really don't know Julia well, so maybe this is a very basic question: what exactly does `solve` accept as a function argument? It seems to convert it to an `ODEFunction` struct (https://github.com/SciML/DiffEqBase.jl/blob/master/src/diffeqfunction.jl#L21). I guess it just calls the constructor function here? Would it be possible to add a field of type `Any` to `ODEFunction` and store the callback object in there?
Is it possible that pyjulia leaks objects that are passed into function in general? I can't get it to finalize an object after passing it to a julia function:
```python
import gc
import weakref

from julia import Main

class Foo:
    pass

def foo():
    obj = Foo()

    def destroy():
        print('destroying...')

    weakref.finalize(obj, destroy)

    use_python_obj = Main.eval(
        """
        using PyCall

        function use_python_val(obj)
            1
        end
        """
    )
    use_python_obj(obj)
    gc.collect()

foo()
```
This should print 'destroying...' I think, but doesn't. It does if we don't pass `obj` to a Julia function. Could this be a pyjulia/PyCall bug where the refcount isn't decremented properly?
I opened a separate issue about this leak for pyjulia: https://github.com/JuliaPy/pyjulia/issues/365
Sorry for all the noise, but that leak turned out not to exist :-) I forgot to call the Julia GC...

But from what I can tell with a little experiment, just referencing the callback in the inner function seems to capture it, even if we don't use it. It does clean up the callback if we comment out the marked line. @tkf Do you know if that is guaranteed, or if that could change in the future due to some compiler optimizations?
```python
import gc
import weakref

import numba
from julia import Main

def foo():
    c_sig = numba.types.int64(
        numba.types.int64,
        numba.types.int64,
        numba.types.CPointer(numba.types.float64),
        numba.types.CPointer(numba.types.float64),
        numba.types.CPointer(numba.types.float64),
        numba.types.float64,
    )

    @numba.cfunc(c_sig)
    def f(n_states, n_params, du_, u_, p_, t):
        u = numba.carray(u_, 3)
        du = numba.carray(du_, 3)
        p = numba.carray(p_, 3)
        x, y, z = u
        sigma, rho, beta = p
        du[0] = sigma * (y - x)
        du[1] = x * (rho - z) - y
        du[2] = x * y - beta * z
        return 0

    def destroy():
        print('destroying...')

    weakref.finalize(f, destroy)

    wrap_python_callback = Main.eval(
        """
        using PyCall

        function wrap_python_callback(callback)
            py_func_ptr = callback.address
            func_ptr = convert(Csize_t, py_func_ptr)
            func_ptr = convert(Ptr{Nothing}, func_ptr)

            function f(du, u, p, t)
                callback  # !!!!!!!!!!!!!!!!!!!!!
                n_states = length(du)
                n_params = length(p)
                ccall(
                    convert(Ptr{Nothing}, func_ptr),
                    Nothing,
                    (Int64, Int64, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64),
                    n_states, n_params, du, u, p, t
                )
                #return callback
            end
        end
        """
    )

    out = wrap_python_callback(f)
    Main.GC.gc()
    gc.collect()
    return out

foo()
```
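Whether a bare reference in a closure really pins an object can be checked in pure Python, without numba or Julia; in current CPython the capture does keep the object alive, though as discussed a smarter compiler could in principle drop an unused name. A minimal sketch:

```python
import gc
import weakref

class Callback:
    """Stand-in for a numba cfunc object."""
    pass

def make_wrapper(callback):
    # Merely naming `callback` in the body creates a closure cell that
    # keeps the object alive as long as the wrapper itself is alive.
    def wrapper():
        callback
    return wrapper

cb = Callback()
ref = weakref.ref(cb)
w = make_wrapper(cb)

del cb
gc.collect()
alive_while_wrapped = ref() is not None  # True: the closure pins the object

del w
gc.collect()
alive_after = ref() is not None  # False: nothing references it anymore
```

This shows the mechanism behind the marked line above, but also why relying on it is fragile: the liveness hinges on the closure cell existing, which is an implementation detail rather than a documented guarantee.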
> Do you know if that is guaranteed or if that could change in the future due to some compiler optimizations?

Yeah, I think it's possible that the compiler can eliminate this in the future. I think what we need is `GC.@preserve`:
```julia
function wrap_python_callback(callback)
    py_func_ptr = callback.address
    func_ptr = convert(Csize_t, py_func_ptr)
    func_ptr = convert(Ptr{Nothing}, func_ptr)

    function f(du, u, p, t)
        n_states = length(du)
        n_params = length(p)
        GC.@preserve callback ccall(
            convert(Ptr{Nothing}, func_ptr),
            Nothing,
            (Int64, Int64, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64),
            n_states, n_params, du, u, p, t
        )
    end
end
```
Does it still hurt the performance?
BTW, my longer term answer to this is that I kind of think that ModelingToolkit would be a good interface from Python, so I plan to update the docs to showcase how to use it. In the backend it would be generating Julia code so the overhead would be cut out, and it would give a nice front end DSL.
Using the new JIT functions works well.