Ryu1845 / lru.py
Last active December 13, 2023 13:17
Linear Recurrent Unit (LRU) from the paper ["Resurrecting Recurrent Neural Networks for Long Sequences"](https://arxiv.org/abs/2303.06349)
"""
Simplified Implementation of the Linear Recurrent Unit
------------------------------------------------------
We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU).
The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$
according to the following formula (and efficiently parallelized using an associative scan):
$x_{k} = \Lambda x_{k-1} + \exp(\gamma^{\log}) \odot (B u_{k})$,
and the output is computed at each time step $k$ as $y_k = \Re(C x_k) + D u_k$.
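The recurrence and its scan form can be illustrated with a minimal plain-Python sketch. It is not the JAX implementation itself. It assumes toy sizes, random untrained parameters, $x_0 = 0$, and the initialization $\gamma^{\log}_i = \log\sqrt{1 - |\lambda_i|^2}$, and every helper name in it is hypothetical. The sketch runs the recurrence step by step and then runs it again as an inclusive scan over pairs $(a, b)$, where each pair encodes the elementwise affine map $x \mapsto a \odot x + b$. Composing two such maps is associative, and that property is what allows a parallel prefix scan. In JAX, a combine function of this shape is what `jax.lax.associative_scan` takes.

```python
import cmath
import math
import random

# Hypothetical toy sizes: N = state size, H = input/output size, L = sequence length.
N, H, L = 4, 3, 8
rng = random.Random(0)

def rand_complex():
    return complex(rng.gauss(0, 1), rng.gauss(0, 1))

def matvec(M, v):
    # Plain matrix-vector product on nested lists.
    return [sum(m * x for m, x in zip(row, v)) for row in M]

# Diagonal Lambda stored as a vector of complex eigenvalues inside the unit disk.
lam = [cmath.rect(rng.uniform(0.5, 0.99), rng.uniform(0, 2 * math.pi)) for _ in range(N)]
# Assumed init: exp(gamma_log) = sqrt(1 - |lambda|^2).
gamma_log = [math.log(math.sqrt(1 - abs(lam_i) ** 2)) for lam_i in lam]
B = [[rand_complex() for _ in range(H)] for _ in range(N)]   # N x H, complex
C = [[rand_complex() for _ in range(N)] for _ in range(H)]   # H x N, complex
D = [rng.gauss(0, 1) for _ in range(H)]                      # diagonal skip term
u = [[rng.gauss(0, 1) for _ in range(H)] for _ in range(L)]  # L x H real inputs

# Input term exp(gamma_log) * (B u_k), elementwise over the N state channels.
bu = [[math.exp(g) * z for g, z in zip(gamma_log, matvec(B, u_k))] for u_k in u]

def sequential(lam, bu):
    x = [0j] * len(lam)  # x_0 = 0 (assumption)
    states = []
    for b in bu:
        x = [lam_i * xi + bi for lam_i, xi, bi in zip(lam, x, b)]
        states.append(x)
    return states

def combine(e1, e2):
    # (a, b) encodes the elementwise affine map x -> a * x + b.
    # Applying e1 first and then e2 gives x -> (a2 * a1) * x + (a2 * b1 + b2).
    (a1, b1), (a2, b2) = e1, e2
    return ([p * q for p, q in zip(a2, a1)],
            [p * q + r for p, q, r in zip(a2, b1, b2)])

def associative_scan(op, elems):
    # Hillis-Steele inclusive prefix scan: about log2(L) rounds, and all op
    # calls inside one round are independent of each other (parallelizable).
    out = list(elems)
    step = 1
    while step < len(out):
        out = [out[i] if i < step else op(out[i - step], out[i]) for i in range(len(out))]
        step *= 2
    return out

def outputs(states, u):
    # y_k = Re(C x_k) + D * u_k
    return [[z.real + d * uh for z, d, uh in zip(matvec(C, x), D, u_k)]
            for x, u_k in zip(states, u)]

seq = sequential(lam, bu)
# With x_0 = 0, the b component of each prefix composition is the state x_k.
par = [b for _, b in associative_scan(combine, [(lam, b) for b in bu])]
max_err = max(abs(p - s) for xs, ps in zip(seq, par) for s, p in zip(xs, ps))
y = outputs(par, u)
print(f'max |scan - sequential| = {max_err:.2e}')
```

The two state sequences should agree up to floating-point rounding. The sequential loop needs $L$ dependent steps. The scan needs only about $\log_2 L$ rounds of independent combine calls, at the cost of more total work in this simple Hillis-Steele variant.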
In our code, $B$ and $C$ follow Glorot initialization, with $B$ additionally scaled by a factor of 2
to compensate for the halving of variance caused by keeping only the real part of the output projection.
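The motivation for the factor of 2 can be checked numerically. The sketch below is an illustration under stated assumptions and not the repository's initializer. It reads "Glorot" as a normal draw with $E|w|^2 = 2/(n_{in} + n_{out})$, split evenly between the real and imaginary parts, and it reads "scaled by a factor of 2" as doubling the variance of $B$. The helper `glorot_complex` and the sizes are hypothetical. The sketch then estimates that $\Re(Cx)$ keeps about half of the second moment of $Cx$ when $x$ is a random complex vector.

```python
import math
import random

rng = random.Random(0)

def glorot_complex(fan_out, fan_in, variance_scale=1.0):
    # Hypothetical helper. Glorot/Xavier normal variance 2 / (fan_in + fan_out),
    # optionally multiplied by variance_scale, split evenly between the real and
    # imaginary parts so that E|w|^2 equals the target variance.
    var = variance_scale * 2.0 / (fan_in + fan_out)
    std = math.sqrt(var / 2.0)
    return [[complex(rng.gauss(0, std), rng.gauss(0, std)) for _ in range(fan_in)]
            for _ in range(fan_out)]

def mean_sq(values):
    # Empirical second moment E|v|^2.
    return sum(abs(v) ** 2 for v in values) / len(values)

N, H = 64, 32  # hypothetical state size and model width
B = glorot_complex(N, H, variance_scale=2.0)  # N x H, variance doubled (assumed reading)
C = glorot_complex(H, N)                      # H x N, plain Glorot

# Keeping only the real part of C x retains about half of its second moment.
all_cx, real_cx = [], []
for _ in range(200):
    x = [complex(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(N)]
    cx = [sum(c * xi for c, xi in zip(row, x)) for row in C]
    all_cx.extend(cx)
    real_cx.extend(z.real for z in cx)

ratio = mean_sq(real_cx) / mean_sq(all_cx)
print(f'E[Re(Cx)^2] / E[|Cx|^2] = {ratio:.3f}')
```

The state $x_k$ is linear in $B$, and the entries of $B$ are independent. Doubling their variance therefore doubles $E|x_k|^2$, which offsets the halving caused by discarding the imaginary part of $C x_k$.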