Skip to content

Instantly share code, notes, and snippets.

@segeljakt
Created April 23, 2024 12:24
Show Gist options
  • Select an option

  • Save segeljakt/e6342d5d3243d0e867dfd3d2d5763645 to your computer and use it in GitHub Desktop.

Select an option

Save segeljakt/e6342d5d3243d0e867dfd3d2d5763645 to your computer and use it in GitHub Desktop.
; Expr: the term language of symbolic integer arithmetic that every rule in
; this file matches on and constructs. Num wraps an i64 literal; all other
; constructors take one or two Expr children.
(datatype Expr
; Arithmetic
; Integer literal.
(Num i64)
; Addition, with an explicit cost of 1.
(Add Expr Expr :cost 1)
(Sub Expr Expr)
; Multiplication, with an explicit cost of 10 (ten times the cost given to Add).
(Mul Expr Expr :cost 10)
(Div Expr Expr)
(Pow Expr Expr)
(Mod Expr Expr)
; Unary constructors.
(Log10 Expr)
(Floor Expr)
(Max Expr Expr))
; The Universe relation records Expr terms found in the e-graph.
; There is one rule per Expr constructor: whenever a term built with that
; constructor exists, the term is added to Universe.
(ruleset universe)
(relation Universe (Expr))
; Literal.
(rule ((= node (Num value)))     ((Universe node)) :ruleset universe)
; Binary constructors.
(rule ((= node (Add lhs rhs)))   ((Universe node)) :ruleset universe)
(rule ((= node (Sub lhs rhs)))   ((Universe node)) :ruleset universe)
(rule ((= node (Mul lhs rhs)))   ((Universe node)) :ruleset universe)
(rule ((= node (Div lhs rhs)))   ((Universe node)) :ruleset universe)
(rule ((= node (Pow lhs rhs)))   ((Universe node)) :ruleset universe)
(rule ((= node (Mod lhs rhs)))   ((Universe node)) :ruleset universe)
; Unary constructors.
(rule ((= node (Log10 arg)))     ((Universe node)) :ruleset universe)
(rule ((= node (Floor arg)))     ((Universe node)) :ruleset universe)
; Binary constructor.
(rule ((= node (Max lhs rhs)))   ((Universe node)) :ruleset universe)
; Karatsuba multiplication for x*y, derived step by step from Rules 1-6:
;
; Rule 1: a = b*10^n + c
; where n = floor(log10(a))
; b = a / 10^n
; c = a % 10^n
;
; x*y = (x1*10^n + x0) * (y1*10^n + y0)
;
; Rule 2: (a+b) * (c+d) = ac + ad + bc + bd
;
; ... = x1*10^n*y1*10^n + x1*10^n*y0 + x0*y1*10^n + x0*y0
;
; Rule 3: a*b = b*a
;
; ... = x1*y1*10^n*10^n + x1*y0*10^n + x0*y1*10^n + x0*y0
;
; Rule 4: a^n*a^m = a^(n+m)
;
; ... = x1*y1*10^2n + x1*y0*10^n + x0*y1*10^n + x0*y0
;
; Rule 5: a*b + c*b = (a + c)*b
;
; ... = x1*y1*10^2n + (x1*y0 + x0*y1)*10^n + x0*y0
;
; z2 = x1*y1
; z1 = x1*y0 + x0*y1
; z0 = x0*y0
;
; ... = z2*10^2n + z1*10^n + z0
;
; Rule 6: a*d + b*c = (a+b)*(c+d) - a*c - b*d
;
; z3 = (x1+x0)*(y1+y0)
; z1 = (x1+x0)*(y1+y0) - x1*y1 - x0*y0
;    = z3 - z2 - z0
;
; ... = z2*10^2n + (z3 - z2 - z0)*10^n + z0
; Rule 1 (split): x = x1 * 10^n + x0, where n = floor(log10(x)).
; Fires for every term recorded in Universe and unions it with its split
; form. This rule adds no Universe facts itself, so the terms it builds are
; only split again if the universe ruleset is run afterwards.
(ruleset rule1-split)
(rule ((Universe whole))
      ((let digits (Floor (Log10 whole)))   ; n    = floor(log10(x))
       (let scale (Pow (Num 10) digits))    ; 10^n
       (let high (Div whole scale))         ; x1   = x / 10^n
       (let low (Mod whole scale))          ; x0   = x % 10^n
       (union whole (Add (Mul high scale) low)))
      :ruleset rule1-split)
; Rule 2 (distribute): (a + b) * (c + d) = (ac + ad) + (bc + bd).
; A product of two sums is unioned with the sum of its four cross products.
(ruleset rule2-mul-add-distribute)
(rewrite (Mul (Add a b) (Add c d))
         (Add (Add (Mul a c) (Mul a d))
              (Add (Mul b c) (Mul b d)))
         :ruleset rule2-mul-add-distribute)
; Rule 3 (commutativity): a*b = b*a.
; Every product is unioned with the same product with its operands swapped.
(ruleset rule3-mul-commutative)
(rule ((= product (Mul left right)))
      ((union product (Mul right left)))
      :ruleset rule3-mul-commutative)
; Rule 4 (exponent sum): a^n * a^m = a^(n+m).
; The pattern uses the same variable for both bases, so it only matches a
; product of two powers whose bases are equal in the e-graph.
(ruleset rule4-pow-add)
(rewrite (Mul (Pow base n) (Pow base m))
         (Pow base (Add n m))
         :ruleset rule4-pow-add)
; Rule 5 (factor): a*b + c*b = (a + c)*b.
; Matches a sum of two products that share the same right-hand factor and
; pulls that factor out.
(ruleset rule5-mul-factor)
(rewrite (Add (Mul a shared) (Mul c shared))
         (Mul (Add a c) shared)
         :ruleset rule5-mul-factor)
; Rule 6 (Karatsuba identity): a*d + b*c = (a+b)*(c+d) - a*c - b*d.
; Rewrites the sum of two cross products as one product of sums minus the
; two straight products a*c and b*d.
(ruleset rule6-karatsuba)
(rewrite (Add (Mul a d) (Mul b c))
         (Sub (Sub (Mul (Add a b) (Add c d))
                   (Mul a c))
              (Mul b d))
         :ruleset rule6-karatsuba)
; Test for Rule 1: after splitting, 1234 must be equal to x1*10^n + x0.
; Runs inside push/pop so none of these bindings leak into later tests.
(push 1)
(let input (Num 1234))
; Populate Universe first, then apply the split rule to what it recorded.
(run-schedule
  (saturate universe)
  (saturate rule1-split))
; Rebuild the split form by hand and check it landed in the same e-class.
(let digits (Floor (Log10 input)))
(let scale (Pow (Num 10) digits))
(let split-form (Add (Mul (Div input scale) scale) (Mod input scale)))
(check (= input split-form))
(pop 1)
; Test for Rule 2: (p + q) * (r + s) must be equal to (pr + ps) + (qr + qs).
; Runs inside push/pop so none of these bindings leak into later tests.
(push 1)
(let p (Num 1234))
(let q (Num 5678))
(let r (Num 9101))
(let s (Num 1121))
(let product (Mul (Add p q) (Add r s)))
; Same schedule order as the original test: universe, split, then distribute.
(run-schedule
  (saturate universe)
  (saturate rule1-split)
  (saturate rule2-mul-add-distribute))
; Build the fully distributed form by hand and compare e-classes.
(let expanded (Add (Add (Mul p r) (Mul p s))
                   (Add (Mul q r) (Mul q s))))
(check (= product expanded))
(pop 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment