
dqn.py indexing is not right

Open • ShibiHe opened this issue 8 years ago • 1 comment

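The indexing in `build_functions` is not right. You can run the code below to verify that `VS[:, A]` indexes incorrectly: instead of selecting one Q-value per sample, it pairs every row of `VS` with every action in `A`.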

This indexing should be written as in lines 51 and 52 of https://github.com/ShibiHe/DQN_OpenAI_keras/blob/master/agents.py, or as in https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py.

```python
# Imports assumed by this excerpt (Keras 1.x on the Theano backend).
from keras import backend as K
from keras.layers import Input
from keras.optimizers import RMSprop
from theano.gradient import disconnected_grad

def build_functions(self):
    S = Input(shape=self.state_size)
    NS = Input(shape=self.state_size)
    A = Input(shape=(1,), dtype='int32')
    R = Input(shape=(1,), dtype='float32')
    T = Input(shape=(1,), dtype='int32')
    self.build_model()
    self.value_fn = K.function([S], self.model(S))

    VS = self.model(S)
    VNS = disconnected_grad(self.model(NS))
    future_value = (1-T) * VNS.max(axis=1, keepdims=True)
    discounted_future_value = self.discount * future_value
    target = R + discounted_future_value
    # Bug: VS[:, A] has shape (batch, batch, 1), not (batch, 1).
    cost0 = VS[:, A] - target
    cost = ((VS[:, A] - target)**2).mean()
    opt = RMSprop(0.0001)
    params = self.model.trainable_weights
    updates = opt.get_updates(params, [], cost)

    self.train_fn = K.function([S, NS, A, R, T], [cost, cost0, target, A], updates=updates)
    # import numpy as np
    # t = self.train_fn([np.random.rand(10, *self.state_size), np.random.rand(10, *self.state_size), np.ones((10, 1)), np.ones((10, 1)), np.zeros((10, 1))])
    # print('cost=', t[0])
    # print('cost0=', t[1])
    # print('target=', t[2])
    # print('A=', t[3])
    # raw_input()
```
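To see the shape problem without building the model, here is a small NumPy sketch. It assumes Theano follows NumPy's advanced-indexing rules for `VS[:, A]` with an integer `(n, 1)` index; the toy values of `VS`, `A` and `target` are made up for illustration.

```python
import numpy as np

# Toy batch standing in for the tensors in build_functions:
# VS holds Q-values for n states and k actions; A is an (n, 1) int column
# like the Input(shape=(1,), dtype='int32') placeholder.
n, k = 4, 3
VS = np.arange(n * k, dtype=float).reshape(n, k)
A = np.array([[0], [2], [1], [2]])
target = np.zeros((n, 1))

# VS[:, A] pairs EVERY row with EVERY action in the batch:
# entry [i, j, 0] is VS[i, A[j, 0]], so the shape is (n, n, 1), not (n, 1).
wrong = VS[:, A]
print(wrong.shape)  # (4, 4, 1)

# target (n, 1) broadcasts against (n, n, 1), so the "cost" averages
# n * n mostly mismatched (state, action) pairs instead of n matched ones.
wrong_cost = ((wrong - target) ** 2).mean()

# What the loss is supposed to use: Q(s_i, a_i), one value per sample.
intended = np.array([[VS[i, A[i, 0]]] for i in range(n)])
intended_cost = ((intended - target) ** 2).mean()
print(wrong_cost, intended_cost)  # 45.0 48.75
```

The two costs differ, so the gradient from the original expression is not the gradient of the intended Q-learning loss.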

ShibiHe, Jun 06 '16 16:06

Hi @ShibiHe, thanks for your comment. You're right. This is a bug. I should be using `np.arange(n)` instead of `:`.

sherjilozair, Oct 28 '16 00:10
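The fix suggested in the reply (an explicit row index such as `np.arange(n)` in place of `:`) can be sketched in NumPy as below. The helper names `q_for_actions`, `q_for_actions_masked` and `td_cost` are hypothetical, and this is not the repository's code; inside the Theano graph the symbolic counterpart of `np.arange` would presumably be `theano.tensor.arange`. The one-hot mask variant is an alternative formulation that avoids integer indexing and gives the same values.

```python
import numpy as np

def q_for_actions(q_values, actions):
    """Hypothetical helper: Q(s_i, a_i) for each row i, shape (n, 1).

    q_values: (n, k) float array; actions: (n, 1) integer array.
    """
    n = q_values.shape[0]
    # Pair row i with action actions[i, 0]: exactly one element per sample.
    return q_values[np.arange(n), actions.reshape(-1)].reshape(-1, 1)

def q_for_actions_masked(q_values, actions):
    """Same result via a one-hot mask of the taken actions."""
    k = q_values.shape[1]
    mask = (np.arange(k)[None, :] == actions.reshape(-1, 1)).astype(q_values.dtype)
    return (q_values * mask).sum(axis=1, keepdims=True)

def td_cost(vs, vns, a, r, t, discount):
    """Mean squared TD error following build_functions, with the fix applied."""
    future_value = (1 - t) * vns.max(axis=1, keepdims=True)
    target = r + discount * future_value
    return ((q_for_actions(vs, a) - target) ** 2).mean()

# Usage with a 2-sample, 2-action batch.
vs = np.array([[1.0, 2.0], [3.0, 4.0]])
vns = np.array([[0.0, 1.0], [2.0, 0.0]])
a = np.array([[1], [0]])
r = np.array([[1.0], [1.0]])
t = np.array([[0], [1]])  # second transition is terminal
print(q_for_actions(vs, a).tolist())  # [[2.0], [3.0]]
print(td_cost(vs, vns, a, r, t, discount=0.5))  # 2.125
```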