PythonCall with Jax: fast inference with numpy, but only jax.numpy works with jax.grad
I am trying to use Jax in my Julia codebase for something that Zygote cannot do well (meta-learning). Someone recommended PythonCall as a solution to some issues I was having with PyCall.
So far, PythonCall has been great. Things work and it is generally quite quick.
There is one pain point: jax.grad should work with numpy.array (it does in Python) but it errors with PythonCall:
PyException(<py TypeError("Cannot interpret value of type <class 'juliacall.ArrayValue'> as an abstract array; it does not have a dtype attribute")>)
Instead, jax.numpy.array works with jax.grad but it is slower.
I assume this is due to the way PythonCall handles non-copying conversions under the hood.
Is there any way to reconcile the two, i.e. get jax.grad to work with numpy.array while keeping the non-copying array optimizations?
I've included an MWE:
using PythonCall
using Flux
jax = pyimport("jax")
jnp = pyimport("jax.numpy")
np = pyimport("numpy")
stax = pyimport("jax.example_libraries.stax")
optimizers = pyimport("jax.example_libraries.optimizers")
optax = pyimport("optax")
random = pyimport("jax.random")
rngkey = random.PRNGKey(123)
dense = stax.Dense
Relu = stax.Relu
in_shape = (-1, 1)
learner_init, learner_apply = stax.serial(dense(1))
out_shape, learner_params = learner_init(input_shape=in_shape, rng=rngkey)
x = transpose(rand(Float32, 1, 1))
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
opt_state = opt_init(learner_params)
p = get_params(opt_state)
function mse_loss(p, x)
    y_hat = learner_apply(p, x)
    return jnp.mean(optax.l2_loss(y_hat, x))
end
###
### For inference, both np.array and jnp.array work fine
###
learner_apply(p, np.array(x)); # Faster if you check with @btime
learner_apply(p, jnp.array(x)); # Slower
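### A hedged aside: one way to check the timing claim above (needs BenchmarkTools,
### which is not in the MWE's dependencies, hence commented out):
# using BenchmarkTools
# @btime learner_apply($p, np.array($x));
# @btime learner_apply($p, jnp.array($x));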
###
### For gradients, only jnp.array works:
###
try
    jax.grad(mse_loss)(p, jnp.array(x))
    println("jax.numpy.array works with jax.grad but is slow")
    println()
catch e
    println(e)
end
try
    jax.grad(mse_loss)(p, np.array(x))
catch e
    println(e)
    println()
    println("numpy.array does not work with jax.grad")
end
The CondaPkg.toml file:
[deps]
python = "3.10"
[pip.deps]
jax = ""
jaxlib = ""
optax = ""
What's the full stack trace?
The full stack trace is:
ERROR: LoadError: Python: TypeError: Cannot interpret value of type <class 'juliacall.ArrayValue'> as an abstract array; it does not have a dtype attribute
Python stacktrace:
[1] apply_fun
@ jax.example_libraries.stax ~/julia-dev/RLE2/.CondaPkg/env/lib/python3.10/site-packages/jax/example_libraries/stax.py:61
[2] apply_fun
@ jax.example_libraries.stax ~/julia-dev/RLE2/.CondaPkg/env/lib/python3.10/site-packages/jax/example_libraries/stax.py:307
[3] __call__
@ ~/.julia/packages/PythonCall/3GRYN/src/jlwrap/any.jl:202
Stacktrace:
[1] pythrow()
@ PythonCall ~/.julia/packages/PythonCall/3GRYN/src/err.jl:94
[2] errcheck
@ ~/.julia/packages/PythonCall/3GRYN/src/err.jl:10 [inlined]
[3] pycallargs(f::Py, args::Py)
@ PythonCall ~/.julia/packages/PythonCall/3GRYN/src/abstract/object.jl:210
[4] pycall(::Py, ::Py, ::Vararg{Py}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PythonCall ~/.julia/packages/PythonCall/3GRYN/src/abstract/object.jl:228
[5] pycall(::Py, ::Py, ::Vararg{Any})
@ PythonCall ~/.julia/packages/PythonCall/3GRYN/src/abstract/object.jl:218
[6] (::Py)(::Py, ::Vararg{Py}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ PythonCall ~/.julia/packages/PythonCall/3GRYN/src/Py.jl:352
[7] (::Py)(::Py, ::Vararg{Any})
@ PythonCall ~/.julia/packages/PythonCall/3GRYN/src/Py.jl:352
[8] top-level scope
@ ~/julia-dev/RLE2/jax_problem4.jl:58
[9] include(fname::String)
@ Base.MainInclude ./client.jl:451
[10] top-level scope
@ REPL[127]:1
[11] top-level scope
@ ~/.julia/packages/CUDA/ZdCxS/src/initialization.jl:155
in expression starting at /home/user/julia-dev/RLE2/jax_problem4.jl:58
This is really confusing.
The error itself is fairly self-explanatory: JAX is trying to interpret some Julia array as a JAX-compatible array, and is failing because it is expecting a dtype attribute.
What's confusing is that I don't see how any Julia array is getting that far. AFAICT the only Julia array in your code is x, but that always gets converted to something else (jnp.array(x) or np.array(x)). In the failing part of your code there is only np.array(x), which should return a numpy array that copies x but does not reference it in any way.
You'll need to do some digging to figure out how a Julia array is getting that far. I can't dig myself as I only have a Windows computer right now and jaxlib requires Linux or Mac.
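For example, a quick first check (an untested sketch, for the reason above) would be to print what mse_loss actually receives:
function mse_loss_dbg(p, x)
    @show typeof(x)  # is x a raw Py, or already converted to a Julia-side wrapper?
    y_hat = learner_apply(p, x)
    return jnp.mean(optax.l2_loss(y_hat, x))
end
jax.grad(mse_loss_dbg)(p, np.array(x))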
It seems the Tracer class that Jax uses to calculate a gradient somehow receives the Julia PyArray. I found this out by printing pytype(args_[0][0][0]) (to unpack the tuple, list, etc.) and args_ inside the pycall function found in src/abstract/object.jl:266. Here are the type and value of args_ that trigger the error:
<class 'jax._src.interpreters.ad.JVPTracer'>
([(Traced<ConcreteArray([[0.14509337]], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = Array([[0.14509337]], dtype=float32)
tangent = Traced<ShapedArray(float32[1,1])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[1,1]), None)
recipe = LambdaBinding(), Traced<ConcreteArray([-0.01266545], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = Array([-0.01266545], dtype=float32)
tangent = Traced<ShapedArray(float32[1])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[1]), None)
recipe = LambdaBinding())], Julia:
1×1 PyArray{Float32, 2}:
0.088733494)
ERROR: LoadError: Python: TypeError: Cannot interpret value of type <class 'juliacall.ArrayValue'> as an abstract array; it does not have a dtype attribute
I should also add that this error does not occur if np.array(x) is not an argument to mse_loss and is instead treated as a constant. This is not a usable workaround, however, because the gradient function will be jit'd and needs to take the data as an argument.
fixed_x = np.array(x)
function mse_loss2(p)
    y_hat = learner_apply(p, fixed_x)
    return jnp.mean(optax.l2_loss(y_hat, fixed_x))
end
jax.grad(mse_loss2)(p) # Works!
Edit: I realized I do not need to jit the gradient function itself; instead I can jit around a function that calculates the gradient (and updates the parameters, not shown below), like so:
function get_grad(p, fixed_x)
    function mse_loss_fixed(p)
        y_hat = learner_apply(p, fixed_x)
        return jnp.mean(optax.l2_loss(y_hat, fixed_x))
    end
    return jax.grad(mse_loss_fixed)(p)
end
get_grad(p, np.array(x))
get_grad_jit = jax.jit(get_grad)
get_grad_jit(p, np.array(x))
A bit hacky, but it seems to work. Benchmarking with BenchmarkTools.@btime, it looks like we can get the benefits of both the non-copying conversion and jit compilation in jax!
julia> @btime get_grad(p, np.array(x));
3.548 ms (82 allocations: 1.84 KiB)
julia> @btime get_grad(p, jnp.array(x));
5.476 ms (879 allocations: 91.44 KiB)
julia> @btime get_grad_jit(p, np.array(x));
7.286 μs (23 allocations: 672 bytes)
julia> @btime get_grad_jit(p, jnp.array(x));
1.815 ms (820 allocations: 90.25 KiB)
OK, I've realised what's going on here. When you call mse_loss(p, np.array(x)), the numpy array gets wrapped as a PyArray on the Julia side (that is, the x variable inside mse_loss is a PyArray). When this gets passed back to Python in the functions that mse_loss calls, it is not unwrapped back to the original numpy array, but is wrapped again as a Julia object (specifically a juliacall.ArrayValue). This wrapped object behaves a lot like a numpy array, but not enough for Jax, hence the error.
The clue was in this bit of your above output:
Julia:
1×1 PyArray{Float32, 2}:
0.088733494
Note that it says PyArray and not Array.
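As a minimal sketch of that round trip, using the x from your MWE (this is the behaviour prior to the fix described below):
xpy = np.array(x)         # Py: a numpy array
xjl = PyArray(xpy)        # the Julia-side wrapper that x becomes inside mse_loss
println(pytype(Py(xjl)))  # back in Python it is <class 'juliacall.ArrayValue'>,
                          # not the original numpy array, hence the missing dtype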
Firstly, this has highlighted a bug, which is that PyArray was not marked as a Python wrapper type. Wrapper types always get unwrapped when passed back to Python, which would have solved your problem, since then the Python code would have only seen the numpy array. I have fixed this on the main branch.
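If you want to try the fix before it is released, you can track the development branch with standard Pkg syntax:
pkg> add PythonCall#main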
Secondly, an arguably better solution to your problem is to use pyfunc, which wraps a Julia function as a Python function:
jax.grad(pyfunc(mse_loss))(p, np.array(x))
The difference between this and the default conversion is that the arguments to the wrapped function are unconverted Python objects (of type Py). This means the numpy array is still a numpy array inside mse_loss, not a PyArray. You usually want this for functions passed from Julia to Python (e.g. callback functions), since it gives you total control over the arguments; the default behaviour, which converts the arguments, is intended more for calling arbitrary Julia functions from Python.
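For instance, a small untested sketch: inside a pyfunc-wrapped function the arguments stay as Py, which you can confirm with pytype:
mse_loss_py = pyfunc() do p, x
    @show pytype(x)  # numpy.ndarray when called directly; a jax tracer under jax.grad
    y_hat = learner_apply(p, x)
    jnp.mean(optax.l2_loss(y_hat, x))
end
mse_loss_py(p, np.array(x))            # prints <class 'numpy.ndarray'>
jax.grad(mse_loss_py)(p, np.array(x))  # differentiates without the dtype error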
I don't know what you're ultimately trying to do, but you can also just use @pyexec to do all the Python stuff entirely in Python. For example:
@pyexec """
def closure():
    import jax, jax.numpy, ...
    ...
    def mse_loss(p, x):
        ...
    grad = jax.grad(mse_loss)
    return loss, grad, p
loss, grad, p = closure()
""" => (loss, grad, p)
grad(p, np.array(x))