Created
May 15, 2019 19:18
-
-
Save acl21/4f60cc010d9fd6eed3230dc9c67c93a8 to your computer and use it in GitHub Desktop.
Mini-Batch Gradient Descent
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def do_mini_batch_gradient_descent():
    """Run mini-batch gradient descent and return the learned parameters.

    Relies on module-level globals: ``init_w``, ``init_b`` (initial
    parameters), ``max_epochs`` (number of passes over the data),
    ``X``, ``Y`` (training points), and ``grad_w``, ``grad_b``
    (per-point gradient functions) — TODO confirm these are defined
    alongside this snippet.

    Returns:
        tuple: the final ``(w, b)`` after ``max_epochs`` epochs.

    Fixes over the original:
      * ``num_points_seen`` is reset each epoch; previously it persisted
        across epochs while ``dw, db`` were reset, so a batch straddling
        an epoch boundary silently dropped its accumulated gradients.
      * A trailing partial batch (when ``len(X)`` is not a multiple of
        the batch size) is now applied instead of being discarded.
    """
    w, b, eta = init_w, init_b, 1.0
    mini_batch_size = 2
    for _ in range(max_epochs):
        # Reset both the accumulators and the counter together so they
        # can never fall out of sync across epochs.
        dw, db = 0, 0
        num_points_seen = 0
        for x, y in zip(X, Y):
            dw += grad_w(w, b, x, y)
            db += grad_b(w, b, x, y)
            num_points_seen += 1
            if num_points_seen % mini_batch_size == 0:
                # Seen one full mini-batch: take a step, then reset.
                w = w - eta * dw
                b = b - eta * db
                dw, db = 0, 0
        if num_points_seen % mini_batch_size != 0:
            # Flush the leftover partial batch at the end of the epoch.
            w = w - eta * dw
            b = b - eta * db
    return w, b
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment