Skip to content

Instantly share code, notes, and snippets.

@evancasey
Created August 9, 2016 00:52
Show Gist options
  • Select an option

  • Save evancasey/3192b882693d413cad879f257046e05a to your computer and use it in GitHub Desktop.

Select an option

Save evancasey/3192b882693d413cad879f257046e05a to your computer and use it in GitHub Desktop.
CEM rollout: episode-rollout and network weight-update helpers for a cross-entropy-method agent (CartPole-style 4-dimensional observations).
def rollout(self, w, render):
    """
    Play one episode with the policy induced by weight vector `w`.

    The network is first updated from `w` (see `_update_network`), then the
    episode runs for at most `self.max_num_steps - 1` environment steps or
    until the environment reports a terminal state.

    Args:
        w: flat weight vector pushed into `self.pred_network` before the
            episode starts.
        render: when truthy, render the environment after each
            non-terminal step.

    Returns:
        float: sum of the per-step rewards collected during the episode.
    """
    self._update_network(w)
    observation = self.env.reset()
    total_reward = 0.0
    # NOTE(review): bound kept from the original — it executes at most
    # max_num_steps - 1 steps; confirm the off-by-one is intended.
    # `range` (not Py2-only `xrange`) so this runs on both Python 2 and 3.
    for _ in range(self.max_num_steps - 1):
        # Observation is reshaped to a (1, 4) batch for the network.
        action = self.pred_network.calc_action(observation.reshape(1, 4))
        observation, reward, is_terminal, _info = self.env.step(action)
        total_reward += reward
        if is_terminal:
            break
        # Render after the terminal check, so the terminal step itself is
        # never rendered (order preserved from the original).
        if render:
            self.env.render()
    return total_reward
def _update_network(self, w):
    """
    Load the flat weight vector `w` into `pred_network` in place.

    The first ``len(w) - 1`` entries become the (4, 1) weight matrix and
    the final entry becomes the bias. The assignments are grouped into a
    single op and executed on `self.sess`, mutating the existing
    `pred_network` variables as a side effect.
    """
    weight_matrix = w[:-1].reshape(4, 1)
    bias_vector = [w[-1]]
    update_op = tf.group(
        self.pred_network.w.assign(weight_matrix),
        self.pred_network.b.assign(bias_vector),
        name="update",
    )
    self.sess.run(update_op)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment