Last active
October 3, 2015 09:17
-
-
Save forkloop/4d5ee23d7a677bd6dd79 to your computer and use it in GitHub Desktop.
Particle Filter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! /usr/bin/env python | |
| """ | |
| particle filter | |
| http://people.csail.mit.edu/rplatt/hw5.html | |
| using SIR | |
| @forkloop | |
| """ | |
| from __future__ import division | |
| import numpy as np | |
| import random as rd | |
| from operator import mul | |
| from math import floor | |
| from math import ceil | |
| import matplotlib.pyplot as plt | |
# ---- simulation settings -------------------------------------------------
PLOT = False        # presumably the plotting on/off switch
particle_num = 200  # number of particles in the cloud
obs_var = 0.5       # passed as the *sigma* of rd.gauss for observation noise
mov_var = 0.4       # passed as the *sigma* of rd.gauss for motion noise
xy_range = (6, 7)   # map extent: (width along x, height along y)
start = (1.5, 5.5)  # presumably the true (x, y) start position

# Occupancy map indexed grid[y, x]; a 1 marks a wall cell.  Rows y = 0, 2, 4, 6
# are blocked for x = 0..4, leaving three horizontal corridors that are joined
# by the open column x = 5.
grid = np.zeros((xy_range[1], xy_range[0]))
grid[::2, :5] = 1

# Motion command per step: 3 = move toward +x, 1 = move toward -y.
path = [3] * 8 + [1] * 3

# Reference (noise-free) sensor readings per step, in the same column order
# as transform() produces; step k's observation is built from pos[k].
pos = [
    [0.5, 0.5, 1.5, 4.5],
    [0.5, 0.5, 2,   4],
    [0.5, 0.5, 2.5, 3.5],
    [0.5, 0.5, 3,   3],
    [0.5, 0.5, 3.5, 2.5],
    [0.5, 0.5, 4,   2],
    [0.5, 0.5, 4.5, 1.5],
    [0.5, 0.5, 5,   1],
    [1.5, 5.5, 5.5, 0.5],
    [2,   5,   5.5, 0.5],
    [2.5, 4.5, 5.5, 0.5],
    [3,   4,   5.5, 0.5],
]
def plot(particle):
    """Draw the obstacle map and the particle cloud, then call ``plt.show()``.

    ``particle`` is any sequence of rows whose first two entries are (x, y).
    Every row is drawn, including for clouds that are smaller or larger
    than ``particle_num``, and an empty cloud is accepted.  Walls come from
    the module-level ``grid`` and the axis limits from ``xy_range``.
    """
    xs = [row[0] for row in particle]
    ys = [row[1] for row in particle]
    fig = plt.figure()
    ax = fig.gca()
    ax.set_aspect('equal')
    ax.pcolor(grid, cmap=plt.cm.binary)
    # One vectorised call instead of one artist per particle; all particles
    # share the same marker colour.
    ax.plot(xs, ys, 'o')
    ax.set_xlim(0, xy_range[0])
    ax.set_ylim(0, xy_range[1])
    ax.grid()
    plt.show()
def observe_update(obs):
    """Update the belief after observation.

    Placeholder only: the body does nothing and the call returns None.
    main() carries out the observation update inline.
    """
def move_update(m):
    """Update the belief after movement.

    Placeholder only: the body does nothing and the call returns None.
    main() carries out the motion update inline.
    """
def check_particle(p):
    """Tell whether the point ``p`` = (x, y) lies in the map's free space.

    Free space, all bounds inclusive:
      * 0 <= x <= 5: only the corridor bands 1<=y<=2, 3<=y<=4, 5<=y<=6;
      * 5 <  x <= 6: the whole column 0<=y<=7.
    Any point with x < 0 or x > 6 is rejected.  Returns a plain bool.
    """
    x = p[0]
    # Outside the map horizontally: y is irrelevant (and never read).
    if not (0 <= x <= 6):
        return False
    y = p[1]
    if x > 5:
        return bool(0 <= y <= 7)
    return any(low <= y <= low + 1 for low in (1, 3, 5))
def transform(p):
    """Return the ideal sensor readings for every particle in ``p``.

    Each row of ``p`` holds an (x, y) position.  Row i of the result holds
    four distances: toward +y, toward -y, toward x = 0 and toward x = 6.

    * x <= 5 (corridor region): +y = ceil(y) - y, -y = 1 - (+y),
      x = 0 -> x, x = 6 -> 6 - x;
    * x > 5 (vertical column): +y = 7 - y, -y = y,
      x = 0 -> x, x = 6 -> 6 - x.

    The row count is taken from ``p`` itself rather than ``particle_num``,
    so clouds of any size (including empty) are accepted.  The result is
    a float array of shape ``(len(p), 4)``.
    """
    count = len(p)
    res = np.zeros((count, 4))
    for i in range(count):
        x, y = p[i][0], p[i][1]
        if x <= 5:
            up = ceil(y) - y
            down = 1 - up
        else:
            up = 7 - y
            down = y
        res[i] = (up, down, x, 6 - x)
    return res
