1day_1paper
1day_1paper copied to clipboard
[26] HiPPO: Recurrent Memory with Optimal Polynomial Projections
S4 ( #52 ) 의 시초. HiPPO ==> LSSL ( #51 ) ==> S4 ( #52 ) 모두 1저자 작품. 수식이 워낙 어려워서 컨셉 위주로만 이해해 보려 한다.
Abstract
HiPPO 는 polynomial bases 에 projection 을 해서 continuous-signal 과 discrete time-signal 를 online 하게 compression 하는 녀석이다.
각 time-step의 중요도 measurement 가 정의되면, HiPPO 는 natural online function approximation problem
의 optimal solution을 찾아낸다. 특별한 경우에, first principle 로 LMU(Legendre Memory Unit)의 derivation을 만들어 내고, GRU와 같은 recurrent neural network의 gating mechanism을 generalize 한다.
new memory update mechanism (HiPPO-LegS) 또한 적용되는데, 모든 history를 기억하게 해서 timescale 에 따른 prior 를 배제할 수 있다. HiPPO-LegS 는 이론적으로 timescale robustness, fast update, bounded gradient 의 이점이 있다. 이런 memory dynamic은 recurrent neural network 와 결합될 때, 실험적으로도 temporal dependency 를 잘 잡아내는 것을 보여준다.
~하... abstract 도 어렵다.~
Online Function Approximation: A Formalism for Incremental Memory Representations
이론적으로 분석이 쉬운 discrete-time 을 우선 continuous time setting 으로 바꾸는 것이 첫 번째 step이다. 그리고 다음과 같은 질문을 던진다.
1 dimension 의 continuous function f(t) 가 주어졌을 때,
모든 time t 에 대해 0 ~ t 의 history 를 간직하는 fixed size representation c(t) 를 유지할 수 있을까?
이 문제를 정의하기 위해서 2가지를 추가로 정의한다.
- Quality of approximation
- function history 의 optimal-approximation 이란 무엇일까?
- 이를 정의하려면 past-time 에 대한 weight function이 필요하다.
- Basis
- Continuous function을 fixed-length vector로 압축하는 방법은 단순 projection 이 있다.
- 간단히 polynomial basis 를 사용한다고 가정한다.
HiPPO
- 모든 함수에 대해,
- 모든 시간 t 에는 f를 polynomial-space로 optimal projection하는 g(t)가 있다. (with measure µ(t) weighing the past.)
- 적절한 basis 선택으로 c(t) 는 history 를 나타내게 된다.
- Discretize 시킨다. (
)
여러가지 증명을 거쳐 나온 HiPPO 는 다음과 같은 형태이다.
간단히 말해, measurement들을 갖고 closed form matrix A(t), B(t) 를 이용한 ODE 를 구성한다.
Measurement에 따른 output. 상당히 hmmteresting !
결국 이를 rnn 에 붙인다면, 이런 형태이다
사실 코드는 S4 레포의 hippo 구현체를 보는게 더 편했다.
class HiPPO_LegS(nn.Module):
""" Vanilla HiPPO-LegS model (scale invariant instead of time invariant) """
def __init__(self, N, max_length=1024, measure='legs', discretization='bilinear'):
"""
max_length: maximum sequence length
"""
super().__init__()
self.N = N
A, B = transition(measure, N)
B = B.squeeze(-1)
A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
B_stacked = np.empty((max_length, N), dtype=B.dtype)
for t in range(1, max_length + 1):
At = A / t
Bt = B / t
if discretization == 'forward':
A_stacked[t - 1] = np.eye(N) + At
B_stacked[t - 1] = Bt
elif discretization == 'backward':
A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, np.eye(N), lower=True)
B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
elif discretization == 'bilinear':
A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True)
B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, Bt, lower=True)
else: # ZOH
A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True)
self.A_stacked = torch.Tensor(A_stacked) # (max_length, N, N)
self.B_stacked = torch.Tensor(B_stacked) # (max_length, N)
# print("B_stacked shape", B_stacked.shape)
vals = np.linspace(0.0, 1.0, max_length)
self.eval_matrix = torch.Tensor((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)
def forward(self, inputs):
"""
inputs : (length, ...)
output : (length, ..., N) where N is the order of the HiPPO projection
"""
L = inputs.shape[0]
inputs = inputs.unsqueeze(-1)
u = torch.transpose(inputs, 0, -2)
u = u * self.B_stacked[:L]
u = torch.transpose(u, 0, -2) # (length, ..., N)
result = variable_unroll_matrix(self.A_stacked[:L], u)
return result
def reconstruct(self, c):
a = self.eval_matrix @ c.unsqueeze(-1)
return a.squeeze(-1)
def transition(measure, N, **measure_args):
""" A, B transition matrices for different measures.
measure: the type of measure
legt - Legendre (translated)
legs - Legendre (scaled)
glagt - generalized Laguerre (translated)
lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization
"""
# Laguerre (translated)
if measure == 'lagt':
b = measure_args.get('beta', 1.0)
A = np.eye(N) / 2 - np.tril(np.ones((N, N)))
B = b * np.ones((N, 1))
if measure == 'tlagt':
# beta = 1 corresponds to no tilt
b = measure_args.get('beta', 1.0)
A = (1.-b)/2 * np.eye(N) - np.tril(np.ones((N, N)))
B = b * np.ones((N, 1))
# Generalized Laguerre
# alpha 0, beta small is most stable (limits to the 'lagt' measure)
# alpha 0, beta 1 has transition matrix A = [lower triangular 1]
if measure == 'glagt':
alpha = measure_args.get('alpha', 0.0)
beta = measure_args.get('beta', 0.01)
A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1)
B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None]
L = np.exp(.5 * (ss.gammaln(np.arange(N)+alpha+1) - ss.gammaln(np.arange(N)+1)))
A = (1./L[:, None]) * A * L[None, :]
B = (1./L[:, None]) * B * np.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2)
# Legendre (translated)
elif measure == 'legt':
Q = np.arange(N, dtype=np.float64)
R = (2*Q + 1) ** .5
j, i = np.meshgrid(Q, Q)
A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :]
B = R[:, None]
A = -A
# LMU: equivalent to LegT up to normalization
elif measure == 'lmu':
Q = np.arange(N, dtype=np.float64)
R = (2*Q + 1)[:, None] # / theta
j, i = np.meshgrid(Q, Q)
A = np.where(i < j, -1, (-1.)**(i-j+1)) * R
B = (-1.)**Q[:, None] * R
# Legendre (scaled)
elif measure == 'legs':
q = np.arange(N, dtype=np.float64)
col, row = np.meshgrid(q, q)
r = 2 * q + 1
M = -(np.where(row >= col, r, 0) - np.diag(q))
T = np.sqrt(np.diag(2 * q + 1))
A = T @ M @ np.linalg.inv(T)
B = np.diag(T)[:, None]
return A, B
def variable_unroll_matrix(A, u, s=None, variable=True, recurse_limit=16):
if s is None:
s = torch.zeros_like(u[0])
has_batch = len(u.shape) >= len(A.shape)
op = lambda x, y: batch_mult(x, y, has_batch)
sequential_op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
matmul = lambda x, y: x @ y
return variable_unroll_general(A, u, s, op, compose_op=matmul, sequential_op=sequential_op, variable=variable, recurse_limit=recurse_limit)
def batch_mult(A, u, has_batch=None):
""" Matrix mult A @ u with special case to save memory if u has additional batch dim
The batch dimension is assumed to be the second dimension
A : (L, ..., N, N)
u : (L, [B], ..., N)
has_batch: True, False, or None. If None, determined automatically
Output:
x : (L, [B], ..., N)
A @ u broadcasted appropriately
"""
if has_batch is None:
has_batch = len(u.shape) >= len(A.shape)
if has_batch:
u = u.permute([0] + list(range(2, len(u.shape))) + [1])
else:
u = u.unsqueeze(-1)
v = (A @ u)
if has_batch:
v = v.permute([0] + [len(u.shape)-1] + list(range(1, len(u.shape)-1)))
else:
v = v[..., 0]
return v
Results
감사합니다. 덕분에 논문에 대한 이해가 빨라질 것 같습니다.