Skip to content

Instantly share code, notes, and snippets.

@santolucito
Last active September 16, 2025 20:12
Show Gist options
  • Select an option

  • Save santolucito/43519726b0181e95581cca9501f157a9 to your computer and use it in GitHub Desktop.

Select an option

Save santolucito/43519726b0181e95581cca9501f157a9 to your computer and use it in GitHub Desktop.
# amaranth_sygus_pbe_demo.py
#
# Minimal Amaranth demo for a SyGuS-style PBE bitvector task on FPGA-like hardware.
# It enumerates a tiny fixed-shape program: y = op2( op1(x, c0), c1 )
# over 4-bit values, and checks each candidate against example IO pairs stored in ROM.
# A candidate that satisfies all examples is reported as a solution.
#
# This is a hardware-acceleration-friendly shape: the checker is a small FSM driving a
# combinational datapath; the enumerator steps through configuration space (op1, op2, c0,
# c1). On a real FPGA you would replicate the checker and/or feed batches; here we simulate
# with Amaranth's Simulator.
#
# Requires: amaranth (>=0.4). Install: `pip install amaranth`
#
# In Amaranth's Python simulator this finds a solution in about 30 seconds on a CPU.
# Not yet tried on real hardware (FPGA).
#
# Author: santolucito (c) 2025, provided for educational demo purposes.
from amaranth import *
from amaranth.sim import Simulator
# ----- DSL definition -----
# 4-bit domain; small op set that is expressive enough to solve simple PBE tasks
# while keeping search space tiny for a demo.
# Opcodes are the consecutive integers 0..7, so a 3-bit selector covers them all.
(
    OP_NOP,    # 0: identity
    OP_NOT,    # 1: bitwise not
    OP_ANDC,   # 2: x & c
    OP_ORC,    # 3: x | c
    OP_XORC,   # 4: x ^ c
    OP_ADDC,   # 5: (x + c) mod 16
    OP_ROTL1,  # 6: rotate left by 1
    OP_ROTR1,  # 7: rotate right by 1
) = range(8)
NUM_OPS = OP_ROTR1 + 1  # number of opcodes (8)
class OpUnit(Elaboratable):
    """Single operation: y = op(x, c) over 4-bit words.

    Purely combinational: `op_sel` picks one of the eight DSL operations
    (OP_NOP .. OP_ROTR1) applied to input `x` and constant `c`, and the
    selected result drives `y`.
    """

    def __init__(self):
        self.x = Signal(4)                      # operand
        self.c = Signal(4)                      # constant operand (ignored by unary ops)
        self.op_sel = Signal(range(NUM_OPS))    # opcode selector
        self.y = Signal(4)                      # result

    def elaborate(self, platform):
        m = Module()
        a, k = self.x, self.c
        # Opcode -> result expression; insertion order fixes the Case order.
        results = {
            OP_NOP: a,
            OP_NOT: ~a,
            OP_ANDC: a & k,
            OP_ORC: a | k,
            OP_XORC: a ^ k,
            OP_ADDC: (a + k) & 0xF,            # drop the carry out of bit 3
            OP_ROTL1: Cat(a[3], a[:3]),        # old MSB becomes new LSB
            OP_ROTR1: Cat(a[1:], a[0]),        # old LSB becomes new MSB
        }
        with m.Switch(self.op_sel):
            for code, expr in results.items():
                with m.Case(code):
                    m.d.comb += self.y.eq(expr)
        return m
class TwoStageProgram(Elaboratable):
    """Fixed-shape program: y = op2( op1(x, c0), c1 )

    Two OpUnit instances chained combinationally: the first applies
    `op1_sel` with constant `c0` to `x`; the second applies `op2_sel` with
    constant `c1` to the first stage's result and drives `y`.
    """

    def __init__(self):
        self.x = Signal(4)
        self.c0 = Signal(4)
        self.c1 = Signal(4)
        self.op1_sel = Signal(range(NUM_OPS))
        self.op2_sel = Signal(range(NUM_OPS))
        self.y = Signal(4)

    def elaborate(self, platform):
        m = Module()
        first = OpUnit()
        m.submodules.op1 = first
        second = OpUnit()
        m.submodules.op2 = second

        # Stage 1: op1(x, c0)
        m.d.comb += [
            first.x.eq(self.x),
            first.c.eq(self.c0),
            first.op_sel.eq(self.op1_sel),
        ]
        # Stage 2: op2(stage-1 result, c1)
        m.d.comb += [
            second.x.eq(first.y),
            second.c.eq(self.c1),
            second.op_sel.eq(self.op2_sel),
        ]
        # Program output
        m.d.comb += self.y.eq(second.y)
        return m
