kernex
Support Pytrees
Does kernex support Pytrees? I did not find an example. It would be very useful to support moving-window filters with "global" weights or simply multiple inputs, such as a cross-channel bilateral filter in my case.
Repro:

```python
import jax.numpy as jnp
import kernex

@kernex.kmap(kernel_size=(3, 3))
def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(20 * 30).reshape((20, 30))
out = kernel((data, data))
```
This raises:

```
Traceback (most recent call last):
  File "/home/clemisch/kernex_tree.py", line 52, in <module>
    out = kernel((data, data))
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/clemisch/venvs/11/lib64/python3.11/site-packages/kernex/interface/kernel_interface.py", line 131, in __call__
    self.shape = array.shape
                 ^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'shape'
```
Hello, thanks for your question. This is a reasonable request; I will try to look into it when I have time.
Hello, in the meantime, can you try this?
The key point here is to stack the arrays along some axis `i` and make the kernel size along that axis equal to the axis length, with `valid` padding for that axis. Each window then contains all of the stacked arrays, and the kernel can unpack them along `i` (here with `x, y = tree`). In this example, `i` is the first axis.
I also recommend using `jax.debug.print` to ensure the array views are what you are looking for.
```python
import jax
import jax.numpy as jnp
import kernex

# Stack the inputs along axis 0 and use kernel size 2 (= number of stacked
# arrays) with "valid" padding on that axis, so every window holds both arrays.
@kernex.kmap(kernel_size=(2, 3, 3), padding=("valid", "valid", "valid"))
def kernel(tree):
    x, y = tree  # unpack the (2, 3, 3) window along axis 0
    jax.debug.print("x={x} \n\n y={y}\n", x=x, y=y)
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(25).reshape(5, 5)
out = kernel(jnp.stack([data, data], axis=0))
# With "valid" padding everywhere, out should have shape (1, 3, 3);
# out[0] drops the stacked axis.

# Printed windows:
# x=[[ 0  1  2]
#    [ 5  6  7]
#    [10 11 12]]
# y=[[ 0  1  2]
#    [ 5  6  7]
#    [10 11 12]]
# x=[[ 1  2  3]
#    [ 6  7  8]
#    [11 12 13]]
# y=[[ 1  2  3]
#    [ 6  7  8]
#    [11 12 13]]
# x=[[ 2  3  4]
#    [ 7  8  9]
#    [12 13 14]]
# y=[[ 2  3  4]
#    [ 7  8  9]
#    [12 13 14]]
# x=[[ 5  6  7]
#    [10 11 12]
#    [15 16 17]]
# y=[[ 5  6  7]
#    [10 11 12]
#    [15 16 17]]
# x=[[ 6  7  8]
#    [11 12 13]
#    [16 17 18]]
# y=[[ 6  7  8]
#    [11 12 13]
#    [16 17 18]]
# x=[[ 7  8  9]
#    [12 13 14]
#    [17 18 19]]
# y=[[ 7  8  9]
#    [12 13 14]
#    [17 18 19]]
# x=[[10 11 12]
#    [15 16 17]
#    [20 21 22]]
# y=[[10 11 12]
#    [15 16 17]
#    [20 21 22]]
# x=[[11 12 13]
#    [16 17 18]
#    [21 22 23]]
# y=[[11 12 13]
#    [16 17 18]
#    [21 22 23]]
# x=[[12 13 14]
#    [17 18 19]
#    [22 23 24]]
# y=[[12 13 14]
#    [17 18 19]
#    [22 23 24]]
```
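To check what the stacked kernel computes without going through kernex, a small plain-JAX reference can help. The helper name `reference_sum_x_ysq` is hypothetical. It relies on this particular kernel being a sum of per-pixel terms, so each output equals a box sum of `x * y**2` over a `valid` 3x3 window; a nonlinear kernel such as a bilateral filter would need the actual windows.

```python
import jax.numpy as jnp


def reference_sum_x_ysq(x, y, kernel_size=(3, 3)):
    """Hypothetical plain-JAX reference for the stacked kmap example.

    For every 'valid' window of shape kernel_size, computes
    sum(x_window * y_window**2). Because the kernel is a sum of
    per-pixel terms, this equals a box sum of z = x * y**2.
    """
    kh, kw = kernel_size
    out_h = x.shape[0] - kh + 1  # 'valid' padding: no windows past the border
    out_w = x.shape[1] - kw + 1
    z = x * jnp.square(y)
    # Add up the kh * kw shifted views of z; each output pixel then holds
    # the sum over its own window.
    return sum(
        z[di:di + out_h, dj:dj + out_w]
        for di in range(kh)
        for dj in range(kw)
    )


data = jnp.arange(25).reshape(5, 5)
ref = reference_sum_x_ysq(data, data)
print(ref.shape)  # (3, 3)
```

Assuming `out` from the stacked example has shape `(1, 3, 3)`, `jnp.array_equal(out[0], ref)` should hold.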
Thanks, that works for me!
As a follow-up, I think it would be simpler to specify which arguments (`argnums`) the kernel is generated over. For the previous example, the API might look something like `kmap(..., argnums=(0, 1))(lambda x, y: ...)`.
What do you think?
Thanks for the follow-up and for including me in this.
To clarify, do you mean supporting multiple arguments instead of trees? So something like

```python
@kernex.kmap(kernel_size=(3, 3), argnums=(0, 1))
def kernel(x, y):
    return jnp.sum(x * jnp.square(y))
```

or, for non-mapped local weights,

```python
@kernex.kmap(kernel_size=(3, 3), argnums=(0,))
def kernel(x, y_local):
    return jnp.sum(x * jnp.square(y_local))
```

where `y_local` would not be a window mapped over `y`, but a constant `(3, 3)` array.
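For illustration only, here is a plain-JAX sketch of what an `argnums`-style moving-window map could do. `window_map` is a hypothetical helper, not part of kernex; it assumes 2D mapped inputs of equal shape, `valid` padding, stride 1, and a kernel that returns a scalar. Mapped arguments are cut into windows with `jax.lax.dynamic_slice`; all other arguments (such as `y_local`) are passed through unchanged.

```python
import jax
import jax.numpy as jnp


def window_map(fn, kernel_size, argnums):
    """Hypothetical sketch (not kernex's API): call fn on every 'valid'
    window of the positional arguments listed in argnums; the remaining
    arguments are passed through unchanged (e.g. constant local weights).
    Assumes all mapped arguments are 2D with the same shape."""
    kh, kw = kernel_size

    def wrapped(*args):
        h, w = args[argnums[0]].shape
        rows = jnp.arange(h - kh + 1)  # window top-left row indices
        cols = jnp.arange(w - kw + 1)  # window top-left column indices

        def at(i, j):
            call_args = list(args)
            for k in argnums:
                # Replace each mapped argument by its (kh, kw) window at (i, j).
                call_args[k] = jax.lax.dynamic_slice(args[k], (i, j), (kh, kw))
            return fn(*call_args)

        # Inner vmap over columns, outer vmap over rows -> (out_h, out_w).
        return jax.vmap(lambda i: jax.vmap(lambda j: at(i, j))(cols))(rows)

    return wrapped


data = jnp.arange(25.0).reshape(5, 5)

# Both arguments mapped (argnums=(0, 1)).
both = window_map(lambda x, y: jnp.sum(x * jnp.square(y)), (3, 3), argnums=(0, 1))
print(both(data, data).shape)  # (3, 3)

# Only x mapped; y_local is a constant (3, 3) weight array.
weights = jnp.ones((3, 3))
local = window_map(lambda x, y_local: jnp.sum(x * jnp.square(y_local)), (3, 3), argnums=(0,))
print(local(data, weights)[0, 0])  # 54.0
```

Materializing every window under `vmap` is simple but not memory-efficient for large kernels; the sketch is meant to pin down the semantics of the two cases, not performance.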
TL;DR: Anything is fine for me. I think supporting trees would be slightly more powerful, but any reasonable task should be translatable to multiple arguments instead of a tree.
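As a sketch of that translation, a kernel that takes a pytree can be wrapped so it accepts the tree's leaves as separate positional arguments, using `jax.tree_util.tree_flatten` and `jax.tree_util.tree_unflatten`. The helper `as_multi_arg` is hypothetical; it only shows that a multi-argument interface loses no generality as long as the tree structure is known up front.

```python
import jax
import jax.numpy as jnp


def as_multi_arg(tree_fn, treedef):
    """Hypothetical helper: wrap a kernel that takes one pytree so that it
    takes the tree's leaves as separate positional arguments instead."""
    def multi_arg_fn(*leaves):
        # Rebuild the original tree from the flat leaves, then call the kernel.
        return tree_fn(jax.tree_util.tree_unflatten(treedef, leaves))
    return multi_arg_fn


def kernel(tree):
    # Same body as the original repro.
    x, y = tree
    return jnp.sum(x * jnp.square(y))


window = jnp.arange(9.0).reshape(3, 3)
leaves, treedef = jax.tree_util.tree_flatten((window, window))
multi_kernel = as_multi_arg(kernel, treedef)
print(multi_kernel(*leaves))  # 1296.0
```

The same wrapper works for dicts or nested containers, since `treedef` records the structure and the leaves are passed in flattened order.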