Last active
June 1, 2019 02:59
-
-
Save mohira/9346672ed38e0cd161adf27cae0c1200 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from sklearn.datasets import load_iris | |
| from sklearn.tree import DecisionTreeClassifier | |
| %matplotlib inline | |
# Throwaway model so the plot can be tried immediately
# (which is why no train/test split is performed here).
iris = load_iris()
X_train, y_train = (
    pd.DataFrame(data=iris.data, columns=iris.feature_names),
    iris.target,
)
# fit() returns the estimator itself, so construction and training chain.
model = DecisionTreeClassifier().fit(X_train, y_train)
def plot_feature_importance(model, X_train):
    """Draw a horizontal bar chart of the model's feature importances.

    Bars are ordered by ascending importance and each bar is labelled with
    the name of the feature it belongs to.

    Args:
        model: Fitted estimator exposing ``feature_importances_``.
        X_train: DataFrame whose columns are the features the model was
            fitted on; the column names are used as the y tick labels.
    """
    # Partially adapted from
    # https://qiita.com/takapy0210/items/73415599579f2588080e
    n_features = X_train.shape[1]

    # Sort importances and names *together*. Sorting only the values (as the
    # previous version did) left the tick labels in column order, so bars
    # were shown next to the wrong feature names.
    ranked = sorted(
        zip(model.feature_importances_, X_train.columns),
        key=lambda pair: pair[0],
    )
    importances = [importance for importance, _ in ranked]
    names = [name for _, name in ranked]

    # Plain range-based positions: numpy is not imported in this snippet, so
    # the former np.arange call raised NameError.
    positions = list(range(n_features))
    plt.barh(positions, importances, align='center')
    plt.yticks(positions, names)
    plt.xlabel('Feature importance')
    plt.ylabel('Feature')
    # Fixed 0..1 axis so small importances are not visually exaggerated.
    plt.xlim(0, 1.0)


plot_feature_importance(model, X_train)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import seaborn as sns | |
| from sklearn.datasets import load_iris | |
| from sklearn.tree import DecisionTreeClassifier | |
| %matplotlib inline | |
# Throwaway model so the plot can be tried immediately
# (which is why no train/test split is performed here).
iris = load_iris()
X_train, y_train = (
    pd.DataFrame(data=iris.data, columns=iris.feature_names),
    iris.target,
)
# fit() returns the estimator itself, so construction and training chain.
model = DecisionTreeClassifier().fit(X_train, y_train)
def visualize_feature_importance(model, X_train):
    """Show the model's non-zero feature importances as a seaborn bar plot.

    Args:
        model: Fitted estimator exposing ``feature_importances_``.
        X_train: DataFrame whose column names label the features.
    """
    # Assemble the name/importance table that barplot consumes.
    importance_table = pd.DataFrame({
        'feature_name': X_train.columns,
        'feature_importance': model.feature_importances_,
    })

    # Descending order reads more easily.
    ranked = importance_table.sort_values(
        by='feature_importance', ascending=False)

    # Features with zero importance are treated as noise and left out.
    has_importance = ranked['feature_importance'] > 0
    visible = ranked[has_importance]

    sns.barplot(data=visible, x='feature_importance', y='feature_name')

    # Pin the upper limit so that a cluster of low values is not mistaken
    # for large ones because of automatic axis scaling.
    plt.xlim(0, 1.0)


visualize_feature_importance(model, X_train)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment