1day_1paper icon indicating copy to clipboard operation
1day_1paper copied to clipboard

[24] Combining Recurrent, Convolutional, and Continuous-time Models with Linear State-Space Layers (LSSL)

Open dhkim0225 opened this issue 3 years ago • 0 comments

paper code

Abstract

새로운 Sequence 모델 제안 !! ( time-series 데이터에 유리한 구조 제안! ) Linear State-Space Layers (LSSL) 은 sequence mapping 를 간단하게 수행한다.
linear continuous-time state-space representation 를 simulate 하는 것이 포인트

이론적으로도 LSSL models 이 temporal CNN, RNN, neural differential equations (NDE) 모델군의 장점들을 갖고 있다고 주장한다.

State Space Representation 을 잘 모른다면 #53 참조

Technical Background

Approximations of differential equations

는 적분 때리면 와 같다. 적분 진행된 수식은 x를 위한 approximation을 들고 있는 방식으로 numerically 풀어낼 수 있다. Picard iteration 이 많이 사용된다.

Discretization

discrete times 에 대해서 들은 를 iterate 해서 구할 수 있다. 다른 방식의 RHS signal approximation 을 사용하면 다른 discretization scheme 를 만들어 낸다. 저자들은 linear ODE 에 특화된 generalized bilinear transform (GBT) 를 제안한다. 에 대해 GBT update 는 다음과 같다. image 여기서 alpha 에 1/2 를 넣고 다음과 같이 discrete state-space model을 뽑아낸다. image

∆t as a timescale

대부분의 모델에서 캡처할 수 있는 종속성의 길이는 1/∆t 에 비례한다. 그래서 ∆t 를 timescale 로 잡는다. 이건 continuous-time ODE를 discrete-time recurrence로 변환하는 데 있어 필수적이고, 대부분의 ODE 기반 RNN 모델은 이를 non-trainalbe hyperaparameter로 가지고 있다.

Continuous-time memory

u(t) ==> input function ω(t) ==> a fixed probability measure (HiPPO 에서 꼭 필요한 measurement) 에다가 Sequence of N basis function 들이 있다 해보자 time t 에서, t 이전 step history u 는 basis 로 projection 될 수 있음. output은 coefficient x(t) 로 정의 한다. 이런 mapping 을 수행했던게 HiPPO 이다.

LSSL

LSSL 은 앞서 말했듯이 3가지 구조의 장점을 가져온다.

  • LSSLs are recurrent.
    • discrete step-size ∆t 에 대해 LSSL은 linear recurrence 형태로 간단히 discretize 될 수 있다.
    • 매 time-step 마다 inference 중 constant memory 를 갖는 stateful recurrent model 을 simulate 한다.
  • LSSLs are convolutional.
    • The linear time-invariant systems 은 continuous convolution으로 나타낼 수 있다.
    • discrete time version은 convolution을 이용해서 parallelize 도 할 수 있다.
  • LSSLs are continuous-time.
    • The LSSL 그 자체가 differential equation 이다.
    • 고유한 application을 수행할 수 있다
      • continuous processes 를 simulate
      • missing data 처리
      • 다른 timescale을 적용하기

image

파라미터는 A, B, C, D, ∆t 이다. A, B 는 gbt 를 통해 구하게 되고, C, D 는 learnable parameter 이다.

LSSL 은 state matrix A 를 잘 선택하는 것이 Long Range Dependency (LRD) 에 좋은 영향이 있다고 한다. state-space representation 관점에서 보면 당연한 것.

computation bottleneck이 강하게 있는데, 이를 해결한 게 #52 이다. 4번 수식을 풀 때, iteration 으로 simulate 해서 풀게 되는데 여기서 MatMul 이 오래걸린다. 추가로 convolution 단의 krylov function 도 돌릴 때 오래 걸린다.

코드 원본은 https://github.com/HazyResearch/state-spaces/blob/main/src/models/sequence/ss/standalone/lssl.py 에 있음

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

def triangular_toeplitz_multiply(u, v):
    n = u.shape[-1]
    u_expand = F.pad(u, (0, n))
    v_expand = F.pad(v, (0, n))
    u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1)
    v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1)
    uv_f = u_f * v_f
    output = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n]
    return output

def krylov(L, A, b):
    """ Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick.  """

    x = b.unsqueeze(-1) # (..., N, 1)
    A_ = A

    done = L == 1
    while not done:
        # Save memory on last iteration
        l = x.shape[-1]
        if L - l <= l:
            done = True
            _x = x[..., :L-l]
        else: _x = x

        _x = A_ @ _x
        x = torch.cat([x, _x], dim=-1) # there might be a more efficient way of ordering axes
        if not done: A_ = A_ @ A_

    assert x.shape[-1] == L

    x = x.contiguous()
    return x


def hippo(N):
    """ Return the HiPPO-LegT state matrices """
    Q = np.arange(N, dtype=np.float64)
    R = (2*Q + 1) ** .5
    j, i = np.meshgrid(Q, Q)
    A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :]
    B = R[:, None]
    A = -A
    return A, B

class AdaptiveTransition(nn.Module):
    """ General class which supports discretizing a state space equation x' = Ax + Bu
    Different subclasses can compute the forward and inverse mults in different ways
    This particular method is specialized to the HiPPO-LegT transition for simplicity
    """

    def __init__(self, N):
        """
        N: State space order, size of HiPPO matrix
        """

        super().__init__()
        self.N = N

        A, B = hippo(N)
        A = torch.as_tensor(A, dtype=torch.float)
        B = torch.as_tensor(B, dtype=torch.float)[:, 0]
        self.register_buffer('A', A)
        self.register_buffer('B', B)

        # Register some common buffers
        # (helps make sure every subclass has access to them on the right device)
        I = torch.eye(N)
        self.register_buffer('I', I)


    def forward_mult(self, u, delta):
        """ Computes (I + delta A) u
        A: (n, n)
        u: (..., n)
        delta: (...) or scalar
        output: (..., n)
        """
        raise NotImplementedError

    def inverse_mult(self, u, delta): # TODO swap u, delta everywhere
        """ Computes (I - d A)^-1 u """
        raise NotImplementedError

    def forward_diff(self, d, u, v):
        """ Computes the 'forward diff' or Euler update rule: (I - d A)^-1 u + d B v
        d: (...)
        u: (..., N)
        v: (...)
        """
        v = d * v
        v = v.unsqueeze(-1) * self.B
        x = self.forward_mult(u, d)
        x = x + v
        return x

    def backward_diff(self, d, u, v):
        """ Computes the 'forward diff' or Euler update rule: (I - d A)^-1 u + d (I - d A)^-1 B v
        d: (...)
        u: (..., N)
        v: (...)
        """
        v = d * v
        v = v.unsqueeze(-1) * self.B
        x = u + v
        x = self.inverse_mult(x, d)
        return x

    def bilinear(self, dt, u, v, alpha=.5):
        """ Computes the bilinear (aka trapezoid or Tustin's) update rule.
        (I - d/2 A)^-1 (I + d/2 A) u + d B (I - d/2 A)^-1 B v
        dt: (...)
        u: (..., N)
        v: (...)
        """
        x = self.forward_mult(u, (1-alpha)*dt)
        v = dt * v
        v = v.unsqueeze(-1) * self.B
        x = x + v
        x = self.inverse_mult(x, (alpha)*dt)
        return x

    def gbt_A(self, dt, alpha=.5):
        """ Compute the transition matrices associated with bilinear transform
        dt: (...)
        returns: (..., N, N)
        """
        # solve (N, ...) parallel problems of size N
        dims = len(dt.shape)
        I = self.I.view([self.N] + [1]*dims + [self.N])
        A = self.bilinear(dt, I, dt.new_zeros(*dt.shape), alpha=alpha) # (N, ..., N)
        A = rearrange(A, 'n ... m -> ... m n', n=self.N, m=self.N)
        return A

    def gbt_B(self, dt, alpha=.5):
        B = self.bilinear(dt, dt.new_zeros(*dt.shape, self.N), dt.new_ones(1), alpha=alpha) # (..., N)
        return B


