burlap
burlap copied to clipboard
Possible bug in the `performReachabilityFrom` method of PolicyIteration.java
The number of states reported as reachable from the given state differs between PolicyIteration and ValueIteration.
When both were run on a graph-defined MDP with 100 states, ValueIteration reported 100 reachable states but PolicyIteration reported 99.
Code snippet that creates the graph and runs policy iteration and value iteration (warning: this code is untested — I'm working on an assignment, apologies):
public class GraphDefinedMDP {
/** Number of nodes in the generated graph. */
private static final int NUM_STATES = 100;
/** Generator used to declare the transitions, terminal function and reward function. */
private final GraphDefinedDomain graphDefinedDomainGen;
/** Domain produced by {@code graphDefinedDomainGen.generateDomain()}. */
private final SADomain domain;
/** Node the agent starts in for planning and rollouts. */
private final State initialState;
/**
 * Goal (terminal) node ids. {@code NUM_STATES / 75} is integer division, so with
 * {@code NUM_STATES == 100} this array holds exactly one entry (and none if NUM_STATES < 75).
 */
private final int[] goalStates = new int[NUM_STATES / 75];
/** State hashing used by both planners. */
private final HashableStateFactory hashingFactory;
/** Simulated environment rooted at the initial state (not read in the visible code). */
private final Environment env;
/**
 * Builds a fully connected graph-defined MDP with {@code NUM_STATES} nodes and prepares a
 * simulated environment that starts at the given node.
 *
 * <p>Every node has a transition to every other node (never to itself) with probability
 * {@code 1 / (NUM_STATES - 1)}. Goal states are terminal; a transition into a goal state
 * yields reward 2 and any other transition yields -1.
 *
 * @param initialStateNum id of the node the agent starts in
 * @throws IllegalStateException if the generated graph fails
 *     {@code GraphDefinedDomain.isValidMDPGraph()}
 */
public GraphDefinedMDP(int initialStateNum) {
    graphDefinedDomainGen = new GraphDefinedDomain(NUM_STATES);

    // Deterministic goal states.
    // NOTE(review): this places goal i at i*k + k/2 (k = goalStates.length), which packs the
    // goals near node 0 rather than spreading them over NUM_STATES — confirm this is intended.
    for (int i = 0; i < goalStates.length; i++) {
        goalStates[i] = i * goalStates.length + (goalStates.length / 2);
    }
    System.out.format("Goal States: %s\n", Util.intArrayToString(goalStates));

    // Goal states are terminal; entering one is rewarded, every other step is penalized.
    TerminalFunction tf = new GraphTF(goalStates);
    RewardFunction rf = new GraphRF() {
        @Override
        public double reward(int s, int a, int sprime) {
            for (int goalState : goalStates) {
                if (goalState == sprime) {
                    return 2;
                }
            }
            return -1;
        }
    };
    graphDefinedDomainGen.setTf(tf);
    graphDefinedDomainGen.setRf(rf);

    // All nodes are equally reachable from every other node; outgoing edges are assigned to
    // actions round-robin over goalStates.length actions.
    // NOTE(review): each action receives only ~1/k of a node's outgoing edges, so its
    // probabilities presumably sum to 1 only when goalStates.length == 1; if
    // goalStates.length == 0 (NUM_STATES < 75) the modulo below throws ArithmeticException.
    int action = 0;
    double probability = 1.0 / (NUM_STATES - 1);
    for (int srcNode = 0; srcNode < NUM_STATES; srcNode++) {
        for (int dstNode = 0; dstNode < NUM_STATES; dstNode++) {
            if (srcNode != dstNode) {
                graphDefinedDomainGen.setTransition(srcNode, action, dstNode, probability);
                action = (action + 1) % goalStates.length;
            }
        }
    }

    // Abort only when the graph is NOT valid (the check must be negated); a constructor
    // should report failure with an exception rather than terminate the JVM.
    if (!graphDefinedDomainGen.isValidMDPGraph()) {
        throw new IllegalStateException(
                "Generated graph failed GraphDefinedDomain.isValidMDPGraph()");
    }

    domain = graphDefinedDomainGen.generateDomain();
    initialState = new GraphStateNode(initialStateNum);
    System.out.println("initialState: " + initialState);
    hashingFactory = new SimpleHashableStateFactory();
    env = new SimulatedEnvironment(domain, initialState);
}
/**
 * Plans with value iteration from the initial state, rolls out the resulting policy from
 * that same state, and prints statistics about the rolled-out episode.
 */
public void valueIteration() {
    // NOTE(review): presumably discount 0.99, convergence threshold 0.001 and at most 200
    // iterations, following BURLAP's ValueIteration constructor order — confirm.
    final Planner vi = new ValueIteration(domain, 0.99, hashingFactory, 0.001, 200);
    final Policy plannedPolicy = vi.planFromState(initialState);

    // The final argument presumably caps the rollout at 500 steps.
    final Episode trajectory =
            PolicyUtils.rollout(plannedPolicy, initialState, domain.getModel(), 500);
    printEpisodeStats(trajectory);
}
/**
 * Plans with policy iteration from the initial state, rolls out the resulting policy from
 * that same state, and prints statistics about the rolled-out episode.
 */
public void policyIteration() {
    // NOTE(review): presumably discount 0.99, evaluation threshold 0.001, at most 200
    // evaluation iterations and at most 10 policy-improvement iterations, following
    // BURLAP's PolicyIteration constructor order — confirm.
    final Planner pi = new PolicyIteration(domain, 0.99, hashingFactory, 0.001, 200, 10);
    final Policy plannedPolicy = pi.planFromState(initialState);

    // The final argument presumably caps the rollout at 500 steps.
    final Episode trajectory =
            PolicyUtils.rollout(plannedPolicy, initialState, domain.getModel(), 500);
    printEpisodeStats(trajectory);
}
/**
 * Runs value iteration and then policy iteration, each on a freshly built copy of the same
 * graph MDP starting from node 99, and prints the wall-clock time of each planning run.
 * Each MDP is constructed before its timer starts, so construction is not included.
 *
 * @param args unused
 */
public static void main(String[] args) {
    final int startNode = 99;

    System.out.println("---Value iteration---");
    final GraphDefinedMDP viMdp = new GraphDefinedMDP(startNode);
    long t0 = System.currentTimeMillis();
    viMdp.valueIteration();
    final long viMillis = System.currentTimeMillis() - t0;
    System.out.format("Time taken for value iteration: %d ms\n\n\n", viMillis);

    System.out.println("---Policy iteration---");
    final GraphDefinedMDP piMdp = new GraphDefinedMDP(startNode);
    t0 = System.currentTimeMillis();
    piMdp.policyIteration();
    final long piMillis = System.currentTimeMillis() - t0;
    System.out.format("Time taken for policy iteration: %d ms\n\n\n", piMillis);
}
Diff of the `performReachabilityFrom(State state)` method: https://www.diffchecker.com/4ZsjNQx7