Gato-A-Generalist-Agent
Gato-A-Generalist-Agent copied to clipboard
问题请教
你好非常感谢你共享出如此清晰,可读性那么高的代码供我学习。但是在代码中有一处我没有想太明白,想请教一下。 在model.py文件DecisionTransformer类的forward函数中最后对动作、状态进行预测哪一部分。
h = self.embed_ln(h)
# transformer and prediction
h = self.transformer(h)
# 修改 h 的形状, 使得 [B, 3 * T, H] -> [B, 3, T, H]
# h[:, i] 的形状为 [B, T, H]
# h[:, 0, t] 是基于输入序列 r_0, s_0, a_0 ... r_t 来条件化的
# h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t
# h[:, 1, t] 是基于输入序列 r_0, s_0, a_0 ... r_t, s_t 来条件化的
# h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t
# h[:, 2, t] 是基于输入序列 r_0, s_0, a_0 ... r_t, s_t, a_t 来条件化的
# h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t
# 即每个时间步 (t) 我们有 3 个 transformer 根据以前序列提取的特征输出
# 以前序列包括 t 之前的所有时间步以及当前时刻 t 的 3 个输入 (r_t, s_t, a_t)
h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)
# get predictions
return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a
state_preds = self.predict_state(h[:, 0]) # predict next state given r, s, a
action_preds = self.predict_action(h[:, 1]) # predict action given r, s
self.predict_rtg(h[:, 2]) 中的h[:, 2]为什么是这样取。我理解的应该是取h[:, 0]去预测rtg;取h[:,1]预测state;取h[:,2]预测action。请教一下是不是我理解出错了。非常能给出解答。