Skip to content

Instantly share code, notes, and snippets.

@joelanders
Created March 11, 2021 10:52
Show Gist options
  • Select an option

  • Save joelanders/e3a9f4f57fb9ebacc859d7f73595d46a to your computer and use it in GitHub Desktop.

Select an option

Save joelanders/e3a9f4f57fb9ebacc859d7f73595d46a to your computer and use it in GitHub Desktop.

Revisions

  1. joelanders created this gist Mar 11, 2021.
    70 changes: 70 additions & 0 deletions cumulative_plotter.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,70 @@
    from datetime import datetime
    from datetime import timedelta
    import pandas as pd
    import plotly.express as px
    import plotly.graph_objs as go
    import statsmodels.api as sm
    import math

    # df = pd.read_csv('owid-covid-data.csv')
    # df = pd.read_csv('vaccinations.csv')
    df = pd.read_csv('vaccinations.csv')
    df = df[df.date >= "2021-01-01 00:00:00.000"]
    df = df[df.date <= "2021-03-12 00:00:00.000"]

    df['day_of_year'] = [(datetime.strptime(d, "%Y-%m-%d") - datetime(2021,1,1)).days for d in df['date']]

    # XXX sloppy way to set the x scale
    x_days = 190
    x_max_date = datetime(2021,1,1) + timedelta(days=x_days)

    fig = go.Figure(
    layout_yaxis_range=[0,100],
    layout_xaxis_range=["2021-01-01 00:00:00.0000", x_max_date],
    layout_yaxis_title_text="cumulative doses per hundred people",
    layout_title_text="cumulative doses per hundred people",
    layout_title_x=0.5,
    )

    fig.update_yaxes(nticks=20)

    # for country_index, country in enumerate(['United Kingdom', 'United States', 'Italy', 'France', 'Germany']):
    for country_index, country in enumerate(['United Kingdom', 'United States', 'Germany']):
    country_df = df[df.location == country]
    country_df = country_df.dropna(subset=['total_vaccinations_per_hundred'])

    recent_country_df = country_df[df.date >= "2021-02-25 00:00:00.000"]

    x = sm.add_constant(recent_country_df['day_of_year'])
    model = sm.OLS(recent_country_df['total_vaccinations_per_hundred'], x)
    results = model.fit()
    print(results.params)

    b = results.params[0]
    m = results.params[1]

    # scatter plot for data
    fig.add_trace(
    go.Scatter(
    x=country_df['date'],
    y=country_df['total_vaccinations_per_hundred'],
    mode="markers",
    marker_color=px.colors.qualitative.Dark2[country_index],
    name=country,
    showlegend=False,
    )
    )

    # extrapolated line
    fig.add_trace(
    go.Scatter(
    x=["2021-01-01 00:00:00.0000", x_max_date],
    y=[results.params[0], (m*x_days + b)],
    mode="lines",
    marker_color=px.colors.qualitative.Dark2[country_index],
    name=country,
    showlegend=True,
    )
    )

    fig.show()