Last active
September 16, 2025 20:12
-
-
Save santolucito/43519726b0181e95581cca9501f157a9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # amaranth_sygus_pbe_demo.py | |
| # | |
| # Minimal Amaranth demo for a SyGuS-style PBE bitvector task on FPGA-like hardware. | |
| # It enumerates a tiny fixed-shape program: y = op2( op1(x, c0), c1 ) | |
| # over 4-bit values, and checks each candidate against example IO pairs stored in ROM. | |
| # A candidate that satisfies all examples is reported as a solution. | |
| # | |
| # This program shape is friendly to hardware acceleration: the checker is a small pipeline, and the | |
| # enumerator steps through the configuration space (op1, op2, c0, c1). On a real FPGA you would | |
| # replicate the checker and/or feed it batches; here we simulate with Amaranth's Simulator. | |
| # | |
| # Requires: amaranth (>=0.4). Install: `pip install amaranth` | |
| # | |
| # Finds a solution in about 30 seconds on a CPU. Still need to try on real hardware (FPGA). | |
| # | |
| # Author: santolucito (c) 2025, provided for educational demo purposes. | |
| from amaranth import * | |
| from amaranth.sim import Simulator | |
| # ----- DSL definition ----- | |
| # 4-bit domain; small op set that is expressive enough to solve simple PBE tasks | |
| # while keeping search space tiny for a demo. | |
# Op-code encoding for the DSL. The codes are consecutive integers starting at
# zero, so they fit exactly in a Signal(range(NUM_OPS)) selector.
NUM_OPS = 8
(
    OP_NOP,    # 0: identity
    OP_NOT,    # 1: bitwise not
    OP_ANDC,   # 2: x & c
    OP_ORC,    # 3: x | c
    OP_XORC,   # 4: x ^ c
    OP_ADDC,   # 5: (x + c) mod 16
    OP_ROTL1,  # 6: rotate left by 1
    OP_ROTR1,  # 7: rotate right by 1
) = range(NUM_OPS)
class OpUnit(Elaboratable):
    """One combinational DSL operation over 4-bit words: ``y = op(x, c)``.

    Ports:
        x:      4-bit data operand.
        c:      4-bit constant operand (unused by NOP, NOT and the rotates).
        op_sel: selects which of the ``NUM_OPS`` operations drives ``y``.
        y:      4-bit result.
    """

    def __init__(self):
        self.x = Signal(4)
        self.c = Signal(4)
        self.op_sel = Signal(range(NUM_OPS))
        self.y = Signal(4)

    def elaborate(self, platform):
        m = Module()
        operand, const = self.x, self.c

        # Op-code -> result expression. Cat() places its first argument in the
        # least-significant position, which is what turns the slices below
        # into one-bit rotations.
        results = {
            OP_NOP: operand,
            OP_NOT: ~operand,
            OP_ANDC: operand & const,
            OP_ORC: operand | const,
            OP_XORC: operand ^ const,
            OP_ADDC: (operand + const) & 0xF,
            OP_ROTL1: Cat(operand[3], operand[:3]),
            OP_ROTR1: Cat(operand[1:], operand[0]),
        }

        with m.Switch(self.op_sel):
            for code, value in results.items():
                with m.Case(code):
                    m.d.comb += self.y.eq(value)

        return m
class TwoStageProgram(Elaboratable):
    """Fixed-shape candidate program: ``y = op2(op1(x, c0), c1)``.

    Two ``OpUnit`` instances are chained combinationally: the first consumes
    ``x`` and ``c0``; the second consumes the first's result and ``c1``.
    """

    def __init__(self):
        self.x = Signal(4)
        self.c0 = Signal(4)
        self.c1 = Signal(4)
        self.op1_sel = Signal(range(NUM_OPS))
        self.op2_sel = Signal(range(NUM_OPS))
        self.y = Signal(4)

    def elaborate(self, platform):
        m = Module()

        stage1 = OpUnit()
        stage2 = OpUnit()
        m.submodules.op1 = stage1
        m.submodules.op2 = stage2

        # First stage: y1 = op1(x, c0)
        m.d.comb += stage1.x.eq(self.x)
        m.d.comb += stage1.c.eq(self.c0)
        m.d.comb += stage1.op_sel.eq(self.op1_sel)

        # Second stage: y = op2(y1, c1)
        m.d.comb += stage2.x.eq(stage1.y)
        m.d.comb += stage2.c.eq(self.c1)
        m.d.comb += stage2.op_sel.eq(self.op2_sel)

        m.d.comb += self.y.eq(stage2.y)

        return m
class ExampleROM(Elaboratable):
    """Read-only store of up to 256 example ``(x, y)`` pairs, 4 bits each.

    Ports:
        addr: index of the example to read.
        x, y: input and expected output of the addressed example. They follow
              ``addr`` combinationally (asynchronous read), so a consumer that
              registers ``addr`` in one cycle sees the matching data in the
              very next cycle.

    Args:
        examples: non-empty sequence of at most 256 items; item ``[0]`` is the
            input and item ``[1]`` the expected output. Both are masked to
            4 bits.

    Raises:
        ValueError: if ``examples`` is empty or has more than 256 entries.
    """

    def __init__(self, examples):
        # Explicit check instead of `assert`, which is stripped under -O.
        if not 0 < len(examples) <= 256:
            raise ValueError(
                f"expected between 1 and 256 examples, got {len(examples)}"
            )
        self.N = len(examples)
        self.addr = Signal(range(self.N))
        self.x = Signal(4)
        self.y = Signal(4)
        self._memx = Memory(width=4, depth=self.N,
                            init=[pair[0] & 0xF for pair in examples])
        self._memy = Memory(width=4, depth=self.N,
                            init=[pair[1] & 0xF for pair in examples])

    def elaborate(self, platform):
        m = Module()

        # read_port() defaults to a *synchronous* port, whose data lags the
        # address by one clock. Request an asynchronous (comb-domain) port so
        # x/y always correspond to the current `addr`.
        rp_x = m.submodules.rp_x = self._memx.read_port(domain="comb")
        rp_y = m.submodules.rp_y = self._memy.read_port(domain="comb")

        m.d.comb += [
            rp_x.addr.eq(self.addr),
            rp_y.addr.eq(self.addr),
            self.x.eq(rp_x.data),
            self.y.eq(rp_y.data),
        ]
        return m
class EnumeratorChecker(Elaboratable):
    """Enumerate all configs for y = op2(op1(x, c0), c1) and check them.

    Each candidate is evaluated against every example in the ROM. The first
    candidate that matches all examples is latched into the ``sol_*`` outputs
    and ``found`` is raised; if the whole space is exhausted without a match,
    ``no_sol`` is raised instead.

    Config layout of ``candidate_idx`` (incremented as a counter), MSB first:
        [ op2_sel (3b) | op1_sel (3b) | c1 (4b) | c0 (4b) ]  -> 14 bits total
        => 2^14 = 16384 candidates.

    Timing: after one IDLE cycle, every (candidate, example) pair takes two
    cycles (LOAD_EX then EVAL).
    NOTE(review): EVAL assumes the ROM outputs already correspond to the
    address registered during LOAD_EX -- confirm against ExampleROM's read
    latency.
    """

    def __init__(self, examples):
        self.rom = ExampleROM(examples)

        # Outputs
        self.found = Signal()
        self.no_sol = Signal()
        self.sol_c0 = Signal(4)
        self.sol_c1 = Signal(4)
        self.sol_op1 = Signal(3)
        self.sol_op2 = Signal(3)

        # Stats/visibility
        self.candidate_idx = Signal(14)  # enumerator counter
        self.example_idx = Signal(range(self.rom.N))
        self.pass_accum = Signal()  # AND of per-example matches seen so far

    def elaborate(self, platform):
        m = Module()
        m.submodules.rom = rom = self.rom
        prog = m.submodules.prog = TwoStageProgram()

        # Decode fields from candidate counter
        c0 = Signal(4)
        c1 = Signal(4)
        op1_sel = Signal(3)
        op2_sel = Signal(3)
        m.d.comb += [
            c0.eq(self.candidate_idx[0:4]),
            c1.eq(self.candidate_idx[4:8]),
            op1_sel.eq(self.candidate_idx[8:11]),
            op2_sel.eq(self.candidate_idx[11:14]),
        ]

        # Hook program under test
        m.d.comb += [
            prog.c0.eq(c0),
            prog.c1.eq(c1),
            prog.op1_sel.eq(op1_sel),
            prog.op2_sel.eq(op2_sel),
            prog.x.eq(rom.x),
        ]

        last_example = rom.N - 1
        last_candidate = (1 << 14) - 1

        # FSM to test candidates across all examples
        with m.FSM(reset="IDLE"):
            with m.State("IDLE"):
                # Reset indices and accumulators
                m.d.sync += [
                    self.candidate_idx.eq(0),
                    self.example_idx.eq(0),
                    self.pass_accum.eq(1),
                    self.found.eq(0),
                    self.no_sol.eq(0),
                ]
                m.next = "LOAD_EX"

            with m.State("LOAD_EX"):
                # Point ROM at current example
                m.d.sync += rom.addr.eq(self.example_idx)
                m.next = "EVAL"

            with m.State("EVAL"):
                # Combinational program output available same cycle
                eq = Signal()
                m.d.comb += eq.eq(prog.y == rom.y)

                # Verdict *including* the example evaluated in this cycle.
                # The registered pass_accum only covers earlier examples, so
                # deciding on it alone would ignore the last example.
                all_pass = Signal()
                m.d.comb += all_pass.eq(self.pass_accum & eq)

                # Accumulate pass/fail (overridden below when a new candidate
                # starts: the later assignment wins).
                m.d.sync += self.pass_accum.eq(all_pass)

                # Advance examples or decide on candidate
                with m.If(self.example_idx == last_example):
                    with m.If(all_pass):
                        # Found satisfying candidate; latch solution fields
                        m.d.sync += [
                            self.found.eq(1),
                            self.sol_c0.eq(c0),
                            self.sol_c1.eq(c1),
                            self.sol_op1.eq(op1_sel),
                            self.sol_op2.eq(op2_sel),
                        ]
                        m.next = "DONE"
                    with m.Elif(self.candidate_idx == last_candidate):
                        # Search space exhausted
                        m.d.sync += self.no_sol.eq(1)
                        m.next = "DONE"
                    with m.Else():
                        # Next candidate
                        m.d.sync += [
                            self.candidate_idx.eq(self.candidate_idx + 1),
                            self.example_idx.eq(0),
                            self.pass_accum.eq(1),
                        ]
                        m.next = "LOAD_EX"
                with m.Else():
                    m.d.sync += self.example_idx.eq(self.example_idx + 1)
                    m.next = "LOAD_EX"

            with m.State("DONE"):
                # Terminal state: hold found/no_sol and the latched solution.
                m.next = "DONE"

        return m
