Skip to content

Instantly share code, notes, and snippets.

@saulshanabrook
Created April 30, 2026 18:05
Show Gist options
  • Select an option

  • Save saulshanabrook/3f46d8450fd10783866e1a5f51c0514a to your computer and use it in GitHub Desktop.

Select an option

Save saulshanabrook/3f46d8450fd10783866e1a5f51c0514a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f479154d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5b721ce8547b465da62eb5e21a7c5417",
"version_major": 2,
"version_minor": 1
},
"text/plain": [
"VisualizerWidget(egraphs=['{\"nodes\":{\"primitive-i64-3\":{\"op\":\"3\",\"children\":[],\"eclass\":\"i64-3\",\"cost\":1.0,\"su…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(sort Layout)\n",
"(sort Num)\n",
"(sort Vec[Num] (Vec Num))\n",
"(constructor Layout_strided (Vec[Num] Vec[Num]) Layout)\n",
"(constructor Int (i64) Num)\n",
"(Layout_strided (vec-of (Int 4) (Int 3) (Int 4)) (vec-of (Int 12) (Int 4) (Int 1)))\n",
"(ruleset __main__.layout_rules)\n",
"(constructor Num___add__ (Num Num) Num)\n",
"(rewrite (Num___add__ (Int _i) (Int _j)) (Int (+ _i _j)) :ruleset __main__.layout_rules)\n",
"(constructor Num___mul__ (Num Num) Num)\n",
"(rewrite (Num___mul__ (Int _i) (Int _j)) (Int (* _i _j)) :ruleset __main__.layout_rules)\n",
"(constructor GetOffset (Layout Vec[Num]) Num)\n",
"(sort UnstableFn[Num,Num,Num] (UnstableFn (Num Num) Num))\n",
"(primitive _lambda_0 (Num Num) Num (Num___add__ _0 _1))\n",
"(sort UnstableFn[Num,i64] (UnstableFn (i64) Num))\n",
"(primitive _lambda_1 (Vec[Num] Vec[Num] i64) Num (Num___mul__ (vec-get _1 _2) (vec-get _0 _2)))\n",
"(sort Vec[i64] (Vec i64))\n",
"(rewrite (GetOffset (Layout_strided _shape _strides) _idxs) (vec-foldl (unstable-fn \"_lambda_0\") (Int 0) _multiplied_idxs) :when ((= _multiplied_idxs (unstable-vec-map (unstable-fn \"_lambda_1\" _strides _idxs) (vec-range (vec-length _shape))))) :ruleset __main__.layout_rules)\n",
"(constructor Layout_column_major (Vec[Num]) Layout)\n",
"(sort UnstableFn[Vec[Num],Num,Vec[Num]] (UnstableFn (Num Vec[Num]) Vec[Num]))\n",
"(primitive _lambda_2 (Num Vec[Num]) Vec[Num] (vec-append (vec-of (Num___mul__ (vec-get _1 0) _0)) _1))\n",
"(rewrite (Layout_column_major _shape) (Layout_strided _shape (vec-foldr (unstable-fn \"_lambda_2\") (vec-of (Int 1)) (vec-remove _shape 0))) :ruleset __main__.layout_rules)\n",
"(function Layout__strided_as_column_major_strides (Layout) Vec[Num] :no-merge)\n",
"(rule ((= _l (Layout_strided _shape _strides)))\n",
" ((set (Layout__strided_as_column_major_strides _l) (vec-foldr (unstable-fn \"_lambda_2\") (vec-of (Int 1)) (vec-remove _shape 0))))\n",
" :ruleset __main__.layout_rules )\n",
"(rewrite (Layout_strided _shape _strides) (Layout_column_major _shape) :when ((= _strides (Layout__strided_as_column_major_strides _l))) :ruleset __main__.layout_rules)\n",
"(constructor Layout_row_major (Vec[Num]) Layout)\n",
"(rewrite (Layout_row_major _shape) (Layout_strided _shape (vec-foldr (unstable-fn \"_lambda_2\") (vec-of (Int 1)) (vec-remove _shape (- (vec-length _shape) 1)))) :ruleset __main__.layout_rules)\n",
"(run-schedule (saturate (run __main__.layout_rules)))\n",
"(check (= (Layout_strided (vec-of (Int 4) (Int 3) (Int 4)) (vec-of (Int 12) (Int 4) (Int 1))) (Layout_column_major (vec-of (Int 4) (Int 3) (Int 4)))))\n",
"\n"
]
}
],
"source": [
"from __future__ import annotations\n",
"\n",
"from egglog import *\n",
"\n",
"\n",
"class Num(Expr):\n",
"    \"\"\"Symbolic number: an i64 literal (egglog `Int`) or a sum/product of Nums.\"\"\"\n",
"\n",
"    @method(egg_fn=\"Int\")\n",
"    def __init__(self, value: i64Like) -> None: ...\n",
"\n",
"    # + and * only build symbolic terms; layout_rules folds them when both sides are literals.\n",
"    def __add__(self, other: Num) -> Num: ...\n",
"\n",
"    def __mul__(self, other: Num) -> Num: ...\n",
"\n",
"\n",
"# Accept plain i64 values wherever a Num is expected by wrapping them with Num(...).\n",
"converter(i64, Num, Num)\n",
"\n",
"\n",
"class Layout(Expr):\n",
"    \"\"\"\n",
"    A mapping from logical (multi-dimensional) indices to physical offsets.\n",
"\n",
"    Two layouts should be equivalent when, for every logical index, they give the same\n",
"    physical offset (i.e. they produce the same list of physical offsets).\n",
"    \"\"\"\n",
"\n",
"    # Explicit layout: the rules compute layout[idxs] as the sum of idxs[k] * strides[k].\n",
"    @classmethod\n",
"    def strided(cls, shape: Vec[Num], strides: Vec[Num]) -> Layout: ...\n",
"\n",
"    # Dense layout given by its shape only. The rules expand it to strides in which the\n",
"    # LAST axis is contiguous, e.g. shape (4, 3, 4) -> strides (12, 4, 1).\n",
"    # NOTE(review): numpy/C call that order row-major -- confirm the intended naming.\n",
"    @classmethod\n",
"    def column_major(cls, shape: Vec[Num]) -> Layout: ...\n",
"\n",
"    # Helper filled in by the rules for each strided layout: the column_major strides for\n",
"    # that layout's shape, used to recognise strided layouts that are column_major.\n",
"    @property\n",
"    def _strided_as_column_major_strides(self) -> Vec[Num]: ...\n",
"\n",
"    # Dense layout given by its shape only; presumably the opposite axis order to\n",
"    # column_major (first axis contiguous) -- TODO confirm.\n",
"    @classmethod\n",
"    def row_major(cls, shape: Vec[Num]) -> Layout: ...\n",
"\n",
"    # Physical offset of the logical index `indxs`; exposed to egglog as `GetOffset`.\n",
"    @method(egg_fn=\"GetOffset\")\n",
"    def __getitem__(self, indxs: Vec[Num]) -> Num: ...\n",
"\n",
"\n",
"# Sofia B WIP on zulip PLDI\n",
"\n",
"\n",
"@ruleset\n",
"def layout_rules(\n",
"    shape: Vec[Num], strides: Vec[Num], idxs: Vec[Num], i: i64, j: i64, multiplied_idxs: Vec[Num], l: Layout\n",
"):\n",
"    \"\"\"\n",
"    Rules for Num arithmetic and Layout normalisation.\n",
"\n",
"    - Constant-fold + and * on integer literals.\n",
"    - Evaluate strided_layout[idxs] as the sum of idxs[k] * strides[k].\n",
"    - Expand column_major / row_major into explicit strided layouts, and rewrite a\n",
"      strided layout back to column_major when its strides are the column_major ones.\n",
"    \"\"\"\n",
"    # Integer constant folding.\n",
"    yield rewrite(Num(i) + Num(j)).to(Num(i + j))\n",
"    yield rewrite(Num(i) * Num(j)).to(Num(i * j))\n",
"\n",
"    # Offset of a strided layout: multiply each index by its stride, then sum.\n",
"    yield rewrite(Layout.strided(shape, strides)[idxs]).to(\n",
"        vec_foldl(lambda acc, x: acc + x, Num(0), multiplied_idxs),\n",
"        multiplied_idxs == shape.length().range().map(lambda i_: idxs[i_] * strides[i_]),\n",
"    )\n",
"\n",
"    # Strides with the LAST axis contiguous, built right-to-left by prepending running\n",
"    # products: shape (4, 3, 4) -> (12, 4, 1).\n",
"    column_major_strides = vec_foldr(lambda s, acc: Vec(acc[i64(0)] * s).append(acc), Vec(Num(1)), shape.remove(0))\n",
"\n",
"    yield rewrite(Layout.column_major(shape)).to(\n",
"        Layout.strided(shape, column_major_strides),\n",
"    )\n",
"\n",
"    # For every strided layout, record the strides it would have if it were column_major.\n",
"    yield rule(l == Layout.strided(shape, strides)).then(\n",
"        set_(l._strided_as_column_major_strides).to(column_major_strides)\n",
"    )\n",
"\n",
"    # Fold a strided layout back to column_major when its strides match that record.\n",
"    # `l` must be bound to the layout being rewritten: otherwise it is a free variable and\n",
"    # the rule fires when ANY layout's recorded strides equal `strides`, even one of a\n",
"    # different shape -- e.g. strided((4, 6, 2), (12, 4, 1)) would be merged with\n",
"    # column_major((4, 6, 2)) as soon as a strided layout of shape (4, 3, 4) exists.\n",
"    yield rewrite(Layout.strided(shape, strides)).to(\n",
"        Layout.column_major(shape),\n",
"        l == Layout.strided(shape, strides),\n",
"        strides == l._strided_as_column_major_strides,\n",
"    )\n",
"\n",
"    # Strides with the FIRST axis contiguous, built left-to-right by appending running\n",
"    # products: shape (4, 3, 4) -> (1, 4, 12). Prepending with a right fold here would give\n",
"    # (12, 3, 1), which maps distinct indices to the same offset.\n",
"    row_major_strides = vec_foldl(\n",
"        lambda acc, s: acc.append(Vec(acc[acc.length() - 1] * s)),\n",
"        Vec(Num(1)),\n",
"        shape.remove(shape.length() - 1),\n",
"    )\n",
"    yield rewrite(Layout.row_major(shape)).to(\n",
"        Layout.strided(shape, row_major_strides),\n",
"    )\n",
"\n",
"\n",
"# One (4, 3, 4) dense layout, written both via column_major and with explicit strides.\n",
"shape = Vec(Num(4), Num(3), Num(4))\n",
"layout = Layout.column_major(shape)\n",
"layout2 = Layout.strided(shape, Vec(Num(12), Num(4), Num(1)))\n",
"\n",
"# E-graph 1: an offset lookup on the column_major layout plus the explicit layout.\n",
"# The offset of index (1, 2, 3) must fold to 1*12 + 2*4 + 3*1 = 23, and both\n",
"# spellings of the layout must end up in the same e-class.\n",
"egraph = EGraph(save_egglog_string=True)\n",
"res = layout[Vec(Num(1), Num(2), Num(3))]\n",
"egraph.register(res, layout2)\n",
"egraph.run(layout_rules.saturate())\n",
"egraph.check(res == Num(23))\n",
"egraph.check(layout == layout2)\n",
"\n",
"# E-graph 2: starting from only the explicit strided layout, the rules must recover\n",
"# the column_major form.\n",
"egraph2 = EGraph(save_egglog_string=True)\n",
"egraph2.register(layout2)\n",
"egraph2.run(layout_rules.saturate())\n",
"egraph2.check(layout2 == layout)\n",
"\n",
"egraph.display()\n",
"print(egraph2.as_egglog_string)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8a07884",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "egglog (3.13.11)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment