text-classification-svm icon indicating copy to clipboard operation
text-classification-svm copied to clipboard

Classifier parameter setting

Open Opdoop opened this issue 5 years ago • 1 comment

感谢@hankcs。请问如何使用这段grid search? https://github.com/hankcs/text-classification-svm/blob/be8c6541720363faeccf21a05021dd9e30047229/src/main/java/com/hankcs/hanlp/classification/classifiers/LinearSVMClassifier.java#L96

Opdoop avatar Feb 27 '19 06:02 Opdoop

这是一段搜索正则化因子的函数。

/**
 * Automatic parameter search for Liblinear: a grid search over the cost
 * parameter C, scored by cross validation accuracy.
 *
 * @author hankcs
 */
public class grid
{
    /**
     * Searches a grid of C values and returns the one with the highest cross
     * validation accuracy (see {@link #validate(Problem, double)}).
     * <p>
     * The candidates are {@code from, from + step, from + 2 * step, ...} up to
     * but excluding {@code end}. If the range is shorter than one step, the
     * single candidate {@code from} is evaluated. Candidates are evaluated in
     * parallel on a fixed thread pool with one thread per available processor,
     * and progress is printed to standard output.
     *
     * @param prob the training problem to cross validate on
     * @param from one bound of the search range; the bounds are swapped if
     *             {@code from > end}
     * @param end  the other bound of the search range (exclusive)
     * @param step distance between two candidates; its sign is ignored
     * @return the candidate C with the best cross validation accuracy; on ties
     *         the smallest such candidate wins
     * @throws IllegalArgumentException if {@code step} is zero or NaN
     */
    public static double find_parameters(final Problem prob, double from, double end, double step)
    {
        if (from > end)
        {
            // Swap the bounds so that from <= end.
            double x = from;
            from = end;
            end = x;
        }
        if (step < 0) step = -step;
        if (!(step > 0))
        {
            // Also rejects NaN, for which every comparison is false.
            throw new IllegalArgumentException("step must be a non-zero number, got " + step);
        }
        // Always keep at least one candidate so that a best C can be returned.
        final int count = Math.max(1, (int) ((end - from) / step));
        final double[] cs = new double[count];
        final double[] as = new double[cs.length];
        Linear.setDebugOutput(null);
        ExecutorService fixedThreadPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        final AtomicInteger finished = new AtomicInteger(0);
        for (int i = 0; i < cs.length; i++)
        {
            cs[i] = from + step * i;
            final int index = i;
            fixedThreadPool.execute(new Runnable()
            {
                public void run()
                {
                    as[index] = validate(prob, cs[index]);
                    // Count the task only once its result is stored, so the
                    // percentage reflects completed work.
                    int n = finished.incrementAndGet();
                    System.out.printf("%.2f%%...\n", n / (double)cs.length * 100.);
                }
            });
        }
        fixedThreadPool.shutdown();
        try
        {
            fixedThreadPool.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
        }
        catch (InterruptedException e)
        {
            // Drop the tasks that have not started yet and keep the interrupt
            // flag set for the caller; the best result so far is still returned.
            fixedThreadPool.shutdownNow();
            Thread.currentThread().interrupt();
        }
        int p = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < as.length; i++)
        {
            if (as[i] > max)
            {
                max = as[i];
                p = i;
            }
        }
        System.out.printf("Best Cross Validation Accuracy = %g%%, C = %f%n", max * 100, cs[p]);

        return cs[p];
    }

    /**
     * Runs {@code Linear.crossValidation} with 5 folds on the problem, using an
     * {@code L1R_LR} solver parameterised with the given C and 0.01, and
     * measures how many predictions match the true labels.
     *
     * @param prob the problem to cross validate on
     * @param C    the cost parameter to evaluate
     * @return the fraction of the {@code prob.l} instances whose predicted
     *         label equals {@code prob.y[i]}
     */
    private static double validate(Problem prob, double C)
    {
        double[] target = new double[prob.l];
        Parameter param = new Parameter(SolverType.L1R_LR, C, 0.01);
        Linear.crossValidation(prob, param, 5, target);
        int total_correct = 0;
        for (int i = 0; i < prob.l; i++)
        {
            if (target[i] == prob.y[i]) ++total_correct;
        }
        return total_correct / (double)prob.l;
    }

    /**
     * Demo entry point: reads a problem file and prints the best C found in
     * the range 1 to 1000 with step 1.
     *
     * @param args optional; {@code args[0]} is the path of the problem file,
     *             defaulting to {@code libsvm/dataset/heart_scale.txt}
     * @throws IOException               if the file cannot be read
     * @throws InvalidInputDataException if the file content is malformed
     */
    public static void main(String[] args) throws IOException, InvalidInputDataException
    {
        String path = args != null && args.length > 0 ? args[0] : "libsvm/dataset/heart_scale.txt";
        Problem problem = Train.readProblem(new File(path), -1);
        System.out.println(find_parameters(problem, 1., 1000., 1.));
    }
}

hankcs avatar Mar 05 '19 17:03 hankcs