feat(scipy.special): add erfcx — scaled complementary error function
PR Description
Overview
This PR adds support for the scaled complementary error function, erfcx, to the jax.scipy.special module.
The function is defined mathematically as:
$$\operatorname{erfcx}(x) = e^{x^2}\,\operatorname{erfc}(x)$$
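As a quick illustration in plain Python (the standard-library math module, not JAX), evaluating the definition literally works at moderate x but breaks down at large x, because erfc(x) underflows to zero while exp(x²) overflows. The helper name erfcx_naive is hypothetical and exists only for this sketch.

```python
import math


def erfcx_naive(x: float) -> float:
    """Literal evaluation of erfcx(x) = exp(x**2) * erfc(x); illustration only."""
    return math.exp(x * x) * math.erfc(x)


# Moderate x: both factors are representable, so the product is accurate.
print(erfcx_naive(1.0))  # about 0.42758

# Large x: erfc(30) is around 1e-393, far below the float64 range, so it underflows ...
print(math.erfc(30.0))  # 0.0
# ... even though erfcx(30) itself is close to its leading asymptotic term 1/(x*sqrt(pi)).
print(1.0 / (30.0 * math.sqrt(math.pi)))  # about 0.0188

# The literal formula cannot recover it: exp(900) overflows, which math.exp reports as OverflowError.
try:
    erfcx_naive(30.0)
    overflowed = False
except OverflowError:
    overflowed = True
print("exp(30**2) overflowed:", overflowed)
```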
It is particularly useful in numerical computations involving large values of x, where erfc alone may underflow. Because erfcx(x) decays only like 1/(x√π) as x → +∞, a function that returns it directly can stay finite and accurate long after erfc(x) has underflowed to zero.
What's Included
Function Implementation
Added erfcx in jax/_src/scipy/special.py using JAX-native ops:
return lax.mul(lax.exp(lax.square(x)), lax.erfc(x))
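For reviewers who want to try the expression outside the JAX source tree, here is a minimal standalone sketch that wraps the same lax one-liner in a function (the wrapper itself is illustrative, not the contents of jax/_src/scipy/special.py) and exercises it under jit and vmap. It assumes floating-point input and JAX's default float32 precision.

```python
import jax
import jax.numpy as jnp
from jax import lax


def erfcx(x):
    """Scaled complementary error function via the literal formula exp(x**2) * erfc(x).

    Same lax expression as the implementation line above; expects floating-point
    input. Accurate for moderate x, but exp(x**2) overflows once x exceeds about
    9.4 in float32 (about 26.6 in float64), so the result becomes inf or nan there.
    """
    x = jnp.asarray(x)
    return lax.mul(lax.exp(lax.square(x)), lax.erfc(x))


xs = jnp.linspace(-2.0, 5.0, 8)
jit_out = jax.jit(erfcx)(xs)    # compiles the same elementwise expression
vmap_out = jax.vmap(erfcx)(xs)  # maps the scalar function over the batch
print(erfcx(1.0))               # about 0.42758, matching scipy.special.erfcx(1.0)
print(erfcx(10.0))              # not finite in float32: exp(100) overflows
```

The last line shows why the large-x behavior discussed further down in this thread matters: the one-liner inherits the overflow of its exp(x²) factor.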
Public API Exposure
Exported erfcx in jax/scipy/special.py for user access.
Documentation
Added erfcx to the Sphinx autosummary in docs/jax.scipy.rst with proper formatting and math support.
Unit Tests
Added an op_record entry to tests/lax_scipy_special_functions_test.py.
Validated JAX outputs against scipy.special.erfcx using randomized input values and standard float dtypes.
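The op_record harness in tests/lax_scipy_special_functions_test.py is not reproduced here. As a standalone sketch of the same kind of check, assuming NumPy and SciPy are available, the literal formula can be compared against scipy.special.erfcx for random inputs in both float32 and float64. The input range and tolerances are illustrative choices, not the values used in the PR's test.

```python
import jax

jax.config.update("jax_enable_x64", True)  # enable float64 so both dtypes can be exercised

import jax.numpy as jnp
import numpy as np
import scipy.special as osp_special
from jax import lax


def erfcx(x):
    # Same literal formula as the implementation line: exp(x**2) * erfc(x).
    return lax.mul(lax.exp(lax.square(x)), lax.erfc(x))


rng = np.random.default_rng(0)
# Inputs stay in [-3, 3] so exp(x**2) cannot overflow in either dtype.
for dtype, rtol in [(np.float32, 1e-4), (np.float64, 1e-10)]:
    x = rng.uniform(-3.0, 3.0, size=1000).astype(dtype)
    actual = np.asarray(erfcx(jnp.asarray(x)))
    expected = osp_special.erfcx(x)
    np.testing.assert_allclose(actual, expected, rtol=rtol)
    print(dtype.__name__, "max relative error:", np.max(np.abs(actual / expected - 1)))
```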
Motivation
Adding erfcx closes a functional gap between JAX and SciPy for scientific-computing workloads involving Gaussian integrals, differential equations, and error-function-based models. A dedicated erfcx is valuable precisely in cases where exp(x^2) and erfc(x) would individually overflow or underflow, which makes it a critical function for users who require stable large-x behavior.
Notes
The current implementation supports only real-valued inputs.
Gradients are supported via autodiff; additional tests for edge-case gradients can be added in future follow-ups.
No new dependencies were introduced.
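One way to sanity-check the autodiff gradient is the identity d/dx erfcx(x) = 2x·erfcx(x) − 2/√π, which follows from differentiating exp(x²)·erfc(x) with erfc′(x) = −(2/√π)·exp(−x²). The sketch below uses the same literal formula; the helper erfcx_grad_reference, the input range, and the tolerances are illustrative assumptions. For large x the literal formula's gradient inherits the same overflow as its value.

```python
import jax
import jax.numpy as jnp
from jax import lax


def erfcx(x):
    # Literal formula exp(x**2) * erfc(x), as in the implementation line.
    return lax.mul(lax.exp(lax.square(x)), lax.erfc(x))


def erfcx_grad_reference(x):
    # Analytic derivative: d/dx erfcx(x) = 2 x erfcx(x) - 2/sqrt(pi).
    return 2.0 * x * erfcx(x) - 2.0 / jnp.sqrt(jnp.pi)


xs = jnp.linspace(-2.0, 3.0, 11)
autodiff = jax.vmap(jax.grad(erfcx))(xs)  # per-element gradient via autodiff
reference = erfcx_grad_reference(xs)
print("max abs difference:", jnp.max(jnp.abs(autodiff - reference)))
```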
Testing & Verification
pre-commit passed for all modified files.
pytest -k erfcx confirms output matches scipy.special.erfcx on supported inputs.
JIT and vmap compatibility verified locally.
Please let me know if additional coverage, benchmarks, or gradient validation would be helpful for merge readiness.
It might be useful: https://github.com/jax-ml/jax/pull/3856#issuecomment-663746553
For x > 26, erfcx can instead be computed from its asymptotic expansion as x → +∞:
$$\operatorname{erfcx}(x) \approx \frac{1}{\sqrt{\pi}}\left(\frac{1}{x} - \frac{1}{2x^{3}} + \frac{3}{4x^{5}} - \frac{15}{8x^{7}}\right)$$
For x > 26, this provides a relative error of no more than 2.6e-10.
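The only caveat is that for x beyond approximately 2.535599352761576e+307 it starts returning zero, which differs from SciPy's behavior. At that point the leading term 1/(x√π) falls below the smallest normal float64 value (about 2.2e-308).
These thresholds are for float64; for float32, the boundary should be set at x > 9.1.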
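The suggestion above can be sketched as a piecewise function: the literal formula below the threshold and the four-term expansion above it, with the thresholds 26 (float64) and 9.1 (float32) taken from the comment. The name erfcx_with_asymptotic_tail is hypothetical, the sketch assumes floating-point input and NumPy/SciPy for the comparison, and it does not address the near-underflow caveat for extremely large x. It uses the "double where" pattern so the unselected branch never sees inputs that overflow, which keeps gradients finite as well as values.

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64, so the x > 26 threshold applies

import jax.numpy as jnp
import numpy as np
import scipy.special as osp_special
from jax import lax


def erfcx_with_asymptotic_tail(x):
    """Hypothetical sketch: literal formula for moderate x, asymptotic series for large x."""
    x = jnp.asarray(x)
    # Thresholds from the comment above; float16/bfloat16 would need their own (not handled).
    threshold = 26.0 if x.dtype == jnp.float64 else 9.1
    use_tail = x > threshold
    # Double where: each branch only receives inputs it can evaluate without inf/nan,
    # so neither the forward values nor the gradients of the unused branch are poisoned.
    x_small = jnp.where(use_tail, 0.0, x)
    x_large = jnp.where(use_tail, x, 1.0)
    naive = lax.exp(lax.square(x_small)) * lax.erfc(x_small)
    # 1/sqrt(pi) * (1/x - 1/(2x^3) + 3/(4x^5) - 15/(8x^7)), factored as (1/(x sqrt(pi))) * (...)
    r = 1.0 / (x_large * x_large)
    tail = (1.0 / (jnp.sqrt(jnp.pi) * x_large)) * (1.0 - 0.5 * r + 0.75 * r**2 - 1.875 * r**3)
    return jnp.where(use_tail, tail, naive)


# Usage: compare against SciPy on both sides of the switch-over point.
xs = jnp.array([0.5, 5.0, 25.0, 27.0, 100.0, 1e5])
rel_err = np.abs(np.asarray(erfcx_with_asymptotic_tail(xs)) / osp_special.erfcx(np.asarray(xs)) - 1.0)
print("max relative error vs SciPy:", rel_err.max())
print("gradient at x = 30:", jax.grad(erfcx_with_asymptotic_tail)(30.0))
```

Computing the series through r = 1/x² rather than raw powers of x avoids forming x⁷, which would itself overflow for very large inputs.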
@KAVYANSHTYAGI Are you planning to continue developing erfcx for JAX? It would be very helpful for my work, and I would be happy to contribute as well.