Last active
August 9, 2016 00:51
-
-
Save evancasey/d1da4195697242ac3dad18697d45044b to your computer and use it in GitHub Desktop.
CEM agent
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(self, render=False):
    """
    Run the cross-entropy method over policy weights.

    Repeatedly samples weight vectors from a diagonal Gaussian,
    scores each one with a rollout, and refits the Gaussian to the
    top-scoring ('elite') samples. Per-iteration mean scores are
    appended to `self.perf_hist`.

    Returns the policy network (`self.pred_network`).
    """
    tf.initialize_all_variables().run()
    # One extra dimension accounts for the bias weight.
    dim = self.num_features + 1
    mu = np.zeros(dim)
    sigma_sq = np.full((dim,), .1)
    for itr in xrange(self.n_iter):
        # Draw a batch of candidate weight vectors from the current Gaussian.
        candidates = self._sample_weights(self.batch_size, mu, sigma_sq)
        # Score every candidate by performing a rollout with it.
        scores = np.apply_along_axis(self.rollout, 1, candidates, render)
        self.perf_hist.append(np.mean(scores, 0))
        # Keep only the elite fraction of candidates ...
        elite = self._top_weights(candidates, scores)
        # ... then refit the Gaussian to them. Decaying extra variance
        # is added to discourage premature convergence.
        extra_var = max(5 - itr / 10, 0)
        mu = np.mean(elite, 0)
        sigma_sq = np.var(elite, 0) + extra_var
    return self.pred_network
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment