Skip to content

Instantly share code, notes, and snippets.

@forkloop
Last active October 3, 2015 09:17
Show Gist options
  • Select an option

  • Save forkloop/4d5ee23d7a677bd6dd79 to your computer and use it in GitHub Desktop.

Select an option

Save forkloop/4d5ee23d7a677bd6dd79 to your computer and use it in GitHub Desktop.
Particle Filter
#! /usr/bin/env python
"""
particle filter
http://people.csail.mit.edu/rplatt/hw5.html
using SIR
@forkloop
"""
from __future__ import division
import numpy as np
import random as rd
from operator import mul
from math import floor
from math import ceil
import matplotlib.pyplot as plt
# Flag presumably intended to switch plotting on/off; nothing in the code
# shown here reads it.
PLOT = False

# Number of particles carried by the filter (rows of the particle array).
particle_num = 200

# Spread arguments handed to random.gauss for the observation noise and the
# motion noise respectively.
obs_var = 0.5
mov_var = 0.4

# World extent as (x, y); also the upper bounds used when sampling particles.
xy_range = (6, 7)

# Presumably the (x, y) starting point -- unused in the code shown here.
start = (1.5, 5.5)

# Occupancy map indexed as [y, x]: rows 0, 2, 4 and 6 are marked occupied
# across columns 0..4; everything else is left at zero.
grid = np.zeros(xy_range[::-1])
grid[::2, :5] = 1

# Direction code for each step; main() treats 3 as a move along +x and 1 as a
# move along -y.
path = [3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1]

# One four-element reading per time step. main() adds Gaussian noise to row
# ``step + 1`` and compares it with the rows returned by transform().
pos = [
    [0.5, 0.5, 1.5, 4.5],
    [0.5, 0.5, 2, 4],
    [0.5, 0.5, 2.5, 3.5],
    [0.5, 0.5, 3, 3],
    [0.5, 0.5, 3.5, 2.5],
    [0.5, 0.5, 4, 2],
    [0.5, 0.5, 4.5, 1.5],
    [0.5, 0.5, 5, 1],
    [1.5, 5.5, 5.5, 0.5],
    [2, 5, 5.5, 0.5],
    [2.5, 4.5, 5.5, 0.5],
    [3, 4, 5.5, 0.5],
]
def plot(particle):
    """Draw the occupancy grid with the particles on top and show the figure.

    Parameters
    ----------
    particle : sequence of rows
        Each row supplies an x coordinate at index 0 and a y coordinate at
        index 1. Every row that is passed in is drawn.

    Returns
    -------
    None
        The function only produces a matplotlib figure via ``plt.show()``.
    """
    fig = plt.figure()
    ax = fig.gca()
    ax.set_aspect('equal')
    ax.pcolor(grid, cmap=plt.cm.binary)

    # Iterate over the rows actually supplied rather than a global count, and
    # keep one plot call per particle so each marker stays a separate artist.
    for point in particle:
        ax.plot(point[0], point[1], 'o')

    # The grid is stored as [y, x], so its shape gives (y extent, x extent).
    n_rows, n_cols = grid.shape
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.grid()
    plt.show()
"""
Update the belief after observation
"""
def observe_update(obs):
    """Placeholder for the observation belief update.

    The argument is ignored; the function does nothing and returns None.
    """
    return None
"""
Update the belief after movement
"""
def move_update(m):
    """Placeholder for the movement belief update.

    The argument is ignored; the function does nothing and returns None.
    """
    return None
def check_particle(p):
    """Report whether the point ``p`` lies in an allowed region.

    ``p`` is indexable with x at index 0 and y at index 1. Two regions are
    accepted:

    * ``0 <= x <= 5`` with y inside one of the bands [1, 2], [3, 4], [5, 6];
    * ``5 < x <= 6`` with ``0 <= y <= 7``.

    Any other point yields False.
    """
    x = p[0]

    if 0 <= x <= 5:
        # y is only looked up once x has qualified for this region.
        y = p[1]
        bands = ((1, 2), (3, 4), (5, 6))
        return any(low <= y <= high for low, high in bands)

    if 5 < x <= 6:
        return bool(0 <= p[1] <= 7)

    return False
def transform(p):
    """Map particle coordinates to the four-value readings used for weighting.

    Parameters
    ----------
    p : array-like of rows
        Each row holds x at index 0 and y at index 1. Any number of rows is
        accepted, including none.

    Returns
    -------
    numpy.ndarray
        Array with one row per input row and four columns:

        * rows with ``x <= 5``: ``[ceil(y) - y, 1 - (ceil(y) - y), x, 6 - x]``
        * all other rows:       ``[7 - y, y, x, 6 - x]``
    """
    pts = np.asarray(p, dtype=float)
    if pts.size == 0:
        # Nothing to convert; keep the four-column layout callers expect.
        return np.zeros((0, 4))

    xs = pts[:, 0]
    ys = pts[:, 1]
    left_region = xs <= 5

    res = np.empty((len(pts), 4))
    res[:, 0] = np.where(left_region, np.ceil(ys) - ys, 7 - ys)
    # Column 1 is derived from column 0 in the left region, from y elsewhere.
    res[:, 1] = np.where(left_region, 1 - res[:, 0], ys)
    res[:, 2] = xs
    res[:, 3] = 6 - xs
    return res
def init():
    """Placeholder initialiser; does nothing and returns None."""
    return None
def check_move(p, dis, dr):
    """Return the coordinate reached by moving ``dis`` from ``p``, clamped.

    Parameters
    ----------
    p : indexable
        Current position with x at index 0 and y at index 1.
    dis : number
        Signed displacement along the chosen axis.
    dr : int
        Direction code: ``1`` moves along y; any other value moves along x.

    Returns
    -------
    number
        The new y coordinate when ``dr == 1``, otherwise the new x
        coordinate. Only the moved coordinate is returned; ``p`` itself is
        not modified.

    Clamping rules
    --------------
    * y move with ``0 <= x <= 5``: y stays inside the band ([1, 2], [3, 4] or
      [5, 6]) that contains it. If y is in none of them the bounds fall back
      to ``floor(y)`` / ``ceil(y)``.
    * y move otherwise: y stays inside ``[0, xy_range[1]]``.
    * x move with y inside a band: x stays inside ``[0, xy_range[0]]``.
    * x move otherwise: x stays inside ``[xy_range[0] - 1, xy_range[0]]``.
    """
    bands = ((1, 2), (3, 4), (5, 6))
    x, y = p[0], p[1]

    if dr == 1:
        target = y + dis
        if 0 <= x <= 5:
            # Look the band up explicitly instead of using floor/ceil alone:
            # when y sits exactly on a band edge floor(y) == ceil(y), which
            # would pin the particle to that edge for every later move.
            for low, high in bands:
                if low <= y <= high:
                    lower, upper = low, high
                    break
            else:
                lower, upper = floor(y), ceil(y)
        else:
            lower, upper = 0, xy_range[1]
    else:
        target = x + dis
        in_band = any(low <= y <= high for low, high in bands)
        lower = 0 if in_band else xy_range[0] - 1
        upper = xy_range[0]

    if target > upper:
        return upper
    if target < lower:
        return lower
    return target
"""
Main
"""
def main():
    """Run the particle filter along ``path``, plotting the particles each step.

    Steps performed:

    1. Draw ``particle_num`` particles uniformly over ``xy_range`` and keep
       only those accepted by ``check_particle`` (rejection sampling).
    2. For every direction code in ``path``:
       a. plot the current particles;
       b. move each particle by ``abs(gauss(0, mov_var))`` -- along -y for
          code 1, along +x for code 3, not at all for any other code --
          clamped by ``check_move``;
       c. build a noisy reading from ``pos[step + 1]`` and weight each
          particle by the inverse Euclidean distance between that reading
          and the particle's ``transform`` row;
       d. resample ``particle_num`` particles with probability proportional
          to those weights.

    Returns None. Randomness comes from the ``random`` module (``rd``).
    """
    # --- initialisation: rejection-sample the starting population ---------
    particle = np.zeros((particle_num, 2))
    n = 0
    while n < particle_num:
        p = [rd.uniform(0, xy_range[0]), rd.uniform(0, xy_range[1])]
        if check_particle(p):
            particle[n, :] = p
            n += 1

    for step, d in enumerate(path):
        plot(particle)

        # --- move update ---------------------------------------------------
        for n in range(particle_num):
            # Displacement magnitude is always non-negative; the sign is
            # chosen by the direction code below.
            disp = abs(rd.gauss(0, mov_var))
            if d == 1:
                particle[n][1] = check_move(particle[n, :], -disp, d)
            elif d == 3:
                particle[n][0] = check_move(particle[n, :], disp, d)

        # --- observe update ------------------------------------------------
        obs_pos = np.array(
            [pos[step + 1][k] + rd.gauss(0, obs_var) for k in range(4)]
        )
        particle_pos = transform(particle)
        dist = np.linalg.norm(obs_pos - particle_pos, axis=1)
        # Guard the reciprocal against an exact match (distance zero).
        dist = np.maximum(dist, 1e-12)
        # Weights are rebuilt from scratch every step: after resampling the
        # population is equally weighted, so nothing carries over.
        weight = 1.0 / dist
        weight /= weight.sum()

        # --- resample ------------------------------------------------------
        accu_weight = np.cumsum(weight)
        draws = [rd.random() for _ in range(particle_num)]
        # First index whose cumulative weight exceeds the draw; the clip
        # covers a draw landing above a last entry that rounded below 1.
        picks = np.searchsorted(accu_weight, draws, side='right')
        picks = np.minimum(picks, particle_num - 1)
        # Fancy indexing produces a fresh array, so the new generation never
        # shares storage with the one it was sampled from.
        particle = particle[picks]
# Script entry point: run the simulation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
@forkloop
Copy link
Copy Markdown
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment