Created September 30, 2017 02:35
-
-
Save w-garcia/3b39e62998a7a33afc086273e22bc117 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sample_from_discretized_mix_logistic(l, nr_mix):
    """Draw one 3-channel sample per spatial position from a logistic mixture.

    Args:
        l: Rank-4 network output tensor (indexed as ``l[:, :, :, ...]``).
            Its last axis holds ``nr_mix`` mixture logits followed by
            values that are reshaped to ``[..., 3, 3 * nr_mix]``. For each
            of the 3 channels, that trailing axis is laid out as
            ``nr_mix`` means, ``nr_mix`` log-scales, then ``nr_mix``
            pre-tanh channel-coupling coefficients.
        nr_mix: Number of logistic mixture components.

    Returns:
        Tensor with the same leading shape as ``l`` and a last axis of
        size 3, with every value clamped to ``[-1, 1]``.
    """
    # NOTE(review): int_shape is defined elsewhere; presumably it returns the
    # static shape as a Python list of ints (it is concatenated with lists).
    ls = int_shape(l)
    xs = ls[:-1] + [3]

    # Unpack parameters: mixture logits, then per-channel component params.
    logit_probs = l[:, :, :, :nr_mix]
    params = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix * 3])

    # Gumbel-max trick: argmax(logits - log(-log(u))) is argmax of
    # logits + Gumbel noise, i.e. a sample from softmax(logits).
    # u is kept away from 0 and 1 so neither log blows up.
    gumbel_log_term = tf.log(-tf.log(tf.random_uniform(
        logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5)))
    component = tf.argmax(logit_probs - gumbel_log_term, 3)

    # One-hot mask over the chosen component, broadcast across the 3 channels.
    sel = tf.one_hot(component, depth=nr_mix, dtype=tf.float32)
    sel = tf.reshape(sel, xs[:-1] + [1, nr_mix])

    # Pick the chosen component's parameters for every channel.
    means = tf.reduce_sum(params[:, :, :, :, :nr_mix] * sel, 4)
    # Lower-bound log-scales at -7 so the scale never drops below exp(-7).
    log_scales = tf.maximum(
        tf.reduce_sum(params[:, :, :, :, nr_mix:2 * nr_mix] * sel, 4), -7.)
    coeffs = tf.reduce_sum(
        tf.nn.tanh(params[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, 4)

    # Sample from the logistic via its inverse CDF: mu + s * logit(u).
    # We don't actually round to the nearest 8-bit value when sampling.
    u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5)
    x = means + tf.exp(log_scales) * (tf.log(u) - tf.log(1. - u))

    # Channels are sampled autoregressively. Channel 1 is shifted by
    # coeffs[0] * channel 0. Channel 2 is shifted by coeffs[1] * channel 0
    # and coeffs[2] * channel 1. Each result is clamped to [-1, 1].
    x0 = tf.minimum(tf.maximum(x[:, :, :, 0], -1.), 1.)
    x1 = tf.minimum(
        tf.maximum(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, -1.), 1.)
    x2 = tf.minimum(
        tf.maximum(
            x[:, :, :, 2]
            + coeffs[:, :, :, 1] * x0
            + coeffs[:, :, :, 2] * x1,
            -1.),
        1.)

    return tf.concat([
        tf.reshape(x0, xs[:-1] + [1]),
        tf.reshape(x1, xs[:-1] + [1]),
        tf.reshape(x2, xs[:-1] + [1])], 3)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.