Chris Li
Chris Li
需要对chi2函数也加上@jax.jit吗?确保其中所有的函数都是jitable的? 这样的话,好像jnp.zeros等函数也无法使用
## 目前进展 算法明文测试通过,和sklearn结果基本一致,以下是两种方法得到的结果: data:image/s3,"s3://crabby-images/a7752/a7752ed595498eff113851dbfd75f723e64be0aa" alt="image" ## 目前遇到的问题 dot函数和outer函数均不能使用,查阅文档后发现spu是支持dot函数的,outer函数没有查到,目前的代码如下所示: ```python def chi2(X, y, label_lst): total_samples = X.shape[0] num_feature = X.shape[1] num_class = len(label_lst) X = jnp.array(X) y = jnp.array(y) feature_count = jnp.sum(X,...
> > ## 目前进展 > > 算法明文测试通过,和sklearn结果基本一致,以下是两种方法得到的结果: data:image/s3,"s3://crabby-images/91916/919166feb3e54343220e1cb4262b39676846e9cd" alt="image" > > ## 目前遇到的问题 > > dot函数和outer函数均不能使用,查阅文档后发现spu是支持dot函数的,outer函数没有查到,目前的代码如下所示: > > ```python > > def chi2(X, y, label_lst): > > total_samples = X.shape[0] >...
目前的实现如下,按照scipy源码写了一个python版本 ```python import jax import jax.numpy as jnp import numpy as np from sklearn.datasets import load_iris MACHEP = 0.0000001 # the machine roundoff error / tolerance BIG = 4.503599627370496e15 BIGINV...
> > 目前的实现如下,按照scipy源码写了一个python版本 > > ```python > > import jax > > import jax.numpy as jnp > > import numpy as np > > from sklearn.datasets import load_iris > >...
## 目前的实现 ```python MAX_ITER = 5 def _sf(x, df): if df
您的建议我已采纳,将max_iter设置为0,环改成fm128。尝试后发现,即使不进入循环,也会发生下溢现象,出现在计算ax语句 ```python ax = jnp.power(x, a) * jnp.exp(-x) * jnp.exp(-jax.lax.lgamma(a)) ``` 具体是计算`jnp.exp(-x)`时发生了下溢,这里我尝试采用`jnp.exp(-(x-jnp.max(x))`进行修正,但仍旧会出现溢出
> why closed it? sry, commit record may be a bit confusing, please check it first
> I tried to test it but failed. Please let me know if I missed anything. > > ```shell > atest run --report grpc -p sample/testsuite-gitee.yaml --report-dest 127.0.0.1:7070/writer_templates.ReportWriter/SendReportResult > 2024-05-17T04:33:56.903Z...
> LGTM > > And welcome to be the 21st contributor! I'm very glad to contribute to the project!