Add first draft of ml-in-wayang.md
This PR provides a short .md guide that shows an exemplary usage of the previously introduced cost-model abstraction. The guide shows how it can be utilized to predict query plan runtimes with a pre-trained ML model.
Using Machine Learning for query optimization in Apache Wayang (incubating)
Apache Wayang (incubating) can be customized with concrete implementations of the EstimatableCost interface in order to optimize for a desired metric. The implementation can be enabled by providing it to a Configuration.
public class CustomEstimatableCost implements EstimatableCost {
    /* Provide concrete implementations to match desired cost function(s)
     * by implementing the interface in this class.
     */
}
public class WordCount {
    public static void main(String[] args) {
        /* Create a Wayang context and specify the platforms Wayang will consider. */
        Configuration config = new Configuration();
        /* Provide an EstimatableCost implementation as the custom cost model. */
        config.setCostModel(new CustomEstimatableCost());
        WayangContext wayangContext = new WayangContext(config)
                .withPlugin(Java.basicPlugin())
                .withPlugin(Spark.basicPlugin());
        /* ... omitted */
    }
}
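The path to the pre-trained model can be supplied through the same Configuration. The snippet below is a minimal sketch that continues the example above: the property key wayang.ml.model.file is the one read by the OrtMLModel example further below, it is assumed here that Configuration.setProperty accepts plain string key/value pairs, and the file path is only a placeholder.

/* Point Wayang to the pre-trained ONNX model consumed by the cost model. */
/* Hypothetical path; replace with the actual location of the .onnx file. */
config.setProperty("wayang.ml.model.file", "/path/to/model.onnx");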
In combination with an encoding scheme and a third-party package for loading ML models, the following example shows how to predict the runtimes of query execution plans in Apache Wayang (incubating):
import org.apache.wayang.core.optimizer.costs.EstimatableCost;
import org.apache.wayang.core.optimizer.costs.EstimatableCostFactory;
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval;
import org.apache.wayang.core.optimizer.enumeration.LoopImplementation;
import org.apache.wayang.core.optimizer.enumeration.PlanImplementation;
import org.apache.wayang.core.platform.Junction;
import org.apache.wayang.core.plan.executionplan.ExecutionPlan;
import org.apache.wayang.core.plan.executionplan.ExecutionStage;
import org.apache.wayang.core.plan.wayangplan.Operator;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.ml.encoding.OneHotEncoder;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.api.exception.WayangException;
import org.apache.wayang.core.plan.executionplan.Channel;
import org.apache.wayang.ml.OrtMLModel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Set;
import java.util.List;
public class MLCost implements EstimatableCost {

    public EstimatableCostFactory getFactory() {
        return new Factory();
    }

    public static class Factory implements EstimatableCostFactory {
        @Override public EstimatableCost makeCost() {
            return new MLCost();
        }
    }

    /** Encodes the plan and lets the pre-trained model predict its runtime. */
    @Override public ProbabilisticDoubleInterval getEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                .getOptimizationContext()
                .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return ProbabilisticDoubleInterval.ofExactly(
                model.runModel(OneHotEncoder.encode(plan))
            );
        } catch (Exception e) {
            return ProbabilisticDoubleInterval.zero;
        }
    }

    @Override public ProbabilisticDoubleInterval getParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                .getOptimizationContext()
                .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return ProbabilisticDoubleInterval.ofExactly(
                model.runModel(OneHotEncoder.encode(plan))
            );
        } catch (Exception e) {
            return ProbabilisticDoubleInterval.zero;
        }
    }

    /** Returns a squashed cost estimate. */
    @Override public double getSquashedEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                .getOptimizationContext()
                .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return model.runModel(OneHotEncoder.encode(plan));
        } catch (Exception e) {
            return 0;
        }
    }

    @Override public double getSquashedParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
        try {
            Configuration config = plan
                .getOptimizationContext()
                .getConfiguration();
            OrtMLModel model = OrtMLModel.getInstance(config);
            return model.runModel(OneHotEncoder.encode(plan));
        } catch (Exception e) {
            return 0;
        }
    }

    @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(PlanImplementation plan, Operator operator) {
        List<ProbabilisticDoubleInterval> intervalList = new ArrayList<ProbabilisticDoubleInterval>();
        List<Double> doubleList = new ArrayList<Double>();
        intervalList.add(this.getEstimate(plan, true));
        doubleList.add(this.getSquashedEstimate(plan, true));
        return new Tuple<>(intervalList, doubleList);
    }

    /** Picks the plan implementation with the lowest squashed cost estimate. */
    public PlanImplementation pickBestExecutionPlan(
            Collection<PlanImplementation> executionPlans,
            ExecutionPlan existingPlan,
            Set<Channel> openChannels,
            Set<ExecutionStage> executedStages) {
        final PlanImplementation bestPlanImplementation = executionPlans.stream()
            .reduce((p1, p2) -> {
                final double t1 = p1.getSquashedCostEstimate();
                final double t2 = p2.getSquashedCostEstimate();
                return t1 < t2 ? p1 : p2;
            })
            .orElseThrow(() -> new WayangException("Could not find an execution plan."));
        return bestPlanImplementation;
    }
}
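Like the CustomEstimatableCost above, the ML-backed cost model is enabled by handing it to the Configuration before the WayangContext is created. A minimal sketch, continuing the configuration from the first example:

/* Let the optimizer use the ML-based cost model for its estimates. */
config.setCostModel(new MLCost());

Note that the sequential and parallel estimates both delegate to the same pre-trained model, and that any failure during encoding or inference falls back to a zero estimate, so plan enumeration can still proceed when the model is unavailable.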
Third-party packages such as OnnxRuntime can be used to load pre-trained .onnx files that contain the desired ML models.
import org.apache.wayang.core.api.Configuration;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import java.util.Vector;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.Map;
import java.util.function.BiFunction;
public class OrtMLModel {

    private static OrtMLModel INSTANCE;
    private OrtSession session;
    private OrtEnvironment env;
    private final Map<String, OnnxTensor> inputMap = new HashMap<>();
    private final Set<String> requestedOutputs = new HashSet<>();

    public static OrtMLModel getInstance(Configuration configuration) throws OrtException {
        if (INSTANCE == null) {
            INSTANCE = new OrtMLModel(configuration);
        }
        return INSTANCE;
    }

    private OrtMLModel(Configuration configuration) throws OrtException {
        this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
    }

    /** Loads the ONNX model from the given file into a reusable session. */
    public void loadModel(String filePath) throws OrtException {
        if (this.env == null) {
            this.env = OrtEnvironment.getEnvironment();
        }
        if (this.session == null) {
            this.session = env.createSession(filePath, new OrtSession.SessionOptions());
        }
    }

    public void closeSession() throws OrtException {
        this.session.close();
        this.env.close();
    }

    /**
     * Runs the loaded model on an encoded execution plan.
     *
     * @param encodedVector the encoded execution plan
     * @return the predicted cost, or NaN if the model output could not be read
     * @throws OrtException if the input tensor cannot be created or the model fails to run
     */
    public double runModel(Vector<Long> encodedVector) throws OrtException {
        double costPrediction;
        OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector);
        this.inputMap.put("input", tensor);
        this.requestedOutputs.add("output");
        BiFunction<Result, String, Double> unwrapFunc = (r, s) -> {
            try {
                return ((double[]) r.get(s).get().getValue())[0];
            } catch (OrtException e) {
                return Double.NaN;
            }
        };
        try (Result r = session.run(inputMap, requestedOutputs)) {
            costPrediction = unwrapFunc.apply(r, "output");
        }
        return costPrediction;
    }
}
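OrtMLModel keeps a single ONNX Runtime environment and session open so that repeated cost predictions reuse the loaded model. Once no further estimates are needed, the session can be released explicitly. A minimal sketch, assuming the same Configuration that carries the wayang.ml.model.file property:

OrtMLModel model = OrtMLModel.getInstance(config);
/* ... predictions happen during plan enumeration ... */
model.closeSession();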