BasicTS
BasicTS copied to clipboard
请教一下关于loss部分的修改
我目前将其他模型移植到BasicTS,但是面临问题是移植模型的loss部分需要多个损失项的相加,loss = loss1(data1)+loss2(data2)+....,但BasicTS框架中losses.py由于自定义损失函数被包装了,传参只有input_data, target_data两个,是否有解决方案?
您可以参考STEP的loss。 总的来说,目前自定义loss函数的参数输入需要满足下述限制:
def customized_loss(prediction, real_value, other_param_1, other_param_2, ..., null_val=np.nan):
# main loss
pass
其中,prediction, real_value, other_param_1, other_param_2, ...,这些参数是和runner的forward函数的返回值相匹配的。换句话说,runner的返回值会自动作为参数注入到loss中。同理您可以参考STEP的runner。
最后一个参数,null_val是用来识别数据集中需要被忽略的异常点,默认一般为np.nan。您可以通过CFG.NULL_VAL在配置文件中进行设定。例如,对于交通数据集来说,0值一般是异常值(传感器宕机)。我们不希望模型强制拟合这些异常值,此时的NULL_VAL就会被设定为0.0,再进一步采用masked_mae等指标。