Carlos Martin

Results 114 comments of Carlos Martin

@fabianp Done: #1089.

Here's a simpler reproduction:

```python3
import jax
import optax
from jax import numpy as jnp

logits = jnp.array(1000.)
labels = jnp.array(1)

print(jax.grad(optax.sigmoid_binary_cross_entropy)(logits, labels))  # 0.0
print(jax.grad(optax.sigmoid_focal_loss)(logits, labels, gamma=0))  # 0.0...
```

@mattjj Good point. I thought JAX would immediately optimize `x ** 0.0` to `1.0` before the gradient is taken, but it looks like that's not the case:

```python3
import jax
...
```
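A quick way to see this (a standalone sketch, not the snippet from the thread) is to differentiate `x ** 0.0` directly: JAX applies the power rule `y * x**(y - 1)` rather than folding the expression to a constant, so the gradient at zero is `0 * inf`:

```python3
import jax

# If x ** 0.0 were folded to 1.0 before differentiation, the gradient
# would be 0.0. Instead the power rule yields 0.0 * 0.0**(-1.0) == nan.
g = jax.grad(lambda x: x ** 0.0)(0.0)
print(g)  # nan
```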

`nan` appears under `grad` for any floating `gamma < 1` because `D[x^y, x] = y x^(y - 1)`, which causes a division by zero when `x == 0` and `y...
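A common workaround in JAX for this class of problem (a sketch of the general "double-where" technique, not necessarily the fix adopted in optax; `safe_pow` is a hypothetical name) is to mask the problematic base before the power is taken, so the non-finite branch never appears inside the derivative:

```python3
import jax
import jax.numpy as jnp

def safe_pow(x, y):
    # Hypothetical helper. Replace x == 0 with a harmless dummy base
    # before exponentiating, then restore the correct primal value.
    # The dummy base keeps y * x**(y - 1) finite under differentiation.
    zero = x == 0.0
    safe_x = jnp.where(zero, 1.0, x)
    return jnp.where(zero, 0.0 ** y, safe_x ** y)

# Plain x ** 0.5 has a non-finite gradient at x == 0; this stays finite.
print(jax.grad(safe_pow)(0.0, 0.5))  # 0.0
```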

@vroulet Close. That wrongly gives a zero gradient for `gamma` in the situation where `gamma == 0.0` but the base `1 - p_t` is strictly positive (so there's no `nan`...

This sounds reasonable to me. To make sure everything's covered, I'd explicitly test all combinations of the following:

- Generic and degenerate cases
- Derivatives of order 0 (values), 1,...
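For the derivative checks, `jax.test_util.check_grads` compares analytic against numerical derivatives up to a chosen order. A rough sketch of such a test matrix (the `loss` stand-in and the specific grid are hypothetical):

```python3
import itertools

import jax
from jax.test_util import check_grads

# float64 makes the finite-difference comparisons far less noisy
jax.config.update("jax_enable_x64", True)

def loss(x, gamma):
    # Hypothetical stand-in for the loss under test. A fixed-up loss
    # would also be checked at the degenerate base x == 0.0.
    return x ** gamma

# generic bases crossed with exponents around the tricky gamma < 1 region
for x, gamma in itertools.product([0.5, 1.0], [0.0, 0.5, 1.0, 2.0]):
    check_grads(loss, (x, gamma), order=2)  # 1st and 2nd derivatives
```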

You raise an excellent point. I actually had the same thought: It would be better to make the interface function-based rather than object-based, since functions are more "nimble" (and let...

@q-berthet Thanks for your comment.

> there is no risk of accidentally changing one and not the other.

I think the same applies to the proposed `noise_fn`.

> why do...

@vroulet Yes, that kind of caching might be contributing to the speedup as well.