
Implement the chi-squared statistic with SPU

Open Candicepan opened this issue 11 months ago • 28 comments

This issue is a task issue for the SecretFlow Open Source Contribution Plan (SF OSCP). Community developers are welcome to contribute~ If there is a task you would like to claim but you have not registered yet, please complete the registration first~

Task Introduction

  • Task name: implement the chi-squared statistic with SPU
  • Technical area: SPU/SML
  • Difficulty: advanced 🌟🌟
  • Time to complete: 4 weeks

Detailed Requirements

  • Security: reveal as little as possible
  • Functionality: return both the chi2 statistic and the corresponding p-value
  • Correctness: numerically consistent with sklearn's results (within an acceptable error margin)
  • Code style: Python code must be formatted with black+isort (the CI pipeline includes a style-check gate)
  • Submission: reference this issue and submit the code to https://github.com/secretflow/spu/tree/main/sml (the exact folder can be discussed with the reviewer)
  • Special notes: if a feature has particular constraints, e.g. it requires FM128 or more fxp bits, state this clearly in the comments/documentation

Skill Requirements

  • Familiarity with classical statistics
  • Familiarity with JAX or NumPy; able to implement the algorithm in NumPy

Instructions

  • See the sklearn documentation: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.feature_selection
  • Development guide: https://www.secretflow.org.cn/docs/spu/latest/en-US/getting_started/tutorials/develop_your_first_mpc_application
  • Example: https://github.com/secretflow/spu/blob/main/sml/linear_model/simple_sgd.py
  • SecretFlow MOOC (season 2) videos:
    • https://www.bilibili.com/video/BV1ba4y1S73z/?spm_id_from=333.788
    • https://www.bilibili.com/video/BV1Qe411D7MG/?spm_id_from=333.788
  • Contribution guide: https://github.com/secretflow/spu/blob/main/sml/development.md

Candicepan avatar Mar 08 '24 06:03 Candicepan

Give it to me

lizzy-0323 avatar Mar 27 '24 14:03 lizzy-0323

Could you tell me which folder I should submit to?

lizzy-0323 avatar Mar 28 '24 07:03 lizzy-0323

Hello, you can first create a new folder named "feature_selection" with a new file named "univariate_selection.py", and then implement the chi2 function in this file.

In other words, chi2 should be implemented in sml/feature_selection/univariate_selection.py.

(BTW, don't forget the essential tests/emulations folders and the bazel files.)
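
For orientation, the resulting layout would look roughly like this (a sketch; the BUILD.bazel file names are an assumption based on the repo's bazel convention, so check sml/development.md for the exact rules):

sml/feature_selection/
├── univariate_selection.py          # chi2 implementation
├── BUILD.bazel
├── tests/
│   ├── BUILD.bazel
│   └── univariate_selection_test.py
└── emulations/
    ├── BUILD.bazel
    └── univariate_selection_emul.py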

Best wishes

deadlywing avatar Mar 28 '24 07:03 deadlywing

Thank you!

lizzy-0323 avatar Mar 28 '24 07:03 lizzy-0323

A problem with building and running in the Docker environment

OS

Ubuntu 20.04.6 LTS

Command

bazel run -c opt //sml/cluster/tests:kmeans_test

Error

[error screenshot] P.S. All prerequisite libraries are installed; it looks like a network issue, but I haven't found a solution 0.0

lizzy-0323 avatar Mar 30 '24 06:03 lizzy-0323

Please make sure Docker can access GitHub properly.

anakinxc avatar Mar 30 '24 06:03 anakinxc

Docker can access GitHub normally; operations like clone work fine, and I am using a proxy.

lizzy-0323 avatar Apr 01 '24 06:04 lizzy-0323

Hmmm, this looks somewhat related to the xz incident from last weekend; let me take a look.

anakinxc avatar Apr 01 '24 06:04 anakinxc

@lizzy-0323 could you try again with the latest code on main?

anakinxc avatar Apr 01 '24 07:04 anakinxc

OK, it builds and runs successfully now!

lizzy-0323 avatar Apr 01 '24 10:04 lizzy-0323

If I need to count the number of label classes, should a num_class parameter be required before calling the chi2 function? To satisfy the jit constraints, most of the usual approaches don't seem usable.

lizzy-0323 avatar Apr 09 '24 02:04 lizzy-0323

Yes. But I'd suggest manually specifying the concrete value range of y (e.g. passing in a list), which can be added as an extra parameter~

deadlywing avatar Apr 09 '24 02:04 deadlywing

Sorry, I didn't quite understand. For example, when y takes the three values 0, 1, and 3, should I pass in [0, 2]? Or should I deduplicate all possible values of y and pass those in? This is my current implementation, but it doesn't seem to pass in the SPU environment: [screenshot]

lizzy-0323 avatar Apr 09 '24 02:04 lizzy-0323

When y takes the values 0, 1, and 3, I mean passing in [0, 1, 3] directly. BTW, you can first add @jax.jit to the function; only a function that is jitable in plaintext can possibly execute correctly on SPU.
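
A minimal sketch of that plaintext jit check (count_classes and its arguments are hypothetical, not part of the task): passing the known label values as a static tuple keeps every shape fixed under tracing.

import jax
import jax.numpy as jnp
from functools import partial

# label_lst is a static (hashable) argument, so len(label_lst) is a plain
# Python int at trace time and all shapes below are fixed
@partial(jax.jit, static_argnames=("label_lst",))
def count_classes(y, label_lst=(0, 1, 3)):
    onehot = (y[:, None] == jnp.array(label_lst)[None, :]).astype(jnp.float32)
    return onehot.sum(axis=0)

# count_classes(jnp.array([0, 1, 3, 3])) -> [1., 1., 2.]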

deadlywing avatar Apr 09 '24 03:04 deadlywing

Do I need to add @jax.jit to the chi2 function as well, to make sure every function inside it is jitable? In that case, it seems functions like jnp.zeros can't be used either.

lizzy-0323 avatar Apr 09 '24 05:04 lizzy-0323

You can add jit while testing the plaintext implementation. SPU jits automatically at run time, so in the end you don't need to add it explicitly in your code.

When coding, make sure all shapes are static. As long as no dynamic shape is involved, functions like jnp.zeros can be used.
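
A minimal illustration of that rule (hypothetical functions): shapes derived from x.shape are static and trace fine, while shapes that depend on runtime values do not.

import jax
import jax.numpy as jnp

@jax.jit
def ok(x):
    buf = jnp.zeros_like(x)          # static shape: allowed under jit
    return jnp.where(x > 0, x, buf)  # element-wise select keeps the shape fixed

@jax.jit
def not_ok(x):
    return x[x > 0]  # output shape depends on runtime values: fails to trace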

deadlywing avatar Apr 09 '24 05:04 deadlywing

Current progress

The plaintext tests of the algorithm pass, and the results are essentially consistent with sklearn; below are the results obtained by the two methods: [screenshot]

Current problem

Neither the dot function nor the outer function works. After consulting the docs I found that SPU does support dot; I found nothing about outer. The current code is shown below:

# (snippet: jnp is jax.numpy; sf is assumed to be a chi-squared survival
# function, e.g. scipy.stats.chi2.sf, imported elsewhere)
def chi2(X, y, label_lst):
    total_samples = X.shape[0]
    num_feature = X.shape[1]
    num_class = len(label_lst)
    X = jnp.array(X)
    y = jnp.array(y)
    feature_count = jnp.sum(X, axis=0)
    class_prob = jnp.mean(y, axis=0)
    # observed = jnp.zeros((num_feature, num_class))
    observed = jnp.dot(X.T, y)
    expected = jnp.zeros((num_feature, num_class))
    # expected = jnp.outer(feature_count, class_prob)
    # print(observed)
    # print(expected)
    chi2_stats = (observed - expected) ** 2 / expected
    chi2_stats = jnp.nansum(chi2_stats, axis=1)
    df = num_class - 1
    p_value = sf(chi2_stats, df=df)
    return chi2_stats, p_value

Function arguments: the y array passed in is one-hot encoded, and an additional label array is required as an auxiliary input. To confirm that the two functions don't work, during testing I initialized the other array with zeros. Could you advise where the problem might be?

lizzy-0323 avatar Apr 09 '24 08:04 lizzy-0323

Hello, sorry, I was tied up yesterday and didn't reply in time;

First, both the dot and outer functions are supported. Could you open a PR directly and push the complete code? The cause can't really be found from this snippet alone -。-
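
For reference, the two primitives in question behave like this in plaintext JAX (hypothetical shapes matching the snippet above):

import jax.numpy as jnp

X = jnp.ones((6, 4))  # (n_samples, n_features)
Y = jnp.ones((6, 3))  # one-hot labels, (n_samples, n_classes)
observed = jnp.dot(X.T, Y)                           # (n_features, n_classes)
expected = jnp.outer(X.sum(axis=0), Y.mean(axis=0))  # same shape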

deadlywing avatar Apr 10 '24 02:04 deadlywing

OK, the PR has been submitted. Sorry for the trouble, and thank you!

lizzy-0323 avatar Apr 10 '24 04:04 lizzy-0323

After discussion, it is confirmed that lizzy-0323 will continue the development of this project, and the task deadline is extended to mid-May. Thanks for your interest in and support of SecretFlow~

Yeekin-GYJ avatar Apr 16 '24 07:04 Yeekin-GYJ

My current implementation is as follows; I wrote a Python version following the scipy source code.

import jax
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import load_iris

MACHEP = 0.0000001  # the machine roundoff error / tolerance
BIG = 4.503599627370496e15
BIGINV = 2.22044604925031308085e-16


def _sf(x, df):
    def body_func(x_i):
        if x_i < 0:
            return 1
        if x_i == 0:
            return 0
        if df <= 0:
            raise ValueError("Domain error")
        if x_i < 1.0 or x_i < df:
            return 1.0 - _igam(0.5 * df, 0.5 * x_i)
        return _igamc(df * 0.5, x_i * 0.5)

    return jax.vmap(body_func)(x)


def _igam(a, x):
    # Compute  x**a * exp(-x) / Gamma(a)
    ax = jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a))

    # Power series
    r = a
    c = 1.0
    ans = 1.0

    while True:
        r += 1.0
        c *= x / r
        ans += c
        if c / ans <= MACHEP:
            return ans * ax / a


def _igamc(a, x):
    # Compute ax = x**a * exp(-x) / Gamma(a)
    ax = jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a))
    # print(ax)
    # Continued fraction
    y = 1.0 - a
    z = x + y + 1.0
    c = 0.0
    pkm2 = 1.0
    qkm2 = x
    pkm1 = x + 1.0
    qkm1 = z * x
    ans = pkm1 / qkm1
    while True:
        c += 1.0
        y += 1.0
        z += 2.0
        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        if qk != 0:
            r = pk / qk
            t = abs((ans - r) / r)
            ans = r
        else:
            t = 1.0
        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk
        if abs(pk) > BIG:
            pkm2 *= BIGINV
            pkm1 *= BIGINV
            qkm2 *= BIGINV
            qkm1 *= BIGINV
        if t <= MACHEP:
            return ans * ax


def chi2(X, y, label_lst):
    """
    Calculate the chi-squared statistic and p-value for feature independence testing.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        The feature matrix from which to calculate the chi-squared statistic. Each row
        represents a sample and each column a feature.

    y : array-like, shape (n_samples,)
        The class labels of each sample as integers.

    label_lst : list, length = n_classes
        The distinct class label values; its length determines the number of
        classes used to one-hot encode y.

    Returns
    -------
    chi2_stats : array, shape (n_features,)
        The chi-squared statistic for each feature, indicating the degree of
        association between the feature and the class labels.

    p_value : array, shape (n_features,)
        The p-value associated with each chi-squared statistic, which can be used
        to test the null hypothesis that the features are independent of the
        class labels.
    """
    total_samples = X.shape[0]
    num_feature = X.shape[1]
    num_class = len(label_lst)
    # one hot encoding
    y = jnp.eye(num_class)[y]
    X = jnp.array(X)
    y = jnp.array(y)
    feature_count = jnp.sum(X, axis=0)
    class_prob = jnp.mean(y, axis=0)
    # Calculate the observed frequency count
    observed = jnp.dot(X.T, y)
    expected = jnp.outer(feature_count, class_prob)
    # Calculate the chi-squared statistic
    chi2_stats = (observed - expected) ** 2 / expected
    # Sum over class dimensions
    chi2_stats = jnp.nansum(chi2_stats, axis=1)
    # Degrees of freedom
    df = num_class - 1
    # Calculate the p-value for each feature
    p_value = _sf(chi2_stats, df=df)
    return chi2_stats, p_value


if __name__ == '__main__':
    x, y = load_iris(return_X_y=True)
    label_lst = np.unique(y)
    chi2_stats, p_value = chi2(x, y, label_lst)
    print(chi2_stats, p_value)
    from sklearn.feature_selection import chi2 as chi2_sklearn

    sklearn_chi2_stats, sklearn_p_value = chi2_sklearn(x, y)
    print(sklearn_chi2_stats, sklearn_p_value)

lizzy-0323 avatar Apr 23 '24 06:04 lizzy-0323

Hi, the plaintext implementation relies on eps (machine precision) to dynamically decide the number of loop iterations, which SPU does not support. You can manually fix an upper bound on the iteration count and then examine the operator's accuracy; as long as the accuracy is OK, that's fine~
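
A minimal sketch of that transformation (power_series and MAX_ITER are placeholder names): the eps-controlled while loop becomes a jax.lax.fori_loop with a fixed trip count, which gives the static control flow SPU can trace.

import jax
import jax.numpy as jnp

MAX_ITER = 32  # fixed bound; tune it against the required accuracy

def power_series(a, x):
    # same recurrence as the while loop in _igam, minus the eps stopping test
    def body(_, val):
        r, c, ans = val
        r = r + 1.0
        c = c * (x / r)
        return (r, c, ans + c)

    init = (jnp.asarray(a, x.dtype), jnp.ones_like(x), jnp.ones_like(x))
    _, _, ans = jax.lax.fori_loop(0, MAX_ITER, body, init)
    return ans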

deadlywing avatar Apr 23 '24 08:04 deadlywing

OK. But could you also check whether my use of vmap in the _sf() function is correct? Following scipy's approach, I want to choose a different method (continued fraction or power series) depending on the relation between x and df, but conditional statements can't be executed under vmap, and I'm not sure how to implement this.

lizzy-0323 avatar Apr 23 '24 08:04 lizzy-0323

  1. First, you shouldn't need vmap here; x is just a 2-D array, which is not complicated to handle.
  2. Where a branch is needed, e.g. when choosing between _igam and _igamc, consider using jnp.select.
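
A minimal jnp.select example (hypothetical values): the first matching condition wins, so conditions go from most to least specific, with default covering everything else.

import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 0.5, 3.0])
out = jnp.select(
    [x < 0, x == 0, x < 1],
    [jnp.ones_like(x), jnp.zeros_like(x), x * 2],
    default=x * 10,
)
# out == [1., 0., 1., 30.]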

deadlywing avatar Apr 23 '24 08:04 deadlywing

Current implementation

MAX_ITER = 5


def _sf(x, df):
    if df <= 0:
        raise ValueError("Domain error")
    condlist = [x < 0, x == 0, jnp.logical_or(x < 1, x < df)]
    choicelist = [
        jnp.ones_like(x),
        jnp.zeros_like(x),
        1 - _igam(0.5 * df, 0.5 * x),
    ]
    result = jnp.select(condlist, choicelist, default=_igamc(df * 0.5, x * 0.5))
    return result


def _igam(a, x):
    # series
    # Compute  x**a * exp(-x) / Gamma(a)
    ax = jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a))
    # Power series
    r = jnp.full_like(x, a)
    c = jnp.ones_like(x)
    ans = jnp.ones_like(x)

    def loop_body(_, val):
        r, c, ans = val
        r += 1.0
        c *= x / r
        ans += c
        return (r, c, ans)

    init_val = (r, c, ans)
    _, _, ans = jax.lax.fori_loop(0, MAX_ITER, loop_body, init_val)
    return ans * ax / a


def _igamc(a, x):
    # Continued fraction
    # Compute ax = x**a * exp(-x) / Gamma(a)
    ax = jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a))
    y = jnp.ones_like(x) - a
    z = x + y + 1.0
    c = jnp.zeros_like(x)
    pkm2 = jnp.ones_like(x)
    qkm2 = x  # continued-fraction seed, matching the cephes recurrence
    pkm1 = x + 1.0
    qkm1 = z * x
    ans = pkm1 / qkm1

    def loop_body(_, val):
        (c, y, z, pkm1, pkm2, qkm1, qkm2), ans = val
        c += 1.0
        y += 1.0
        z += 2.0
        yc = y * c
        # one term of the continued fraction
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        r = pk / qk
        ans = r
        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk
        return (c, y, z, pkm1, pkm2, qkm1, qkm2), ans

    params = (c, y, z, pkm1, pkm2, qkm1, qkm2)
    init_val = (params, ans)
    _, ans = jax.lax.fori_loop(0, MAX_ITER, loop_body, init_val)
    return ans * ax

Plaintext test results

[screenshot]

SPU test results

[screenshot]

Problem

The current results are shown above; there still seem to be some numerical overflow problems. Is there any way to approach this part? I have already tried lowering the iteration limit, but the last two values still can't be computed.

lizzy-0323 avatar Apr 24 '24 03:04 lizzy-0323

The overflow happens because pk and qk grow quickly in the continued-fraction method. You could consider:

  1. reducing MAX_ITER to 3 (in this example)
  2. switching the ring to FM128

Suggestions:

  1. Make MAX_ITER a hyperparameter the user can choose.
  2. Note in the chi2 docstring that the p-value computation overflows easily (e.g. advise users to switch to FM128 or reduce MAX_ITER appropriately).
  3. Make the p-value computation optional. This differs from sklearn's behavior, but I think it is still meaningful: computing the chi2 statistic is very cheap while computing the p-value is very expensive, and the two serve almost equivalent purposes. For feature selection it is entirely reasonable to compute only the former. This also needs to be explained in a comment.
  4. Some simple optimizations are possible. For example, jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a)) is efficient when a is not an integer, but when a is an integer, jnp.power(x, a) * jnp.exp(-x) * jnp.exp(-jax.lax.lgamma(a)) saves one log op (since a is public, this optimization is allowed).
  5. API simplification: label_lst seems unnecessary; the number of classes should suffice. Also, the expected format of y needs documenting: the current implementation should only support y taking values 0, 1, 2, ..., and this must be stated explicitly. (A rough sketch combining these suggestions follows below.)
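
Pulling suggestions 3 and 5 together, the public entry point could look roughly like this (a sketch, not the final API; _sf is the fixed-iteration survival function above, and the argument names are placeholders):

import jax.numpy as jnp

def chi2(X, y, num_class, compute_p_value=True):
    # y must take integer values 0, 1, ..., num_class - 1
    y_onehot = jnp.eye(num_class)[y]
    observed = jnp.dot(X.T, y_onehot)
    expected = jnp.outer(jnp.sum(X, axis=0), jnp.mean(y_onehot, axis=0))
    chi2_stats = jnp.nansum((observed - expected) ** 2 / expected, axis=1)
    if not compute_p_value:  # plain Python bool, so this branch is static
        return chi2_stats, None
    # NOTE: the p-value path overflows easily; consider FM128 or a smaller MAX_ITER
    return chi2_stats, _sf(chi2_stats, df=num_class - 1)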

deadlywing avatar Apr 24 '24 07:04 deadlywing

I've adopted your suggestions: I set max_iter to 0 and switched the ring to fm128. After trying this, underflow still occurs even without entering the loop; it happens in the statement that computes ax

ax = jnp.power(x, a) * jnp.exp(-x) * jnp.exp(-jax.lax.lgamma(a))

Specifically, the underflow occurs when computing jnp.exp(-x). I tried correcting it here with jnp.exp(-(x - jnp.max(x))), but overflow still occurs.

lizzy-0323 avatar Apr 25 '24 05:04 lizzy-0323

Underflow is normal here; it is reasonable for very small values to simply become 0.

deadlywing avatar Apr 25 '24 05:04 deadlywing