import numpy as np
import tensorflow as tf


class Layer(object):
    def __init__(self, scope="dense_layer"):
        self.scope = scope

    def __call__(self, x, **kwargs):
        with tf.name_scope(self.scope):
            return self.output(x, **kwargs)

    def output(self, x, **kwargs):
        raise NotImplementedError()

    @staticmethod
    def variable_summary(var):
        """
        Used to log summaries of variables.
        """
        mean = tf.reduce_mean(var)
        tf.summary.scalar(var.name + '_mean', mean)
        tf.summary.histogram(var.name, var)
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar(var.name + '_stddev', stddev)
class Net:
    @staticmethod
    def cross_entropy(y_logit, y_true, summary=True):
        with tf.name_scope('cross_entropy'):
            ce = tf.nn.softmax_cross_entropy_with_logits(logits=y_logit,
                                                         labels=y_true)
            mean_ce = tf.reduce_mean(ce)
            if summary:
                tf.summary.scalar('cross_entropy', mean_ce)
            return mean_ce

    def loss(self, y_pred, y_true, zero_step, max_step, task='cl'):
        kl_div = self.kl_loss()
        if task == 'cl':
            nll = Net.cross_entropy(y_pred, y_true)
        else:
            # Regression loss; Net.mse is assumed to be defined elsewhere.
            nll = Net.mse(y_pred, y_true)
        if self.anneal:
            with tf.name_scope('annealing'):
                # Warm up the KL weight linearly from 0 (at zero_step)
                # to 1 (at zero_step + max_step).
                beta_curr = (tf.cast(self.global_step,
                                     tf.float32) - zero_step) / max_step
                beta_t = tf.maximum(tf.cast(beta_curr, tf.float32), 0.)
                annealing = tf.minimum(1., tf.cond(
                    self.global_step < zero_step,
                    lambda: tf.zeros((1,))[0],
                    lambda: beta_t))
                tf.summary.scalar('annealing_beta', annealing)
        else:
            annealing = 1.
        with tf.name_scope('lowerbound'):
            lowerbound = nll + annealing * kl_div
            tf.summary.scalar('lower_bound', lowerbound)
            return lowerbound

    def kl_loss(self):
        kl_div = 0
        for i, layer in enumerate(self.layers):
            if type(layer) not in [StudentDense, StudentConv]:
                continue
            layer_kl = layer.layer_kl()
            tf.summary.scalar('Layer{}_KL'.format(i + 1), layer_kl)
            kl_div += layer_kl
        kldiv = (-1. / self.ds_size) * kl_div
        tf.summary.scalar('kl_div', kl_div)
        return kldiv
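# A minimal plain-Python sketch (not part of the original gist) of the
# annealing schedule that Net.loss builds with TF ops above, so the KL
# warm-up behaviour is easy to inspect; the zero_step / max_step values
# below are arbitrary examples.
def _annealing_beta(step, zero_step, max_step):
    """Pure-Python mirror of the `annealing` coefficient in Net.loss."""
    if step < zero_step:
        return 0.0
    return min(1.0, (step - zero_step) / float(max_step))

# With zero_step=1000 and max_step=4000 the KL weight is 0 for the first
# 1000 steps, ramps linearly afterwards and saturates at 1.0 from step 5000:
# [_annealing_beta(s, 1000, 4000) for s in (0, 1000, 3000, 5000, 6000)]
# -> [0.0, 0.0, 0.5, 1.0, 1.0]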
class StudentDense(Layer):
    def __init__(self, fan_out, gam_v, scope='student_dense_',
                 var_summaries=False):
        r"""
        Student prior: P(w) = \int N(w|0,t) Gam(t|v/2,v/2) dt
        Approx. posterior: q(w,t) = q(w|t)q(t) = N(w|\mu, t) LN(t|\mu_t, \sigma^{2}_t)
        """
        self.fan_out = fan_out
        self.gam_v = gam_v
        self.var_summaries = var_summaries
        self.params = None
        super(StudentDense, self).__init__(scope=scope)

    def initialize(self, x):
        fan_in = x.get_shape()[-1].value
        with tf.variable_scope(self.scope):
            self.mu_w = he_init((fan_in, self.fan_out), name='mu_w',
                                std_scale=1e-6)
            self.mu_t = he_init((fan_in, self.fan_out), name='mu_t',
                                mu=-11.)
            self.logsigma_t = he_init((fan_in, self.fan_out), mu=-11.,
                                      std_scale=1e-9, name='logsigma_t')
            self.b = tf.Variable(tf.constant(0.0, shape=[self.fan_out]),
                                 name='bias')
            self.gam_v = tf.constant(self.gam_v,
                                     shape=(fan_in, self.fan_out),
                                     name='gamma_v')
            self.params = [self.mu_w, self.gam_v, self.mu_t,
                           self.logsigma_t]
            if self.var_summaries:
                for var in self.params:
                    variable_summary(var)

    def layer_kl(self):
        v = self.gam_v
        v2 = self.gam_v / 2.0
        mu_t, sigma_t = self.mu_t, tf.exp(self.logsigma_t)
        kl_qt = (.5 * (-v * mu_t - tf.log(2 * np.pi) - v * tf.log(v2)
                       + v * tf.exp(mu_t + sigma_t / 2.) - 1.)
                 + tf.lgamma(v2) - self.logsigma_t)
        kl_qwt = (tf.square(self.mu_w) / 2.) * tf.exp(
            -self.mu_w + .5 * tf.square(sigma_t))
        return tf.reduce_sum(-kl_qt) + tf.reduce_sum(-kl_qwt)

    def output(self, x, deterministic=False, **kwargs):
        if not self.params:
            self.initialize(x)
        # Sample the per-weight scale t from the log-normal posterior
        # (or use its mode when deterministic), then sample the weights.
        eps_t = tf.random_normal(tf.shape(self.mu_t), 0., 1.)
        sigma_t = tf.exp(self.logsigma_t)
        if not deterministic:
            t = tf.exp(self.mu_t + sigma_t * eps_t)
        else:
            t = tf.exp(self.mu_t)
        t = clip_val(t, top=0.3)
        eps_w = tf.random_normal(tf.shape(self.mu_w), 0., 1.)
        w = self.mu_w + t * eps_w if not deterministic else self.mu_w
        out = tf.matmul(x, w) + self.b
        tf.summary.histogram(self.scope + '/output', out)
        return out
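# `clip_val` is called by StudentDense.output (and StudentConv.output below)
# but is not included in this gist. The following is only a plausible sketch,
# assuming it clips the sampled scales into [bottom, top]; the `bottom`
# default is a guess and the author's actual helper may differ.
def clip_val(x, top, bottom=1e-8):
    """Hypothetical helper: clip a tensor elementwise into [bottom, top]."""
    return tf.clip_by_value(x, bottom, top)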
class StudentConv(Layer):
    def __init__(self, filter_shape, nfilters, gam_v, padding='SAME',
                 strides=(1, 1, 1, 1),
                 scope='student_conv_', var_summaries=False):
        assert padding in ['SAME', 'VALID']
        self.filter_shape = filter_shape
        self.nfilters = nfilters
        self.padding = padding
        self.strides = strides
        self.gam_v = gam_v
        self.var_summaries = var_summaries
        self.params = None
        super(StudentConv, self).__init__(scope=scope)

    def initialize(self, x):
        input_channels = x.get_shape()[-1].value
        kernel_shape = list(self.filter_shape) + [input_channels,
                                                  self.nfilters]
        with tf.variable_scope(self.scope):
            self.mu_w = he_init(kernel_shape, name='mu_w', std_scale=1e-6)
            self.mu_t = he_init(kernel_shape, name='mu_t', mu=-11.)
            self.logsigma_t = he_init(kernel_shape, mu=-11.,
                                      std_scale=1e-9, name='logsigma_t')
            self.b = tf.Variable(tf.constant(0.0, shape=kernel_shape[-1:]),
                                 name='bias')
            self.gam_v = tf.constant(self.gam_v, shape=kernel_shape,
                                     name='gamma_v')
            self.params = [self.mu_w, self.gam_v, self.mu_t,
                           self.logsigma_t]
            if self.var_summaries:
                for var in self.params:
                    variable_summary(var)

    def layer_kl(self):
        v = self.gam_v
        v2 = self.gam_v / 2.
        mu_t, sigma_t = self.mu_t, tf.exp(self.logsigma_t)
        kl_qt = (.5 * (-v * mu_t - tf.log(2 * np.pi) - v * tf.log(v2)
                       + v * tf.exp(mu_t + sigma_t / 2.) - 1.)
                 + tf.lgamma(v2) - self.logsigma_t)
        kl_qwt = (tf.square(self.mu_w) / 2.) * tf.exp(
            -self.mu_w + .5 * tf.square(sigma_t))
        return tf.reduce_sum(-kl_qt) + tf.reduce_sum(-kl_qwt)

    def output(self, x, deterministic=False, **kwargs):
        if not self.params:
            self.initialize(x)
        eps_t = tf.random_normal(tf.shape(self.mu_t), 0., 1.)
        sigma_t = tf.exp(self.logsigma_t)
        if not deterministic:
            t = tf.exp(self.mu_t + sigma_t * eps_t)
        else:
            t = tf.exp(self.mu_t)
        t = clip_val(t, top=0.3)
        eps_w = tf.random_normal(tf.shape(self.mu_w), 0., 1.)
        w = self.mu_w + t * eps_w if not deterministic else self.mu_w
        out = tf.nn.conv2d(x, filter=w, strides=self.strides,
                           padding=self.padding,
                           use_cudnn_on_gpu=True) + self.b
        tf.summary.histogram(self.scope + '/output', out)
        return out
def variable_summary(var):
    """
    Log summaries of variables.
    """
    mean = tf.reduce_mean(var)
    tf.summary.scalar(var.name + '_mean', mean)
    tf.summary.histogram(var.name, var)
    stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
    tf.summary.scalar(var.name + '_stddev', stddev)


def he_init(shape, mu=0., name='weights', he_type='he2', std_scale=1.,
            trainable=True):
    in_size, out_size = init_shapes(shape)
    if he_type == 'he1':
        stddev = tf.cast(tf.sqrt(2. / in_size), tf.float32)
    elif he_type == 'he2':
        stddev = tf.cast(tf.sqrt(4. / (in_size + out_size)), tf.float32)
    else:
        raise ValueError("he_type must be 'he1' or 'he2'")
    init_w = tf.random_normal(shape, mean=mu, stddev=stddev * std_scale)
    return tf.Variable(init_w, name=name, trainable=trainable)
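# `init_shapes` is called by he_init above but is not part of this gist.
# The sketch below is an assumption: it returns the fan-in / fan-out of
# either a dense weight matrix or a conv kernel; the author's actual helper
# may differ.
def init_shapes(shape):
    """Hypothetical helper: fan-in and fan-out for dense or conv shapes."""
    if len(shape) == 2:
        # (fan_in, fan_out) for a dense layer.
        return shape[0], shape[1]
    # (height, width, in_channels, out_channels) for a conv kernel.
    receptive_field = np.prod(shape[:-2])
    return receptive_field * shape[-2], receptive_field * shape[-1]


# Rough usage sketch, not from the original gist: build one StudentDense
# layer, run a single stochastic forward pass and read off its KL term.
# Input size, batch size and gam_v are arbitrary example values.
if __name__ == '__main__':
    x = tf.placeholder(tf.float32, shape=(None, 784))
    layer = StudentDense(fan_out=100, gam_v=1.0)
    logits = layer(x)  # variables are created lazily on the first call
    layer_kl = layer.layer_kl()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out, kl = sess.run(
            [logits, layer_kl],
            feed_dict={x: np.random.randn(32, 784).astype(np.float32)})
        print(out.shape, kl)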