"""adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state""" import tensorflow as tf, numpy as np RNNCell = tf.nn.rnn_cell.RNNCell class BNLSTMCell(RNNCell): '''Batch normalized LSTM as described in arxiv.org/abs/1603.09025''' def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95): """ * max bn steps is the maximum number of steps for which to store separate population stats """ self._num_units = num_units self._training = is_training_tensor self._max_bn_steps = max_bn_steps self._activation = activation self._decay = decay self._initial_scale = 0.1 @property def state_size(self): return (self._num_units, self._num_units, 1) @property def output_size(self): return self._num_units def _batch_norm(self, x, name_scope, step, epsilon=1e-5, no_offset=False, set_forget_gate_bias=False): '''Assume 2d [batch, values] tensor''' with tf.variable_scope(name_scope): size = x.get_shape().as_list()[1] scale = tf.get_variable('scale', [size], initializer=tf.constant_initializer(self._initial_scale)) if no_offset: offset = 0 elif set_forget_gate_bias: offset = tf.get_variable('offset', [size], initializer=offset_initializer()) else: offset = tf.get_variable('offset', [size], initializer=tf.zeros_initializer) pop_mean_all_steps = tf.get_variable('pop_mean', [self._max_bn_steps, size], initializer=tf.zeros_initializer, trainable=False) pop_var_all_steps = tf.get_variable('pop_var', [self._max_bn_steps, size], initializer=tf.ones_initializer(), trainable=False) step = tf.minimum(step, self._max_bn_steps - 1) pop_mean = pop_mean_all_steps[step] pop_var = pop_var_all_steps[step] batch_mean, batch_var = tf.nn.moments(x, [0]) def batch_statistics(): pop_mean_new = pop_mean * self._decay + batch_mean * (1 - self._decay) pop_var_new = pop_var * self._decay + batch_var * (1 - self._decay) with tf.control_dependencies([pop_mean.assign(pop_mean_new), pop_var.assign(pop_var_new)]): return tf.nn.batch_normalization(x, batch_mean, batch_var, offset, scale, epsilon) def population_statistics(): return tf.nn.batch_normalization(x, pop_mean, pop_var, offset, scale, epsilon) return tf.cond(self._training, batch_statistics, population_statistics) def __call__(self, x, state, scope=None): with tf.variable_scope(scope or type(self).__name__): c, h, step = state _step = tf.squeeze(tf.gather(tf.cast(step, tf.int32), 0)) x_size = x.get_shape().as_list()[1] W_xh = tf.get_variable('W_xh', [x_size, 4 * self._num_units], initializer=orthogonal_lstm_initializer()) W_hh = tf.get_variable('W_hh', [self._num_units, 4 * self._num_units], initializer=orthogonal_lstm_initializer()) hh = tf.matmul(h, W_hh) xh = tf.matmul(x, W_xh) bn_hh = self._batch_norm(hh, 'hh', _step, set_forget_gate_bias=True) bn_xh = self._batch_norm(xh, 'xh', _step, no_offset=True) hidden = bn_xh + bn_hh f, i, o, j = tf.split(1, 4, hidden) new_c = c * tf.sigmoid(f) + tf.sigmoid(i) * self._activation(j) bn_new_c = self._batch_norm(new_c, 'c', _step) new_h = self._activation(bn_new_c) * tf.sigmoid(o) return new_h, (new_c, new_h, step+1) def orthogonal_lstm_initializer(): def orthogonal(shape, dtype=tf.float32, partition_info=None): # taken from https://github.com/cooijmanstim/recurrent-batch-normalization # taken from https://gist.github.com/kastnerkyle/f7464d98fe8ca14f2a1a """ benanne lasagne ortho init (faster than qr approach)""" flat_shape = (shape[0], np.prod(shape[1:])) a = np.random.normal(0.0, 1.0, flat_shape) u, _, v = np.linalg.svd(a, full_matrices=False) q = u if u.shape == flat_shape else v # pick the one with the correct shape q = q.reshape(shape) return tf.constant(q[:shape[0], :shape[1]], dtype) return orthogonal def offset_initializer(): def _initializer(shape, dtype=tf.float32, partition_info=None): size = shape[0] assert size % 4 == 0 size = size // 4 res = [np.ones((size)), np.zeros((size*3))] return tf.constant(np.concatenate(res, axis=0), dtype) return _initializer