beast2

Optimise State.calculateCalcNodePath()

Open: rbouckaert opened this issue 2 years ago

In large starbeast3 analyses, this method can show up in the profiler as the most computationally expensive method, so optimising it would help this kind of analysis. According to the profiler, most of the time is spent on membership checks in an `ArrayList`, so replacing these with `Set` membership checks should help: `ArrayList.contains()` is a linear scan, whereas `HashSet.contains()` takes expected constant time.
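
As a minimal, standalone illustration of the pattern (not BEAST code; the class `OrderedDedupSketch` and method `distinctInOrder` are hypothetical), an ordered `List` is paired with a `HashSet`, and the return value of `Set.add()` (`false` if the element was already present) replaces the linear `List.contains()` check. This gives the same answers as the `ArrayList` version as long as `hashCode()` is consistent with `equals()`, which both collections rely on.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class OrderedDedupSketch {

    // Hypothetical helper: returns the distinct elements of `items`
    // in first-seen order, using a HashSet for the membership test.
    static <T> List<T> distinctInOrder(Iterable<T> items) {
        List<T> ordered = new ArrayList<>();
        Set<T> seen = new HashSet<>();
        for (T item : items) {
            // add() returns false if item is already in the set, which
            // replaces the linear-time ordered.contains(item) check
            if (seen.add(item)) {
                ordered.add(item);
            }
        }
        return ordered;
    }

    public static void main(String[] args) {
        List<String> result = distinctInOrder(List.of("a", "b", "a", "c", "b"));
        System.out.println(result); // [a, b, c]
    }
}
```

The list is still needed because the path must keep its insertion order and be indexable while it grows; the set only answers "have we seen this node yet?".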

The following implementation would do this, but needs rigorous testing:

    final List<CalculationNode> calcNodes = new ArrayList<>();
    // HashSet gives expected O(1) membership checks instead of ArrayList.contains()
    final Set<CalculationNode> seen = new HashSet<>();

    // first collect the outputs of the StateNodes that have changed
    for (int k = 0; k < nrOfChangedStateNodes; k++) {
        int i = changeStateNodes[k];
        for (CalculationNode node : stateNodeOutputs[i]) {
            // Set.add() returns false if the node is already on the path
            if (seen.add(node)) {
                calcNodes.add(node);
            }
        }
    }

    // next follow the outputs until no more CalculationNodes can be added;
    // calcNodes grows while it is traversed, so this is a breadth-first
    // search (it does nothing if no StateNode outputs were collected)
    for (int calcNodeIndex = 0; calcNodeIndex < calcNodes.size(); calcNodeIndex++) {
        CalculationNode node = calcNodes.get(calcNodeIndex);
        for (BEASTInterface output : outputMap.get(node)) {
            if (output instanceof CalculationNode) {
                final CalculationNode calcNode = (CalculationNode) output;
                if (seen.add(calcNode)) {
                    calcNodes.add(calcNode);
                }
            } else {
                throw new RuntimeException("DEVELOPER ERROR: found a"
                        + " non-CalculationNode ("
                        + output.getClass().getName()
                        + ") on path between StateNode and Runnable");
            }
        }
    }

    // put calc nodes in partial order: if node i has an output at an
    // earlier index j, swap the two and re-examine index i
    for (int i = calcNodes.size() - 1; i > 0; i--) {
        CalculationNode node = calcNodes.get(i);
        Set<BEASTInterface> outputs = node.getOutputs();
        for (int j = 0; j < i; j++) {
            if (outputs.contains(calcNodes.get(j))) {
                // swap
                final CalculationNode node2 = calcNodes.get(j);
                calcNodes.set(j, node);
                calcNodes.set(i, node2);
                i++;
                break;
            }
        }
    }

    return calcNodes;
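
Since the change needs rigorous testing, one option is to exercise the same logic on a toy graph outside BEAST. The sketch below is hypothetical throughout: `Node` stands in for `CalculationNode`, its single `outputs` set stands in for both `outputMap` and `getOutputs()` (assumed to agree), and the list passed to `calcNodePath` plays the role of the outputs of the changed StateNodes. It reproduces the collection and swap-based ordering steps, and `isPartiallyOrdered` checks the invariant the ordering is meant to guarantee: every node appears before all of its outputs that are on the path.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class CalcNodePathSketch {

    // Hypothetical stand-in for CalculationNode: a name and its outputs.
    static final class Node {
        final String name;
        final Set<Node> outputs = new LinkedHashSet<>(); // deterministic iteration
        Node(String name) { this.name = name; }
        @Override public String toString() { return name; }
    }

    static void link(Node from, Node to) {
        from.outputs.add(to);
    }

    // Collects every node reachable from changedOutputs, then reorders the
    // list so that each node precedes its outputs. Assumes an acyclic graph.
    static List<Node> calcNodePath(List<Node> changedOutputs) {
        List<Node> calcNodes = new ArrayList<>();
        Set<Node> seen = new HashSet<>();
        for (Node node : changedOutputs) {
            if (seen.add(node)) {
                calcNodes.add(node);
            }
        }
        // breadth-first: calcNodes grows while it is being traversed
        for (int k = 0; k < calcNodes.size(); k++) {
            for (Node output : calcNodes.get(k).outputs) {
                if (seen.add(output)) {
                    calcNodes.add(output);
                }
            }
        }
        // swap-based partial ordering, as in the snippet above
        for (int i = calcNodes.size() - 1; i > 0; i--) {
            Node node = calcNodes.get(i);
            for (int j = 0; j < i; j++) {
                Node other = calcNodes.get(j);
                if (node.outputs.contains(other)) {
                    calcNodes.set(j, node);
                    calcNodes.set(i, other);
                    i++; // re-examine index i, which now holds `other`
                    break;
                }
            }
        }
        return calcNodes;
    }

    // True if every node appears before all of its outputs that are in the list.
    static boolean isPartiallyOrdered(List<Node> nodes) {
        Map<Node, Integer> position = new HashMap<>();
        for (int i = 0; i < nodes.size(); i++) {
            position.put(nodes.get(i), i);
        }
        for (int i = 0; i < nodes.size(); i++) {
            for (Node output : nodes.get(i).outputs) {
                Integer p = position.get(output);
                if (p != null && p <= i) {
                    return false;
                }
            }
        }
        return true;
    }

    public static void main(String[] args) {
        Node a = new Node("a"), b = new Node("b"), c = new Node("c");
        Node d = new Node("d"), e = new Node("e");
        link(a, c); link(a, d); link(b, c); link(c, e); link(d, e);
        List<Node> path = calcNodePath(List.of(c, a, b));
        System.out.println(path);                     // [b, a, c, d, e]
        System.out.println(isPartiallyOrdered(path)); // true
    }
}
```

The same invariant check could be run on the output of the existing `ArrayList`-based implementation for real models, to confirm that both versions return the same set of nodes in a valid order.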

rbouckaert, Mar 29 '23 23:03