class LegTTransitionDense(AdaptiveTransition):
    """ Slower and memory inefficient version via manual matrix mult/inv """

    def forward_mult(self, u, delta, transpose=False):
        if isinstance(delta, torch.Tensor):
            delta = delta.unsqueeze(-1)
        A_ = self.A.transpose(-1, -2) if transpose else self.A
        x = (A_ @ u.unsqueeze(-1)).squeeze(-1)
        x = u + delta * x

        return x

    def inverse_mult(self, u, delta, transpose=False):
        """ Computes (I - d A)^-1 u """

        if isinstance(delta, torch.Tensor):
            delta = delta.unsqueeze(-1).unsqueeze(-1)
        _A = self.I - delta * self.A
        if transpose: _A = _A.transpose(-1, -2)

        # x = torch.linalg.solve(_A, u.unsqueeze(-1)).squeeze(-1) # this can run out of memory
        xs = []
        for _A_, u_ in zip(*torch.broadcast_tensors(_A, u.unsqueeze(-1))):
            x_ = torch.linalg.solve(_A_, u_[...,:1]).squeeze(-1)
            xs.append(x_)
        x = torch.stack(xs, dim=0)

        return x


class StateSpace(nn.Module):
    """ Computes a state space layer.
    Simulates the state space ODE
    x' = Ax + Bu
    y = Cx + Du
    - A single state space computation maps a 1D function u to a 1D function y
    - For an input of H features, each feature is independently run through the state space
      with a different timescale / sampling rate / discretization step size.
    """
    def __init__(
            self,
            d, # hidden dimension, also denoted H below
            order=-1, # order of the state space, i.e. dimension N of the state x
            dt_min=1e-3, # discretization step size - should be roughly inverse to the length of the sequence
            dt_max=1e-1,
            channels=1, # denoted by M below
            dropout=0.0,
        ):
        super().__init__()
        self.H = d
        self.N = order if order > 0 else d

        # Construct transition
        # self.transition = LegTTransition(self.N) # NOTE use this line for speed
        self.transition = LegTTransitionDense(self.N)

        self.M = channels

        self.C = nn.Parameter(torch.randn(self.H, self.M, self.N))
        self.D = nn.Parameter(torch.randn(self.H, self.M))

        # Initialize timescales
        log_dt = torch.rand(self.H) * (math.log(dt_max)-math.log(dt_min)) + math.log(dt_min)
        self.register_buffer('dt', torch.exp(log_dt))

        # Cached Krylov (convolution filter)
        self.k = None

        self.activation_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout)

        self.output_linear = nn.Linear(self.M * self.H, self.H)


    def forward(self, u): # absorbs return_output and transformer src mask
        """
        u: (L, B, H) or (length, batch, hidden)
        Returns: (L, B, H)
        """

        # We need to compute the convolution filter if first pass or length changes
        if self.k is None or u.shape[0] > self.k.shape[-1]:
            A = self.transition.gbt_A(self.dt) # (..., N, N)
            B = self.transition.gbt_B(self.dt) # (..., N)
            self.k = krylov(u.shape[0], A, B) # (H, N, L)

        # Convolution
        y = self.linear_system_from_krylov(u, self.k[..., :u.shape[0]]) # (L, B, H, M)

        # Dropout
        y = self.dropout(self.activation_fn(y))

        # Linear
        y = rearrange(y, 'l b h m -> l b (h m)') # (L, B, H*M)
        y = self.output_linear(y) # (L, B, H)
        return y

    def linear_system_from_krylov(self, u, k):
        """
        Computes the state-space system y = Cx + Du from Krylov matrix K(A, B)
        u: (L, B, ...) ... = H
        C: (..., M, N) ... = H
        D: (..., M)
        k: (..., N, L) Krylov matrix representing b, Ab, A^2b...
        y: (L, B, ..., M)
        """


        k = self.C @ k # (..., M, L)

        k = rearrange(k, '... m l -> m ... l')
        k = k.to(u) # if training in half precision, need to go back to float32 for the fft
        k = k.unsqueeze(1) # (M, 1, ..., L)

        v = u.unsqueeze(-1).transpose(0, -1) # (1, B, ..., L)
        y = triangular_toeplitz_multiply(k, v) # (M, B, ..., L)
        y = y.transpose(0, -1) # (L, B, ..., M)
        y = y + u.unsqueeze(-1) * self.D # (L, B, ..., M)
        return y

Results

image

WIP (Too hard. It takes time)

dhkim0225 avatar Nov 08 '21 06:11 dhkim0225