xlearn icon indicating copy to clipboard operation
xlearn copied to clipboard

FM 分类训练--no-norm 预测正常,去掉--no-norm 预测得到接近1的值

Open BHMliang opened this issue 4 years ago • 3 comments

通过命令行训练模型,加上--no-norm 得到模型文件。用自定义代码进行预测,与使用命令行加上--no-norm预测的结果一致; 通过命令行训练模型,不加--no-norm 得到模型文件。用自定义代码(归一化,norm = 1.0 / feaIds.size(); sqrtNorm = Math.sqrt(norm))进行预测,与使用命令行不加--no-norm预测的结果不一致;自定义代码预测的值比较大; 自定义预测代码:

        double sumWeight = 0.0;
        // bias
        sumWeight += bias;
        // norm
        double norm = 1.0 / feaIds.size();

        //norm = 1.0;

        // sqrt norm
        double sqrtNorm = Math.sqrt(norm);

        // linear term
        sumWeight += feaIds.stream().mapToDouble(feaId -> NumberUtils.doubleValue(feature2Weight.get(feaId)) * sqrtNorm).sum();
       //latent factor
       for (int i = 0; i < latentSize; i++) {
            //sum1
            double sum1 = 0.0;
            //sum2
            double sum2 = 0.0;
            for (long feaId : feaIds) {
                List<Double> embedding = getEmbedding(feaId);
                if(CollectionUtils.isEmpty(embedding)){
                    continue;
                }
                double d = embedding.get(i) * sqrtNorm;
                sum1 += d;
                sum2 += d * d;
            }
            sumWeight += (0.5 * (sum1 * sum1 - sum2));
        }
        for (int i = 0; i < latentSize; i++) {
            //sum1
            double sum1 = 0.0;
            //sum2
            double sum2 = 0.0;
            for (long feaId : feaIds) {
                List<Double> embedding = getEmbedding(feaId);
                if(CollectionUtils.isEmpty(embedding)){
                    continue;
                }
                double d = embedding.get(i) * sqrtNorm;
                sum1 += d;
                sum2 += d * d;
            }
         sumWeight += (0.5 * (sum1 * sum1 - sum2));

@aksnzhy 请问我的代码有什么问题吗

BHMliang avatar Oct 16 '19 18:10 BHMliang

double d = embedding.get(i) * sqrtNorm; 改为double d = embedding.get(i) * norm;之后结果一致,但不明白为什么要*norm。 根据原公式: image

v*x 翻译过来应该是embedding.get(i) * sqrtNorm * 1。 不明白这里为什么却是embedding.get(i) * norm * 1

BHMliang avatar Oct 17 '19 09:10 BHMliang

@aksnzhy 和上面一样的疑问,看了一遍xlearn的代码,在predict和CaclGrad的时候,在embedding这里都是乘的norm,而不是sqrtNorm

dongjiewhu avatar Oct 17 '19 11:10 dongjiewhu

@aksnzhy 我这边也是同样的问题,为什么在embedding这里都是乘的norm,而不是sqrtNorm呢

BUCTdarkness avatar Oct 17 '19 11:10 BUCTdarkness