beast2
beast2 copied to clipboard
Optimise State.calculateCalcNodePath()
This method can show up in the profiler as the most computationally expensive method when running large starbeast3 analyses, so some optimisation would be helpful for these kinds of analyses. According to the profiler, most of the time is spent doing membership checks in an ArrayList, so replacing these with HashSet membership checks should help (linear time vs. expected constant time per lookup).
The following implementation would do this, but needs rigorous testing:
// Builds the list of CalculationNodes that lie downstream of the StateNodes
// listed in changeStateNodes[0 .. nrOfChangedStateNodes), then reorders that
// list with the swap pass at the end, and returns it.
final List<CalculationNode> calcNodes = new ArrayList<>();
// Holds exactly the same elements as calcNodes; exists only so that the
// "already collected?" test is a hash lookup instead of a scan of the list.
final Set<CalculationNode> seen = new HashSet<>();

for (int k = 0; k < nrOfChangedStateNodes; k++) {
    final int i = changeStateNodes[k];
    // go grab the path to the Runnable
    // first the outputs of the StateNodes that is changed
    for (CalculationNode node : stateNodeOutputs[i]) {
        // Set.add returns true only if the node was not yet present, so a
        // single call both tests and records membership.
        if (seen.add(node)) {
            calcNodes.add(node);
        }
    }
}

// next the path following the outputs
// calcNodes doubles as the work queue: nodes appended inside the loop are
// visited by later iterations because size() is re-read every time. When the
// first pass collected nothing, this loop and the ordering pass below do not
// execute at all, so no separate "progress" flag is needed.
for (int calcNodeIndex = 0; calcNodeIndex < calcNodes.size(); calcNodeIndex++) {
    final CalculationNode node = calcNodes.get(calcNodeIndex);
    // NOTE(review): assumes outputMap has an entry for every collected node;
    // a missing entry would surface as a NullPointerException here - confirm.
    for (BEASTInterface output : outputMap.get(node)) {
        if (!(output instanceof CalculationNode)) {
            throw new RuntimeException("DEVELOPER ERROR: found a"
                    + " non-CalculatioNode ("
                    + output.getClass().getName()
                    + ") on path between StateNode and Runnable");
        }
        final CalculationNode calcNode = (CalculationNode) output;
        if (seen.add(calcNode)) {
            calcNodes.add(calcNode);
        }
    }
}

// put calc nodes in partial order
// Walk from the last position towards the front. If the node at position i
// has one of the earlier entries (j < i) among its outputs, swap the two so
// that the output ends up behind the node feeding it.
// NOTE(review): presumably relies on the output graph being acyclic; two nodes
// that list each other as outputs would be swapped back and forth forever.
for (int i = calcNodes.size() - 1; i > 0; i--) {
    final CalculationNode node = calcNodes.get(i);
    final Set<BEASTInterface> outputs = node.getOutputs();
    for (int j = 0; j < i; j++) {
        if (outputs.contains(calcNodes.get(j))) {
            // swap
            final CalculationNode node2 = calcNodes.get(j);
            calcNodes.set(j, node);
            calcNodes.set(i, node2);
            // Cancels the outer loop's i-- so that position i, which now
            // holds a different node, is examined again.
            i++;
            break;
        }
    }
}
return calcNodes;