Skip to content

Instantly share code, notes, and snippets.

@24hours
Created March 3, 2016 07:14
Show Gist options
  • Select an option

  • Save 24hours/0545f92d5407bcdd3106 to your computer and use it in GitHub Desktop.

Select an option

Save 24hours/0545f92d5407bcdd3106 to your computer and use it in GitHub Desktop.
import tensorflow as tf
# Small experiment: feed a value, assign it to a variable, track an
# exponential moving average (EMA) of the assigned value, and use tf.cond to
# return either the fresh value (isTrain=True) or the EMA (isTrain=False).
tf.reset_default_graph()

# Feeds: a boolean train/test switch and the value written into `beta`.
isTrain = tf.placeholder(tf.bool)
user_input = tf.placeholder(tf.float32)

# NOTE(review): the original indentation was lost, so the extent of this
# device scope is reconstructed -- confirm which ops were meant to be pinned.
with tf.device('/cpu:0'):
    beta = tf.Variable(tf.ones([1]))
    # The fed value is written into `beta`; the assign op's output is the
    # tensor whose moving average is tracked below.
    batch_mean = tf.Print(beta.assign(user_input), ['beta assign'])

    ema = tf.train.ExponentialMovingAverage(decay=0.5)
    ema_apply_op = ema.apply([batch_mean])
    # ema_apply_op = tf.Print(ema.apply([batch_mean]), ["ema_apply_op"])
    ema_mean = tf.Print(ema.average(batch_mean), ['ema_mean'])

    def mean_var_with_update():
        """Return `batch_mean`, with a control dependency on the EMA update."""
        with tf.control_dependencies([ema_apply_op]):
            return tf.Print(tf.identity(batch_mean), ["mean_var_with_update"])
            # return tf.identity(batch_mean)

    # NOTE(review): `ema_apply_op` and `batch_mean` are created outside the
    # tf.cond branch functions. Per the tf.cond documentation, ops defined
    # outside the branches may execute regardless of which branch is taken,
    # so the EMA may also be updated when isTrain is False -- confirm whether
    # that is the behavior this script is meant to exhibit.
    mean = tf.Print(
        tf.cond(isTrain,
                mean_var_with_update,
                lambda: (tf.Print(ema_mean, ["ema_mean(cond)"]))),
        ["evaluating mean", isTrain])
# ======= End Here ==========

# `saver` is created but not used anywhere in the visible script.
saver = tf.train.Saver()
init = tf.initialize_all_variables()

u_input = [[2], [3], [4]]

# The context manager closes the session (and frees its resources) once both
# passes are done; the original never closed it.
with tf.Session() as sess:
    sess.run(init)

    # Training pass: takes the tf.cond branch that returns the fresh value.
    for u in u_input:
        aa = sess.run([mean], feed_dict={user_input: u, isTrain: True})
        print("Train", aa)

    # Alternative test pass that fetches the EMA tensor directly:
    # for u in u_input:
    #     aa = sess.run([ema_mean], feed_dict={user_input: u, isTrain: False})
    #     print("Test correct", aa)

    print("Testing")
    # Test pass: takes the tf.cond branch that returns the EMA.
    for u in u_input:
        aa = sess.run([mean], feed_dict={user_input: u, isTrain: False})
        print("Test", aa)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment