Gato-A-Generalist-Agent icon indicating copy to clipboard operation
Gato-A-Generalist-Agent copied to clipboard

问题请教

Open le-wei opened this issue 11 months ago • 0 comments

你好非常感谢你共享出如此清晰,可读性那么高的代码供我学习。但是在代码中有一处我没有想太明白,想请教一下。 在model.py文件DecisionTransformer类的forward函数中最后对动作、状态进行预测哪一部分。

        h = self.embed_ln(h)

        # transformer and prediction
        h = self.transformer(h)

        # 修改 h 的形状, 使得 [B, 3 * T, H] -> [B, 3, T, H]
        # h[:, i] 的形状为 [B, T, H]
        # h[:, 0, t] 是基于输入序列 r_0, s_0, a_0 ... r_t 来条件化的
        # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t
        # h[:, 1, t] 是基于输入序列 r_0, s_0, a_0 ... r_t, s_t 来条件化的
        # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t
        # h[:, 2, t] 是基于输入序列 r_0, s_0, a_0 ... r_t, s_t, a_t 来条件化的
        # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t
        # 即每个时间步 (t) 我们有 3 个 transformer 根据以前序列提取的特征输出
        # 以前序列包括 t 之前的所有时间步以及当前时刻 t 的 3 个输入 (r_t, s_t, a_t)
        h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)

        # get predictions
        return_preds = self.predict_rtg(h[:, 2])     # predict next rtg given r, s, a
        state_preds = self.predict_state(h[:, 0])    # predict next state given r, s, a
        action_preds = self.predict_action(h[:, 1])  # predict action given r, s

self.predict_rtg(h[:, 2]) 中的h[:, 2]为什么是这样取。我理解的应该是取h[:, 0]去预测rtg;取h[:,1]预测state;取h[:,2]预测action。请教一下是不是我理解出错了。非常能给出解答。

le-wei avatar Mar 07 '24 13:03 le-wei