Time-Series-Library
Short-term forecasting: question about how the model is called
https://github.com/thuml/Time-Series-Library/blob/main/exp/exp_short_term_forecasting.py#L88
The model is called like this:
```python
outputs = self.model(batch_x, None, dec_inp, None)
```
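The call above passes `None` for both `x_mark_enc` and `x_mark_dec`. A single shared call site can work for every model only if each model's `forward` accepts the same four positional arguments and tolerates `None` marks. Below is a minimal sketch of that convention in plain Python, with lists standing in for tensors. `ToyForecaster` and its mean-based "forecast" are hypothetical stand-ins, not code from the repository.

```python
from typing import List, Optional

Series = List[List[float]]  # [time][channel] for a single sample


class ToyForecaster:
    """Hypothetical model following the four-argument forward convention."""

    def __init__(self, pred_len: int) -> None:
        self.pred_len = pred_len

    def forecast(self, x_enc: Series, x_mark_enc: Optional[Series],
                 x_dec: Series, x_mark_dec: Optional[Series]) -> Series:
        # Time-feature marks are optional; this toy simply ignores them.
        n_channels = len(x_enc[0])
        # Predict each channel's historical mean for every future step.
        means = [sum(row[c] for row in x_enc) / len(x_enc) for c in range(n_channels)]
        return [list(means) for _ in range(self.pred_len)]

    def __call__(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Same positional order as the call in the experiment script.
        return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)


model = ToyForecaster(pred_len=2)
batch_x = [[1.0, 10.0], [3.0, 20.0]]
dec_inp = [[0.0, 0.0], [0.0, 0.0]]
outputs = model(batch_x, None, dec_inp, None)
print(outputs)  # [[2.0, 15.0], [2.0, 15.0]]
```

Whether a particular model in the library actually tolerates `None` marks depends on its own code (typically the embedding layer called inside `forecast`), so that model's source is the place to check.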
I noticed that different models define `forecast` differently. For example, when I use iTransformer, its `forecast` is defined like this: https://github.com/thuml/Time-Series-Library/blob/main/models/iTransformer.py#L51
```python
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # Normalization from Non-stationary Transformer
    means = x_enc.mean(1, keepdim=True).detach()
    x_enc = x_enc - means
    stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
    x_enc /= stdev
    _, _, N = x_enc.shape
    # Embedding
    enc_out = self.enc_embedding(x_enc, x_mark_enc)
    enc_out, attns = self.encoder(enc_out, attn_mask=None)
    dec_out = self.projection(enc_out).permute(0, 2, 1)[:, :, :N]
    # De-Normalization from Non-stationary Transformer
    dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1))
    return dec_out
```
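The `forecast` method standardizes each channel of the input window by its own mean and standard deviation over the time dimension (population variance, i.e. `unbiased=False`, plus `1e-5`), and then applies the inverse transform to every predicted step. A minimal pure-Python sketch of that round trip for a single sample, with lists standing in for tensors; `normalize` and `denormalize` are hypothetical helper names, not library functions.

```python
import math
from typing import List, Tuple

Series = List[List[float]]  # [time][channel] for a single sample


def normalize(x: Series, eps: float = 1e-5) -> Tuple[Series, List[float], List[float]]:
    """Per-channel standardization over time, like the start of forecast()."""
    t_len, n = len(x), len(x[0])
    means = [sum(x[t][c] for t in range(t_len)) / t_len for c in range(n)]
    # Population variance (divide by t_len), then add eps before the sqrt.
    stdevs = [
        math.sqrt(sum((x[t][c] - means[c]) ** 2 for t in range(t_len)) / t_len + eps)
        for c in range(n)
    ]
    x_norm = [[(x[t][c] - means[c]) / stdevs[c] for c in range(n)] for t in range(t_len)]
    return x_norm, means, stdevs


def denormalize(y: Series, means: List[float], stdevs: List[float]) -> Series:
    """Inverse transform applied to each predicted step, like the end of forecast()."""
    return [[v * stdevs[c] + means[c] for c, v in enumerate(row)] for row in y]


x = [[1.0, 10.0], [3.0, 30.0], [5.0, 20.0]]
x_norm, means, stdevs = normalize(x)
restored = denormalize(x_norm, means, stdevs)
```

Because the statistics come from the input window only, the encoder sees scale-free values while the returned forecast is back in the original units.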
So, does exp_short_term_forecasting.py need to be modified when running different models?