@Ryu1845
Last active December 13, 2023 13:17
Revisions

  1. Ryu1845 revised this gist Mar 19, 2023. 1 changed file with 0 additions and 61 deletions.
    61 changes: 0 additions & 61 deletions lru_flax.py
    @@ -1,61 +0,0 @@
    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    parallel_scan = jax.lax.associative_scan
    # Randomness.
    seed = 0
    root_key = jax.random.PRNGKey(seed=seed)

    class LRU(nn.Module):
        N: int  # state dimension
        H: int  # model dimension
        r_min: float = 0.0
        r_max: float = 1.0
        max_phase: float = 6.28


        @nn.compact
        def __call__(self, input_sequence):
            """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""
            # Initialize parameters of the LRU layer.
            N, H, r_min, r_max, max_phase = self.N, self.H, self.r_min, self.r_max, self.max_phase
            # Keys
            u1_key, u2_key, B_re_key, B_im_key, C_re_key, C_im_key, D_key = jax.random.split(key=root_key, num=7)
            # Initialization of Lambda is complex valued, distributed uniformly on a ring
            # between r_min and r_max, with phase in [0, max_phase].
            u1 = jax.random.uniform(u1_key, shape=(N,))
            u2 = jax.random.uniform(u2_key, shape=(N,))
            nu_log = jnp.log(-0.5 * jnp.log(u1 * (r_max**2 - r_min**2) + r_min**2))
            theta_log = jnp.log(max_phase * u2)

    # Glorot initialized Input/Output projection matrices
    B_re = jnp.random.normal(B_re_key, shape=(N,H))/jnp.sqrt(2*H)
    B_im = jnp.random.normal(B_im_key, shape=(N,H))/jnp.sqrt(2*H)
    C_re = jnp.random.normal(C_re_key, shape=(H,N))/jnp.sqrt(N)
    C_im = jnp.random.normal(C_im_key, shape=(H,N))/jnp.sqrt(N)
    D = jnp.random.normal(D_key, shape=(H,))

            # Normalization factor
            diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
            gamma_log = jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2))
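            # Why sqrt(1 - |lambda|^2): for the scalar recurrence x_k = lambda * x_{k-1} + gamma * w_k
            # driven by unit-variance white noise w_k, the stationary variance is
            # gamma^2 / (1 - |lambda|^2), so this choice of gamma keeps the state variance
            # near 1 at initialization (a standard AR(1) variance argument).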

            # Materializing the diagonal of Lambda and projections
            Lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
            B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
            C = C_re + 1j * C_im

            # Running the LRU + output projection.
            # For details on the parallel scan, see the discussion in Smith et al. (2022).
            Lambda_elements = jnp.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
            Bu_elements = jax.vmap(lambda u: B_norm @ u)(input_sequence)
            elements = (Lambda_elements, Bu_elements)

            def binary_operator_diag(element_i, element_j):
                """Binary operator for parallel scan of linear recurrence."""
                a_i, bu_i = element_i
                a_j, bu_j = element_j
                return a_j * a_i, a_j * bu_i + bu_j

            _, inner_states = parallel_scan(binary_operator_diag, elements)  # all x_k
            y = jax.vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)
            return y
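
    As a quick sanity check, here is a minimal usage sketch for the module above (the dimensions and PRNG seed below are hypothetical, chosen only for illustration). Since all weights are drawn inside __call__ rather than registered with self.param, init returns an empty variable tree and apply simply runs the forward pass:

    layer = LRU(N=64, H=32)
    u = jnp.ones((16, 32))  # dummy input sequence of shape (L, H) = (16, 32)
    variables = layer.init(jax.random.PRNGKey(1), u)  # empty: no self.param calls
    y = layer.apply(variables, u)  # output of shape (16, 32)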
  2. Ryu1845 revised this gist Mar 19, 2023. 1 changed file with 18 additions and 14 deletions.
    32 changes: 18 additions & 14 deletions lru_flax.py
    @@ -1,9 +1,11 @@
      import jax
      import jax.numpy as jnp
    - import numpy as np
      from flax import linen as nn
    - parallel_scan = jax.lax.associative_scan

    + parallel_scan = jax.lax.associative_scan
    + # Randomness.
    + seed = 0
    + root_key = jax.random.PRNGKey(seed=seed)

      class LRU(nn.Module):
          N: int  # state dimension
    @@ -18,23 +20,25 @@ def __call__(self, input_sequence):
    """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""
    # Initialize parameters of the LRU layer."""
    N, H, r_min, r_max, max_phase = self.N, self.H, self.r_min, self.r_max, self.max_phase
    # Keys
    u1_key, u2_key, B_re_key, B_im_key, C_re_key, C_im_key, D_key = jax.random.split(key=root_key, num=7)
    # Initialization of Lambda is complex valued distributed uniformly on ring
    # between r_min and r_max, with phase in [0, max_phase].
    u1 = np.random.uniform(size = (N,))
    u2 = np.random.uniform(size = (N,))
    nu_log = np.log(-0.5*np.log(u1*(r_max**2-r_min**2) + r_min**2))
    theta_log = np.log(max_phase*u2)
    u1 = jax.random.uniform(u1_key, shape = (N,))
    u2 = jax.random.uniform(u2_key, shape = (N,))
    nu_log = jnp.log(-0.5*jnp.log(u1*(r_max**2-r_min**2) + r_min**2))
    theta_log = jnp.log(max_phase*u2)

    # Glorot initialized Input/Output projection matrices
    B_re = np.random.normal(size=(N,H))/np.sqrt(2*H)
    B_im = np.random.normal(size=(N,H))/np.sqrt(2*H)
    C_re = np.random.normal(size=(H,N))/np.sqrt(N)
    C_im = np.random.normal(size=(H,N))/np.sqrt(N)
    D = np.random.normal(size=(H,))
    B_re = jnp.random.normal(B_re_key, shape=(N,H))/jnp.sqrt(2*H)
    B_im = jnp.random.normal(B_im_key, shape=(N,H))/jnp.sqrt(2*H)
    C_re = jnp.random.normal(C_re_key, shape=(H,N))/jnp.sqrt(N)
    C_im = jnp.random.normal(C_im_key, shape=(H,N))/jnp.sqrt(N)
    D = jnp.random.normal(D_key, shape=(H,))

    # Normalization factor
    diag_lambda = np.exp(-np.exp(nu_log) + 1j*np.exp(theta_log))
    gamma_log = np.log(np.sqrt(1-np.abs(diag_lambda)**2))
    diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j*jnp.exp(theta_log))
    gamma_log = jnp.log(jnp.sqrt(1-jnp.abs(diag_lambda)**2))

    # Materializing the diagonal of Lambda and projections
    Lambda = jnp.exp(-jnp.exp(nu_log) + 1j*jnp.exp(theta_log))
    @@ -54,4 +58,4 @@ def binary_operator_diag(element_i, element_j):
                  return a_j * a_i, a_j * bu_i + bu_j
              _, inner_states = parallel_scan(binary_operator_diag, elements)  # all x_k
              y = jax.vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)
    -         return y
    +         return y
  3. Ryu1845 revised this gist Mar 19, 2023. 1 changed file with 9 additions and 2 deletions.
    11 changes: 9 additions & 2 deletions lru_flax.py
    @@ -6,11 +6,18 @@


      class LRU(nn.Module):
    +     N: int  # state dimension
    +     H: int  # model dimension
    +     r_min: float = 0.0
    +     r_max: float = 1.0
    +     max_phase: float = 6.28
    +
    +
          @nn.compact
    -     def __call__(self, N, H, r_min=0, r_max=1, max_phase=6.28):
    +     def __call__(self, input_sequence):
              """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""
              # Initialize parameters of the LRU layer.
    -         # N: state dimension, H: model dimension
    +         N, H, r_min, r_max, max_phase = self.N, self.H, self.r_min, self.r_max, self.max_phase
              # Initialization of Lambda is complex valued, distributed uniformly on a ring
              # between r_min and r_max, with phase in [0, max_phase].
              u1 = np.random.uniform(size=(N,))
  4. Ryu1845 revised this gist Mar 19, 2023. 2 changed files with 51 additions and 1 deletion.
    2 changes: 1 addition & 1 deletion lru.py
    @@ -64,7 +64,7 @@ def init_lru_parameters(N, H, r_min=0, r_max=1, max_phase=6.28):


      def binary_operator_diag(element_i, element_j):
    -     "Binary operator for parallel scan of linear recurrence."
    +     """Binary operator for parallel scan of linear recurrence."""
          a_i, bu_i = element_i
          a_j, bu_j = element_j
          return a_j * a_i, a_j * bu_i + bu_j
    50 changes: 50 additions & 0 deletions lru_flax.py
    @@ -0,0 +1,50 @@
    import jax
    import jax.numpy as jnp
    import numpy as np
    from flax import linen as nn
    parallel_scan = jax.lax.associative_scan


    class LRU(nn.Module):
        @nn.compact
        def __call__(self, N, H, r_min=0, r_max=1, max_phase=6.28):
            """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""
            # Initialize parameters of the LRU layer.
            # N: state dimension, H: model dimension
            # Initialization of Lambda is complex valued, distributed uniformly on a ring
            # between r_min and r_max, with phase in [0, max_phase].
            u1 = np.random.uniform(size=(N,))
            u2 = np.random.uniform(size=(N,))
            nu_log = np.log(-0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2))
            theta_log = np.log(max_phase * u2)

            # Glorot initialized input/output projection matrices
            B_re = np.random.normal(size=(N, H)) / np.sqrt(2 * H)
            B_im = np.random.normal(size=(N, H)) / np.sqrt(2 * H)
            C_re = np.random.normal(size=(H, N)) / np.sqrt(N)
            C_im = np.random.normal(size=(H, N)) / np.sqrt(N)
            D = np.random.normal(size=(H,))

            # Normalization factor
            diag_lambda = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
            gamma_log = np.log(np.sqrt(1 - np.abs(diag_lambda) ** 2))

            # Materializing the diagonal of Lambda and projections
            Lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
            B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
            C = C_re + 1j * C_im

            # Running the LRU + output projection.
            # For details on the parallel scan, see the discussion in Smith et al. (2022).
            Lambda_elements = jnp.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
            Bu_elements = jax.vmap(lambda u: B_norm @ u)(input_sequence)
            elements = (Lambda_elements, Bu_elements)

            def binary_operator_diag(element_i, element_j):
                """Binary operator for parallel scan of linear recurrence."""
                a_i, bu_i = element_i
                a_j, bu_j = element_j
                return a_j * a_i, a_j * bu_i + bu_j

            _, inner_states = parallel_scan(binary_operator_diag, elements)  # all x_k
            y = jax.vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)
            return y
  5. Ryu1845 revised this gist Mar 19, 2023. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion lru.py
    @@ -64,7 +64,7 @@ def init_lru_parameters(N, H, r_min=0, r_max=1, max_phase=6.28):


      def binary_operator_diag(element_i, element_j):
    -     # Binary operator for parallel scan of linear recurrence.
    +     "Binary operator for parallel scan of linear recurrence."
          a_i, bu_i = element_i
          a_j, bu_j = element_j
          return a_j * a_i, a_j * bu_i + bu_j
  6. Ryu1845 created this gist Mar 19, 2023.
    70 changes: 70 additions & 0 deletions lru.py
    @@ -0,0 +1,70 @@
    """
    Simplified Implementation of the Linear Recurrent Unit
    ------------------------------------------------------
    We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU).
    The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$
    according to the following formula (and efficiently parallelized using an associative scan):
    $x_{k} = \Lambda x_{k-1} +\exp(\gamma^{\log})\odot (B u_{k})$,
    and the output is computed at each timestamp $k$ as follows: $y_k = C x_k + D u_k$.
    In our code, $B,C$ follow Glorot initialization, with $B$ scaled additionally by a factor 2
    to account for halving the state variance by taking the real part of the output projection.
    $D$ is random $H$-dimensional and mutiplies elementwise each $u_k$, where $k$ is the timestamp.
    $\Lambda$ is initialized with the help of Lemma, with phase potentially restricted to a thin slice
    """
    import jax
    import jax.numpy as jnp
    import numpy as np
    parallel_scan = jax.lax.associative_scan


    def forward(lru_parameters, input_sequence):
        """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""

        # All LRU parameters
        nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log = lru_parameters

        # Materializing the diagonal of Lambda and projections
        Lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
        B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
        C = C_re + 1j * C_im

        # Running the LRU + output projection.
        # For details on the parallel scan, see the discussion in Smith et al. (2022).
        Lambda_elements = jnp.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
        Bu_elements = jax.vmap(lambda u: B_norm @ u)(input_sequence)
        elements = (Lambda_elements, Bu_elements)
        _, inner_states = parallel_scan(binary_operator_diag, elements)  # all x_k
        y = jax.vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)

        return y

    def init_lru_parameters(N, H, r_min=0, r_max=1, max_phase=6.28):
        """Initialize parameters of the LRU layer."""

        # N: state dimension, H: model dimension
        # Initialization of Lambda is complex valued, distributed uniformly on a ring
        # between r_min and r_max, with phase in [0, max_phase].
        u1 = np.random.uniform(size=(N,))
        u2 = np.random.uniform(size=(N,))
        nu_log = np.log(-0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2))
        theta_log = np.log(max_phase * u2)

        # Glorot initialized input/output projection matrices
        B_re = np.random.normal(size=(N, H)) / np.sqrt(2 * H)
        B_im = np.random.normal(size=(N, H)) / np.sqrt(2 * H)
        C_re = np.random.normal(size=(H, N)) / np.sqrt(N)
        C_im = np.random.normal(size=(H, N)) / np.sqrt(N)
        D = np.random.normal(size=(H,))

        # Normalization factor
        diag_lambda = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
        gamma_log = np.log(np.sqrt(1 - np.abs(diag_lambda) ** 2))

        return nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log


    def binary_operator_diag(element_i, element_j):
        # Binary operator for parallel scan of linear recurrence.
        a_i, bu_i = element_i
        a_j, bu_j = element_j
        return a_j * a_i, a_j * bu_i + bu_j
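
    To see what the scan computes, here is a minimal end-to-end sketch (all dimensions and values below are hypothetical, chosen only for illustration). The associative scan with binary_operator_diag reproduces the sequential recurrence x_k = a_k * x_{k-1} + b_k with x_0 = 0, and the full layer maps a (L, H) sequence to a (L, H) output:

    # Scan vs. sequential recurrence on toy scalars.
    a = jnp.array([0.5, 0.8, 0.9])
    b = jnp.array([1.0, 2.0, 3.0])
    _, xs = parallel_scan(binary_operator_diag, (a, b))
    x, xs_ref = 0.0, []
    for a_k, b_k in zip(a, b):
        x = a_k * x + b_k  # sequential update
        xs_ref.append(x)
    assert jnp.allclose(xs, jnp.array(xs_ref))

    # Full layer: initialize parameters, then run a dummy (L, H) sequence through.
    params = init_lru_parameters(N=64, H=32)
    u = jnp.ones((16, 32))  # sequence length L=16, model dimension H=32
    y = forward(params, u)  # y has shape (16, 32)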