Last active
December 13, 2023 13:17
-
-
Save Ryu1845/7e78da4baa8925b4de482969befa949d to your computer and use it in GitHub Desktop.
Revisions
-
Ryu1845 revised this gist
Mar 19, 2023. 1 changed file with 0 additions and 61 deletions. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,61 +0,0 @@ -
Ryu1845 revised this gist
Mar 19, 2023. 1 changed file with 18 additions and 14 deletions. There are no files selected for viewing.
"""Flax implementation of the Linear Recurrent Unit (LRU) — final gist revision."""
import jax
import jax.numpy as jnp
from flax import linen as nn

# Associative scan runs the linear recurrence in O(log L) depth.
parallel_scan = jax.lax.associative_scan

# Randomness.
seed = 0
root_key = jax.random.PRNGKey(seed=seed)


class LRU(nn.Module):
    """Linear Recurrent Unit layer.

    Maps an input sequence of shape (L, H) to an output of the same shape via a
    diagonal complex linear recurrence of state size N.
    """

    N: int  # state dimension
    H: int  # model dimension
    r_min: float = 0.0   # inner radius of the eigenvalue ring
    r_max: float = 1.0   # outer radius of the eigenvalue ring
    max_phase: float = 6.28  # phase range [0, max_phase] for the eigenvalues

    @nn.compact
    def __call__(self, input_sequence):
        """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""
        N, H, r_min, r_max, max_phase = self.N, self.H, self.r_min, self.r_max, self.max_phase

        # One independent key per randomly initialized parameter.
        u1_key, u2_key, B_re_key, B_im_key, C_re_key, C_im_key, D_key = jax.random.split(
            key=root_key, num=7
        )

        # Initialization of Lambda is complex valued, distributed uniformly on a ring
        # between r_min and r_max, with phase in [0, max_phase].
        u1 = jax.random.uniform(u1_key, shape=(N,))
        u2 = jax.random.uniform(u2_key, shape=(N,))
        nu_log = jnp.log(-0.5 * jnp.log(u1 * (r_max**2 - r_min**2) + r_min**2))
        theta_log = jnp.log(max_phase * u2)

        # Glorot-initialized input/output projection matrices.
        # BUG FIX: the original called jnp.random.normal, which does not exist in JAX;
        # the correct API is jax.random.normal with an explicit PRNG key.
        B_re = jax.random.normal(B_re_key, shape=(N, H)) / jnp.sqrt(2 * H)
        B_im = jax.random.normal(B_im_key, shape=(N, H)) / jnp.sqrt(2 * H)
        C_re = jax.random.normal(C_re_key, shape=(H, N)) / jnp.sqrt(N)
        C_im = jax.random.normal(C_im_key, shape=(H, N)) / jnp.sqrt(N)
        D = jax.random.normal(D_key, shape=(H,))

        # Normalization factor gamma so the state variance stays O(1) at init.
        diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
        gamma_log = jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2))

        # Materializing the diagonal of Lambda and projections.
        Lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
        B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
        C = C_re + 1j * C_im

        # Running the LRU + output projection.
        # For details on parallel scan, check discussion in Smith et al (2022).
        Lambda_elements = jnp.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
        Bu_elements = jax.vmap(lambda u: B_norm @ u)(input_sequence)
        elements = (Lambda_elements, Bu_elements)

        def binary_operator_diag(element_i, element_j):
            """Binary operator for parallel scan of linear recurrence."""
            a_i, bu_i = element_i
            a_j, bu_j = element_j
            return a_j * a_i, a_j * bu_i + bu_j

        _, inner_states = parallel_scan(binary_operator_diag, elements)  # all x_k
        y = jax.vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)
        return y
Ryu1845 revised this gist
Mar 19, 2023. 1 changed file with 9 additions and 2 deletions. There are no files selected for viewing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -6,11 +6,18 @@ class LRU(nn.Module) N: int # state dimension H: int # model dimension r_min: int = 0 r_max: int = 1 max_phase: float = 6.28 @nn.compact def __call__(self, input_sequence): """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H).""" # Initialize parameters of the LRU layer.""" N, H, r_min, r_max, max_phase = self.N, self.H, self.r_min, self.r_max, self.max_phase # Initialization of Lambda is complex valued distributed uniformly on ring # between r_min and r_max, with phase in [0, max_phase]. u1 = np.random.uniform(size = (N,)) -
Ryu1845 revised this gist
Mar 19, 2023. 2 changed files with 51 additions and 1 deletion. There are no files selected for viewing.
"""First Flax-module port of the LRU gist (numpy-based parameter init)."""
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn

parallel_scan = jax.lax.associative_scan


