bnlstm.py, forked by @scotthuang1989 from spitis/bnlstm.py on March 11, 2018 (original gist created by @spitis on Feb 2, 2017)
    """adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state"""
    import tensorflow as tf, numpy as np
    RNNCell = tf.nn.rnn_cell.RNNCell

class BNLSTMCell(RNNCell):
    '''Batch normalized LSTM as described in arxiv.org/abs/1603.09025'''

    def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95):
        """
        * max_bn_steps is the maximum number of steps for which to store separate population stats
        """
        self._num_units = num_units
        self._training = is_training_tensor
        self._max_bn_steps = max_bn_steps
        self._activation = activation
        self._decay = decay
        self._initial_scale = initial_scale  # was hardcoded to 0.1, silently ignoring the constructor argument

    @property
    def state_size(self):
        # state is a (c, h, step) tuple; the trailing 1 is the per-batch step counter
        return (self._num_units, self._num_units, 1)

    @property
    def output_size(self):
        return self._num_units

    def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False, set_forget_gate_bias=False):
        '''Assume 2d [batch, values] tensor'''

        with tf.variable_scope(name_scope):
            size = x.get_shape().as_list()[1]

            scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(self._initial_scale))
            if no_offset:
                offset = 0
            elif set_forget_gate_bias:
                offset = tf.get_variable('offset', [size], initializer=offset_initializer())
            else:
                offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer())

            pop_mean_all_steps = tf.get_variable('pop_mean', [self._max_bn_steps, size], initializer=tf.zeros_initializer(), trainable=False)
            pop_var_all_steps = tf.get_variable('pop_var', [self._max_bn_steps, size], initializer=tf.ones_initializer(), trainable=False)

            # steps beyond max_bn_steps reuse the statistics of the last stored step
            step = tf.minimum(step, self._max_bn_steps - 1)

            pop_mean = pop_mean_all_steps[step]
            pop_var = pop_var_all_steps[step]

            batch_mean, batch_var = tf.nn.moments(x, [0])

            def batch_statistics():
                # update this step's population statistics as a side effect,
                # then normalize with the batch statistics
                pop_mean_new = pop_mean * self._decay + batch_mean * (1 - self._decay)
                pop_var_new = pop_var * self._decay + batch_var * (1 - self._decay)
                with tf.control_dependencies([pop_mean.assign(pop_mean_new), pop_var.assign(pop_var_new)]):
                    return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon)

            def population_statistics():
                return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon)

            return tf.cond(self._training, batch_statistics, population_statistics)

    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            c, h, step = state
            # step is carried as a float [batch, 1] tensor; recover a scalar int index
            _step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0))

            x_size = x.get_shape().as_list()[1]
            W_xh = tf.get_variable('W_xh',
                                   [x_size, 4 * self._num_units],
                                   initializer=orthogonal_lstm_initializer())
            W_hh = tf.get_variable('W_hh',
                                   [self._num_units, 4 * self._num_units],
                                   initializer=orthogonal_lstm_initializer())

            hh = tf.matmul(h, W_hh)
            xh = tf.matmul(x, W_xh)

            # only one of the two pre-activation norms carries a learnable offset,
            # so a single bias is shared between them
            bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True)
            bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True)

            hidden = bn_xh + bn_hh

            f, i, o, j = tf.split(hidden, 4, axis=1)  # was tf.split(1, 4, hidden) in pre-1.0 TF

            new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j)
            bn_new_c = self._batch_norm(new_c, 'c', _step)

            new_h = self._activation(bn_new_c) * tf.sigmoid(o)
            return new_h, (new_c, new_h, step + 1)

def orthogonal_lstm_initializer():
    def orthogonal(shape, dtype=tf.float32, partition_info=None):
        # taken from https://github.com/cooijmanstim/recurrent-batch-normalization
        # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a
        """benanne lasagne ortho init (faster than qr approach)"""
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v  # pick the one with the correct shape
        q = q.reshape(shape)
        return tf.constant(q[:shape[0], :shape[1]], dtype)
    return orthogonal

def offset_initializer():
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        # initialize the forget-gate quarter of the offset to 1, the rest to 0
        size = shape[0]
        assert size % 4 == 0
        size = size // 4
        res = [np.ones(size), np.zeros(size * 3)]
        return tf.constant(np.concatenate(res, axis=0), dtype)
    return _initializer
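
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): one way to drive BNLSTMCell
# with tf.nn.dynamic_rnn under TF 1.x. The placeholder names (`inputs`,
# `is_training`) and the sizes below are illustrative assumptions.
if __name__ == '__main__':
    batch_size, seq_len, input_dim, num_units = 32, 20, 16, 128
    inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
    is_training = tf.placeholder(tf.bool, [])

    cell = BNLSTMCell(num_units, is_training, max_bn_steps=seq_len)
    # zero_state builds the (c, h, step) tuple from state_size; the step
    # counter starts at 0 and is incremented once per time step in __call__.
    initial_state = cell.zero_state(batch_size, tf.float32)
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)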