feat(scipy.special): add erfcx — scaled complementary error function

Open • KAVYANSHTYAGI opened this issue 6 months ago • 2 comments

PR Description

Overview

This PR adds support for the scaled complementary error function, erfcx, to the jax.scipy.special module.

The function is defined mathematically as:

$$\operatorname{erfcx}(x) = e^{x^{2}} \cdot \operatorname{erfc}(x)$$
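A quick float64 check with scipy.special, which the PR's tests use as the reference, shows why the scaled form matters. At x = 30 (an arbitrary illustrative point), erfc underflows to zero. erfcx, by contrast, is an ordinary number close to 1/(x√π).

```python
import math

from scipy import special

x = 30.0
# The true erfc(30) is about 2.6e-393, far below the smallest float64 subnormal (~4.9e-324)
print(special.erfc(x))  # 0.0
# The scaled value is an ordinary float64 number, close to the leading term 1/(x*sqrt(pi))
print(special.erfcx(x))
print(1.0 / (x * math.sqrt(math.pi)))
```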

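erfcx is particularly useful in numerical computations involving large values of x, where erfc alone underflows. In float64, erfc(x) becomes subnormal above about x = 26.5 and underflows to zero above about x = 27.2. Over the same range, erfcx(x) behaves like 1/(x√π) and remains comfortably representable. Working with the scaled quantity therefore preserves precision that erfc(x) on its own would lose.

What's Included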

Function Implementation

Added erfcx in jax/_src/scipy/special.py using JAX-native ops:

return lax.mul(lax.exp(lax.square(x)), lax.erfc(x))
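The sketch below is a self-contained version of this one-liner. The wrapper name erfcx_direct is hypothetical; only the body comes from the PR. It shows that the direct product is accurate for moderate inputs but cannot handle large x in float32, which is JAX's default dtype. Above about x = 9.4, exp(x**2) overflows to inf while erfc(x) has already underflowed, so the result is not finite.

```python
import jax.numpy as jnp
from jax import lax


def erfcx_direct(x):
    """Literal erfcx(x) = exp(x**2) * erfc(x), as in the PR's one-liner (hypothetical name)."""
    x = jnp.asarray(x)
    return lax.mul(lax.exp(lax.square(x)), lax.erfc(x))


# Moderate input: finite and close to scipy.special.erfcx(1.0)
print(erfcx_direct(jnp.float32(1.0)))

# Large input in float32: exp(900) overflows to inf and erfc(30) underflows to 0,
# so the product is not finite
print(erfcx_direct(jnp.float32(30.0)))
```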

Public API Exposure

Exported erfcx in jax/scipy/special.py for user access.

Documentation

Added erfcx to the Sphinx autosummary in docs/jax.scipy.rst with proper formatting and math support.

Unit Tests

Included an op_record entry in tests/lax_scipy_special_functions_test.py.

Validated JAX outputs against scipy.special.erfcx using randomized input values and standard float dtypes.

Motivation

The addition of erfcx closes a functional gap between JAX and SciPy for scientific computing workloads involving Gaussian integrals, differential equations, and error-function-based models. A dedicated erfcx is meant to stay finite and accurate in cases where exp(x^2) and erfc(x) would individually overflow or underflow. That makes it a critical function for users who require stable large-x behavior.

Notes

The current implementation supports only real-valued inputs.

Gradients are supported via autodiff; additional tests for edge-case gradients can be added in follow-up PRs.
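One possible alternative to plain autodiff is a custom JVP built on the identity d/dx erfcx(x) = 2x·erfcx(x) − 2/√π. This identity follows from differentiating e^{x²}·erfc(x), because erfc'(x) = −(2/√π)e^{−x²}. The sketch below is not how the PR computes gradients. It assumes the PR's direct product as the forward pass, and erfcx_cjvp is a hypothetical name.

```python
import math

import jax
import jax.numpy as jnp
from jax import lax

_TWO_OVER_SQRT_PI = 2.0 / math.sqrt(math.pi)


@jax.custom_jvp
def erfcx_cjvp(x):
    """Hypothetical erfcx with a hand-written derivative rule.

    The forward pass is the direct product exp(x**2) * erfc(x) from the PR.
    """
    x = jnp.asarray(x)
    return jnp.exp(x * x) * lax.erfc(x)


@erfcx_cjvp.defjvp
def _erfcx_cjvp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = erfcx_cjvp(x)
    # d/dx [exp(x^2) * erfc(x)] = 2x * exp(x^2) * erfc(x) - 2/sqrt(pi)
    #                           = 2x * erfcx(x) - 2/sqrt(pi)
    dydx = 2.0 * x * y - _TWO_OVER_SQRT_PI
    return y, dydx * x_dot


print(jax.grad(erfcx_cjvp)(1.0))  # derivative at x = 1
```

Writing the derivative in terms of erfcx itself keeps separate exp(x²) and erfc(x) factors out of the gradient. However, for very large x the subtraction 2x·erfcx(x) − 2/√π cancels heavily and loses relative precision. A production version would therefore likely need its own large-x branch for the derivative.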

No dependencies were introduced.

Testing & Verification

pre-commit passed for all modified files.

pytest -k erfcx confirms output matches scipy.special.erfcx on supported inputs.

JIT and vmap compatibility verified locally.
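The following is a minimal sketch of this kind of check, not the PR's actual test. It re-declares the direct-product implementation under the hypothetical name erfcx_direct. Inputs are restricted to [-3, 3] so that neither factor overflows or underflows in float32. The sketch checks that jax.jit and jax.vmap agree with eager evaluation, and that the eager result matches scipy.special.erfcx.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from scipy import special


def erfcx_direct(x):
    """Hypothetical stand-in for the PR's implementation: exp(x**2) * erfc(x)."""
    return jnp.exp(x * x) * lax.erfc(x)


# Random float32 inputs in [-3, 3], where neither factor overflows or underflows
rng = np.random.default_rng(0)
xs = jnp.asarray(rng.uniform(-3.0, 3.0, size=(4, 8)).astype(np.float32))

eager = erfcx_direct(xs)
jitted = jax.jit(erfcx_direct)(xs)
vmapped = jax.vmap(erfcx_direct)(xs)  # maps over the leading axis of size 4

# Transformations should not change the numbers beyond rounding
np.testing.assert_allclose(np.asarray(jitted), np.asarray(eager), rtol=1e-5)
np.testing.assert_allclose(np.asarray(vmapped), np.asarray(eager), rtol=1e-5)

# Compare against SciPy's float64 reference with a float32-appropriate tolerance
ref = special.erfcx(np.asarray(xs, dtype=np.float64))
np.testing.assert_allclose(np.asarray(eager), ref, rtol=1e-4)
```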

Please let me know if additional coverage, benchmarks, or gradient validation would be helpful for merge readiness.

KAVYANSHTYAGI • Jun 13 '25 12:06

It might be useful: https://github.com/jax-ml/jax/pull/3856#issuecomment-663746553

DanisNone • Jun 19 '25 08:06

When x > 26, we can use the asymptotic expansion for x → +∞:

erfcx(x) = 1/sqrt(π) * (1/x - 1/(2x³) + 3/(4x⁵) - 15/(8x⁷))

For x > 26, this provides a relative error of no more than 2.6e-10.

The only caveat is that for values beyond approximately 2.535599352761576e+307, it starts returning zero, which differs from SciPy's behavior.

This is for float64. For float32, the boundary should be set at x > 9.1.

DanisNone • Jun 19 '25 16:06
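The thresholded approach from the comment above can be sketched as follows. The sketch makes several assumptions:

- It uses the PR's direct product below the threshold and the four-term series above it.
- It uses 26 as the float64 threshold and 9.1 for every other dtype; lower-precision types such as float16 are not considered.
- The name erfcx_threshold is hypothetical.

Each branch receives clamped inputs, so the branch that jnp.where does not select cannot produce inf or nan, which would otherwise leak into gradients.

```python
import math

import jax.numpy as jnp
import numpy as np
from jax import lax
from scipy import special

_INV_SQRT_PI = 1.0 / math.sqrt(math.pi)


def erfcx_threshold(x):
    """Hypothetical erfcx: direct product below a threshold, asymptotic series above.

    Assumes real floating-point input.
    """
    x = jnp.asarray(x)
    # Thresholds suggested in the thread: 26 for float64, 9.1 for float32
    # (applied here to every non-float64 dtype).
    threshold = 26.0 if x.dtype == np.float64 else 9.1
    large = x > threshold

    # Clamp each branch's input into its safe region (the "double where" pattern),
    # so the unselected branch never evaluates exp(x^2) at huge x or 1/x near 0.
    x_small = jnp.where(large, threshold, x)
    x_large = jnp.where(large, x, threshold)

    # Below the threshold: the direct product from the PR.
    direct = jnp.exp(x_small * x_small) * lax.erfc(x_small)

    # Above it: 1/(x sqrt(pi)) * (1 - 1/(2x^2) + 3/(4x^4) - 15/(8x^6)),
    # the four-term series in Horner form in t = 1/x^2, so no x^7 is formed.
    t = 1.0 / (x_large * x_large)
    series = (_INV_SQRT_PI / x_large) * (1.0 - t * (0.5 - t * (0.75 - t * 1.875)))

    return jnp.where(large, series, direct)


# Compare a few float32 inputs on both sides of the threshold with SciPy
xs = np.array([0.5, 5.0, 9.0, 9.5, 30.0, 1e4], dtype=np.float32)
ours = np.asarray(erfcx_threshold(jnp.asarray(xs)))
ref = special.erfcx(xs.astype(np.float64))
print(np.max(np.abs(ours - ref) / ref))  # worst relative error vs SciPy
```

Large negative x still goes through the direct branch and overflows once exp(x²) exceeds the dtype's range. At that point the true value, roughly 2·exp(x²), is itself out of range.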

@KAVYANSHTYAGI Are you planning to continue developing erfcx for JAX? It would be very helpful for my work, and I would be happy to contribute as well.

TheSkyentist • Nov 19 '25 13:11