Implement the chi-squared statistic with SPU
This issue is a task issue of the SecretFlow Open Source Contribution Plan (SF OSCP). Community developers are welcome to participate. If you are interested in claiming a task but have not yet registered, please complete registration first.
Task Description
- Task name: Implement the chi-squared statistic with SPU
- Technical direction: SPU/SML
- Difficulty: Advanced 🌟🌟
- Time to complete: 4 weeks
Detailed Requirements
- Security: reveal as little as possible
- Functionality: return both the chi2 statistic and the corresponding p-value
- Correctness: numerically consistent with sklearn's results, within an acceptable error tolerance (a plaintext reference sketch is given after the instructions below)
- Code style: Python code must be formatted with black + isort (the CI pipeline includes a code-style check gate)
- Submission: link this issue and submit the code to https://github.com/secretflow/spu/tree/main/sml (the exact folder can be discussed with the reviewer)
- Special note: if a feature has special constraints, e.g. it requires FM128 or more fxp bits, this must be stated clearly in the comments/documentation
Required Skills
- Familiarity with classical statistical theory
- Familiarity with JAX or NumPy; able to implement the algorithm with NumPy
Instructions
- See the description in sklearn: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.feature_selection
- Development guide: https://www.secretflow.org.cn/docs/spu/latest/en-US/getting_started/tutorials/develop_your_first_mpc_application
- Example: https://github.com/secretflow/spu/blob/main/sml/linear_model/simple_sgd.py
- SecretFlow MOOC (season 2) videos:
  - https://www.bilibili.com/video/BV1ba4y1S73z/?spm_id_from=333.788
  - https://www.bilibili.com/video/BV1Qe411D7MG/?spm_id_from=333.788
- Contribution guide: https://github.com/secretflow/spu/blob/main/sml/development.md
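For orientation, here is a plaintext NumPy sketch of the quantity to be reproduced, following the contingency-table formulation used by sklearn.feature_selection.chi2 (the function and variable names are illustrative only, not part of the task requirements):

import numpy as np
from scipy.stats import chi2 as chi2_dist

def chi2_reference(X, y_onehot):
    # observed[j, k]: total of feature j accumulated over the samples of class k
    observed = X.T @ y_onehot
    # expected counts under independence: feature totals split by class frequencies
    feature_count = X.sum(axis=0)        # shape (n_features,)
    class_prob = y_onehot.mean(axis=0)   # shape (n_classes,)
    expected = np.outer(feature_count, class_prob)
    stats = ((observed - expected) ** 2 / expected).sum(axis=1)
    p_values = chi2_dist.sf(stats, df=y_onehot.shape[1] - 1)  # chi2 survival function
    return stats, p_values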
lizzy-0323 Give it to me
Please tell me which folder I should submit to?
hello, you can first create a new folder named "feature_selection", with a new file named "univariate_selection.py", and then implement the chi2 function in this file.
In other words, the chi2 should be implemented in sml/feature_selection/univariate_selection.py;
(BTW, don't forget the essential tests/emulations folders and the bazel files.)
Best wishes
Thank you!
About the problem of being unable to build and run in the Docker environment
OS
Ubuntu 20.04.6 LTS
Command
bazel run -c opt //sml/cluster/tests:kmeans_test
Error
P.S. All prerequisite libraries are installed; it seems to be a network issue, but I have not found a solution.
Please make sure Docker can access GitHub normally.
Docker can access GitHub fine; operations such as clone also work, and I am using a proxy.
Hmmm, this looks somewhat related to what happened with xz last weekend; let me take a look.
@lizzy-0323 Could you try again with the latest code on main?
Great, it runs successfully now!
If I want to count the label classes, do I need to require a num_class parameter to be passed in before calling the chi2 function? Because the code has to be jittable, most of the usual approaches do not seem usable.
Yes. But my suggestion is to manually specify the concrete set of values y can take (e.g. pass in a list); it can be passed as an extra parameter.
Sorry, I didn't quite understand. For example, when y takes the three values 0, 1 and 3, should I pass in [0, 2]? Or should I deduplicate all possible values of y and pass those in?
This is my current implementation, but it does not seem to pass in the SPU environment:
If y takes the values 0, 1 and 3, I mean pass [0, 1, 3] in directly. BTW, you can first add @jax.jit to the function; only a function that is jittable in plaintext can possibly execute correctly on SPU.
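For illustration, one way to one-hot encode y against an explicitly supplied label list while keeping every shape static (the helper name is hypothetical):

import jax.numpy as jnp

def one_hot_from_labels(y, label_lst):
    # label_lst is a plaintext Python list such as [0, 1, 3]; its length fixes
    # the output shape, so the function remains jittable.
    labels = jnp.asarray(label_lst)
    return (y[:, None] == labels[None, :]).astype(jnp.float32)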
Do I need to add @jax.jit to the chi2 function as well, to make sure every function inside it is jittable? In that case, functions like jnp.zeros also seem unusable.
You can add jit while testing the plaintext implementation; SPU jits automatically at runtime, so in the end you do not need to add it explicitly in the code.
When coding, make sure every shape is fixed; as long as no dynamic shapes are involved, functions like jnp.zeros can be used.
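A minimal sketch of that workflow on a toy function (jit only for plaintext testing; shapes derived from x.shape are static and therefore fine):

import jax
import jax.numpy as jnp

@jax.jit
def column_sums(x):
    # x.shape is known at trace time, so jnp.zeros with a shape derived
    # from it is allowed; only data-dependent shapes would fail.
    acc = jnp.zeros((x.shape[1],))
    return acc + jnp.sum(x, axis=0)

column_sums(jnp.ones((8, 3)))  # traces and runs in plaintext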
Current progress
The plaintext test of the algorithm passes and the results are basically consistent with sklearn; below are the results from the two methods:
Current problem
Neither the dot function nor the outer function can be used. After checking the documentation I found that SPU does support dot, and I found nothing about outer. The current code is shown below:
def chi2(X, y, label_lst):
    total_samples = X.shape[0]
    num_feature = X.shape[1]
    num_class = len(label_lst)
    X = jnp.array(X)
    y = jnp.array(y)
    feature_count = jnp.sum(X, axis=0)
    class_prob = jnp.mean(y, axis=0)
    # observed = jnp.zeros((num_feature, num_class))
    observed = jnp.dot(X.T, y)
    expected = jnp.zeros((num_feature, num_class))
    # expected = jnp.outer(feature_count, class_prob)
    # print(observed)
    # print(expected)
    chi2_stats = (observed - expected) ** 2 / expected
    chi2_stats = jnp.nansum(chi2_stats, axis=1)
    df = num_class - 1
    p_value = sf(chi2_stats, df=df)
    return chi2_stats, p_value
Function arguments: the y array passed in is one-hot encoded, and a label array also needs to be passed in as a helper. To confirm that the two functions fail, the test initializes the other array with zeros. Could you advise where the problem most likely lies?
Hello, sorry for not replying promptly yesterday, I was busy.
First, both dot and outer are supported. Could you open a PR directly and push the complete code? It is hard to find the cause from this snippet alone.
OK, the PR has been submitted. Thanks a lot for your help!
After discussion, it has been confirmed that lizzy-0323 will continue developing this project, and the task deadline is extended to mid-May. Thanks for the attention and support for SecretFlow.
My current implementation is below; I wrote a Python version following the scipy source code.
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.datasets import load_iris

MACHEP = 0.0000001  # the machine roundoff error / tolerance
BIG = 4.503599627370496e15
BIGINV = 2.22044604925031308085e-16


def _sf(x, df):
    def body_func(x_i):
        if x_i < 0:
            return 1
        if x_i == 0:
            return 0
        if df <= 0:
            raise ValueError("Domain error")
        if x_i < 1.0 or x_i < df:
            return 1.0 - _igam(0.5 * df, 0.5 * x_i)
        return _igamc(df * 0.5, x_i * 0.5)

    return jax.vmap(body_func)(x)


def _igam(a, x):
    # Compute ax = x**a * exp(-x) / Gamma(a)
    ax = jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a))
    # Power series
    r = a
    c = 1.0
    ans = 1.0
    while True:
        r += 1.0
        c *= x / r
        ans += c
        if c / ans <= MACHEP:
            return ans * ax / a


def _igamc(a, x):
    # Compute ax = x**a * exp(-x) / Gamma(a)
    ax = jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a))
    # print(ax)
    # Continued fraction
    y = 1.0 - a
    z = x + y + 1.0
    c = 0.0
    pkm2 = 1.0
    qkm2 = x
    pkm1 = x + 1.0
    qkm1 = z * x
    ans = pkm1 / qkm1
    while True:
        c += 1.0
        y += 1.0
        z += 2.0
        yc = y * c
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        if qk != 0:
            r = pk / qk
            t = abs((ans - r) / r)
            ans = r
        else:
            t = 1.0
        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk
        if abs(pk) > BIG:
            pkm2 *= BIGINV
            pkm1 *= BIGINV
            qkm2 *= BIGINV
            qkm1 *= BIGINV
        if t <= MACHEP:
            return ans * ax


def chi2(X, y, label_lst):
    """
    Calculate the chi-squared statistic and p-value for feature independence testing.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        The feature matrix from which to calculate the chi-squared statistic. Each row
        represents a sample and each column a feature.
    y : array-like, shape (n_samples,)
        The class labels of each sample as integers.
    label_lst : list, length = n_classes
        A list of class labels corresponding to the columns of the one-hot encoded y.

    Returns
    -------
    chi2_stats : array, shape (n_features,)
        The chi-squared statistic for each feature, indicating the degree of
        association between the feature and the class labels.
    p_value : array, shape (n_features,)
        The p-value associated with each chi-squared statistic, which can be used
        to test the null hypothesis that the features are independent of the
        class labels.
    """
    total_samples = X.shape[0]
    num_feature = X.shape[1]
    num_class = len(label_lst)
    # one-hot encoding
    y = jnp.eye(num_class)[y]
    X = jnp.array(X)
    y = jnp.array(y)
    feature_count = jnp.sum(X, axis=0)
    class_prob = jnp.mean(y, axis=0)
    # Calculate the observed frequency count
    observed = jnp.dot(X.T, y)
    expected = jnp.outer(feature_count, class_prob)
    # Calculate the chi-squared statistic
    chi2_stats = (observed - expected) ** 2 / expected
    # Sum over the class dimension
    chi2_stats = jnp.nansum(chi2_stats, axis=1)
    # Degrees of freedom
    df = num_class - 1
    # Calculate the p-value for each feature
    p_value = _sf(chi2_stats, df=df)
    return chi2_stats, p_value


if __name__ == '__main__':
    x, y = load_iris(return_X_y=True)
    label_lst = np.unique(y)
    chi2_stats, p_value = chi2(x, y, label_lst)
    print(chi2_stats, p_value)

    from sklearn.feature_selection import chi2 as chi2_sklearn

    sklearn_chi2_stats, sklearn_p_value = chi2_sklearn(x, y)
    print(sklearn_chi2_stats, sklearn_p_value)
Hi, the plaintext implementation relies on eps (i.e. machine precision) to dynamically decide the number of loop iterations, which SPU does not support. You can set an upper bound on the number of iterations manually and then check the operator's accuracy; as long as the accuracy is OK, that is fine.
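The rewrite being suggested, sketched on a toy power series (MAX_ITER is an illustrative constant; the eps-based stopping test is replaced by a fixed iteration count whose accuracy is then checked empirically):

import jax
import jax.numpy as jnp

MAX_ITER = 20  # fixed bound chosen for accuracy, instead of comparing against MACHEP

def exp_series(x):
    # sum_{k>=0} x**k / k! with a fixed number of terms
    def body(k, val):
        term, acc = val
        term = term * x / (k + 1.0)
        return term, acc + term

    term0 = jnp.ones_like(x)
    _, acc = jax.lax.fori_loop(0, MAX_ITER, body, (term0, term0))
    return acc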
OK. But could you please check whether my use of vmap in the _sf() function is correct? Following scipy's approach, I want to choose a different method (continued fraction or series) depending on how x compares with df, but with vmap the conditional statements cannot be executed, and I am not sure how to implement this.
- First, you should not need vmap here; x is just a 2-D array and is not complicated to handle.
- Where a branch is needed, e.g. when choosing between _igam and _igamc, consider using jnp.select (see the sketch below).
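For illustration, the general shape of that pattern with placeholder branches (all candidate values are computed elementwise and jnp.select picks among them, so no Python-level `if` on traced values is needed):

import jax.numpy as jnp

def select_demo(x, df):
    condlist = [x < 0, x == 0, jnp.logical_or(x < 1.0, x < df)]
    # The real branches would be 1 - _igam(...) and _igamc(...); placeholders here.
    choicelist = [jnp.ones_like(x), jnp.zeros_like(x), x + 1.0]
    return jnp.select(condlist, choicelist, default=x * 2.0)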
Current implementation
MAX_ITER = 5


def _sf(x, df):
    if df <= 0:
        raise ValueError("Domain error")
    condlist = [x < 0, x == 0, jnp.logical_or(x < 1, x < df)]
    choicelist = [
        jnp.ones_like(x),
        jnp.zeros_like(x),
        1 - _igam(0.5 * df, 0.5 * x),
    ]
    result = jnp.select(condlist, choicelist, default=_igamc(df * 0.5, x * 0.5))
    return result


def _igam(a, x):
    # Series method
    # Compute ax = x**a * exp(-x) / Gamma(a)
    ax = jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a))
    # Power series
    r = jnp.full_like(x, a)
    c = jnp.ones_like(x)
    ans = jnp.ones_like(x)

    def loop_body(_, val):
        r, c, ans = val
        r += 1.0
        c *= x / r
        ans += c
        return (r, c, ans)

    init_val = (r, c, ans)
    _, _, ans = jax.lax.fori_loop(0, MAX_ITER, loop_body, init_val)
    return ans * ax / a


def _igamc(a, x):
    # Continued fraction method
    # Compute ax = x**a * exp(-x) / Gamma(a)
    ax = jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a))
    y = jnp.ones_like(x) - a
    z = x + y + 1.0
    c = jnp.zeros_like(x)
    pkm2 = jnp.ones_like(x)
    qkm2 = jnp.full_like(x, a)
    pkm1 = x + 1.0
    qkm1 = z * x
    ans = pkm1 / qkm1

    def loop_body(_, val):
        (c, y, z, pkm1, pkm2, qkm1, qkm2), ans = val
        c += 1.0
        y += 1.0
        z += 2.0
        yc = y * c
        # One term of the continued fraction
        pk = pkm1 * z - pkm2 * yc
        qk = qkm1 * z - qkm2 * yc
        r = pk / qk
        ans = r
        pkm2 = pkm1
        pkm1 = pk
        qkm2 = qkm1
        qkm1 = qk
        return (c, y, z, pkm1, pkm2, qkm1, qkm2), ans

    params = (c, y, z, pkm1, pkm2, qkm1, qkm2)
    init_val = (params, ans)
    _, ans = jax.lax.fori_loop(0, MAX_ITER, loop_body, init_val)
    return ans * ax
Plaintext test results
SPU test results
Problem
The current results are as above; there still seem to be some numerical overflow issues. Is there any suggested way to address this? I have already tried lowering the iteration limit, but the last two values still cannot be computed.
The overflow happens because pk and qk in the continued-fraction method grow quickly. You can consider:
- reducing MAX_ITER to 3 (in this example)
- switching the ring to FM128
Suggestions:
- Make MAX_ITER a hyperparameter that the user can choose.
- Note in the chi2 docstring that the p-value computation is prone to overflow (e.g. suggest using FM128 or reducing MAX_ITER appropriately).
- Make the p-value computation optional. (This differs from sklearn's behavior, but I think it is still worthwhile: computing the chi2 statistic is very cheap while computing the p-value is very expensive, and the two serve almost equivalent purposes. For feature selection it is entirely fine to compute only the former; this also needs to be explained in a comment.)
- Some simple optimizations are possible. For example, jnp.exp(a * jnp.log(x) - x - jax.lax.lgamma(a)) is efficient when a is not an integer, but when a is an integer, jnp.power(x, a) * jnp.exp(-x) * jnp.exp(-jax.lax.lgamma(a)) saves one log operator. (Since a is plaintext, this optimization can be done.)
- API simplification: for example, label_lst seems unnecessary; only the number of classes is needed. Also, the format of y needs to be documented: the current implementation should only support y taking values 0, 1, 2, ..., and this must be stated explicitly. (An illustrative sketch of such a signature follows this list.)
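An illustrative signature reflecting these suggestions, building on the _sf above (names, defaults, and the assumption that _sf accepts a max_iter argument are placeholders to be settled in the PR review, not a confirmed API):

def chi2(X, y, num_class, max_iter=3, compute_p_value=True):
    """Chi-squared statistic per feature; y must take integer values 0, 1, ..., num_class - 1.

    The p-value computation truncates a series / continued fraction and can overflow;
    consider FM128 or a smaller max_iter if that happens.
    """
    y_onehot = jnp.eye(num_class)[y]
    observed = jnp.dot(X.T, y_onehot)
    expected = jnp.outer(jnp.sum(X, axis=0), jnp.mean(y_onehot, axis=0))
    chi2_stats = jnp.sum((observed - expected) ** 2 / expected, axis=1)
    if not compute_p_value:  # plaintext Python flag, resolved at trace time
        return chi2_stats, None
    return chi2_stats, _sf(chi2_stats, df=num_class - 1, max_iter=max_iter)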
I have adopted your suggestions: set max_iter to 0 and switched the ring to FM128. After trying this, I found that underflow occurs even without entering the loop, in the statement that computes ax: ax = jnp.power(x, a) * jnp.exp(-x) * jnp.exp(-jax.lax.lgamma(a)). Specifically, the underflow happens when computing jnp.exp(-x). I tried to correct it with jnp.exp(-(x - jnp.max(x))), but overflow still occurs.
Underflow is expected here; it is reasonable for very small values to simply become 0.
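For a sense of scale (plaintext NumPy, purely illustrative):

import numpy as np

# A chi2 statistic on the order of 100 gives exp(-x/2) around 1e-22, far below
# what a fixed-point fraction of a few tens of bits can represent, so a p-value
# of exactly 0 for such features is the expected behaviour.
print(np.exp(-100.0 / 2))  # ~1.9e-22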