Created
August 6, 2016 20:04
-
-
Save evancasey/51fa98e03a7c51ec1bee984c77b9ed32 to your computer and use it in GitHub Desktop.
CEM Linear
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LinearModel(object):
    """Linear scoring model with a binary decision rule, built as a TF graph.

    The graph computes ``inputs @ w + b`` for a single observation row, and
    ``calc_action`` thresholds that score at zero to choose action 0 or 1.

    NOTE(review): ``tf`` is not imported anywhere in the visible source; this
    class assumes ``import tensorflow as tf`` exposing the TF 1.x graph API
    (``tf.placeholder``, ``tf.variable_scope``) is in scope — confirm.
    """

    def __init__(self,
                 sess,
                 num_features,
                 name="BinaryLinear"):
        """Build the input placeholder, parameters and output op.

        Args:
            sess: Session later passed to ``Tensor.eval`` in ``calc_action``.
            num_features: Number of values in one observation.
            name: Variable scope that wraps every op/variable of this model.

        The weight and bias variables are created here but never initialized
        by this class; the caller must run an initializer on ``sess`` before
        calling ``calc_action``.
        """
        self.sess = sess
        with tf.variable_scope(name):
            # Exactly one observation per evaluation: shape [1, num_features].
            self.inputs = tf.placeholder(tf.float32, shape=[1, num_features],
                                         name="inputs")
            # Column vector of weights so that inputs @ w is a [1, 1] score.
            self.w = tf.Variable(tf.random_normal([num_features, 1]),
                                 name="weights")
            self.b = tf.Variable(tf.random_normal([1]), name="bias")
            # Linear score: inputs @ w + b, shape [1, 1].
            self.model = tf.nn.bias_add(tf.matmul(self.inputs, self.w), self.b)

    def calc_action(self, observation):
        """Evaluate the policy for a given observation.

        Args:
            observation: One observation holding ``num_features`` values,
                either flat ``(num_features,)`` or already shaped
                ``(1, num_features)``.

        Returns:
            A binary action: ``1`` if the linear score is negative,
            otherwise ``0``.
        """
        import numpy as np

        # The placeholder only accepts shape [1, num_features]; reshaping lets
        # callers pass a flat observation vector as well as a one-row batch.
        batch = np.asarray(observation, dtype=np.float32).reshape(1, -1)
        pred_value = self.model.eval({self.inputs: batch}, session=self.sess)
        # eval() yields a [1, 1] array; extract the scalar explicitly rather
        # than calling int() on an ndim > 0 array (deprecated in NumPy 1.25).
        return int(pred_value.item() < 0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment