Chris Li

Results 24 comments of Chris Li

需要对chi2函数也加上@jax.jit吗?确保其中所有的函数都是jitable的? 这样的话,好像jnp.zeros等函数也无法使用

## 目前进展 算法明文测试通过,和sklearn结果基本一致,以下是两种方法得到的结果: ![image](https://github.com/secretflow/spu/assets/58327725/9e4b06c4-5f07-4946-a936-ca200a3457f7) ## 目前遇到的问题 dot函数和outer函数均不能使用,查阅文档后发现spu是支持dot函数的,outer函数没有查到,目前的代码如下所示: ```python def chi2(X, y, label_lst): total_samples = X.shape[0] num_feature = X.shape[1] num_class = len(label_lst) X = jnp.array(X) y = jnp.array(y) feature_count = jnp.sum(X,...

> > ## 目前进展 > > 算法明文测试通过,和sklearn结果基本一致,以下是两种方法得到的结果: ![image](https://private-user-images.githubusercontent.com/58327725/320768856-9e4b06c4-5f07-4946-a936-ca200a3457f7.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTI3MTUxODYsIm5iZiI6MTcxMjcxNDg4NiwicGF0aCI6Ii81ODMyNzcyNS8zMjA3Njg4NTYtOWU0YjA2YzQtNWYwNy00OTQ2LWE5MzYtY2EyMDBhMzQ1N2Y3LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDA0MTAlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwNDEwVDAyMDgwNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTBkNzI4ZDZkNTMwMmQ1NGNjZDdiOGY0ZjI4NjM4ODNiNDc3MWQyYTJiNTUwZTI4OWRjZTdhODQ5MmFkNDdjZDkmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.ChVseo9NgSkQRqQUZqVpomkWkFunamWUyPWyszkjqEA) > > ## 目前遇到的问题 > > dot函数和outer函数均不能使用,查阅文档后发现spu是支持dot函数的,outer函数没有查到,目前的代码如下所示: > > ```python > > def chi2(X, y, label_lst): > > total_samples = X.shape[0] >...

目前的实现如下,按照scipy源码写了一个python版本 ```python import jax import jax.numpy as jnp import numpy as np from sklearn.datasets import load_iris MACHEP = 0.0000001 # the machine roundoff error / tolerance BIG = 4.503599627370496e15 BIGINV...

> > 目前的实现如下,按照scipy源码写了一个python版本 > > ```python > > import jax > > import jax.numpy as jnp > > import numpy as np > > from sklearn.datasets import load_iris > >...

## 目前的实现 ```python MAX_ITER = 5 def _sf(x, df): if df

您的建议我已采纳,将max_iter设置为0,环改成fm128。尝试后发现,即使不进入循环,也会发生下溢现象,出现在计算ax语句 ```python ax = jnp.power(x, a) * jnp.exp(-x) * jnp.exp(-jax.lax.lgamma(a)) ``` 具体是计算`jnp.exp(-x)`时发生了下溢,这里我尝试采用`jnp.exp(-(x-jnp.max(x))`进行修正,但仍旧会出现溢出

> why closed it? sry, commit record may be a bit confusing, please check it first

> I tried to test it but failed. Please let me know if I missed anything. > > ```shell > atest run --report grpc -p sample/testsuite-gitee.yaml --report-dest 127.0.0.1:7070/writer_templates.ReportWriter/SendReportResult > 2024-05-17T04:33:56.903Z...

> LGTM > > And welcome to be the 21st contributor! I'm very glad to contribute to the project!