kernex
`mode="reflect"` in padding is incorrect
I think `mode="reflect"` for `padding_kwargs` is incorrect:
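```python
import jax.numpy as jnp
import kernex

@kernex.kmap(
    kernel_size=(3,),
    padding=("same"),
    relative=False,
    padding_kwargs=dict(mode="reflect"),
)
def f(x):
    # Return each window unchanged, so `y` shows the padded windows kernex builds
    return x

x = jnp.array([1, 2, 3, 4, 5])
y = f(x)
# Reference: reflect-pad the same input directly with jax.numpy
z = jnp.pad(x, 1, mode="reflect")

print("x: ", x)
print("y: ", y)
print("z: ", z)
```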
This gives:

```
x: [1 2 3 4 5]
y: [[3 1 2]  # <-- the `3` is incorrect, should be `2`
 [1 2 3]
 [2 3 4]
 [3 4 5]
 [4 5 4]]
z: [2 1 2 3 4 5 4]  # <-- here, the first element is `2`
```
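The kernex output reflects incorrectly at the left edge: the first element of the first window is `3` instead of `2`. The right edge (`[4 5 4]`) matches what `jnp.pad` produces.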
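For comparison, here is a minimal reference sketch of the windows I would expect: reflect-pad the input (NumPy's `np.pad` with `mode="reflect"`, whose semantics `jnp.pad` follows), then take one window of length `kernel_size` per original element. `reflect_windows` is a hypothetical helper, not part of kernex, and it assumes an odd kernel size so the padding is symmetric.

```python
import numpy as np

def reflect_windows(x, kernel_size):
    """Hypothetical reference: "same"-length sliding windows over a reflect-padded x.

    Assumes an odd kernel_size, so kernel_size // 2 elements are padded on each side.
    """
    pad = kernel_size // 2
    # mode="reflect" mirrors around the edge element without repeating it:
    # [1 2 3 4 5] -> [2 1 2 3 4 5 4]
    padded = np.pad(x, pad, mode="reflect")
    # One window per original position, each of length kernel_size
    return np.stack([padded[i:i + kernel_size] for i in range(len(x))])

x = np.array([1, 2, 3, 4, 5])
expected = reflect_windows(x, 3)
print(expected)
# [[2 1 2]
#  [1 2 3]
#  [2 3 4]
#  [3 4 5]
#  [4 5 4]]
```

Only the first row differs from the kernex output above (`[2 1 2]` vs `[3 1 2]`).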