Time-Series-Library

Short-term forecasting: question about how the model is invoked

Open · ChesonHuang opened this issue 6 months ago · 0 comments

At https://github.com/thuml/Time-Series-Library/blob/main/exp/exp_short_term_forecasting.py#L88 the model is called like this: `outputs = self.model(batch_x, None, dec_inp, None)`
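That call passes four positional arguments and uses `None` for the two time-mark inputs (`x_mark_enc` and `x_mark_dec`). To illustrate what such a call asks of a model, here is a minimal plain-Python sketch built around a purely hypothetical `ToyModel` (not part of Time-Series-Library): any model invoked this way needs a four-parameter signature that tolerates `None` for the marks.

```python
from typing import List, Optional

# One sample laid out as [time step][channel]; the batch dimension is omitted.
Series = List[List[float]]


class ToyModel:
    """Hypothetical stand-in; NOT a Time-Series-Library model."""

    def __init__(self, pred_len: int) -> None:
        self.pred_len = pred_len

    def __call__(self, x_enc: Series, x_mark_enc: Optional[Series],
                 x_dec: Series, x_mark_dec: Optional[Series]) -> Series:
        # x_dec and x_mark_dec are accepted only for signature compatibility.
        n_vars = len(x_enc[0])
        # If time-feature marks are supplied, append them as extra channels;
        # if they are None (as in the short-term call), use the values alone.
        if x_mark_enc is not None:
            x_enc = [row + mark for row, mark in zip(x_enc, x_mark_enc)]
        # Toy "forecast": repeat each channel's window mean for pred_len steps.
        channel_means = [sum(row[c] for row in x_enc) / len(x_enc)
                         for c in range(len(x_enc[0]))]
        # Keep only the original variates, similar in spirit to a [:, :, :N] slice.
        return [channel_means[:n_vars] for _ in range(self.pred_len)]


model = ToyModel(pred_len=2)
batch_x = [[1.0, 10.0], [3.0, 30.0]]
dec_inp = [[0.0, 0.0]] * 2
outputs = model(batch_x, None, dec_inp, None)
print(outputs)  # [[2.0, 20.0], [2.0, 20.0]]
```

Because the marks are optional in this sketch, the same positional call works whether or not time features are available; a model that unconditionally used `x_mark_enc` would fail on the `None` input instead.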

I noticed that different models define `forecast` differently. For example, when I use iTransformer, its `forecast` is defined like this: https://github.com/thuml/Time-Series-Library/blob/main/models/iTransformer.py#L51

```python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # Normalization from Non-stationary Transformer
    means = x_enc.mean(1, keepdim=True).detach()
    x_enc = x_enc - means
    stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
    x_enc /= stdev

    _, _, N = x_enc.shape

    # Embedding
    enc_out = self.enc_embedding(x_enc, x_mark_enc)
    enc_out, attns = self.encoder(enc_out, attn_mask=None)

    dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
    # De-Normalization from Non-stationary Transformer
    dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    return dec_out
```
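In the `forecast` method quoted above, the normalization at the start and the de-normalization at the end use the same per-variate statistics: each variate is shifted by its mean over the input window and divided by its standard deviation (population variance, as with `unbiased=False`, plus `1e-5` inside the square root), and the projected output is mapped back by multiplying by that standard deviation and adding the mean. The following is a minimal pure-Python sketch of just that arithmetic for a single sample, with the embedding, encoder, and projection omitted; `normalize` and `denormalize` are hypothetical helper names, not functions from the repository.

```python
import math
from typing import List, Tuple

Series = List[List[float]]  # [time step][variate], single sample


def normalize(x: Series, eps: float = 1e-5) -> Tuple[Series, List[float], List[float]]:
    """Per-variate normalization over the time axis (hypothetical helper)."""
    T, N = len(x), len(x[0])
    means = [sum(x[t][v] for t in range(T)) / T for v in range(N)]
    # Population variance (divide by T), matching unbiased=False; eps keeps sqrt > 0.
    stdev = [math.sqrt(sum((x[t][v] - means[v]) ** 2 for t in range(T)) / T + eps)
             for v in range(N)]
    x_norm = [[(x[t][v] - means[v]) / stdev[v] for v in range(N)] for t in range(T)]
    return x_norm, means, stdev


def denormalize(y: Series, means: List[float], stdev: List[float]) -> Series:
    """Map a normalized [pred_len][variate] output back to the original scale."""
    return [[y[t][v] * stdev[v] + means[v] for v in range(len(means))]
            for t in range(len(y))]


x = [[1.0, 10.0], [3.0, 30.0]]
x_norm, means, stdev = normalize(x)
restored = denormalize(x_norm, means, stdev)
print(means)  # [2.0, 20.0]
```

Applying `denormalize` to the normalized input recovers the original values up to floating-point error, which is the sense in which the two steps are inverses; in the model itself they wrap the network, so the network only ever sees and produces normalized values.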

So, do I need to modify `exp_short_term_forecasting.py` when running different models?

ChesonHuang, Aug 09 '24 08:08