| # ----- Demo harness ----- | |
| # Choose a ground-truth function within the DSL: e.g., y = ( (x ^ 0xA) + 1 ) & 0xF | |
def gt_func(x):
    """Ground-truth target for the demo: XOR with 0xA, add 1, wrap to 4 bits."""
    flipped = x ^ 0b1010
    return (flipped + 1) % 16
# Input/output examples in the SyGuS PBE style: each entry pairs a probe input
# with the ground-truth output. For a realistic task, include 3–8 pairs.
EXAMPLES = [(probe, gt_func(probe)) for probe in (0x0, 0x3, 0x7, 0xE)]
def run_sim(examples=None):
    """Simulate the enumerator/checker and print the outcome.

    Args:
        examples: optional sequence of ``(x, y)`` example pairs to synthesize
            against. Defaults to the module-level ``EXAMPLES``.

    Prints the first satisfying candidate, "No solution in search space." if
    the search is exhausted, or a cycle-cap message if neither flag was raised
    within the cycle budget.
    """
    if examples is None:
        examples = EXAMPLES

    top = EnumeratorChecker(examples)
    sim = Simulator(top)

    # Worst case is one IDLE cycle plus two cycles (LOAD_EX, EVAL) per example
    # per candidate. A fixed cap of 100000 is below that bound for 4 examples
    # (131073 cycles), which would make the "no solution" outcome unreachable,
    # so derive the cap from the actual search size and add a little slack.
    num_candidates = 1 << len(top.candidate_idx)
    cycle_cap = 1 + 2 * len(examples) * num_candidates + 8

    def proc():
        # Run until a DONE flag is set or the cycle cap is reached
        for _ in range(cycle_cap):
            if (yield top.found) or (yield top.no_sol):
                break
            yield

        if (yield top.found):
            c0 = (yield top.sol_c0)
            c1 = (yield top.sol_c1)
            o1 = (yield top.sol_op1)
            o2 = (yield top.sol_op2)
            idx = (yield top.candidate_idx)
            print("FOUND candidate at idx", idx,
                  f"\nop1={o1} c0=0x{c0:X} -> op2={o2} c1=0x{c1:X}")
        elif (yield top.no_sol):
            print("No solution in search space.")
        else:
            print("Cycle cap reached without decision.")

    sim.add_clock(1e-6)  # 1 MHz sim clock (arbitrary)
    sim.add_sync_process(proc)
    sim.run()
# Script entry point: run the demo simulation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_sim()
| # ----- Notes / Extensions ----- | |
| # 1) To target FPGA: wrap EnumeratorChecker in a top-level with AXI-lite for config/reads | |
| # and AXI-Stream/BRAM for example blocks. Replicate the TwoStageProgram pipeline N times. | |
| # 2) To move toward predicate form P(grid, x, y, color): widen x to encode (x,y,color) | |
| # bitvectors; examples become (grid_encoding, cell_encoding, label_bit). The same | |
| # enumerator/checker pattern applies. | |
| # 3) To enlarge the grammar without exploding timing: keep per-stage op fan-in small and | |
| # pipeline between stages. Prefer multiple narrow replicas to one wide tree. | |
| # 4) Integrate with a CEGIS loop by: (a) having the host stream new counterexamples into | |
| # BRAM; (b) having hardware return the first K satisfying candidates or per-candidate | |
| # scores; (c) constraining fields on host to prune equivalent configs. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment