text-classification-svm
text-classification-svm copied to clipboard
Classifier parameter setting
感谢@hankcs。请问如何使用这段grid search? https://github.com/hankcs/text-classification-svm/blob/be8c6541720363faeccf21a05021dd9e30047229/src/main/java/com/hankcs/hanlp/classification/classifiers/LinearSVMClassifier.java#L96
这是一段搜索正则化因子的函数。
/**
 * Liblinear automatic parameter search: a grid search over the cost parameter C.
 * <p>
 * Every candidate C is scored by 5-fold cross-validation accuracy with an
 * {@code SolverType.L1R_LR} model. Candidates are evaluated in parallel on a
 * fixed-size thread pool sized to the number of available processors.
 *
 * @author hankcs
 */
public class grid
{
    /**
     * Searches the evenly spaced grid {@code from, from + step, from + 2*step, ...}, up to and
     * including {@code end}, for the C with the best cross-validation accuracy.
     * <p>
     * Side effect: calls {@code Linear.setDebugOutput(null)} (a global liblinear setting) and does
     * not restore it. Progress and the final result are printed to {@code System.out}.
     *
     * @param prob the training problem to cross-validate on
     * @param from one end of the search range; if it is greater than {@code end} the two are swapped
     * @param end  the other end of the search range (inclusive)
     * @param step grid spacing; its sign is ignored
     * @return the candidate C with the highest accuracy (the smallest such C on ties)
     * @throws IllegalArgumentException if {@code step} is zero/NaN/infinite, a bound is not finite,
     *                                  or the grid would be too large to allocate
     */
    public static double find_parameters(final Problem prob, double from, double end, double step)
    {
        if (from > end)
        {
            // Swap so that from <= end.
            double tmp = from;
            from = end;
            end = tmp;
        }
        step = Math.abs(step);
        final double[] cs = buildGrid(from, end, step);
        final double[] as = new double[cs.length];
        // Silence liblinear's per-training log output.
        Linear.setDebugOutput(null);
        ExecutorService fixedThreadPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        final AtomicInteger finished = new AtomicInteger(0);
        for (int i = 0; i < cs.length; i++)
        {
            final int index = i;
            fixedThreadPool.execute(new Runnable()
            {
                @Override
                public void run()
                {
                    as[index] = validate(prob, cs[index]);
                    // Count after the work is done so the percentage reflects completed candidates.
                    int n = finished.incrementAndGet();
                    System.out.printf("%.2f%%...\n", n / (double) cs.length * 100.);
                }
            });
        }
        fixedThreadPool.shutdown();
        try
        {
            fixedThreadPool.awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
        }
        catch (InterruptedException e)
        {
            // Preserve the caller's interrupt status and cancel queued work. The search stays
            // best-effort: candidates that never ran keep accuracy 0.0 and so rarely win.
            Thread.currentThread().interrupt();
            fixedThreadPool.shutdownNow();
        }
        int p = argMax(as);
        System.out.printf("Best Cross Validation Accuracy = %g%%, C = %f%n", as[p] * 100, cs[p]);
        return cs[p];
    }

    /**
     * Builds the inclusive grid {@code from, from + step, ...} whose last point does not exceed
     * {@code end} (within a small rounding tolerance). Always contains at least one point.
     * Requires {@code from <= end} and {@code step > 0}.
     */
    private static double[] buildGrid(double from, double end, double step)
    {
        if (!(step > 0) || Double.isInfinite(step))
        {
            throw new IllegalArgumentException("step must be finite and non-zero: " + step);
        }
        if (Double.isNaN(from) || Double.isNaN(end) || Double.isInfinite(from) || Double.isInfinite(end))
        {
            throw new IllegalArgumentException("range bounds must be finite: [" + from + ", " + end + "]");
        }
        // The tolerance keeps the end point when the quotient lands just below an integer,
        // e.g. (0.4 - 0.1) / 0.1 == 2.9999999999999996.
        double intervals = Math.floor((end - from) / step + 1e-9);
        if (intervals >= Integer.MAX_VALUE - 8)
        {
            throw new IllegalArgumentException("too many grid points for range [" + from + ", " + end
                    + "] with step " + step);
        }
        double[] cs = new double[(int) intervals + 1];
        for (int i = 0; i < cs.length; i++)
        {
            cs[i] = from + step * i;
        }
        return cs;
    }

    /**
     * Returns the index of the first maximum of {@code values}. Returns 0 when no element is greater
     * than negative infinity (e.g. all NaN).
     */
    private static int argMax(double[] values)
    {
        int best = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++)
        {
            if (values[i] > max)
            {
                max = values[i];
                best = i;
            }
        }
        return best;
    }

    /**
     * Returns the 5-fold cross-validation accuracy, as a fraction of {@code prob.l}, of an
     * {@code SolverType.L1R_LR} model trained with cost {@code C} and eps 0.01.
     */
    private static double validate(Problem prob, double C)
    {
        double[] target = new double[prob.l];
        Parameter param = new Parameter(SolverType.L1R_LR, C, 0.01);
        Linear.crossValidation(prob, param, 5, target);
        int total_correct = 0;
        for (int i = 0; i < prob.l; i++)
        {
            // Predicted and true values are class labels, so exact comparison is intended.
            if (target[i] == prob.y[i])
            {
                ++total_correct;
            }
        }
        return total_correct / (double) prob.l;
    }

    /**
     * Demo: loads {@code libsvm/dataset/heart_scale.txt} (bias -1) and searches C over the
     * inclusive grid 1, 2, ..., 1000.
     */
    public static void main(String[] args) throws IOException, InvalidInputDataException
    {
        Problem problem = Train.readProblem(new File("libsvm/dataset/heart_scale.txt"), -1);
        System.out.println(find_parameters(problem, 1., 1000., 1.));
    }
}