
Possible bug in "performReachabilityFrom" function in PolicyIteration.java

Open · torcellite opened this issue 7 years ago · 0 comments

The number of states found reachable from a given initial state differs between PolicyIteration and ValueIteration.

When PolicyIteration and ValueIteration were run on a graph-defined MDP of 100 states in which every state is reachable from every other state, ValueIteration reported 100 reachable states but PolicyIteration reported 99.
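
For an independent check, the reachable set can also be computed directly with BURLAP's StateReachability utility. This is a minimal sketch; it assumes the domain, initialState, and hashingFactory fields from the snippet below and takes neither planner's code path:

    import java.util.List;

    import burlap.behavior.singleagent.auxiliary.StateReachability;
    import burlap.mdp.core.state.State;

    // Enumerate every state reachable from initialState under the domain's
    // transition model, then count them.
    List<State> reachable = StateReachability.getReachableStates(
            initialState, domain, hashingFactory);
    System.out.println("Reachable states: " + reachable.size()); // expect 100

If this prints 100, the discrepancy points at PolicyIteration's performReachabilityFrom rather than at the domain construction.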

Code snippet to create the graph and run policy iteration and value iteration (warning: the code is untested; I'm working on an assignment, apologies):

    import java.util.Arrays;

    import burlap.behavior.policy.Policy;
    import burlap.behavior.policy.PolicyUtils;
    import burlap.behavior.singleagent.Episode;
    import burlap.behavior.singleagent.planning.Planner;
    import burlap.behavior.singleagent.planning.stochastic.policyiteration.PolicyIteration;
    import burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration;
    import burlap.domain.singleagent.graphdefined.GraphDefinedDomain;
    import burlap.domain.singleagent.graphdefined.GraphRF;
    import burlap.domain.singleagent.graphdefined.GraphStateNode;
    import burlap.domain.singleagent.graphdefined.GraphTF;
    import burlap.mdp.core.TerminalFunction;
    import burlap.mdp.core.state.State;
    import burlap.mdp.singleagent.SADomain;
    import burlap.mdp.singleagent.environment.Environment;
    import burlap.mdp.singleagent.environment.SimulatedEnvironment;
    import burlap.mdp.singleagent.model.RewardFunction;
    import burlap.statehashing.HashableStateFactory;
    import burlap.statehashing.simple.SimpleHashableStateFactory;

    public class GraphDefinedMDP {

	private static final int NUM_STATES = 100;

	private GraphDefinedDomain graphDefinedDomainGen;
	private SADomain domain;

	private State initialState;
	// Note: NUM_STATES / 75 is integer division, so this is a single-element
	// array; the loop below then makes node 0 the only goal state.
	private int[] goalStates = new int[NUM_STATES / 75];

	private HashableStateFactory hashingFactory;

	private Environment env;

	public GraphDefinedMDP(int initialStateNum) {
		graphDefinedDomainGen = new GraphDefinedDomain(NUM_STATES);

		// Deterministic goal states
		for (int i = 0; i < goalStates.length; i++) {
			goalStates[i] = i * goalStates.length + (goalStates.length / 2);
		}
		// Print goal states
		System.out.format("Goal States: %s\n",
				Arrays.toString(goalStates));

		// Set terminal states
		TerminalFunction tf = new GraphTF(goalStates);
		RewardFunction rf = new GraphRF() {
			@Override
			public double reward(int s, int a, int sprime) {
				for (int goalState : goalStates)
					if (goalState == sprime)
						return 2;
				return -1;
			}
		};
		graphDefinedDomainGen.setTf(tf);
		graphDefinedDomainGen.setRf(rf);

		// All nodes are equally reachable from every other node
		int action = 0;
		double probability = 1.0 / (NUM_STATES - 1);
		for (int srcNode = 0; srcNode < NUM_STATES; srcNode++) {
			for (int dstNode = 0; dstNode < NUM_STATES; dstNode++) {
				if (srcNode != dstNode) {
					graphDefinedDomainGen.setTransition(srcNode, action,
							dstNode, probability);
					action = (action + 1) % goalStates.length;
				}
			}
		}
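		// Sanity check: goalStates.length == 1, so action stays 0 and every
		// source node gets NUM_STATES - 1 = 99 outgoing transitions of
		// probability 1/99 each; the distribution on action 0 sums to 1.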

		// Abort if the transition dynamics do not form a valid MDP graph
		if (!graphDefinedDomainGen.isValidMDPGraph()) {
			System.exit(1);
		}

		domain = graphDefinedDomainGen.generateDomain();

		initialState = new GraphStateNode(initialStateNum);
		System.out.println("initialState: " + initialState);

		hashingFactory = new SimpleHashableStateFactory();

		env = new SimulatedEnvironment(domain, initialState);
	}

	public void valueIteration() {
		Planner planner = new ValueIteration(domain, 0.99, hashingFactory,
				0.001, 200);
		Policy p = planner.planFromState(initialState);

		Episode episode = PolicyUtils.rollout(p, initialState,
				domain.getModel(), 500);
		printEpisodeStats(episode);
	}

	public void policyIteration() {
		Planner planner = new PolicyIteration(domain, 0.99, hashingFactory,
				0.001, 200, 10);
		Policy p = planner.planFromState(initialState);

		Episode episode = PolicyUtils.rollout(p, initialState,
				domain.getModel(), 500);
		printEpisodeStats(episode);
	}
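
	// Hypothetical helper (not in the original snippet), added so the class
	// compiles; the 100-vs-99 counts above presumably come from the planners'
	// own reachability debug output, not from these episode statistics.
	private void printEpisodeStats(Episode episode) {
		System.out.format("Episode: %d time steps, discounted return %.2f\n",
				episode.numTimeSteps(), episode.discountedReturn(0.99));
	}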

	public static void main(String[] args) {
		int initialStateNum = 99;
		System.out.println("---Value iteration---");
		GraphDefinedMDP obj1 = new GraphDefinedMDP(initialStateNum);
		long startTime = System.currentTimeMillis();
		obj1.valueIteration();
		long endTime = System.currentTimeMillis();
		System.out.format("Time taken for value iteration: %d ms\n\n\n",
				(endTime - startTime));

		System.out.println("---Policy iteration---");
		GraphDefinedMDP obj2 = new GraphDefinedMDP(initialStateNum);
		startTime = System.currentTimeMillis();
		obj2.policyIteration();
		endTime = System.currentTimeMillis();
		System.out.format("Time taken for policy iteration: %d ms\n\n\n",
				(endTime - startTime));
	}
    }

Diff of the performReachabilityFrom(State state) function: https://www.diffchecker.com/4ZsjNQx7
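
The counts differ by exactly one, and the construction above has exactly one terminal goal state (node 0), so one plausible culprit is the reachability expansion dropping terminal states from its result set. For reference, here is a sketch of the usual breadth-first pattern over BURLAP 3's FullModel interface (a sketch under assumed context, not BURLAP's actual implementation): a terminal state should still be added to the reached set; it is merely not expanded.

    // Sketch only; assumes the domain, initialState, and hashingFactory
    // fields from the snippet above (imports omitted). Terminal states are
    // counted but not expanded; skipping them before they are added would
    // undercount by one state per reachable goal.
    SampleModel model = domain.getModel();
    Set<HashableState> reached = new HashSet<>();
    Deque<HashableState> open = new ArrayDeque<>();
    HashableState start = hashingFactory.hashState(initialState);
    reached.add(start);
    open.add(start);
    while (!open.isEmpty()) {
        HashableState cur = open.poll();
        if (model.terminal(cur.s())) {
            continue; // counted in 'reached', but no outgoing expansion
        }
        for (ActionType at : domain.getActionTypes()) {
            for (Action a : at.allApplicableActions(cur.s())) {
                for (TransitionProb tp : ((FullModel) model).transitions(cur.s(), a)) {
                    HashableState next = hashingFactory.hashState(tp.eo.op);
                    if (reached.add(next)) {
                        open.add(next);
                    }
                }
            }
        }
    }
    System.out.println("Reachable states: " + reached.size());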

torcellite · Apr 24 '17 00:04