class ExampleROM(Elaboratable):
    """Stores up to N example (x,y) pairs as 4-bit values.

    `examples` is a non-empty sequence of at most 256 pairs; each value is
    masked to its low 4 bits. `addr` selects a pair, and `x`/`y` expose it
    through `read_port()` with default arguments (in Amaranth 0.4 that is a
    synchronous port, so data follows `addr` by one clock cycle).
    """

    def __init__(self, examples):
        assert 0 < len(examples) <= 256
        self.N = len(examples)
        self.addr = Signal(range(self.N))
        self.x = Signal(4)
        self.y = Signal(4)
        xs = [pair[0] & 0xF for pair in examples]
        ys = [pair[1] & 0xF for pair in examples]
        self._memx = Memory(width=4, depth=self.N, init=xs)
        self._memy = Memory(width=4, depth=self.N, init=ys)

    def elaborate(self, platform):
        m = Module()
        port_x = self._memx.read_port()
        m.submodules.rp_x = port_x
        port_y = self._memy.read_port()
        m.submodules.rp_y = port_y
        # Both memories share one address; each port feeds its own output.
        for port, out in ((port_x, self.x), (port_y, self.y)):
            m.d.comb += [port.addr.eq(self.addr), out.eq(port.data)]
        return m
class EnumeratorChecker(Elaboratable):
    """
    Enumerate all configs for y = op2(op1(x,c0), c1) and check against examples.
    Reports the first satisfying config via `found` and latches the config;
    sets `no_sol` if the whole space is exhausted without a match.
    Config layout (incremented as a counter):
    [ op2_sel (3b) | op1_sel (3b) | c1 (4b) | c0 (4b) ] -> total 14 bits
    => 2^14 = 16384 candidates.

    Timing: each example costs two cycles (LOAD_EX, EVAL). The ROM address is
    driven combinationally from `example_idx`. ExampleROM uses default
    (synchronous) read ports, so the pair addressed during LOAD_EX is visible
    during the following EVAL cycle. A candidate is accepted only if every
    example matched, including the one evaluated last.
    """

    # Width of the candidate counter (see layout above).
    CANDIDATE_BITS = 14

    def __init__(self, examples):
        self.rom = ExampleROM(examples)
        # Outputs
        self.found = Signal()       # a satisfying candidate was latched
        self.no_sol = Signal()      # search space exhausted without a match
        self.sol_c0 = Signal(4)
        self.sol_c1 = Signal(4)
        self.sol_op1 = Signal(3)
        self.sol_op2 = Signal(3)
        # Stats/visibility
        self.candidate_idx = Signal(self.CANDIDATE_BITS)  # enumerator counter
        self.example_idx = Signal(range(self.rom.N))
        self.pass_accum = Signal()  # AND of per-example matches so far

    def elaborate(self, platform):
        m = Module()
        m.submodules.rom = rom = self.rom
        prog = m.submodules.prog = TwoStageProgram()

        # Decode fields from candidate counter
        c0 = Signal(4)
        c1 = Signal(4)
        op1_sel = Signal(3)
        op2_sel = Signal(3)
        m.d.comb += [
            c0.eq(self.candidate_idx[0:4]),
            c1.eq(self.candidate_idx[4:8]),
            op1_sel.eq(self.candidate_idx[8:11]),
            op2_sel.eq(self.candidate_idx[11:14]),
        ]

        # Hook program under test
        m.d.comb += [
            prog.c0.eq(c0),
            prog.c1.eq(c1),
            prog.op1_sel.eq(op1_sel),
            prog.op2_sel.eq(op2_sel),
            prog.x.eq(rom.x),
        ]

        # The ROM read ports already register the address. Registering it here
        # too would add a second cycle of latency, and EVAL would then compare
        # against the previous example. Drive it combinationally instead.
        m.d.comb += rom.addr.eq(self.example_idx)

        # Per-example comparison, and the running AND *including* the current
        # example (the registered pass_accum alone lags by one example).
        match = Signal()
        all_pass = Signal()
        m.d.comb += [
            match.eq(prog.y == rom.y),
            all_pass.eq(self.pass_accum & match),
        ]

        last_example = self.example_idx == (rom.N - 1)
        last_candidate = self.candidate_idx == ((1 << self.CANDIDATE_BITS) - 1)

        # FSM to test candidates across all examples
        with m.FSM(reset="IDLE"):
            with m.State("IDLE"):
                # Reset indices and accumulators
                m.d.sync += [
                    self.candidate_idx.eq(0),
                    self.example_idx.eq(0),
                    self.pass_accum.eq(1),
                    self.found.eq(0),
                    self.no_sol.eq(0),
                ]
                m.next = "LOAD_EX"
            with m.State("LOAD_EX"):
                # rom.addr already follows example_idx; this cycle covers the
                # synchronous ROM read so rom.x / rom.y are valid in EVAL.
                m.next = "EVAL"
            with m.State("EVAL"):
                # Fold this example's result into the accumulator. A later
                # assignment below (next candidate) overrides this one.
                m.d.sync += self.pass_accum.eq(all_pass)
                with m.If(last_example):
                    with m.If(all_pass):
                        # Found satisfying candidate; latch solution fields
                        m.d.sync += [
                            self.found.eq(1),
                            self.sol_c0.eq(c0),
                            self.sol_c1.eq(c1),
                            self.sol_op1.eq(op1_sel),
                            self.sol_op2.eq(op2_sel),
                        ]
                        m.next = "DONE"
                    with m.Elif(last_candidate):
                        m.d.sync += self.no_sol.eq(1)
                        m.next = "DONE"
                    with m.Else():
                        # Next candidate
                        m.d.sync += [
                            self.candidate_idx.eq(self.candidate_idx + 1),
                            self.example_idx.eq(0),
                            self.pass_accum.eq(1),
                        ]
                        m.next = "LOAD_EX"
                with m.Else():
                    m.d.sync += self.example_idx.eq(self.example_idx + 1)
                    m.next = "LOAD_EX"
            with m.State("DONE"):
                m.next = "DONE"
        return m