class LRU(nn.Module):
    """Linear Recurrent Unit (LRU) as a Flax module.

    NOTE(review): this revision declared ``__call__(self, N, H, ...)`` while the
    body read an undefined name ``input_sequence``, so it could never execute.
    The fix takes ``input_sequence`` as the first argument; the author's later
    revision instead moved N/H/r_min/r_max/max_phase to module fields.
    """

    @nn.compact
    def __call__(self, input_sequence, N, H, r_min=0, r_max=1, max_phase=6.28):
        """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""
        # N: state dimension, H: model dimension.
        # Initialization of Lambda is complex valued, distributed uniformly on a ring
        # between r_min and r_max, with phase in [0, max_phase].
        u1 = np.random.uniform(size=(N,))
        u2 = np.random.uniform(size=(N,))
        nu_log = np.log(-0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2))
        theta_log = np.log(max_phase * u2)

        # Glorot-initialized input/output projection matrices.
        B_re = np.random.normal(size=(N, H)) / np.sqrt(2 * H)
        B_im = np.random.normal(size=(N, H)) / np.sqrt(2 * H)
        C_re = np.random.normal(size=(H, N)) / np.sqrt(N)
        C_im = np.random.normal(size=(H, N)) / np.sqrt(N)
        D = np.random.normal(size=(H,))

        # Normalization factor keeping state variance O(1) at initialization.
        diag_lambda = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
        gamma_log = np.log(np.sqrt(1 - np.abs(diag_lambda) ** 2))

        # Materializing the diagonal of Lambda and projections.
        Lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
        B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
        C = C_re + 1j * C_im

        # Running the LRU + output projection.
        # For details on parallel scan, check discussion in Smith et al (2022).
        Lambda_elements = jnp.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
        Bu_elements = jax.vmap(lambda u: B_norm @ u)(input_sequence)
        elements = (Lambda_elements, Bu_elements)

        def binary_operator_diag(element_i, element_j):
            """Binary operator for parallel scan of linear recurrence."""
            a_i, bu_i = element_i
            a_j, bu_j = element_j
            return a_j * a_i, a_j * bu_i + bu_j

        _, inner_states = parallel_scan(binary_operator_diag, elements)  # all x_k
        y = jax.vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)
        return y
Ryu1845 revised this gist
Mar 19, 2023. 1 changed file with 1 addition and 1 deletion. There are no files selected for viewing.
def binary_operator_diag(element_i, element_j):
    """Associative combine step for the parallel scan of a diagonal linear recurrence.

    Each element is a pair (a, bu); combining (a_i, bu_i) with (a_j, bu_j)
    yields the transition and accumulated input of the composed step.
    """
    (a_i, bu_i), (a_j, bu_j) = element_i, element_j
    return a_i * a_j, bu_j + a_j * bu_i
Ryu1845 created this gist
Mar 19, 2023. There are no files selected for viewing.
r"""
Simplified Implementation of the Linear Recurrent Unit
------------------------------------------------------
We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU).
The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$
according to the following formula (and efficiently parallelized using an associative scan):
$x_{k} = \Lambda x_{k-1} + \exp(\gamma^{\log}) \odot (B u_{k})$,
and the output is computed at each timestamp $k$ as follows: $y_k = C x_k + D u_k$.
In our code, $B, C$ follow Glorot initialization, with $B$ scaled additionally by a factor 2
to account for halving the state variance by taking the real part of the output projection.
$D$ is random $H$-dimensional and multiplies elementwise each $u_k$, where $k$ is the timestamp.
$\Lambda$ is initialized with the help of Lemma, with phase potentially restricted to a thin slice.

NOTE(review): the docstring is a raw string — the original non-raw version contained
invalid escape sequences (\Lambda, \exp, \gamma, \odot).
"""
import jax
import jax.numpy as jnp
import numpy as np

# Associative scan runs the linear recurrence in O(log L) depth.
parallel_scan = jax.lax.associative_scan


def forward(lru_parameters, input_sequence):
    """Forward pass of the LRU layer. Output y and input_sequence are of shape (L, H)."""
    # All LRU parameters, as returned by init_lru_parameters.
    nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log = lru_parameters

    # Materializing the diagonal of Lambda and projections.
    Lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
    B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
    C = C_re + 1j * C_im

    # Running the LRU + output projection.
    # For details on parallel scan, check discussion in Smith et al (2022).
    Lambda_elements = jnp.repeat(Lambda[None, ...], input_sequence.shape[0], axis=0)
    Bu_elements = jax.vmap(lambda u: B_norm @ u)(input_sequence)
    elements = (Lambda_elements, Bu_elements)
    _, inner_states = parallel_scan(binary_operator_diag, elements)  # all x_k
    y = jax.vmap(lambda x, u: (C @ x).real + D * u)(inner_states, input_sequence)
    return y


def init_lru_parameters(N, H, r_min=0, r_max=1, max_phase=6.28):
    """Initialize parameters of the LRU layer.

    N: state dimension, H: model dimension.
    Returns (nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log).
    """
    # Initialization of Lambda is complex valued, distributed uniformly on a ring
    # between r_min and r_max, with phase in [0, max_phase].
    u1 = np.random.uniform(size=(N,))
    u2 = np.random.uniform(size=(N,))
    nu_log = np.log(-0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2))
    theta_log = np.log(max_phase * u2)

    # Glorot-initialized input/output projection matrices.
    B_re = np.random.normal(size=(N, H)) / np.sqrt(2 * H)
    B_im = np.random.normal(size=(N, H)) / np.sqrt(2 * H)
    C_re = np.random.normal(size=(H, N)) / np.sqrt(N)
    C_im = np.random.normal(size=(H, N)) / np.sqrt(N)
    D = np.random.normal(size=(H,))

    # Normalization factor keeping state variance O(1) at initialization.
    diag_lambda = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
    gamma_log = np.log(np.sqrt(1 - np.abs(diag_lambda) ** 2))

    return nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log


def binary_operator_diag(element_i, element_j):
    """Binary operator for parallel scan of linear recurrence."""
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return a_j * a_i, a_j * bu_i + bu_j