Created
October 31, 2017 17:41
-
-
Save nalvared/abdce13fff49486ab13c7a2cdd894069 to your computer and use it in GitHub Desktop.
Revisions
-
nalvared created this gist
Oct 31, 2017 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,34 @@ # PlotHelper ## Import library The _plothelper.py_ file should be in the same folder ```python import numpy as np import pandas as pd from plothelper import PlotHelper as PH %matplotlib inline ``` ## Load dataset with pandas ```python # load dataset dfHouses = pd.read_csv('kc_house_data.csv') ``` ## Using PH.pairplot This is the same function that _seaborn.pairplot()_ but it has two new parameters: - max_per_row: maximum charts in the same row - reg_line_color: if the type of the chart is _kind='reg'_, this property paints the regression line in the color choosen ```python PH().pairplot(data=dfHouses, x_vars=dfHouses.columns.tolist(), y_vars=['price'], max_per_row=5, kind='reg', reg_line_color='red') ``` This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,45 @@ import numpy as np import seaborn as sns class PlotHelper: def splitVariables(self, x_vars=None, y_vars=None, max_per_row=10): slices = [] for y in y_vars: curr_y = y l = [] if max_per_row > len(x_vars): for x in x_vars: if y != x: l.append(x) slices.append([curr_y, l]) else: i = 0 for s in range(int(np.ceil(len(x_vars)/max_per_row))): l = [] for x in x_vars[i:max_per_row+i]: if y != x: l.append(x) slices.append([l,curr_y]) i += max_per_row return slices def pairplot(self, data=None, hue=None, hue_order=None, palette=None, vars=None, x_vars=None, y_vars=None, kind='scatter', diag_kind='hist', markers=None, size=2.5, aspect=1, dropna=True, plot_kws=None, diag_kws=None, grid_kws=None, wrap=True, max_per_row=None, reg_line_color=None): if kind == 'reg' and reg_line_color != None: plot_kws={'line_kws':{'color':reg_line_color}} if max_per_row == None: return sns.pairplot(data=data, hue=hue, hue_order=hue_order, palette=palette, vars=vars, x_vars=x_vars, y_vars=y_vars, kind=kind, diag_kind=diag_kind, markers=markers, size=size, aspect=aspect, dropna=dropna, plot_kws=plot_kws, diag_kws=diag_kws, grid_kws=grid_kws) else: slices = self.splitVariables(x_vars, y_vars, max_per_row) for i in range(len(slices)): sns.pairplot(data=data, hue=hue, hue_order=hue_order, palette=palette, vars=vars, x_vars=slices[i][0], y_vars=slices[i][1], kind=kind, diag_kind=diag_kind, markers=markers, size=size, aspect=aspect, dropna=dropna, plot_kws=plot_kws, diag_kws=diag_kws, grid_kws=grid_kws)