PARL
PARL copied to clipboard
萌新关于drqn的一些疑问
我无法使用 LSTM 实现 TD(时序差分)算法。
例如:随机输入一个递减数列,数列元素取值区间为[0, 20],数列长度区间为[2, 20],输出该数列的相反数。要求:
- 使用LSTM
- 使用公式Q(s_t) = Q(s_{t+1}) - 1,规定Q(0) = 0
我期望得到一个类似[-3, -2, -1, 0]的递增数列,但是大部分情况下我只能得到[-4.5486226 -4.5486226 -4.5486226 -4.5486226],我无法找到原因。
以下是我的主要代码:
def main(steps=5000, lr=1e-3, max_loss=0.2, update_target_steps=10,
         print_steps=100, max_repeats=200):
    """Train the LSTM value model with one-step TD targets on random episodes.

    Each episode is a sequence of consecutive integers counting down by one.
    The learning target follows ``Q(s_t) = Q(s_{t+1}) - 1`` with ``Q(0) = 0``,
    where ``Q(s_{t+1})`` comes from a periodically synchronised target model.

    Args:
        steps: Number of episodes to train on.
        lr: Adam learning rate.
        max_loss: An episode is re-fitted until its loss drops below this.
        update_target_steps: Copy the online weights into the target model
            every this many episodes.
        print_steps: Print diagnostics every this many episodes.
        max_repeats: Upper bound on optimiser updates per episode.

    Raises:
        ValueError: If ``max_repeats`` is smaller than 1.
    """
    if max_repeats < 1:
        raise ValueError("max_repeats must be at least 1")

    # NOTE(review): the learning rate used to be 1e-1. A step size that large
    # with Adam most likely saturates the LSTM gates within a few updates,
    # which would explain predictions collapsing to one input-independent
    # constant. 1e-3 is a conventional starting point; confirm empirically.
    model, target_model = Model(), Model()
    mse_loss_fn = paddle.nn.loss.MSELoss()
    opt = paddle.optimizer.Adam(lr, parameters=model.parameters())

    for step in range(steps):
        if step % update_target_steps == 0:
            target_model.set_state_dict(model.state_dict())

        episode = episode_generate()
        # Decide once, on the plain Python list, whether the episode ends in
        # the terminal state 0; this is invariant across the retries below.
        reached_zero = episode[-1] == 0
        # Shape (1, seqlen, 1): one batch entry, one feature per time step.
        obs = paddle.to_tensor(episode, dtype='float32').unsqueeze([0, 2])

        # Targets are constants for the optimiser, so skip graph construction.
        with paddle.no_grad():
            q_target, _ = target_model(obs)
        q_target = q_target[0, :, 0]  # (seqlen,)

        # Q(s_t) = Q(s_{t+1}) - 1 for every step that has a successor.
        targets = q_target[1:] - 1
        if reached_zero:
            # The final state is 0 and Q(0) is defined to be 0.
            terminal_value = paddle.zeros([1], dtype=targets.dtype)
            targets = paddle.concat([targets, terminal_value])

        # Re-fit the same episode until the loss is small enough.
        for repeats in range(1, max_repeats + 1):
            q_pred, _ = model(obs)
            q_pred = q_pred[0, :, 0]  # (seqlen,)
            if not reached_zero:
                # Episode was cut off before 0: the last state has no
                # successor and therefore no target.
                q_pred = q_pred[:-1]
            loss = mse_loss_fn(q_pred, targets)
            loss.backward()
            opt.step()
            opt.clear_grad()
            # Written as "not >=" so that a NaN loss also ends the retries.
            if not float(loss) >= max_loss:
                break

        if (step + 1) % print_steps == 0:
            print("step: %d" % (step + 1))
            print(q_pred.numpy())
            print(targets.numpy())
            print("loss:", loss.numpy())
            print("重复次数:", repeats)
            print()
class Model(paddle.nn.Layer):
    """Sequence model: two dense layers, an LSTM, then a two-layer output head.

    Given ``obs`` as input, the output it is expected to learn is ``-obs``.
    """

    def __init__(self):
        super().__init__()
        in_features, hidden = 1, 128
        linear = paddle.nn.Linear
        # Input encoder.
        self.fc1 = linear(in_features, hidden)
        self.fc2 = linear(hidden, hidden)
        # Recurrent core.
        self.lstm = paddle.nn.LSTM(hidden, hidden, num_layers=1)
        # Output head producing one value per time step.
        self.output_fc1 = linear(hidden, hidden)
        self.output_fc2 = linear(hidden, 1)

    def forward(self, obs, state=None):
        """Run the network over ``obs``.

        Args:
            obs: Input tensor fed to the first dense layer.
            state: Optional initial LSTM state; ``None`` is passed through
                to the LSTM unchanged.

        Returns:
            A ``(values, state)`` pair: the head output and the LSTM state
            returned by the recurrent layer.
        """
        gelu = paddle.nn.functional.gelu
        encoded = self.fc2(gelu(self.fc1(obs)))
        # Per the original note the LSTM output is
        # [batch_size, time_steps, dim] -- TODO confirm against caller.
        recurrent, state = self.lstm(encoded, initial_states=state)
        hidden = gelu(self.output_fc1(recurrent))
        return self.output_fc2(hidden), state
def episode_generate() -> list:
    """Return a random run of consecutive integers counting down by one.

    The first element is drawn from [1, 20] and a desired length from
    [2, 20]. The run is cut short so that it never goes below 0, which keeps
    every element within [0, 20] and the length within [2, 20].
    """
    first = random.randint(1, 20)
    wanted_len = random.randint(2, 20)
    # Exclusive stop of the range; clamped at -1 so 0 is the smallest value.
    stop = first - wanted_len
    if stop < -1:
        stop = -1
    return [*range(first, stop, -1)]