Created
April 30, 2026 18:05
-
-
Save saulshanabrook/3f46d8450fd10783866e1a5f51c0514a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "f479154d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "5b721ce8547b465da62eb5e21a7c5417", | |
| "version_major": 2, | |
| "version_minor": 1 | |
| }, | |
| "text/plain": [ | |
| "VisualizerWidget(egraphs=['{\"nodes\":{\"primitive-i64-3\":{\"op\":\"3\",\"children\":[],\"eclass\":\"i64-3\",\"cost\":1.0,\"su…" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(sort Layout)\n", | |
| "(sort Num)\n", | |
| "(sort Vec[Num] (Vec Num))\n", | |
| "(constructor Layout_strided (Vec[Num] Vec[Num]) Layout)\n", | |
| "(constructor Int (i64) Num)\n", | |
| "(Layout_strided (vec-of (Int 4) (Int 3) (Int 4)) (vec-of (Int 12) (Int 4) (Int 1)))\n", | |
| "(ruleset __main__.layout_rules)\n", | |
| "(constructor Num___add__ (Num Num) Num)\n", | |
| "(rewrite (Num___add__ (Int _i) (Int _j)) (Int (+ _i _j)) :ruleset __main__.layout_rules)\n", | |
| "(constructor Num___mul__ (Num Num) Num)\n", | |
| "(rewrite (Num___mul__ (Int _i) (Int _j)) (Int (* _i _j)) :ruleset __main__.layout_rules)\n", | |
| "(constructor GetOffset (Layout Vec[Num]) Num)\n", | |
| "(sort UnstableFn[Num,Num,Num] (UnstableFn (Num Num) Num))\n", | |
| "(primitive _lambda_0 (Num Num) Num (Num___add__ _0 _1))\n", | |
| "(sort UnstableFn[Num,i64] (UnstableFn (i64) Num))\n", | |
| "(primitive _lambda_1 (Vec[Num] Vec[Num] i64) Num (Num___mul__ (vec-get _1 _2) (vec-get _0 _2)))\n", | |
| "(sort Vec[i64] (Vec i64))\n", | |
| "(rewrite (GetOffset (Layout_strided _shape _strides) _idxs) (vec-foldl (unstable-fn \"_lambda_0\") (Int 0) _multiplied_idxs) :when ((= _multiplied_idxs (unstable-vec-map (unstable-fn \"_lambda_1\" _strides _idxs) (vec-range (vec-length _shape))))) :ruleset __main__.layout_rules)\n", | |
| "(constructor Layout_column_major (Vec[Num]) Layout)\n", | |
| "(sort UnstableFn[Vec[Num],Num,Vec[Num]] (UnstableFn (Num Vec[Num]) Vec[Num]))\n", | |
| "(primitive _lambda_2 (Num Vec[Num]) Vec[Num] (vec-append (vec-of (Num___mul__ (vec-get _1 0) _0)) _1))\n", | |
| "(rewrite (Layout_column_major _shape) (Layout_strided _shape (vec-foldr (unstable-fn \"_lambda_2\") (vec-of (Int 1)) (vec-remove _shape 0))) :ruleset __main__.layout_rules)\n", | |
| "(function Layout__strided_as_column_major_strides (Layout) Vec[Num] :no-merge)\n", | |
| "(rule ((= _l (Layout_strided _shape _strides)))\n", | |
| " ((set (Layout__strided_as_column_major_strides _l) (vec-foldr (unstable-fn \"_lambda_2\") (vec-of (Int 1)) (vec-remove _shape 0))))\n", | |
| " :ruleset __main__.layout_rules )\n", | |
| "(rewrite (Layout_strided _shape _strides) (Layout_column_major _shape) :when ((= _strides (Layout__strided_as_column_major_strides _l))) :ruleset __main__.layout_rules)\n", | |
| "(constructor Layout_row_major (Vec[Num]) Layout)\n", | |
| "(rewrite (Layout_row_major _shape) (Layout_strided _shape (vec-foldr (unstable-fn \"_lambda_2\") (vec-of (Int 1)) (vec-remove _shape (- (vec-length _shape) 1)))) :ruleset __main__.layout_rules)\n", | |
| "(run-schedule (saturate (run __main__.layout_rules)))\n", | |
| "(check (= (Layout_strided (vec-of (Int 4) (Int 3) (Int 4)) (vec-of (Int 12) (Int 4) (Int 1))) (Layout_column_major (vec-of (Int 4) (Int 3) (Int 4)))))\n", | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from __future__ import annotations\n", | |
| "\n", | |
| "from egglog import *\n", | |
| "\n", | |
| "\n", | |
| "class Num(Expr):\n", | |
| " @method(egg_fn=\"Int\")\n", | |
| " def __init__(self, value: i64Like) -> None: ...\n", | |
| "\n", | |
| " def __add__(self, other: Num) -> Num: ...\n", | |
| "\n", | |
| " def __mul__(self, other: Num) -> Num: ...\n", | |
| "\n", | |
| "\n", | |
| "converter(i64, Num, Num)\n", | |
| "\n", | |
| "\n", | |
| "class Layout(Expr):\n", | |
| " \"\"\"\n", | |
| " Two layouts should be equivalent when, offset by all logical indices, they give the same\n", | |
| " list of physical offsets.\n", | |
| " \"\"\"\n", | |
| "\n", | |
| " @classmethod\n", | |
| " def strided(cls, shape: Vec[Num], strides: Vec[Num]) -> Layout: ...\n", | |
| "\n", | |
| " @classmethod\n", | |
| " def column_major(cls, shape: Vec[Num]) -> Layout: ...\n", | |
| "\n", | |
| " @property\n", | |
| " def _strided_as_column_major_strides(self) -> Vec[Num]: ...\n", | |
| "\n", | |
| " @classmethod\n", | |
| " def row_major(cls, shape: Vec[Num]) -> Layout: ...\n", | |
| "\n", | |
| " @method(egg_fn=\"GetOffset\")\n", | |
| " def __getitem__(self, indxs: Vec[Num]) -> Num: ...\n", | |
| "\n", | |
| "\n", | |
| "# Sofia B WIP on zulip PLDI\n", | |
| "\n", | |
| "\n", | |
| "@ruleset\n", | |
| "def layout_rules(\n", | |
| " shape: Vec[Num], strides: Vec[Num], idxs: Vec[Num], i: i64, j: i64, multiplied_idxs: Vec[Num], l: Layout\n", | |
| "):\n", | |
| " yield rewrite(Num(i) + Num(j)).to(Num(i + j))\n", | |
| " yield rewrite(Num(i) * Num(j)).to(Num(i * j))\n", | |
| "\n", | |
| " yield rewrite(Layout.strided(shape, strides)[idxs]).to(\n", | |
| " vec_foldl(lambda acc, x: acc + x, Num(0), multiplied_idxs),\n", | |
| " multiplied_idxs == shape.length().range().map(lambda i_: idxs[i_] * strides[i_]),\n", | |
| " )\n", | |
| " column_major_strides = vec_foldr(lambda s, acc: Vec(acc[i64(0)] * s).append(acc), Vec(Num(1)), shape.remove(0))\n", | |
| "\n", | |
| " yield rewrite(Layout.column_major(shape)).to(\n", | |
| " Layout.strided(shape, column_major_strides),\n", | |
| " )\n", | |
| "\n", | |
| " yield rule(l == Layout.strided(shape, strides)).then(\n", | |
| " set_(l._strided_as_column_major_strides).to(column_major_strides)\n", | |
| " )\n", | |
| "\n", | |
| " yield rewrite(Layout.strided(shape, strides)).to(\n", | |
| " Layout.column_major(shape),\n", | |
| " # Example for shape [4, 3, 4]: the given strides [Num(12), Num(4), Num(1)] should equal the computed\n", | |
| " # column-major strides [Num(3) * Num(4) * Num(1), Num(4) * Num(1), Num(1)]\n", | |
| " strides == l._strided_as_column_major_strides, # when\n", | |
| " )\n", | |
| "\n", | |
| " yield rewrite(Layout.row_major(shape)).to(\n", | |
| " Layout.strided(\n", | |
| " shape, vec_foldr(lambda s, acc: Vec(acc[0] * s).append(acc), Vec(Num(1)), shape.remove(shape.length() - 1))\n", | |
| " )\n", | |
| " )\n", | |
| "\n", | |
| "\n", | |
| "layout = Layout.column_major(Vec(Num(4), Num(3), Num(4)))\n", | |
| "egraph = EGraph(save_egglog_string=True)\n", | |
| "layout2 = Layout.strided(Vec(Num(4), Num(3), Num(4)), Vec(Num(12), Num(4), Num(1)))\n", | |
| "\n", | |
| "res = layout[Vec(Num(1), Num(2), Num(3))]\n", | |
| "egraph.register(res, layout2)\n", | |
| "egraph.run(layout_rules.saturate())\n", | |
| "# egraph.display()\n", | |
| "egraph.check(res == Num(23))\n", | |
| "egraph.check(layout == layout2)\n", | |
| "\n", | |
| "\n", | |
| "egraph2 = EGraph(save_egglog_string=True)\n", | |
| "egraph2.register(layout2)\n", | |
| "egraph2.run(layout_rules.saturate())\n", | |
| "egraph2.check(layout2 == layout)\n", | |
| "egraph.display()\n", | |
| "print(egraph2.as_egglog_string)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "a8a07884", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "egglog (3.13.11)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.13.11" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment