
Softplus leaks memory (and is no longer needed)

Open · nfergu opened this issue 6 months ago · 0 comments

The TensorFlow Probability implementation of softplus leaks memory, and it no longer appears to be needed. That is, I think the standard tf.nn.softplus can be used instead, since the numerical stability issues it worked around appear to have been fixed.

Currently the implementation of softplus is as follows (from here):

# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
if JAX_MODE:
  _stable_grad_softplus = tf.nn.softplus
else:

  @tf.custom_gradient
  def _stable_grad_softplus(x):
    """A (more) numerically stable softplus than `tf.nn.softplus`."""
    x = tf.convert_to_tensor(x)
    if x.dtype == tf.float64:
      cutoff = -20
    else:
      cutoff = -9

    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))

    def grad_fn(dy):
      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))

    return y, grad_fn
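
For context, here is a small sketch (mine, not from TFP) of the stability concern this custom kernel was guarding against: for very negative x, 1.0 + exp(x) rounds to exactly 1.0 in floating point, so a naive log(1.0 + exp(x)) collapses to 0.0, whereas log1p(exp(x)) still returns a value close to exp(x).

import tensorflow as tf

x = tf.constant(-30.0, dtype=tf.float32)
naive = tf.math.log(1.0 + tf.exp(x))   # 0.0: exp(-30) is swallowed by the 1.0
stable = tf.math.log1p(tf.exp(x))      # ~9.36e-14, close to exp(-30)
print(naive.numpy(), stable.numpy())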

The TFP implementation above leaks memory (in non-JAX mode) due to a couple of issues:

  • The grad_fn closure captures the tensor x. The closure then ends up in the gradient registry, which is never cleared, so the tensor hangs around forever.
  • For a similar reason, TensorFlow's own custom_gradient implementation also leaks memory; see 97697 for more details.

Here is a Colab notebook to demonstrate the memory leak.
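
For reference, a rough standalone reproduction might look like the sketch below (this is my own sketch, not the notebook above, and it assumes psutil is installed for measuring resident memory): repeatedly differentiate through the custom-gradient softplus and watch whether resident memory keeps growing.

import os
import psutil
import tensorflow as tf

@tf.custom_gradient
def _stable_grad_softplus(x):
  x = tf.convert_to_tensor(x)
  cutoff = -20 if x.dtype == tf.float64 else -9
  y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
  def grad_fn(dy):
    return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
  return y, grad_fn

process = psutil.Process(os.getpid())
for step in range(10_000):
  x = tf.random.normal([1024, 1024])
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = _stable_grad_softplus(x)
  _ = tape.gradient(y, x)
  if step % 1000 == 0:
    # If the closure/registry leak is present, this number keeps climbing.
    print(step, process.memory_info().rss // 2**20, "MiB")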

However, I believe that the numerical stability issues with tf.nn.softplus have been solved. Specifically:

  • The tf.nn.softplus implementation now uses log1p as of this commit on May 1 2020.
  • The gradient computation for tf.nn.softplus now uses math_ops.sigmoid as of this commit on April 4 2019.
  • The Eigen implementation of sigmoid (which I think is here) computes it as e^x / (1.0 + e^x), so the e^x approximation in _stable_grad_softplus seems unnecessary to me. If e^x is very small, then 1.0 + e^x rounds to exactly 1.0 and the result is just e^x. If e^x > 1.0, then e^x / (1.0 + e^x) should (I think) be more accurate than simply approximating the gradient as e^x. But I am not a numerical stability expert, so I may be wrong. A quick sanity check of the current behaviour is sketched after this list.
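
Here is the kind of quick check I have in mind (a sketch under my own assumptions, not an exhaustive test): compare tf.nn.softplus and its gradient against the log1p/sigmoid reference at very negative inputs, the regime the custom cutoff was protecting.

import tensorflow as tf

for dtype in (tf.float32, tf.float64):
  x = tf.constant([-100.0, -30.0, -20.0, -9.0, -1.0, 0.0], dtype=dtype)
  with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.nn.softplus(x)
  grad = tape.gradient(y, x)

  ref_y = tf.math.log1p(tf.exp(x))   # stable reference for negative x
  ref_grad = tf.math.sigmoid(x)      # d/dx softplus(x) = sigmoid(x)

  print(dtype.name,
        "max |softplus - ref| =", tf.reduce_max(tf.abs(y - ref_y)).numpy(),
        "max |grad - sigmoid| =", tf.reduce_max(tf.abs(grad - ref_grad)).numpy())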

nfergu · Jul 30 '25 23:07