Created February 2, 2017 04:14
-
-
Save neka-nat/bcee330abf6ecf392945960b5ebc1734 to your computer and use it in GitHub Desktop.
A* path calculation using networkx
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| import numpy as np | |
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
# Numerical tolerance: segments shorter than this are treated as points, an
# obstacle closer than this counts as touching an edge, and 1 / EPS is the
# penalty value cost() assigns to such edges.
EPS = 1e-9
def add_cross_edge(gp, shape):
    """
    Add diagonal edges to a 2D grid graph (in place).

    Every node (i, j) is connected to (i + 1, j + 1) and (i + 1, j - 1)
    whenever that neighbour lies inside ``shape``. Together these two
    offsets cover both diagonal directions, each diagonal pair exactly once.

    Parameters
    ----------
    gp : networkx.Graph
        Grid graph whose nodes are (row, col) integer tuples.
    shape : sequence of int
        Grid extent (rows, cols) used for bounds checking.
    """
    # Snapshot the nodes: nodes_iter() does not exist on networkx >= 2.0,
    # while list(gp.nodes()) works on both the 1.x list and the 2.x NodeView.
    for node in list(gp.nodes()):
        row, col = node[0], node[1]
        for nb in ((row + 1, col + 1), (row + 1, col - 1)):
            inside = 0 <= nb[0] < shape[0] and 0 <= nb[1] < shape[1]
            # add_edge silently creates missing nodes; only link real grid nodes.
            if inside and gp.has_node(nb):
                gp.add_edge(node, nb)
# Build an ngrid x ngrid grid graph and add diagonal edges (8-connectivity).
ngrid = 20
gp = nx.grid_graph(dim=[ngrid, ngrid])
add_cross_edge(gp, [ngrid, ngrid])

# Materialize the node list once: on networkx >= 2.0, gp.nodes() is a NodeView
# indexed by node key rather than by position, so gp.nodes()[i] fails there.
node_list = list(gp.nodes())
# Node attribute dicts live in gp.node on networkx < 2.4 and in gp.nodes on
# newer versions (gp.node was removed in 2.4).
node_attr = gp.node if hasattr(gp, 'node') else gp.nodes

# Set the start, goal and obstacles: 20% of the nodes are drawn without
# replacement; the first is the start, the second the goal, the rest are
# obstacles. At least two are drawn so that start and goal always exist.
n_pick = max(2, int(ngrid * ngrid * 0.2))
idcs = np.random.choice(len(node_list), n_pick, replace=False)
st, gl = node_list[idcs[0]], node_list[idcs[1]]
obs = [node_list[i] for i in idcs[2:]]
node_attr[st]['color'] = 'green'
node_attr[gl]['color'] = 'red'
for o in obs:
    node_attr[o]['color'] = 'black'
# Obstacle coordinates as a float32 array; read by cost().
xobs = np.array(obs, dtype=np.float32)
def point_line_distance(pt, start, end):
    """
    Distance from point ``pt`` to the line segment ``start``-``end``.

    Parameters
    ----------
    pt, start, end : numpy.ndarray
        Coordinate vectors of equal length.

    Returns
    -------
    Euclidean distance from ``pt`` to the closest point on the segment. A
    degenerate segment (length below EPS) is treated as the single point
    ``start``.
    """
    line_vec = start - end
    pt_vec = start - pt
    line_len = np.linalg.norm(line_vec)
    if line_len < EPS:
        return np.linalg.norm(pt - start)
    # Position of pt's orthogonal projection as a fraction of the segment.
    # It must be divided by the SQUARED length; dividing by the length alone
    # mis-scales t for any segment whose length is not 1 (e.g. grid diagonals).
    t = np.dot(line_vec, pt_vec) / (line_len * line_len)
    # Clamp so the closest point stays on the segment, not the infinite line.
    t = max(min(1.0, t), 0.0)
    # pt minus the closest point (start - t * line_vec).
    return np.linalg.norm(line_vec * t - pt_vec)
def dist(a, b):
    """
    A* heuristic: straight-line (Euclidean) distance between two nodes.

    ``a`` and ``b`` are coordinate tuples; the difference is computed in
    float32 before taking its norm.
    """
    delta = np.asarray(a, dtype=np.float32) - np.asarray(b, dtype=np.float32)
    return np.linalg.norm(delta)
def cost(a, b, k1=1.0, k2=10.0, kind='dist', obstacles=None):
    """
    Edge cost: weighted edge length plus a weighted obstacle penalty.

        cost = k1 * |a - b| + k2 * penalty

    Parameters
    ----------
    a, b : sequence of float
        Edge end points (grid node coordinates).
    k1 : float
        Weight of the Euclidean edge length.
    k2 : float
        Weight of the obstacle penalty.
    kind : str
        'intsct' -- penalty is 1 / EPS if the segment a-b passes within EPS
                    of any obstacle, else 0.
        'dist'   -- penalty is 1 / (distance from the segment to the nearest
                    obstacle), with the distance clamped below by EPS.
        other    -- no penalty.
    obstacles : sequence of points, optional
        Obstacle coordinates. Defaults to the module-level ``xobs``.

    Returns
    -------
    The scalar edge cost.
    """
    if obstacles is None:
        obstacles = xobs
    x1 = np.array(a, dtype=np.float32)
    x2 = np.array(b, dtype=np.float32)
    # Named edge_len so it does not shadow the heuristic function ``dist``.
    edge_len = np.linalg.norm(x1 - x2)
    penalty = 0.0
    # min() over an empty sequence raises ValueError; no obstacles => no penalty.
    if kind in ('intsct', 'dist') and len(obstacles) > 0:
        nearest = min(point_line_distance(xob, x1, x2) for xob in obstacles)
        if kind == 'intsct':
            penalty = 1.0 / EPS if nearest < EPS else 0.0
        else:
            # max_i 1/max(EPS, d_i) == 1/max(EPS, min_i d_i) because 1/x is
            # decreasing, so only the nearest obstacle matters.
            penalty = 1.0 / max(EPS, nearest)
    return k1 * edge_len + k2 * penalty
# Attach the obstacle-aware cost to every edge as its A* weight.
# (edges_iter() does not exist on networkx >= 2.0; edges(data=True) works on
# both 1.x and 2.x and yields the live attribute dicts.)
for u, v, d in gp.edges(data=True):
    d['weight'] = cost(u, v)

path = nx.astar_path(gp, st, gl, dist)
# Sum the edge weights along the found path instead of running a second A*
# search through nx.astar_path_length, which computes exactly this sum.
length = sum(gp[u][v]['weight'] for u, v in zip(path[:-1], path[1:]))
# Alternative without a heuristic: nx.dijkstra_path(gp, st, gl) and
# nx.dijkstra_path_length(gp, st, gl).
print(path)
print(length)

# Node attribute dicts live in gp.node on networkx < 2.4 and in gp.nodes on
# newer versions (gp.node was removed in 2.4).
node_attr = gp.node if hasattr(gp, 'node') else gp.nodes

# Colour interior path nodes blue, leaving obstacle nodes (black) visible.
for p in path[1:-1]:
    if node_attr[p].get('color', '') != 'black':
        node_attr[p]['color'] = 'blue'

# Each node is a coordinate tuple, so it doubles as its own drawing position.
nx.draw(gp,
        pos={n: n for n in gp.nodes()},
        node_color=[node_attr[n].get('color', 'white') for n in gp.nodes()],
        node_size=200)
plt.axis('equal')
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.