1day_1paper icon indicating copy to clipboard operation
1day_1paper copied to clipboard

[26] HiPPO: Recurrent Memory with Optimal Polynomial Projections

Open dhkim0225 opened this issue 3 years ago • 1 comments

S4 ( #52 ) 의 시초. HiPPO ==> LSSL ( #51 ) ==> S4 ( #52 ) 모두 1저자 작품. 수식이 워낙 어려워서 컨셉 위주로만 이해해 보려 한다.

paper code blog

Abstract

HiPPO 는 polynomial bases 에 projection 을 해서 continuous-signal 과 discrete time-signal 를 online 하게 compression 하는 녀석이다.

각 time-step의 중요도 measurement 가 정의되면, HiPPO 는 natural online function approximation problem 의 optimal solution을 찾아낸다. 특별한 경우에, first principle 로 LMU(Legendre Memory Unit)의 derivation을 만들어 내고, GRU와 같은 recurrent neural network의 gating mechanism을 generalize 한다.

new memory update mechanism (HiPPO-LegS) 또한 적용되는데, 모든 history를 기억하게 해서 timescale 에 따른 prior 를 배제할 수 있다. HiPPO-LegS 는 이론적으로 timescale robustness, fast update, bounded gradient 의 이점이 있다. 이런 memory dynamic은 recurrent neural network 와 결합될 때, 실험적으로도 temporal dependency 를 잘 잡아내는 것을 보여준다.

~하... abstract 도 어렵다.~

Online Function Approximation: A Formalism for Incremental Memory Representations

이론적으로 분석이 쉬운 discrete-time 을 우선 continuous time setting 으로 바꾸는 것이 첫 번째 step이다. 그리고 다음과 같은 질문을 던진다.

1 dimension 의 continuous function f(t) 가 주어졌을 때,
모든 time t 에 대해 0 ~ t 의 history 를 간직하는 fixed size representation c(t) 를 유지할 수 있을까?

이 문제를 정의하기 위해서 2가지를 추가로 정의한다.

  1. Quality of approximation
    1. function history 의 optimal-approximation 이란 무엇일까?
    2. 이를 정의하려면 past-time 에 대한 weight function이 필요하다.
  2. Basis
    1. Continuous function을 fixed-length vector로 압축하는 방법은 단순 projection 이 있다.
    2. 간단히 polynomial basis 를 사용한다고 가정한다.

HiPPO

image

  1. 모든 함수에 대해,
  2. 모든 시간 t 에는 f를 polynomial-space로 optimal projection하는 g(t)가 있다. (with measure µ(t) weighing the past.)
  3. 적절한 basis 선택으로 c(t) 는 history 를 나타내게 된다.
  4. Discretize 시킨다. ()

여러가지 증명을 거쳐 나온 HiPPO 는 다음과 같은 형태이다. image 간단히 말해, measurement들을 갖고 closed form matrix A(t), B(t) 를 이용한 ODE 를 구성한다.

Measurement에 따른 output. 상당히 hmmteresting ! image

결국 이를 rnn 에 붙인다면, 이런 형태이다 image image

사실 코드는 S4 레포의 hippo 구현체를 보는게 더 편했다.

class HiPPO_LegS(nn.Module):
    """ Vanilla HiPPO-LegS model (scale invariant instead of time invariant) """
    def __init__(self, N, max_length=1024, measure='legs', discretization='bilinear'):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        A, B = transition(measure, N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            if discretization == 'forward':
                A_stacked[t - 1] = np.eye(N) + At
                B_stacked[t - 1] = Bt
            elif discretization == 'backward':
                A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, np.eye(N), lower=True)
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At, Bt, lower=True)
            elif discretization == 'bilinear':
                A_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True)
                B_stacked[t - 1] = la.solve_triangular(np.eye(N) - At / 2, Bt, lower=True)
            else: # ZOH
                A_stacked[t - 1] = la.expm(A * (math.log(t + 1) - math.log(t)))
                B_stacked[t - 1] = la.solve_triangular(A, A_stacked[t - 1] @ B - B, lower=True)
        self.A_stacked = torch.Tensor(A_stacked) # (max_length, N, N)
        self.B_stacked = torch.Tensor(B_stacked) # (max_length, N)
        # print("B_stacked shape", B_stacked.shape)

        vals = np.linspace(0.0, 1.0, max_length)
        self.eval_matrix = torch.Tensor((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)

    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """

        L = inputs.shape[0]

        inputs = inputs.unsqueeze(-1)
        u = torch.transpose(inputs, 0, -2)
        u = u * self.B_stacked[:L]
        u = torch.transpose(u, 0, -2) # (length, ..., N)

        result = variable_unroll_matrix(self.A_stacked[:L], u)
        return result

    def reconstruct(self, c):
        a = self.eval_matrix @ c.unsqueeze(-1)
        return a.squeeze(-1)

def transition(measure, N, **measure_args):
    """ A, B transition matrices for different measures.
    measure: the type of measure
      legt - Legendre (translated)
      legs - Legendre (scaled)
      glagt - generalized Laguerre (translated)
      lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization
    """
    # Laguerre (translated)
    if measure == 'lagt':
        b = measure_args.get('beta', 1.0)
        A = np.eye(N) / 2 - np.tril(np.ones((N, N)))
        B = b * np.ones((N, 1))
    if measure == 'tlagt':
        # beta = 1 corresponds to no tilt
        b = measure_args.get('beta', 1.0)
        A = (1.-b)/2 * np.eye(N) - np.tril(np.ones((N, N)))
        B = b * np.ones((N, 1))
    # Generalized Laguerre
    # alpha 0, beta small is most stable (limits to the 'lagt' measure)
    # alpha 0, beta 1 has transition matrix A = [lower triangular 1]
    if measure == 'glagt':
        alpha = measure_args.get('alpha', 0.0)
        beta = measure_args.get('beta', 0.01)
        A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1)
        B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None]

        L = np.exp(.5 * (ss.gammaln(np.arange(N)+alpha+1) - ss.gammaln(np.arange(N)+1)))
        A = (1./L[:, None]) * A * L[None, :]
        B = (1./L[:, None]) * B * np.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2)
    # Legendre (translated)
    elif measure == 'legt':
        Q = np.arange(N, dtype=np.float64)
        R = (2*Q + 1) ** .5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :]
        B = R[:, None]
        A = -A
    # LMU: equivalent to LegT up to normalization
    elif measure == 'lmu':
        Q = np.arange(N, dtype=np.float64)
        R = (2*Q + 1)[:, None] # / theta
        j, i = np.meshgrid(Q, Q)
        A = np.where(i < j, -1, (-1.)**(i-j+1)) * R
        B = (-1.)**Q[:, None] * R
    # Legendre (scaled)
    elif measure == 'legs':
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]

    return A, B

def variable_unroll_matrix(A, u, s=None, variable=True, recurse_limit=16):
    if s is None:
        s = torch.zeros_like(u[0])
    has_batch = len(u.shape) >= len(A.shape)
    op = lambda x, y: batch_mult(x, y, has_batch)
    sequential_op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0]
    matmul = lambda x, y: x @ y
    return variable_unroll_general(A, u, s, op, compose_op=matmul, sequential_op=sequential_op, variable=variable, recurse_limit=recurse_limit)

def batch_mult(A, u, has_batch=None):
    """ Matrix mult A @ u with special case to save memory if u has additional batch dim
    The batch dimension is assumed to be the second dimension
    A : (L, ..., N, N)
    u : (L, [B], ..., N)
    has_batch: True, False, or None. If None, determined automatically
    Output:
    x : (L, [B], ..., N)
      A @ u broadcasted appropriately
    """

    if has_batch is None:
        has_batch = len(u.shape) >= len(A.shape)

    if has_batch:
        u = u.permute([0] + list(range(2, len(u.shape))) + [1])
    else:
        u = u.unsqueeze(-1)
    v = (A @ u)
    if has_batch:
        v = v.permute([0] + [len(u.shape)-1] + list(range(1, len(u.shape)-1)))
    else:
        v = v[..., 0]
    return v

Results

image

dhkim0225 avatar Nov 09 '21 11:11 dhkim0225

감사합니다. 덕분에 논문에 대한 이해가 빨라질 것 같습니다.

KwangryeolPark avatar Feb 23 '24 11:02 KwangryeolPark