Skip to content

Instantly share code, notes, and snippets.

@malfet
Last active April 23, 2026 16:50
Show Gist options
  • Select an option

  • Save malfet/17cd64ca959e4870b3c2d14fd0395236 to your computer and use it in GitHub Desktop.

Select an option

Save malfet/17cd64ca959e4870b3c2d14fd0395236 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Repack torchaudio wheels as cp310-abi3 (stable ABI) wheels.
Downloads the cp310 wheel for each platform from PyPI, verifies all native
extensions use the stable ABI, patches the WHEEL tag (adds PEP 427 Build
number) and METADATA (adds torch>=<version> dependency), and repacks.
RECORD regeneration and repacking are handled by auditwheel's InWheelCtx.
Usage:
python repack_torchaudio_abi3.py [--version 2.11.0] [--build 1] [--outdir output/]
"""
import argparse
import os
import re
import subprocess
import sys
import tempfile
from pathlib import Path

from auditwheel.wheeltools import InWheelCtx
# Wheels to repack, one entry per platform: (source abi, source platform, new tag).
#   source abi      - interpreter/ABI tag of the wheel to download; download_wheel()
#                     uses it both in the expected filename and for pip's
#                     --abi / --python-version options.
#   source platform - platform tag passed to pip's --platform option.
#   new tag         - full tag written to the WHEEL "Tag:" field and used in the
#                     output wheel's filename.
PLATFORMS = [
("cp310", "manylinux_2_28_x86_64", "cp310-abi3-manylinux_2_28_x86_64"),
("cp310", "manylinux_2_28_aarch64", "cp310-abi3-manylinux_2_28_aarch64"),
("cp310", "win_amd64", "cp310-abi3-win_amd64"),
("cp310", "macosx_11_0_arm64", "cp310-abi3-macosx_11_0_arm64"),
]
def download_wheel(version, abi, platform, destdir):
    """Download one torchaudio wheel into *destdir* and return its path.

    The expected filename is ``torchaudio-<version>-<abi>-<abi>-<platform>.whl``.
    If that file is already present in *destdir* it is reused and pip is not
    invoked.

    Args:
        version: torchaudio version to fetch (pinned with ``==``).
        abi: interpreter/ABI tag such as ``"cp310"``; the digits are also
            passed to pip as ``--python-version``.
        platform: platform tag passed to pip's ``--platform``.
        destdir: directory the wheel is downloaded into.

    Returns:
        Path (str) of the wheel inside *destdir*.

    Raises:
        subprocess.CalledProcessError: pip exited with a non-zero status.
        FileNotFoundError: pip succeeded but the expected file is missing.
    """
    filename = f"torchaudio-{version}-{abi}-{abi}-{platform}.whl"
    dest = os.path.join(destdir, filename)
    if os.path.exists(dest):
        return dest
    subprocess.check_call([
        sys.executable, "-m", "pip", "download",
        f"torchaudio=={version}",
        "--no-deps", "--only-binary=:all:",
        "--python-version", abi.replace("cp", ""),
        "--abi", abi,
        "--platform", platform,
        "-q", "-d", destdir,
    ])
    # pip can exit successfully while saving the wheel under another name
    # (for example a wheel carrying a different compatible tag). Fail here
    # with a clear message instead of returning a path that does not exist.
    if not os.path.exists(dest):
        found = sorted(n for n in os.listdir(destdir) if n.endswith(".whl"))
        raise FileNotFoundError(
            f"pip download did not produce {filename} in {destdir}; "
            f"wheels present: {found}"
        )
    return dest
def verify_native_files(root_path):
    """Check that no native extension under *root_path* is tied to one CPython version.

    Rules applied to every file name found by walking *root_path*:

    * a file ending in ``.so`` must contain ``.abi3.`` in its name;
    * a ``.pyd`` carrying a CPython version suffix such as
      ``.cp310-win_amd64.pyd`` is rejected, because other Python versions
      do not look for that suffix;
    * untagged ``.pyd`` files and ``.dylib`` files are allowed.

    Raises:
        RuntimeError: on the first file that violates a rule.
    """
    # Matches the version-specific Windows extension suffix, e.g.
    # "_ext.cp310-win_amd64.pyd"; a plain "_ext.pyd" does not match.
    version_tagged_pyd = re.compile(r"\.cp\d+-[^.]+\.pyd$")
    for _dirpath, _dirnames, files in os.walk(root_path):
        for f in files:
            if f.endswith(".so") and ".abi3." not in f:
                raise RuntimeError(f"{f} is a .so but not abi3-tagged — cannot repack")
            if version_tagged_pyd.search(f):
                raise RuntimeError(f"{f} is a version-tagged .pyd — cannot repack")
def patch_wheel_tag(wheel_path, new_tag, build_number):
    """Rewrite a WHEEL file so it carries exactly one Build and one Tag field.

    Every existing ``Tag:`` and ``Build:`` line is dropped, all other lines
    are kept in their original order, and ``Build: <build_number>`` followed
    by ``Tag: <new_tag>`` is appended. The run of trailing newlines of the
    original file is preserved.

    Args:
        wheel_path: ``pathlib.Path`` of the WHEEL file; rewritten in place.
        new_tag: value for the single ``Tag:`` field.
        build_number: value for the ``Build:`` field.
    """
    original = wheel_path.read_text(encoding="utf-8")
    body = original.rstrip("\n")
    # Whatever newlines ended the file are re-attached unchanged at the end.
    newline_tail = original[len(body):]

    kept = []
    for line in body.split("\n"):
        if line.startswith(("Tag:", "Build:")):
            continue
        kept.append(line)
    kept += [f"Build: {build_number}", f"Tag: {new_tag}"]

    wheel_path.write_text("\n".join(kept) + newline_tail, encoding="utf-8")
def patch_metadata_deps(meta_path, version):
    """Add ``Requires-Dist: torch (>=<version>)`` to METADATA if torch is not required yet.

    Only the header section (everything before the first blank line) is
    inspected, so text in the long description cannot be mistaken for a
    dependency. The requirement name must be exactly ``torch``; entries such
    as ``torchcodec`` do not count. When METADATA has no blank-line separator
    the new field is appended at the end of the file.

    Args:
        meta_path: ``pathlib.Path`` of the METADATA file; rewritten in place
            only when a field is added.
        version: minimum torch version to require.
    """
    content = meta_path.read_text(encoding="utf-8")
    headers, separator, description = content.partition("\n\n")

    # "torch" must be the whole project name: the next character may not be
    # one that can continue a name (letters, digits, ".", "_", "-").
    already_required = re.search(
        r"^Requires-Dist:[ \t]*torch(?![A-Za-z0-9._-])",
        headers,
        flags=re.MULTILINE | re.IGNORECASE,
    )
    if already_required:
        return

    new_field = f"Requires-Dist: torch (>={version})"
    if separator:
        content = f"{headers}\n{new_field}{separator}{description}"
    else:
        # Headers-only file: append the field and end with a single newline.
        stripped = headers.rstrip("\n")
        content = f"{stripped}\n{new_field}\n" if stripped else f"{new_field}\n"
    meta_path.write_text(content, encoding="utf-8")
def process_wheel(version, abi, platform, new_tag, build_number, download_dir, outdir):
    """Download one platform's wheel, retag it, and write the repacked wheel.

    The source wheel is obtained with download_wheel(), opened with
    auditwheel's InWheelCtx, checked with verify_native_files(), and its
    WHEEL and METADATA files are patched. The output wheel is named
    ``torchaudio-<version>-<build_number>-<new_tag>.whl`` inside *outdir*.

    Args:
        version: torchaudio version being repacked.
        abi: ABI tag of the source wheel.
        platform: platform tag of the source wheel.
        new_tag: tag for the repacked wheel.
        build_number: build number placed in the WHEEL file and filename.
        download_dir: directory source wheels are downloaded into.
        outdir: directory the repacked wheel is written to.

    Returns:
        Resolved ``pathlib.Path`` of the repacked wheel.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Platform: {platform} -> {new_tag}")

    source_wheel = download_wheel(version, abi, platform, download_dir)
    print(f" Source: {os.path.basename(source_wheel)}")

    target_name = f"torchaudio-{version}-{build_number}-{new_tag}.whl"
    target_path = Path(outdir).resolve() / target_name

    with InWheelCtx(Path(source_wheel).resolve()) as wheel_ctx:
        wheel_ctx.out_wheel = target_path
        info_dir = wheel_ctx.path / f"torchaudio-{version}.dist-info"
        verify_native_files(wheel_ctx.path)
        patch_wheel_tag(info_dir / "WHEEL", new_tag, build_number)
        patch_metadata_deps(info_dir / "METADATA", version)

    # The size is read after the context exits: per the module docstring,
    # InWheelCtx does the repacking, so the file is only expected afterwards.
    megabytes = os.path.getsize(target_path) / (1024 * 1024)
    print(f" Output: {target_name} ({megabytes:.1f} MB)")
    return target_path
def main():
    """Command-line entry point: repack every entry of PLATFORMS and list the results.

    Options: ``--version`` (torchaudio version), ``--build`` (integer build
    number) and ``--outdir`` (output directory, created if missing). Source
    wheels are downloaded into a temporary directory that is removed once
    all platforms have been processed.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--version",
        default="2.11.0",
        help="torchaudio version (default: 2.11.0)",
    )
    cli.add_argument(
        "--build",
        default=1,
        type=int,
        help="PEP 427 build number (default: 1)",
    )
    cli.add_argument(
        "--outdir",
        default="output",
        help="output directory (default: output/)",
    )
    opts = cli.parse_args()

    os.makedirs(opts.outdir, exist_ok=True)
    with tempfile.TemporaryDirectory() as scratch_dir:
        for entry in PLATFORMS:
            abi, platform, new_tag = entry
            process_wheel(
                opts.version, abi, platform, new_tag,
                opts.build, scratch_dir, opts.outdir,
            )

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Done! Wheels written to {opts.outdir}/")
    # Lists every .whl in the output directory, not only the ones just built.
    wheel_names = sorted(
        name for name in os.listdir(opts.outdir) if name.endswith(".whl")
    )
    for name in wheel_names:
        print(f" {name}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment