Carlos Martin
@fabianp Done: #1089.
Here's a simpler reproduction:

```python3
import jax
import optax
from jax import numpy as jnp

logits = jnp.array(1000.)
labels = jnp.array(1)

print(jax.grad(optax.sigmoid_binary_cross_entropy)(logits, labels))  # 0.0
print(jax.grad(optax.sigmoid_focal_loss)(logits, labels, gamma=0))  # 0.0...
```
@mattjj Good point. I thought JAX would immediately optimize `x ** 0.0` to `1.0` before the gradient is taken, but it looks like that's not the case:

```python3
import jax
...
```
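A minimal sketch of the kind of check I mean (illustrative, not the exact snippet above): under `grad`, the `pow` is differentiated with the generic rule `y * x ** (y - 1)` rather than being folded away, so the base-`0.0`, exponent-`0.0` case produces `0.0 * inf`.

```python3
import jax

# If `x ** 0.0` were constant-folded to 1.0 before differentiation, this gradient
# would simply be 0.0. Instead the pow rule y * x**(y - 1) is applied, which at
# x == 0.0, y == 0.0 evaluates 0.0 * 0.0**(-1.0) == 0.0 * inf == nan.
print(jax.grad(lambda x: x ** 0.0)(0.0))  # nan
```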
`nan` appears under `grad` for any floating `gamma < 1` because `D[x^y, x] = y x^(y - 1)`, which causes a division by zero when `x == 0` and `y < 1`...
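For example (a minimal sketch; in the focal loss, the zero factor comes from the saturated sigmoid's derivative):

```python3
import jax

# The raw derivative y * x**(y - 1) blows up at x == 0 for y < 1 ...
print(jax.grad(lambda x: x ** 0.5)(0.0))          # inf: 0.5 * 0.0**(-0.5)

# ... and once the chain rule multiplies it by a zero inner derivative (as with
# saturated logits, where d(1 - p_t)/dlogits underflows to 0.0), inf * 0.0 == nan.
print(jax.grad(lambda x: (x * 0.0) ** 0.5)(1.0))  # nan
```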
@vroulet Close. That wrongly gives a zero gradient for `gamma` in the situation where `gamma == 0.0` but the base `1 - p_t` is strictly positive (so there's no `nan` to avoid in the first place)...
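Concretely (a small illustrative check, using a fixed positive base in place of `1 - p_t`): the gradient of `x ** gamma` with respect to `gamma` is `x ** gamma * log(x)`, which is nonzero for `0 < x < 1` even at `gamma == 0.0`, so it shouldn't be zeroed out.

```python3
import jax

# d/dgamma x**gamma == x**gamma * log(x); at gamma == 0.0 and x == 0.5 this is log(0.5).
print(jax.grad(lambda gamma: 0.5 ** gamma)(0.0))  # ~ -0.6931, i.e. log(0.5), not 0.0
```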
This sounds reasonable to me. To make sure everything's covered, I'd explicitly test all combinations of the following:

- Generic and degenerate cases
- Derivatives of order 0 (values), 1,...
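Roughly this shape of check, as a sketch (plain loops here just for illustration; the specific values, tolerances, and parametrization machinery are placeholders):

```python3
import itertools

import jax
import jax.numpy as jnp
import optax

# Generic vs. degenerate (saturated) logits, and a few representative gammas,
# including the floating gamma < 1 cases discussed above.
logits_cases = [jnp.array(0.3), jnp.array(1000.0)]
gammas = [0.0, 0.5, 2.0]
labels = jnp.array(1.0)

for logits, gamma in itertools.product(logits_cases, gammas):
    loss_fn = lambda x: optax.sigmoid_focal_loss(x, labels, gamma=gamma)
    value = loss_fn(logits)           # derivative of order 0 (the value itself)
    grad = jax.grad(loss_fn)(logits)  # derivative of order 1
    # With the fix in place, every combination should be finite.
    assert jnp.isfinite(value) and jnp.isfinite(grad), (float(logits), gamma)
```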
You raise an excellent point. I actually had the same thought: It would be better to make the interface function-based rather than object-based, since functions are more "nimble" (and let...
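To illustrate what I mean by "nimble" (purely illustrative names and signature, not an existing or proposed optax API): with a function-based interface, any callable with the right signature can be passed directly, without defining or instantiating a dedicated class.

```python3
import jax

# Function-based: a lambda, a closure over hyperparameters, or an existing
# function with a compatible signature can all be passed directly as the hook.
scaled_normal = lambda key, shape: 0.01 * jax.random.normal(key, shape)
laplace_noise = jax.random.laplace  # already matches a (key, shape) -> noise signature

# Object-based: the same logic has to be wrapped in a class whose interface is
# fixed in advance before it can be plugged in.
class ScaledNormal:
    def __init__(self, stddev):
        self.stddev = stddev

    def __call__(self, key, shape):
        return self.stddev * jax.random.normal(key, shape)
```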
@q-berthet Thanks for your comment.

> there is no risk of accidentally changing one and not the other.

I think the same applies to the proposed `noise_fn`.

> why do...
@vroulet Sounds good to me.
@vroulet Yes, that kind of caching might be contributing to the speedup as well.