def init():
    """Placeholder for particle initialisation; does nothing and returns None.

    main() draws the initial particle cloud itself.
    """
def check_move(p, dis, dr):
    """Return the coordinate reached by moving ``p`` = (x, y) by ``dis``, clamped.

    dr == 1 moves along y and returns the new y.  In the corridor region
    (0 <= x <= 5) y is confined to its current unit cell [floor(y), ceil(y)];
    elsewhere it is confined to [0, xy_range[1]].

    Any other dr moves along x and returns the new x.  When y lies in a
    corridor band (1..2, 3..4, 5..6) x is confined to [0, xy_range[0]];
    otherwise it is confined to the last column [xy_range[0] - 1, xy_range[0]].
    """
    x, y = p[0], p[1]
    if dr == 1:
        target = y + dis
        if 0 <= x <= 5:
            lo, hi = floor(y), ceil(y)
        else:
            lo, hi = 0, xy_range[1]
    else:
        target = x + dis
        hi = xy_range[0]
        in_band = 1 <= y <= 2 or 3 <= y <= 4 or 5 <= y <= 6
        lo = 0 if in_band else (xy_range[0] - 1)
    # Upper bound is tested first, exactly as in the branch-per-case form.
    if target > hi:
        return hi
    if target < lo:
        return lo
    return target
def main():
    """Run the SIR particle filter along ``path``; draw it when ``PLOT`` is set.

    1. Rejection-sample ``particle_num`` particles uniformly over the free
       space accepted by ``check_particle``.
    2. For each motion command in ``path``:
       * move every particle by a half-normal step (sigma ``mov_var``),
         clamped by ``check_move``;
       * add Gaussian noise (sigma ``obs_var``) to the reference readings
         ``pos[step + 1]`` to form the observation;
       * weight each particle by the inverse Euclidean distance between
         that observation and its ``transform`` readings;
       * resample with replacement in proportion to the weights.

    When ``PLOT`` is true the cloud is drawn before every step and once more
    after the final one.
    """
    # init: rejection-sample uniformly inside the map's free space.
    particle = np.zeros((particle_num, 2))
    n = 0
    while n < particle_num:
        p = [rd.uniform(0, xy_range[0]), rd.uniform(0, xy_range[1])]
        if check_particle(p):
            particle[n, :] = p
            n += 1

    for step, d in enumerate(path):
        if PLOT:
            plot(particle)

        # move update: command 1 moves toward -y, command 3 toward +x;
        # any other command leaves the particles where they are.
        for n in range(particle_num):
            disp = abs(rd.gauss(0, mov_var))
            if d == 1:
                particle[n][1] = check_move(particle[n, :], -disp, d)
            elif d == 3:
                particle[n][0] = check_move(particle[n, :], disp, d)

        # observe update: noisy version of the reference sensor readings.
        obs_pos = np.array([pos[step + 1][k] + rd.gauss(0, obs_var)
                            for k in range(4)])
        diff = np.linalg.norm(obs_pos - transform(particle), axis=1)
        # Weights are recomputed from scratch every step.  After resampling,
        # the old weight at index i belonged to a different particle, and
        # SIR resets the weights to uniform.  The floor avoids dividing by an
        # exact zero distance.
        # NOTE(review): 1/distance is a heuristic likelihood; a Gaussian
        # exp(-diff**2 / (2 * obs_var**2)) would be the textbook choice.
        weight = 1.0 / np.maximum(diff, 1e-12)
        weight /= weight.sum()

        # resample: for each uniform draw, pick the first index whose
        # cumulative weight is strictly greater.  Clip to the last index to
        # absorb rounding when the total sums to slightly under 1.
        accu_weight = np.cumsum(weight)
        draws = [rd.random() for _ in range(particle_num)]
        idx = np.searchsorted(accu_weight, draws, side='right')
        idx = np.minimum(idx, particle_num - 1)
        # Fancy indexing builds a new array, so the cloud being sampled from
        # is never overwritten mid-resample.
        particle = particle[idx]

    if PLOT:
        plot(particle)
# Script entry point: run the filter only when executed directly, not on import.
if __name__ == "__main__":
    main()
Author
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
http://people.csail.mit.edu/rplatt/hw5.html