dqn
dqn.py: indexing in build_functions is wrong
I found that the indexing in build_functions is wrong. You can run the code below to verify the incorrect indexing in VS[:, A]. Because A has shape (N, 1), VS[:, A] pairs every row of VS with every action in the batch instead of selecting one Q-value per row.
The indexing should be written like lines 51 and 52 in https://github.com/ShibiHe/DQN_OpenAI_keras/blob/master/agents.py, or as in https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py
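To see why the slice is wrong, here is a small NumPy sketch with made-up values (N = 4 states, 3 actions). It assumes the Theano graph indexes the same way NumPy does; the array values are illustrative only, not taken from dqn.py.

```python
import numpy as np

# Toy batch: N = 4 states, 3 actions. VS[i, k] plays the role of Q(s_i, k).
N, n_actions = 4, 3
VS = np.arange(N * n_actions, dtype=np.float32).reshape(N, n_actions)

# Actions taken, shaped (N, 1) like the A = Input(shape=(1,)) placeholder.
A = np.array([[0], [2], [1], [2]], dtype=np.int32)

# Buggy: ':' keeps every row, so each row is paired with every action in A.
wrong = VS[:, A]
print(wrong.shape)  # (4, 4, 1)

# Correct: pair row i with action A[i] by indexing the rows with arange(N).
right = VS[np.arange(N), A[:, 0]]
print(right.shape)  # (4,)
```

The buggy result holds N * N entries where only N are wanted, so the mean squared error averages over mismatched (state, action) pairs.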
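def build_functions(self):
    # Excerpt from dqn.py (Keras, Theano backend). Assumes imports such as:
    #   from keras import backend as K
    #   from keras.layers import Input
    #   from keras.optimizers import RMSprop
    #   from theano.gradient import disconnected_grad
    S = Input(shape=self.state_size)          # states
    NS = Input(shape=self.state_size)         # next states
    A = Input(shape=(1,), dtype='int32')      # actions taken, shape (N, 1)
    R = Input(shape=(1,), dtype='float32')    # rewards
    T = Input(shape=(1,), dtype='int32')      # terminal flags
    self.build_model()
    self.value_fn = K.function([S], self.model(S))
    VS = self.model(S)                        # Q(s, .), shape (N, n_actions)
    VNS = disconnected_grad(self.model(NS))   # Q(s', .), no gradient flows through it
    future_value = (1-T) * VNS.max(axis=1, keepdims=True)
    discounted_future_value = self.discount * future_value
    target = R + discounted_future_value
    # BUG: ':' selects every row, so VS[:, A] has shape (N, N, 1) rather than
    # one Q-value per row; each state is paired with every action in the batch.
    cost0 = VS[:, A] - target
    cost = ((VS[:, A] - target)**2).mean()
    opt = RMSprop(0.0001)
    params = self.model.trainable_weights
    updates = opt.get_updates(params, [], cost)
    self.train_fn = K.function([S, NS, A, R, T], [cost, cost0, target, A], updates=updates)
    # Uncomment to print the intermediate values for a random batch of 10:
    # import numpy as np
    # t = self.train_fn([np.random.rand(10, *self.state_size), np.random.rand(10, *self.state_size), np.ones((10, 1)), np.ones((10, 1)), np.zeros((10, 1))])
    # print('cost=', t[0])
    # print('cost0=', t[1])
    # print('target=', t[2])
    # print('A=', t[3])
    # raw_input()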
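Written out, the function is meant to compute the loss L below, with Q(s', .) held constant by disconnected_grad. Here gamma is self.discount, t_i is the terminal flag T, and N is the batch size. Assuming NumPy-style broadcasting of target (shape (N, 1)) against VS[:, A] (shape (N, N, 1)), entry (i, j) pairs state s_i with the action and target of transition j, so cost actually computes L-tilde:

```latex
\begin{aligned}
y_i &= r_i + \gamma \, (1 - t_i) \max_{a'} Q(s'_i, a') \\
L &= \frac{1}{N} \sum_{i=1}^{N} \bigl( Q(s_i, a_i) - y_i \bigr)^2 \\
\tilde{L} &= \frac{1}{N^2} \sum_{i=1}^{N} \sum_{j=1}^{N} \bigl( Q(s_i, a_j) - y_j \bigr)^2
\end{aligned}
```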
Hi @ShibiHe, thanks for your comment. You're right, this is a bug. I should be using np.arange(n) instead of ':'.
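A minimal NumPy sketch of the corrected loss, following the np.arange(n) suggestion. The function name dqn_loss and the default discount of 0.99 are hypothetical, chosen for illustration. In the Keras/Theano graph the row index would have to be symbolic (for example something like theano.tensor.arange(VS.shape[0])); that porting step is an assumption and is not shown here.

```python
import numpy as np

def dqn_loss(q_s, q_ns, a, r, t, discount=0.99):
    """Hypothetical helper (not from dqn.py): mean squared TD error for a batch.

    q_s, q_ns: (N, n_actions) Q-values for states and next states.
    a: (N, 1) int actions; r: (N, 1) rewards; t: (N, 1) terminal flags (0 or 1).
    discount: assumed default, stands in for self.discount.
    """
    n = q_s.shape[0]
    a = a.reshape(n)
    # Bootstrap from the best next-state action; terminal rows get no future value.
    future_value = (1 - t) * q_ns.max(axis=1, keepdims=True)  # (N, 1)
    target = r + discount * future_value                      # (N, 1)
    # np.arange(n) pairs row i with action a[i], giving one Q-value per row.
    q_taken = q_s[np.arange(n), a]                            # (N,)
    td_error = q_taken - target.reshape(n)
    return np.mean(td_error ** 2)
```

For example, with q_s = [[1, 2], [3, 4]], q_ns = [[0, 1], [2, 0]], a = [[1], [0]], r = [[1], [0]], t = [[0], [1]] and discount=0.5, the targets are 1.5 and 0, the selected Q-values are 2 and 3, and the loss is (0.25 + 9) / 2 = 4.625.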