Skip to content

Instantly share code, notes, and snippets.

@nalvared
Created October 31, 2017 17:41
Show Gist options
  • Select an option

  • Save nalvared/abdce13fff49486ab13c7a2cdd894069 to your computer and use it in GitHub Desktop.

Select an option

Save nalvared/abdce13fff49486ab13c7a2cdd894069 to your computer and use it in GitHub Desktop.

Revisions

  1. nalvared created this gist Oct 31, 2017.
    34 changes: 34 additions & 0 deletions ExamplePlotHelper.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,34 @@

    # PlotHelper

    ## Import library
    The _plothelper.py_ file should be in the same folder as the notebook or script that imports it.

    ```python
    import numpy as np
    import pandas as pd
    from plothelper import PlotHelper as PH
    %matplotlib inline
    ```

    ## Load dataset with pandas

    ```python
    # load dataset
    dfHouses = pd.read_csv('kc_house_data.csv')
    ```

    ## Using PH.pairplot
    This is the same function as _seaborn.pairplot()_, but it has two new parameters:
    - max_per_row: maximum charts in the same row
    - reg_line_color: if the chart type is _kind='reg'_, this parameter draws the regression line in the chosen color

    ```python
    # Plot 'price' (y axis) against the dataset columns, at most 5 charts per
    # row, with the regression line drawn in red.
    # x_vars may include 'price' itself: the helper skips any x column that is
    # equal to the current y variable.
    PH().pairplot(data=dfHouses,
    x_vars=dfHouses.columns.tolist(),
    y_vars=['price'],
    max_per_row=5,
    kind='reg',
    reg_line_color='red')
    ```

    45 changes: 45 additions & 0 deletions plothelper.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,45 @@
    import numpy as np
    import seaborn as sns
    class PlotHelper:
        """Wrapper around ``seaborn.pairplot`` that can wrap charts over several rows.

        Compared with calling ``seaborn.pairplot`` directly, ``pairplot`` accepts
        two extra keyword arguments: ``max_per_row`` (maximum number of x
        variables plotted in one grid) and ``reg_line_color`` (color of the
        regression line when ``kind='reg'``).
        """

        def splitVariables(self, x_vars=None, y_vars=None, max_per_row=10):
            """Group the x variables into rows of at most ``max_per_row`` per y variable.

            Args:
                x_vars: Iterable of x variable names (a single scalar name is
                    also accepted). ``None`` is treated as empty.
                y_vars: Iterable of y variable names (a single scalar name is
                    also accepted). ``None`` is treated as empty.
                max_per_row: Maximum number of x variables in one slice.

            Returns:
                A list of ``[x_names, y_name]`` pairs, where ``x_names`` is a
                non-empty list of at most ``max_per_row`` names, none of which
                equals ``y_name``.

            Raises:
                ValueError: If ``max_per_row`` is ``None`` or smaller than 1.
            """
            if max_per_row is None or max_per_row < 1:
                raise ValueError("max_per_row must be a positive integer")

            def as_list(names):
                # A bare string would otherwise be iterated character by character.
                if names is None:
                    return []
                if np.isscalar(names):
                    return [names]
                return list(names)

            x_names = as_list(x_vars)
            slices = []
            for y_name in as_list(y_vars):
                # Drop the y variable first so that no slice ends up empty and
                # every row except possibly the last one is completely filled.
                candidates = [x for x in x_names if x != y_name]
                for start in range(0, len(candidates), max_per_row):
                    # Always [x_names, y_name]: pairplot() reads index 0 as x
                    # and index 1 as y.
                    slices.append([candidates[start:start + max_per_row], y_name])
            return slices

        def pairplot(self, data=None, hue=None, hue_order=None, palette=None,
                     vars=None, x_vars=None, y_vars=None, kind='scatter',
                     diag_kind='hist', markers=None, size=2.5, aspect=1,
                     dropna=True, plot_kws=None, diag_kws=None,
                     grid_kws=None, wrap=True, max_per_row=None, reg_line_color=None):
            """Draw one or several ``seaborn.pairplot`` grids.

            All parameters except ``wrap``, ``max_per_row`` and
            ``reg_line_color`` are forwarded unchanged to ``seaborn.pairplot``.

            Args:
                max_per_row: Maximum number of x variables per grid. When it is
                    ``None`` (or when ``x_vars`` or ``y_vars`` is ``None``) a
                    single ``seaborn.pairplot`` call is made.
                reg_line_color: Color of the regression line; only used when
                    ``kind == 'reg'``. It is merged into ``plot_kws`` without
                    discarding the other entries supplied by the caller.
                wrap: Accepted for backward compatibility; currently unused.

            Returns:
                The object returned by ``seaborn.pairplot`` when a single grid
                is drawn, otherwise a list with one such object per row.
            """
            # NOTE(review): newer seaborn releases renamed ``size`` to
            # ``height`` -- confirm against the installed seaborn version.
            if kind == 'reg' and reg_line_color is not None:
                # Copy before updating so the caller's dictionaries stay untouched.
                plot_kws = dict(plot_kws or {})
                line_kws = dict(plot_kws.get('line_kws') or {})
                line_kws['color'] = reg_line_color
                plot_kws['line_kws'] = line_kws

            common = dict(data=data, hue=hue, hue_order=hue_order, palette=palette,
                          vars=vars, kind=kind, diag_kind=diag_kind,
                          markers=markers, size=size, aspect=aspect,
                          dropna=dropna, plot_kws=plot_kws, diag_kws=diag_kws,
                          grid_kws=grid_kws)

            # Splitting needs explicit x and y variables; otherwise defer to seaborn.
            if max_per_row is None or x_vars is None or y_vars is None:
                return sns.pairplot(x_vars=x_vars, y_vars=y_vars, **common)

            return [sns.pairplot(x_vars=x_names, y_vars=[y_name], **common)
                    for x_names, y_name in
                    self.splitVariables(x_vars, y_vars, max_per_row)]