今天咱们来介绍一个纯时序模型,N-BEATS模型,直接开门见山看看模型的机构,然后对着模型解读。
N-BEATS模型结构
图1.1就是N-BEATS模型的结构,看起来比较复杂,咱们拆开来看。首先来看模型的输入,输入的是观测序列[y1,...,yT], 输出是预测数据[yT,...,yT+H], 这里在原文中描述T=nH, n一般取2到7。
看上图的模型主要有几个结构,分别是stack、block。 多个block组成1个stack。 输入的数就是之前一段时间的观测数据。
Stack
上图是一个stack结构,第一个block的输入是观测序列x, 然后经过一个block会输出两个向量,分别是backcast x,和长度为H的预测向量y,后面block的输入是前面block的输出。
Block
一个Block里面由两部分构成,如下图。
第一部分是全连接层生成θb和θf,这里的线性层是一个简单的投影。其中FC Stack如下
h1=FC1(x)h2=FC2(h1)h3=FC3(h2)h4=FC4(h3)θf=wh4
其中FC是一个标准的全连接层,激活函数RELU。w是需要学习的权重。
第二部分是由gb(θb)和gf(θf)组成,接受上文的θf、θb,用于预测前向序列y,和后向序列x。
y=gf(θf)x=gb(θb)
通过预测θf来优化预测y的准确性,而预测x的作用是移除输入中对预测结果没有帮助的成分,帮助下游模块更好的预测。然后使用双残差的结构。 这里需要注意一下,前4个全连接层是共享的,只有到最后产生前后向参数的时候才引入独立的FC, 从而能够通过历史的残差提升对预测的数据预测精度
x=xt−1−xty=∑yi
输出的x不断求残差,然后每次输出的y在最后聚合以后变成实际的预测值。那么就搞定了预测任务。
总而言之
可以发现N-BEATS模型有一个比较好的思路是不仅预测后续的n个值,同时将过去的可观测值也当成一个输入,这样其实是更能让模型的准确率提升的,也是一个比较有意思的思路。希望在后续的工作中借鉴。
原文地址: N-BEATS