{ "cells": [ { "cell_type": "markdown", "id": "ca4581d4-c4cc-45fe-893c-556611fee0d8", "metadata": {}, "source": [ "A little over eight years ago, I published a [post](https://austinrochford.com/posts/2017-07-09-mrpymc3.html) entitled _MRPyMC3 - Multilevel Regression and Poststratification with PyMC3_, showing how to perform [multilevel regression with poststratification](https://en.wikipedia.org/wiki/Multilevel_regression_with_poststratification) (MRP) in Python with PyMC. I periodically enjoy [revisiting](https://austinrochford.com/posts/revisit-survival-pymc.html) old posts after both the technology and my understanding of the problem have advanced. This post revisits MRP with a few notable changes:\n", "\n", "* [PyMC](http://pymc.io/) is now on major version 5 instead of version 3,\n", "* we use [nutpie](https://pymc-devs.github.io/nutpie/) for sampling from the model instead of PyMC's built-in sampler, and\n", "* we use [Polars](https://pola.rs/) instead of [pandas](https://pandas.pydata.org/).\n", "\n", "We will not repeat the previous post's full exposition of MRP and will instead focus on the mechanics of its implementation.\n", "\n", "First we import the necessary packages and do a bit of light configuration." 
] }, { "cell_type": "code", "execution_count": 1, "id": "edcdf70a-c255-4be0-b937-bce507e14487", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = \"retina\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "5bb22b68-0a11-4759-8534-f2aed0cc0dcf", "metadata": {}, "outputs": [], "source": [ "from itertools import zip_longest\n", "import os\n", "from urllib import request\n", "import us\n", "from zipfile import ZipFile" ] }, { "cell_type": "code", "execution_count": 3, "id": "1fb32e07-91f8-4391-a09c-b49cc349900e", "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "from matplotlib import cm, pyplot as plt, ticker\n", "from matplotlib.colorbar import ColorbarBase\n", "from matplotlib.colors import Normalize\n", "import numpy as np\n", "import nutpie\n", "import polars as pl\n", "import pymc as pm\n", "from pyreadstat import read_dta\n", "from pytensor import tensor as pt\n", "import seaborn as sns\n", "from seaborn import objects as so\n", "from scipy.special import logit" ] }, { "cell_type": "code", "execution_count": 4, "id": "8c061dbe-097c-4f1a-a6d7-b76c24d545c1", "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"darkgrid\", {\"axes.linewidth\": 1, \"axes.edgecolor\": \"black\"})" ] }, { "cell_type": "markdown", "id": "c6b5e21b-6c6f-416a-b222-935bb3e1f033", "metadata": {}, "source": [ "## Load and transform the data\n", "\n", "As in the previous post, we follow [Jonathan Kastellec](https://jkastellec.scholar.princeton.edu/)'s excellent [MRP Primer](https://jkastellec.scholar.princeton.edu/publications/mrp_primer), which focuses on estimating state-level opinions of gay marriage in 2005/2006 from polling data.\n", "\n", "First we download and decompress the data." 
] }, { "cell_type": "code", "execution_count": 5, "id": "4ea540e1-ab76-430c-8a9a-788ad6edd517", "metadata": {}, "outputs": [], "source": [ "DATA_PATH = \"./data\"\n", "DATA_URI = \"https://jkastellec.scholar.princeton.edu/sites/g/files/toruqf3871/files/jkastellec/files/mrp_primer_replication_files.zip\"" ] }, { "cell_type": "code", "execution_count": 6, "id": "67c2f68e-be22-46ab-bd3c-8a82e1aa45d3", "metadata": {}, "outputs": [], "source": [ "USER_AGENT = \"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36\"" ] }, { "cell_type": "code", "execution_count": 7, "id": "21facc6d-c899-464f-8c3c-d1204c4f8aeb", "metadata": {}, "outputs": [], "source": [ "if not os.path.isdir(DATA_PATH):\n", "    os.mkdir(DATA_PATH)\n", "\n", "dest_path = os.path.join(DATA_PATH, os.path.basename(DATA_URI))\n", "\n", "if not os.path.exists(dest_path):\n", "    opener = request.build_opener()\n", "    opener.addheaders = [(\"User-agent\", USER_AGENT)]\n", "    request.install_opener(opener)\n", "    request.urlretrieve(DATA_URI, dest_path)\n", "\n", "    with ZipFile(dest_path) as src:\n", "        src.extractall(DATA_PATH)" ] }, { "cell_type": "markdown", "id": "8d3a2887-7e29-4bf6-bbf6-54690aa5af28", "metadata": {}, "source": [ "### Poll data\n", "\n", "Next we load the polling data and do the light feature engineering needed to build the multilevel model." 
] }, { "cell_type": "code", "execution_count": 8, "id": "c94c9aba-cfc9-49d1-8419-ffd1b3ac0ae8", "metadata": {}, "outputs": [], "source": [ "UNZIPPED_DIR = \"MRP_Primer_Replication_Files\"" ] }, { "cell_type": "code", "execution_count": 9, "id": "d68ae868-6eba-4ae3-b911-826f21813ef0", "metadata": {}, "outputs": [], "source": [ "POLL_PATH = os.path.join(DATA_PATH, UNZIPPED_DIR, \"gay_marriage_megapoll.dta\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "fc65705e-7437-4343-8011-51ad97696821", "metadata": {}, "outputs": [], "source": [ "POLL_COLS = [\n", "    \"race_wbh\",\n", "    \"age_cat\",\n", "    \"edu_cat\",\n", "    \"female\",\n", "    \"region\",\n", "    \"state\",\n", "    \"poll\",\n", "    \"yes_of_all\",\n", "]\n", "\n", "NN_COLS = [\"race_wbh\", \"age_cat\", \"edu_cat\"]" ] }, { "cell_type": "code", "execution_count": 11, "id": "f8c9dbfc-80e3-4fa7-9e74-16b72ca90864", "metadata": {}, "outputs": [], "source": [ "def to_zero_indexed(name):\n", "    col = pl.col(name)\n", "\n", "    return col - col.min()" ] }, { "cell_type": "code", "execution_count": 12, "id": "a67ad88d-a553-46f5-8135-89fa8ab8eae1", "metadata": {}, "outputs": [], "source": [ "CAT_COLS = [\"age_cat\", \"edu_cat\", \"race_wbh\"]\n", "\n", "cat_col_transforms = [to_zero_indexed(name) for name in CAT_COLS]" ] }, { "cell_type": "code", "execution_count": 13, "id": "78f1d8ef-640e-4a9b-b25e-a0ec11a21b54", "metadata": {}, "outputs": [], "source": [ "GENDER = pl.Enum([\"Male\", \"Female\"])\n", "POLL = pl.Enum(\n", "    [\n", "        \"ABC 2004Jan15\",\n", "        \"Gall2004Mar05\",\n", "        \"Gall2005Aug22\",\n", "        \"Pew 2004Dec01\",\n", "        \"Pew 2004Feb11\",\n", "    ]\n", ")\n", "RACE = pl.Enum([\"White\", \"Black\", \"Hispanic\"])\n", "REGION = pl.Enum([\"dc\", \"midwest\", \"northeast\", \"south\", \"west\"])\n", "STATE = pl.Enum(sorted([state.abbr for state in us.states.STATES] + [\"DC\"]))" ] }, { "cell_type": "code", "execution_count": 14, "id": "c61728c1-9761-4519-ae76-02791fffbe53", "metadata": {}, "outputs": [], "source": [ 
"ENUM_CASTS = {\n", "    \"female\": (GENDER, \"gender\"),\n", "    \"poll\": (POLL, None),\n", "    \"race_wbh\": (RACE, \"race\"),\n", "    \"region\": (REGION, None),\n", "    \"state\": (STATE, None),\n", "}\n", "\n", "\n", "def cast_enum_cols(df):\n", "    for name, (enum, new_name) in ENUM_CASTS.items():\n", "        if name in df.columns:\n", "            df = df.with_columns(pl.col(name).cast(enum).alias(new_name or name))\n", "\n", "            if new_name is not None:\n", "                df = df.drop(name)\n", "\n", "    return df" ] }, { "cell_type": "code", "execution_count": 15, "id": "877e89c1-f101-43c2-be00-1888c42d6b99", "metadata": {}, "outputs": [], "source": [ "poll_df = (\n", "    pl.from_pandas(read_dta(POLL_PATH)[0])\n", "    .select(POLL_COLS)\n", "    .drop_nulls(NN_COLS)\n", "    .with_columns(*cat_col_transforms)\n", "    .group_by(POLL_COLS[:-1])\n", "    .agg(pl.sum(\"yes_of_all\"), poll_pop=pl.len().cast(pl.Int32))\n", "    .filter(pl.col(\"state\") != \"\")\n", "    .pipe(cast_enum_cols)\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "id": "e4366cd6-e8b6-4094-84f1-d57b54fa02b9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| age_cat | edu_cat | region | state | poll | yes_of_all | poll_pop | gender | race |
|---|---|---|---|---|---|---|---|---|
| i64 | i64 | enum | enum | enum | i64 | i32 | enum | enum |
| 1 | 1 | "south" | "DE" | "Gall2005Aug22" | 0 | 1 | "Male" | "Black" |
| 3 | 3 | "south" | "FL" | "Pew 2004Dec01" | 3 | 5 | "Male" | "White" |
| 0 | 3 | "south" | "VA" | "Pew 2004Feb11" | 1 | 1 | "Female" | "White" |
| 0 | 1 | "midwest" | "OH" | "Gall2004Mar05" | 0 | 1 | "Male" | "White" |
| 3 | 1 | "south" | "TN" | "ABC 2004Jan15" | 0 | 1 | "Male" | "Black" |
| … | … | … | … | … | … | … | … | … |
| 3 | 1 | "west" | "CA" | "Pew 2004Feb11" | 0 | 1 | "Female" | "Black" |
| 1 | 3 | "midwest" | "IA" | "ABC 2004Jan15" | 1 | 1 | "Female" | "White" |
| 0 | 3 | "south" | "TN" | "Pew 2004Feb11" | 1 | 1 | "Female" | "White" |
| 0 | 2 | "northeast" | "NY" | "Gall2004Mar05" | 0 | 1 | "Female" | "White" |
| 3 | 1 | "midwest" | "WI" | "Gall2005Aug22" | 0 | 1 | "Female" | "White" |

| age_cat | edu_cat | state | pop | region | gender | race |
|---|---|---|---|---|---|---|
| i64 | i64 | enum | i64 | enum | enum | enum |
| 0 | 0 | "AK" | 467 | "west" | "Male" | "White" |
| 1 | 0 | "AK" | 377 | "west" | "Male" | "White" |
| 2 | 0 | "AK" | 419 | "west" | "Male" | "White" |
| 3 | 0 | "AK" | 343 | "west" | "Male" | "White" |
| 0 | 1 | "AK" | 958 | "west" | "Male" | "White" |
| … | … | … | … | … | … | … |
| 3 | 2 | "WY" | 4 | "west" | "Female" | "Hispanic" |
| 0 | 3 | "WY" | 8 | "west" | "Female" | "Hispanic" |
| 1 | 3 | "WY" | 16 | "west" | "Female" | "Hispanic" |
| 2 | 3 | "WY" | 10 | "west" | "Female" | "Hispanic" |
| 3 | 3 | "WY" | 1 | "west" | "Female" | "Hispanic" |

| state | kerry_04 | p_relig | region |
|---|---|---|---|
| enum | f64 | f64 | enum |
| "AK" | 0.355 | 0.154431 | "west" |
| "AL" | 0.368 | 0.410083 | "south" |
| "AR" | 0.446 | 0.436301 | "south" |
| "AZ" | 0.444 | 0.142887 | "west" |
| "CA" | 0.543 | 0.087176 | "west" |
| … | … | … | … |
| "VT" | 0.589 | 0.02912 | "northeast" |
| "WA" | 0.528 | 0.128227 | "west" |
| "WI" | 0.497 | 0.129527 | "midwest" |
| "WV" | 0.432 | 0.116178 | "south" |
| "WY" | 0.291 | 0.208844 | "west" |

| state | poll_pop |
|---|---|
| enum | i32 |
| "AK" | 0 |
| "HI" | 0 |
| "DC" | 6 |
| "WY" | 14 |
| "DE" | 16 |

Sampler Progress

Total Chains: 8, Active Chains: 0, Finished Chains: 8

Sampling for 3 minutes. Estimated Time to Completion: now

| Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|
| 1400 | 0 | 0.05 | 255 |
| 1400 | 0 | 0.04 | 511 |
| 1400 | 0 | 0.04 | 511 |
| 1400 | 0 | 0.04 | 511 |
| 1400 | 0 | 0.05 | 255 |
| 1400 | 0 | 0.05 | 255 |
| 1400 | 0 | 0.04 | 1023 |
| 1400 | 0 | 0.04 | 255 |

<xarray.DataArray ()> Size: 8B
array(1.0036893)