# ----- Demo harness -----
# Choose a ground-truth function within the DSL: e.g., y = ( (x ^ 0xA) + 1 ) & 0xF
def gt_func(x):
    """Reference target: XOR with 0xA, add 1, keep the low 4 bits.

    Expressible in the DSL as op1=OP_XORC (c0=0xA) followed by
    op2=OP_ADDC (c1=0x1).
    """
    flipped = x ^ 0b1010
    return (flipped + 1) & 0b1111


# Build IO examples (like SyGuS PBE). For a realistic task, include 3–8 pairs.
EXAMPLES = [(x_in, gt_func(x_in)) for x_in in (0x0, 0x3, 0x7, 0xE)]
def run_sim():
    """Simulate EnumeratorChecker on EXAMPLES and print the outcome.

    Steps the clock until the checker raises `found` or `no_sol`, or until a
    cycle cap is hit. The cap is derived from the design so an exhaustive
    search can always finish. A fixed 100000-cycle cap is too small for
    2^14 candidates x 2 cycles per example x 4 examples = 131072 cycles.
    """
    top = EnumeratorChecker(EXAMPLES)
    sim = Simulator(top)

    # Worst case: 1 IDLE cycle, then LOAD_EX + EVAL (2 cycles) per example for
    # every candidate. The margin covers the register delay before the result
    # flags become visible.
    num_candidates = 1 << len(top.candidate_idx)
    cycle_cap = 1 + num_candidates * 2 * top.rom.N + 16

    def proc():
        # Run until DONE flags are set or the cycle cap is reached
        for _ in range(cycle_cap):
            found = (yield top.found)
            no_sol = (yield top.no_sol)
            if found or no_sol:
                break
            yield
        if (yield top.found):
            c0 = (yield top.sol_c0)
            c1 = (yield top.sol_c1)
            o1 = (yield top.sol_op1)
            o2 = (yield top.sol_op2)
            idx = (yield top.candidate_idx)
            print("FOUND candidate at idx", idx,
                  f"\nop1={o1} c0=0x{c0:X} -> op2={o2} c1=0x{c1:X}")
        elif (yield top.no_sol):
            print("No solution in search space.")
        else:
            print("Cycle cap reached without decision.")

    sim.add_clock(1e-6)  # 1 MHz sim clock (arbitrary)
    sim.add_sync_process(proc)
    sim.run()


if __name__ == "__main__":
    run_sim()
# ----- Notes / Extensions -----
# 1) To target FPGA: wrap EnumeratorChecker in a top-level with AXI-lite for config/reads
# and AXI-Stream/BRAM for example blocks. Replicate the TwoStageProgram pipeline N times.
# 2) To move toward predicate form P(grid, x, y, color): widen x to encode (x,y,color)
# bitvectors; examples become (grid_encoding, cell_encoding, label_bit). The same
# enumerator/checker pattern applies.
# 3) To enlarge the grammar without exploding timing: keep per-stage op fan-in small and
# pipeline between stages. Prefer multiple narrow replicas to one wide tree.
# 4) Integrate with a CEGIS loop by: (a) having the host stream new counterexamples into
# BRAM; (b) having hardware return the first K satisfying candidates or per-candidate
# scores; (c) constraining fields on host to prune equivalent configs.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment