@Dogacel
Created April 24, 2026 18:10
FlashInfer-Bench Modal Cloud Benchmark Runner
"""
Official-evaluation mirror: runs our solution under the exact environment
the contest uses (per EVALUATION.md of the starter kit).
- Uses the official Docker image: flashinfer/flashinfer-ci-cu132 (pinned tag below)
- Installs our solution by dropping solution.json into the dataset's solutions/ tree
- Downloads the contest dataset from HuggingFace into a Modal volume
- Invokes the flashinfer-bench CLI verbatim per track
- Attempts to lock GPU clocks with `nvidia-smi -ac 3996,1965`
"""
import sys
import modal
from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
dataset_vol = modal.Volume.from_name("mlsys26-contest", create_if_missing=True)
results_vol = modal.Volume.from_name("flashinfer-eval-results", create_if_missing=True)
DATASET_PATH = "/data"
DATASET_SUBDIR = "mlsys26-contest"
RESULTS_PATH = "/results"
# Official image per EVALUATION.md. Python is whatever the image ships.
# We add our extension's build-time deps (git for cloning CUTLASS).
#
# CRITICAL: the image ships `flashinfer-bench 0.1.2` from PyPI, but the spec
# says "flashinfer-bench latest main, built from source" — and the 0.1.2
# evaluator compares indexer output tensors element-by-element with
# required_matched_ratio=1.0, which rejects our valid-but-reordered top-K
# output. We override the image install with a fresh pip install from
# git main to get the lenient DSA evaluator the contest actually runs.
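# A quick way to confirm which flashinfer-bench actually lands in the container
# (illustrative one-liner; importlib.metadata is stdlib):
#   python -c "from importlib.metadata import version; print(version('flashinfer-bench'))"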
image = (
modal.Image.from_registry("flashinfer/flashinfer-ci-cu132:20260401-2c675fb",
add_python="3.12")
.apt_install("git", "wget", "build-essential", "cmake")
.pip_install("huggingface_hub")
.run_commands(
# Force-reinstall latest main of flashinfer-bench. Override the
# PyPI 0.1.2 already in the image. Pip pulls/refreshes deps as
# needed (pydantic, safetensors, tvm-ffi, ...).
"pip install --force-reinstall --upgrade "
"git+https://github.com/flashinfer-ai/flashinfer-bench.git@main",
)
.pip_install(
# cupti-python is required for accurate GPU timing (per
# EVALUATION.md). Without it, flashinfer-bench's timer falls
# back to CUDA events, which include kernel-launch overhead
# (~2-4 μs per call) and over-report latency. The official
# image is *supposed* to ship this but the released 0.1.2
# dep-chain doesn't pull it in.
"cupti-python",
)
# Caches the CUTLASS clone + torch extension build dir across runs.
.env({
"TORCH_EXTENSIONS_DIR": "/results/torch_ext_cache",
"DSA_CUTLASS_DIR": "/results/cutlass",
})
)
app = modal.App("flashinfer-eval")
def _download_dataset_if_missing(dataset_root):
from pathlib import Path
if not dataset_root.exists() or not any(dataset_root.iterdir()):
print(f"Downloading contest dataset to {dataset_root}...")
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="flashinfer-ai/mlsys26-contest",
repo_type="dataset",
local_dir=str(dataset_root),
)
dataset_vol.commit()
@app.function(image=image, gpu="B200:1", timeout=3600,
volumes={DATASET_PATH: dataset_vol, RESULTS_PATH: results_vol})
def run_eval_fast(solution_json: str,
solution_name: str,
definition: str,
env_vars: dict | None = None) -> dict:
"""Fast iteration variant: bypass the CLI, call the Python API directly
with `profile_baseline=False`. Same image / dataset / versions as the
official eval — only difference is skipping the 100-iter PyTorch
baseline measurement per workload (which is ~98% of wall-clock per
our CLAUDE.md notes). Use this for "does our solution pass" checks;
use run_eval for the authoritative speedup_factor report.
env_vars: optional dict[str,str] applied via os.environ BEFORE any
flashinfer_bench / solution import, so the solution's import-time env
reads pick them up."""
import os
if env_vars:
for k, v in env_vars.items():
os.environ[k] = str(v)
print(f"env {k}={v}")
import subprocess
from pathlib import Path
from flashinfer_bench import (
Benchmark, BenchmarkConfig, Solution, TraceSet)
dataset_root = Path(DATASET_PATH) / DATASET_SUBDIR
_download_dataset_if_missing(dataset_root)
# Lock clocks (best effort — fails without priv; logged, continue).
r = subprocess.run(["nvidia-smi", "-ac", "3996,1965"],
capture_output=True, text=True)
print(f"nvidia-smi -ac: rc={r.returncode} {r.stdout.strip()[:120]}")
# Load trace set + filter to target definition / solution.
trace_set = TraceSet.from_path(str(dataset_root))
if definition not in trace_set.definitions:
print(f"Available definitions: {list(trace_set.definitions)}")
raise ValueError(f"Definition {definition!r} not in trace set")
sol = Solution.model_validate_json(solution_json)
defn = trace_set.definitions[definition]
workloads = trace_set.workloads.get(definition, [])
print(f"Running {len(workloads)} workloads for {definition}")
bench_ts = TraceSet(
root=trace_set.root,
definitions={defn.name: defn},
solutions={defn.name: [sol]},
workloads={defn.name: workloads},
traces={defn.name: []},
)
# Keep defaults (warmup_runs=10, iterations=50, num_trials=3,
# rtol/atol=0.01, timeout_seconds=300) — matches official eval except
# for (a) profile_baseline=False (speed) and (b) use_isolated_runner=
# False (PersistentRunner avoids 128× subprocess-spawn + Triton-JIT
# + extension-warmup cost; trades away subprocess isolation which
# isn't needed for single-solution verification).
cfg = BenchmarkConfig(profile_baseline=False,
use_isolated_runner=False)
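# If even the fast path is too slow while iterating, the defaults named above
# could be dialed down explicitly, e.g. (field names assumed from the comment
# above; adjust if the BenchmarkConfig schema spells them differently):
#   cfg = BenchmarkConfig(profile_baseline=False, use_isolated_runner=False,
#                         num_trials=1, warmup_runs=5)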
bench = Benchmark(bench_ts, cfg)
result_ts = bench.run_all(dump_traces=False)
traces = result_ts.traces.get(defn.name, [])
results = {}
for t in traces:
if t.evaluation:
e = {"status": t.evaluation.status.value}
if t.evaluation.performance:
e["latency_ms"] = t.evaluation.performance.latency_ms
if t.evaluation.correctness:
e["abs_err"] = t.evaluation.correctness.max_absolute_error
e["rel_err"] = t.evaluation.correctness.max_relative_error
if t.evaluation.status.value != "PASSED" and t.evaluation.log:
# Surface each failing workload's evaluation log (last 2 kB) so we can
# debug extension / import problems instead of a silent RUNTIME_ERROR.
print(f"LOG {t.workload.uuid[:12]}…: {t.evaluation.log[-2000:]}")
results[t.workload.uuid] = e
# Pretty print summary. Enum value is "PASSED" uppercase.
passed = sum(1 for r in results.values() if r["status"] == "PASSED")
failed = [(u, r) for u, r in results.items() if r["status"] != "PASSED"]
lats = [r["latency_ms"] for r in results.values() if r.get("latency_ms") is not None]
mean_us = (sum(lats) / len(lats) * 1000) if lats else 0.0
print(f"\n=== {definition} ===")
print(f" PASSED: {passed} / {len(results)}")
print(f" mean latency: {mean_us:.3f} μs (n={len(lats)})")
print(" per-workload (sorted by latency):")
sorted_items = sorted(results.items(),
key=lambda kv: kv[1].get("latency_ms") or 0.0)
for u, r in sorted_items:
lat = r.get("latency_ms")
lat_us = f"{lat*1000:7.3f}" if lat is not None else " n/a"
print(f" {u[:12]}… {lat_us} μs {r['status']}")
for u, r in failed[:10]:
print(f" FAIL {u[:8]}… {r['status']} abs={r.get('abs_err')} rel={r.get('rel_err')}")
return {"passed": passed, "total": len(results), "mean_us": mean_us,
"per_workload": {u: r.get("latency_ms") for u, r in results.items()},
"failed": [(u, r) for u, r in failed[:30]]}
@app.function(image=image, gpu="B200:1", timeout=3600,
volumes={DATASET_PATH: dataset_vol, RESULTS_PATH: results_vol})
def run_eval(solution_json: str,
solution_name: str,
definition: str,
extra_args: list | None = None) -> dict:
"""Drop our solution into the dataset's solutions/ tree and invoke the
official flashinfer-bench CLI."""
import os, subprocess, json, shutil
from pathlib import Path
dataset_root = Path(DATASET_PATH) / DATASET_SUBDIR
# ---- 1. Download the contest dataset from HuggingFace if missing ----
_download_dataset_if_missing(dataset_root)
print(f"Dataset contents at {dataset_root}:")
for p in sorted(dataset_root.rglob("*"))[:30]:
print(" ", p.relative_to(dataset_root))
# ---- 2. Drop our solution.json into the expected location ----
# flashinfer-bench expects solutions under <dataset>/solutions/<name>/solution.json
sol_dir = dataset_root / "solutions" / solution_name
sol_dir.mkdir(parents=True, exist_ok=True)
(sol_dir / "solution.json").write_text(solution_json)
print(f"Wrote solution to {sol_dir / 'solution.json'}")
# ---- 3. Attempt to lock GPU clocks (no-op without priv; log outcome) ----
lock_r = subprocess.run(
["nvidia-smi", "-ac", "3996,1965"],
capture_output=True, text=True)
print(f"nvidia-smi -ac 3996,1965: rc={lock_r.returncode}")
if lock_r.stdout: print(" stdout:", lock_r.stdout.strip())
if lock_r.stderr: print(" stderr:", lock_r.stderr.strip())
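# If the lock is rejected, the clock pairs this GPU actually supports can be
# listed (a read-only query, no elevated privileges needed) with:
#   nvidia-smi -q -d SUPPORTED_CLOCKS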
# ---- 4. Invoke the official CLI ----
cmd = [
"flashinfer-bench", "run",
"--local", str(dataset_root),
"--definitions", definition,
"--save-results",
"--use-isolated-runner",
"--log-level", "INFO",
"--resume",
"--timeout", "300",
] + list(extra_args or [])
print(f"$ {' '.join(cmd)}")
r = subprocess.run(cmd, capture_output=True, text=True, timeout=3000,
cwd=str(dataset_root))
print("=== STDOUT ===")
print(r.stdout)
print("=== STDERR ===")
print(r.stderr)
# ---- 5. Find result files the CLI wrote ----
result_files = []
for root in [dataset_root, Path(RESULTS_PATH)]:
for f in root.rglob("*.json"):
if "result" in f.name.lower() or "trace" in f.name.lower():
result_files.append(str(f))
print(f"Result files found: {len(result_files)}")
for f in result_files[:10]:
print(" ", f)
return {
"returncode": r.returncode,
"stdout": r.stdout,
"stderr": r.stderr,
"result_files": result_files,
}
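# main() dispatches to probe.remote() for mode="probe", but the probe function
# itself is not included in this gist. The sketch below is a minimal stand-in
# (illustrative, not the original): it reports the GPU/driver state and which
# flashinfer-bench version the image actually resolved to.
@app.function(image=image, gpu="B200:1", timeout=600)
def probe() -> None:
    import subprocess
    from importlib.metadata import version
    print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
    print("flashinfer-bench:", version("flashinfer-bench"))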
@app.local_entrypoint()
def main(mode: str = "run_fast",
definition: str = "dsa_topk_indexer_fp8_h64_d128_topk2048_ps64",
extra: str = "",
env: str = ""):
"""
mode = "probe" → inspect the image; don't run any benchmark
mode = "run" → pack + full CLI eval (slow — runs PyTorch baseline too)
mode = "run_fast" → pack + Python-API eval with profile_baseline=False
(default — skips baseline iterations, ~10-20x faster)
extra → additional CLI args for "run" mode, comma-separated
(e.g. "--num-trials,3,--warmup-runs,1")
env → comma-separated K=V pairs applied as environment
variables inside the Modal container BEFORE any
flashinfer_bench / solution import (e.g.
"DSA_SPARSE_K_SPLITS_CAP=16,DSA_SPARSE_BLOCK_K=64").
"""
if mode == "probe":
probe.remote()
return
from flashinfer_bench.data import Solution
from scripts.pack_solution import pack_solution
# Parse "K=V,K2=V2" → {K: V, K2: V2}.
env_vars: dict = {}
if env:
for pair in env.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
raise ValueError(f"--env entry {pair!r} must be K=V")
k, v = pair.split("=", 1)
env_vars[k.strip()] = v.strip()
print(f"Forwarding env vars to container: {env_vars}")
print("Packing solution...")
sol_path = pack_solution()
sol = Solution.model_validate_json(sol_path.read_text())
print(f"Loaded: {sol.name} definition={sol.definition}")
if mode == "run_fast":
result = run_eval_fast.remote(
solution_json=sol_path.read_text(),
solution_name=sol.name,
definition=definition,
env_vars=env_vars,
)
print(f"\n=== result ===")
print(f" PASSED: {result['passed']} / {result['total']}")
print(f" mean latency: {result['mean_us']:.3f} μs")
if result['failed']:
print(f" first failures:")
for u, r in result['failed'][:10]:
print(f" {u[:8]}… {r['status']}")
else:
extra_args = [a for a in extra.split(",") if a] if extra else []
result = run_eval.remote(
solution_json=sol_path.read_text(),
solution_name=sol.name,
definition=definition,
extra_args=extra_args,
)
print(f"\n=== final returncode: {result['returncode']} ===")
if result['returncode'] != 0:
print("\n=== stderr (full) ===")
print(result['stderr'])
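# Example invocations (the file path is illustrative; substitute wherever this
# script is saved, e.g. scripts/modal_eval.py):
#   modal run scripts/modal_eval.py --mode run_fast
#   modal run scripts/modal_eval.py --mode run --extra "--num-trials,3"
#   modal run scripts/modal_eval.py --mode run_fast \
#       --env "DSA_SPARSE_K_SPLITS_CAP=16,DSA_SPARSE_BLOCK_K=64"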