FlashInfer-Bench Modal Cloud Benchmark Runner
| """ | |
| Official-evaluation mirror: runs our solution under the exact environment | |
| the contest uses (per EVALUATION.md of the starter kit). | |
| - Uses the official Docker image: flashinfer/flashinfer-ci-cu132:latest | |
| - Installs our solution by dropping solution.json | |
| - Downloads the contest dataset from HuggingFace into a Modal volume | |
| - Invokes the flashinfer-bench CLI verbatim per track | |
| - Attempts to lock GPU clocks with `nvidia-smi -ac 3996,1965` | |
| """ | |

import sys
from pathlib import Path

import modal

PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

dataset_vol = modal.Volume.from_name("mlsys26-contest", create_if_missing=True)
results_vol = modal.Volume.from_name("flashinfer-eval-results", create_if_missing=True)

DATASET_PATH = "/data"
DATASET_SUBDIR = "mlsys26-contest"
RESULTS_PATH = "/results"

# Official image per EVALUATION.md. Python is whatever the image ships.
# We add our extension's build-time deps (git for cloning CUTLASS).
#
# CRITICAL: the image ships `flashinfer-bench 0.1.2` from PyPI, but the spec
# says "flashinfer-bench latest main, built from source" — and the 0.1.2
# evaluator compares indexer output tensors element-by-element with
# required_matched_ratio=1.0, which rejects our valid-but-reordered top-K
# output. We override the image install with a fresh pip install from
# git main to get the lenient DSA evaluator the contest actually runs.
image = (
    modal.Image.from_registry("flashinfer/flashinfer-ci-cu132:20260401-2c675fb",
                              add_python="3.12")
    .apt_install("git", "wget", "build-essential", "cmake")
    .pip_install("huggingface_hub")
    .run_commands(
        # Force-reinstall latest main of flashinfer-bench. Override the
        # PyPI 0.1.2 already in the image. Pip pulls/refreshes deps as
        # needed (pydantic, safetensors, tvm-ffi, ...).
        "pip install --force-reinstall --upgrade "
        "git+https://github.com/flashinfer-ai/flashinfer-bench.git@main",
    )
    .pip_install(
        # cupti-python is required for accurate GPU timing (per
        # EVALUATION.md). Without it, flashinfer-bench's timer falls
        # back to CUDA events, which include kernel-launch overhead
        # (~2-4 μs per call) and over-report latency. The official
        # image is *supposed* to ship this but the released 0.1.2
        # dep-chain doesn't pull it in.
        "cupti-python",
    )
    # Caches the CUTLASS clone + torch extension build dir across runs.
    .env({
        "TORCH_EXTENSIONS_DIR": "/results/torch_ext_cache",
        "DSA_CUTLASS_DIR": "/results/cutlass",
    })
)

app = modal.App("flashinfer-eval")


def _download_dataset_if_missing(dataset_root):
    """Fetch the contest dataset from HuggingFace into the volume (idempotent)."""
    if not dataset_root.exists() or not any(dataset_root.iterdir()):
        print(f"Downloading contest dataset to {dataset_root}...")
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id="flashinfer-ai/mlsys26-contest",
            repo_type="dataset",
            local_dir=str(dataset_root),
        )
        dataset_vol.commit()


@app.function(image=image, gpu="B200:1", timeout=3600,
              volumes={DATASET_PATH: dataset_vol, RESULTS_PATH: results_vol})
def run_eval_fast(solution_json: str,
                  solution_name: str,
                  definition: str,
                  env_vars: dict | None = None) -> dict:
    """Fast iteration variant: bypass the CLI and call the Python API directly
    with `profile_baseline=False`. Same image / dataset / versions as the
    official eval — the only difference is skipping the 100-iter PyTorch
    baseline measurement per workload (which is ~98% of wall-clock per
    our CLAUDE.md notes). Use this for "does our solution pass" checks;
    use run_eval for the authoritative speedup_factor report.

    env_vars: optional dict[str, str] applied via os.environ BEFORE any
    flashinfer_bench / solution import, so the solution's import-time env
    reads pick them up.
    """
    import os
    import subprocess

    if env_vars:
        for k, v in env_vars.items():
            os.environ[k] = str(v)
            print(f"env {k}={v}")

    from flashinfer_bench import (
        Benchmark, BenchmarkConfig, Solution, TraceSet)

    dataset_root = Path(DATASET_PATH) / DATASET_SUBDIR
    _download_dataset_if_missing(dataset_root)

    # Lock clocks (best effort — fails without priv; logged, continue).
    r = subprocess.run(["nvidia-smi", "-ac", "3996,1965"],
                       capture_output=True, text=True)
    print(f"nvidia-smi -ac: rc={r.returncode} {r.stdout.strip()[:120]}")

    # Load trace set + filter to target definition / solution.
    trace_set = TraceSet.from_path(str(dataset_root))
    if definition not in trace_set.definitions:
        print(f"Available definitions: {list(trace_set.definitions)}")
        raise ValueError(f"Definition {definition!r} not in trace set")
    sol = Solution.model_validate_json(solution_json)
    defn = trace_set.definitions[definition]
    workloads = trace_set.workloads.get(definition, [])
    print(f"Running {len(workloads)} workloads for {definition}")

    bench_ts = TraceSet(
        root=trace_set.root,
        definitions={defn.name: defn},
        solutions={defn.name: [sol]},
        workloads={defn.name: workloads},
        traces={defn.name: []},
    )

    # Keep defaults (warmup_runs=10, iterations=50, num_trials=3,
    # rtol/atol=0.01, timeout_seconds=300) — matches official eval except
    # for (a) profile_baseline=False (speed) and (b) use_isolated_runner=
    # False (PersistentRunner avoids 128× subprocess-spawn + Triton-JIT
    # + extension-warmup cost; trades away subprocess isolation which
    # isn't needed for single-solution verification).
    cfg = BenchmarkConfig(profile_baseline=False,
                          use_isolated_runner=False)
    bench = Benchmark(bench_ts, cfg)
    result_ts = bench.run_all(dump_traces=False)

    traces = result_ts.traces.get(defn.name, [])
    results = {}
    for t in traces:
        if t.evaluation:
            e = {"status": t.evaluation.status.value}
            if t.evaluation.performance:
                e["latency_ms"] = t.evaluation.performance.latency_ms
            if t.evaluation.correctness:
                e["abs_err"] = t.evaluation.correctness.max_absolute_error
                e["rel_err"] = t.evaluation.correctness.max_relative_error
            if t.evaluation.status.value != "PASSED" and t.evaluation.log:
                # Surface the first per-workload error so we can debug
                # extension / import problems without silent RUNTIME_ERROR.
                print(f"LOG {t.workload.uuid[:12]}…: {t.evaluation.log[-2000:]}")
            results[t.workload.uuid] = e

    # Pretty-print summary. Enum value is "PASSED", uppercase.
    passed = sum(1 for r in results.values() if r["status"] == "PASSED")
    failed = [(u, r) for u, r in results.items() if r["status"] != "PASSED"]
    lats = [r["latency_ms"] for r in results.values() if r.get("latency_ms") is not None]
    mean_us = (sum(lats) / len(lats) * 1000) if lats else 0.0

    print(f"\n=== {definition} ===")
    print(f"  PASSED: {passed} / {len(results)}")
    print(f"  mean latency: {mean_us:.3f} μs (n={len(lats)})")
    print("  per-workload (sorted by latency):")
    sorted_items = sorted(results.items(),
                          key=lambda kv: kv[1].get("latency_ms") or 0.0)
    for u, r in sorted_items:
        lat = r.get("latency_ms")
        lat_us = f"{lat*1000:7.3f}" if lat is not None else "    n/a"
        print(f"    {u[:12]}… {lat_us} μs  {r['status']}")
    for u, r in failed[:10]:
        print(f"  FAIL {u[:8]}… {r['status']} abs={r.get('abs_err')} rel={r.get('rel_err')}")

    return {"passed": passed, "total": len(results), "mean_us": mean_us,
            "per_workload": {u: r.get("latency_ms") for u, r in results.items()},
            "failed": [(u, r) for u, r in failed[:30]]}


@app.function(image=image, gpu="B200:1", timeout=3600,
              volumes={DATASET_PATH: dataset_vol, RESULTS_PATH: results_vol})
def run_eval(solution_json: str,
             solution_name: str,
             definition: str,
             extra_args: list | None = None) -> dict:
    """Drop our solution into the dataset's solutions/ tree and invoke the
    official flashinfer-bench CLI."""
    import subprocess

    dataset_root = Path(DATASET_PATH) / DATASET_SUBDIR

    # ---- 1. Download the contest dataset from HuggingFace if missing ----
    _download_dataset_if_missing(dataset_root)
    print(f"Dataset contents at {dataset_root}:")
    for p in sorted(dataset_root.rglob("*"))[:30]:
        print("  ", p.relative_to(dataset_root))

    # ---- 2. Drop our solution.json into the expected location ----
    # flashinfer-bench expects solutions under <dataset>/solutions/<name>/solution.json
    sol_dir = dataset_root / "solutions" / solution_name
    sol_dir.mkdir(parents=True, exist_ok=True)
    (sol_dir / "solution.json").write_text(solution_json)
    print(f"Wrote solution to {sol_dir / 'solution.json'}")

    # ---- 3. Attempt to lock GPU clocks (no-op without priv; log outcome) ----
    lock_r = subprocess.run(
        ["nvidia-smi", "-ac", "3996,1965"],
        capture_output=True, text=True)
    print(f"nvidia-smi -ac 3996,1965: rc={lock_r.returncode}")
    if lock_r.stdout:
        print("  stdout:", lock_r.stdout.strip())
    if lock_r.stderr:
        print("  stderr:", lock_r.stderr.strip())

    # ---- 4. Invoke the official CLI ----
    cmd = [
        "flashinfer-bench", "run",
        "--local", str(dataset_root),
        "--definitions", definition,
        "--save-results",
        "--use-isolated-runner",
        "--log-level", "INFO",
        "--resume",
        "--timeout", "300",
    ] + list(extra_args or [])
    print(f"$ {' '.join(cmd)}")
    r = subprocess.run(cmd, capture_output=True, text=True, timeout=3000,
                       cwd=str(dataset_root))
    print("=== STDOUT ===")
    print(r.stdout)
    print("=== STDERR ===")
    print(r.stderr)

    # ---- 5. Find result files the CLI wrote ----
    result_files = []
    for root in [dataset_root, Path(RESULTS_PATH)]:
        for f in root.rglob("*.json"):
            if "result" in f.name.lower() or "trace" in f.name.lower():
                result_files.append(str(f))
    print(f"Result files found: {len(result_files)}")
    for f in result_files[:10]:
        print("  ", f)

    return {
        "returncode": r.returncode,
        "stdout": r.stdout,
        "stderr": r.stderr,
        "result_files": result_files,
    }
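

# `main` below dispatches mode="probe" to `probe.remote()`, but no probe
# function is defined in the original gist. The function below is a minimal,
# assumed sketch of such an image/GPU probe so the script is runnable;
# adjust the checks to whatever you actually need to inspect.
@app.function(image=image, gpu="B200:1", timeout=600)
def probe() -> None:
    """Print interpreter, key package versions, and GPU info inside the image."""
    import subprocess
    import sys
    print("python:", sys.version)
    for pkg in ("torch", "flashinfer_bench"):
        try:
            mod = __import__(pkg)
            print(f"{pkg}: {getattr(mod, '__version__', 'no __version__')}")
        except ImportError as exc:
            print(f"{pkg}: import failed ({exc})")
    r = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    print(r.stdout or r.stderr)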


@app.local_entrypoint()
def main(mode: str = "run_fast",
         definition: str = "dsa_topk_indexer_fp8_h64_d128_topk2048_ps64",
         extra: str = "",
         env: str = ""):
    """
    mode = "probe"    → inspect the image; don't run any benchmark
    mode = "run"      → pack + full CLI eval (slow — runs PyTorch baseline too)
    mode = "run_fast" → pack + Python-API eval with profile_baseline=False
                        (default — skips baseline iterations, ~10-20x faster)
    extra → additional CLI args for "run" mode, comma-separated
            (e.g. "--num-trials,3,--warmup-runs,1")
    env   → comma-separated K=V pairs applied as environment
            variables inside the Modal container BEFORE any
            flashinfer_bench / solution import (e.g.
            "DSA_SPARSE_K_SPLITS_CAP=16,DSA_SPARSE_BLOCK_K=64").
    """
    if mode == "probe":
        probe.remote()
        return

    from flashinfer_bench.data import Solution
    from scripts.pack_solution import pack_solution

    # Parse "K=V,K2=V2" → {K: V, K2: V2}.
    env_vars: dict = {}
    if env:
        for pair in env.split(","):
            pair = pair.strip()
            if not pair:
                continue
            if "=" not in pair:
                raise ValueError(f"--env entry {pair!r} must be K=V")
            k, v = pair.split("=", 1)
            env_vars[k.strip()] = v.strip()
        print(f"Forwarding env vars to container: {env_vars}")

    print("Packing solution...")
    sol_path = pack_solution()
    sol = Solution.model_validate_json(sol_path.read_text())
    print(f"Loaded: {sol.name} definition={sol.definition}")

    if mode == "run_fast":
        result = run_eval_fast.remote(
            solution_json=sol_path.read_text(),
            solution_name=sol.name,
            definition=definition,
            env_vars=env_vars,
        )
        print("\n=== result ===")
        print(f"  PASSED: {result['passed']} / {result['total']}")
        print(f"  mean latency: {result['mean_us']:.3f} μs")
        if result['failed']:
            print("  first failures:")
            for u, r in result['failed'][:10]:
                print(f"    {u[:8]}… {r['status']}")
    else:
        extra_args = [a for a in extra.split(",") if a] if extra else []
        result = run_eval.remote(
            solution_json=sol_path.read_text(),
            solution_name=sol.name,
            definition=definition,
            extra_args=extra_args,
        )
        print(f"\n=== final returncode: {result['returncode']} ===")
        if result['returncode'] != 0:
            print("\n=== stderr (full) ===")
            print(result['stderr'])
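

# Example invocations (assuming this file is saved as eval_modal.py; the
# filename and flag values are illustrative, not prescriptive):
#
#   modal run eval_modal.py --mode run_fast
#   modal run eval_modal.py --mode run_fast --env "DSA_SPARSE_BLOCK_K=64"
#   modal run eval_modal.py --mode run --extra "--num-trials,3,--warmup-runs,1"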