Created November 16, 2017 15:35
-
-
Save bbtfr/7aaa20bb6e231631c3856fe424d5ae4f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

ITEMS_COUNT = 100
MAX_SIMILARITY = 100
CLUSTER_THRESHOLD = 80
SPLIT_THRESHOLD = 50


def make_random_similarity(n, max_similarity=MAX_SIMILARITY, rng=None):
    """Return a random symmetric ``(n, n)`` integer similarity matrix.

    Off-diagonal entries are drawn uniformly from ``0..max_similarity``
    (inclusive); larger values mean more similar. The diagonal is set to
    ``max_similarity`` because every item is maximally similar to itself.

    Args:
        n: number of items.
        max_similarity: largest possible similarity value.
        rng: object with a legacy ``randint`` method (e.g. a
            ``np.random.RandomState``); defaults to the global ``np.random``.

    Returns:
        Integer array ``sim`` of shape ``(n, n)`` with ``sim == sim.T``.
    """
    if rng is None:
        rng = np.random
    # randint's upper bound is exclusive, hence the +1 to include the maximum.
    raw = rng.randint(max_similarity + 1, size=(n, n))
    # Mirror the strict upper triangle so that sim[i, j] == sim[j, i].
    upper = np.triu(raw, k=1)
    sim = upper + upper.T
    np.fill_diagonal(sim, max_similarity)
    return sim


def cluster(similarity, cluster_threshold=CLUSTER_THRESHOLD,
            split_threshold=SPLIT_THRESHOLD, *, verbose=False):
    """Greedily merge items into groups from a pairwise similarity matrix.

    Pairs are visited in decreasing order of similarity for as long as the
    pair's similarity is at least ``cluster_threshold``. For each visited
    pair, the two items' groups are merged if and only if every cross-group
    pair has similarity strictly greater than ``split_threshold``. This
    check keeps a group from absorbing an item that is too dissimilar to
    any of its current members.

    Args:
        similarity: square ``(n, n)`` array; larger values mean more similar.
            Assumed symmetric. TODO confirm for callers: only
            ``similarity[rows of A][:, cols of B]`` is checked when merging
            group A with group B. The array is not modified.
        cluster_threshold: pairs with similarity below this are never visited.
        split_threshold: every cross-group pair must exceed this value
            (exclusive) for two groups to merge.
        verbose: if true, print the current grouping and each visited pair.

    Returns:
        Integer array of length ``n``. Items with equal labels are in the
        same group, and each label is the index of one of that group's items.

    Raises:
        ValueError: if ``similarity`` is not a square 2-D array.
    """
    sim = np.asarray(similarity)
    if sim.ndim != 2 or sim.shape[0] != sim.shape[1]:
        raise ValueError(
            f"similarity must be a square 2-D array, got shape {sim.shape}")
    group = np.arange(sim.shape[0])
    if sim.size == 0:
        return group

    # Keep a separate search matrix: visited pairs are knocked out with -inf
    # here, while merge decisions read the untouched `sim`. This replaces an
    # in-place -1 sentinel, which mutated the caller's data, wrapped around
    # for unsigned dtypes and hid a pair's real value from later merge checks.
    search = sim.astype(np.float64)  # astype returns a copy
    # Self-pairs can never merge two different groups, so skip them.
    np.fill_diagonal(search, -np.inf)

    # TODO: cache group pairs already rejected to avoid re-running identical
    # cross-group checks (invalidate when either group changes).
    while True:
        # Find the most similar pair that has not been visited yet.
        i, j = map(int, np.unravel_index(np.argmax(search), search.shape))
        if search[i, j] < cluster_threshold:
            break
        if verbose:
            print(group)
            print((i, j), sim[i, j])
        # Mark the pair visited in both orientations so it is not revisited.
        search[i, j] = search[j, i] = -np.inf

        gi, gj = group[i], group[j]
        if gi == gj:
            continue  # already in the same group; nothing to merge
        rows = group == gi
        cols = group == gj
        # Merge only if the two groups are pairwise similar enough.
        if np.all(sim[rows][:, cols] > split_threshold):
            group[cols] = gi
    return group


if __name__ == "__main__":
    mat = make_random_similarity(ITEMS_COUNT, MAX_SIMILARITY)
    group = cluster(mat, CLUSTER_THRESHOLD, SPLIT_THRESHOLD, verbose=True)
    print(group)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.