Skip to content

Instantly share code, notes, and snippets.

@evancasey
Created August 6, 2016 20:04
Show Gist options
  • Select an option

  • Save evancasey/51fa98e03a7c51ec1bee984c77b9ed32 to your computer and use it in GitHub Desktop.

Select an option

Save evancasey/51fa98e03a7c51ec1bee984c77b9ed32 to your computer and use it in GitHub Desktop.
CEM Linear
class LinearModel(object):
    """Binary linear policy: one affine layer followed by a threshold at zero.

    The graph computes ``inputs @ weights + bias`` for a single observation
    (the input placeholder is fixed to shape ``[1, num_features]``), and
    ``calc_action`` maps the resulting scalar to a binary action.
    """

    def __init__(self,
                 sess,
                 num_features,
                 name="BinaryLinear"):
        """Build the placeholder, variables and output tensor under ``name``.

        Args:
            sess: Session used by ``calc_action`` to evaluate the model
                (presumably a TF1 ``tf.Session`` whose variables the caller
                initializes — TODO confirm against callers).
            num_features: Number of values in one observation.
            name: Variable-scope name for the placeholder and variables.
        """
        self.sess = sess
        with tf.variable_scope(name):
            # Exactly one observation per evaluation: batch dimension is 1.
            self.inputs = tf.placeholder(tf.float32, shape=[1, num_features],
                                         name="inputs")
            # Weights and bias start from a standard-normal random draw.
            self.w = tf.Variable(tf.random_normal([num_features, 1]),
                                 name="weights")
            self.b = tf.Variable(tf.random_normal([1]), name="bias")
            # [1, num_features] @ [num_features, 1] + [1] -> shape [1, 1].
            self.model = tf.nn.bias_add(tf.matmul(self.inputs, self.w), self.b)

    def calc_action(self, observation):
        """Evaluate the policy for a single observation.

        Args:
            observation: Array-like holding ``num_features`` values, either
                flat (shape ``[num_features]``) or already batched as
                ``[1, num_features]``.

        Returns:
            int: ``1`` if the linear model's output is negative, else ``0``.
        """
        import numpy as np  # local import keeps this block self-contained

        # Accept a flat observation as well as the original [1, n] form;
        # an already-batched [1, n] input passes through unchanged.
        feed = np.asarray(observation, dtype=np.float32).reshape(1, -1)
        pred_value = self.model.eval({self.inputs: feed}, session=self.sess)
        # The output has shape [1, 1]. Extract the scalar explicitly, because
        # int() on an array with ndim > 0 is deprecated since NumPy 1.25.
        return int(pred_value.item() < 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment