provisioning.sh
#!/usr/bin/env bash
# ──────────────────────────────────────────────────────────────────────────────
# Vast.ai PROVISIONING_SCRIPT for the watermark-remover service.
#
# Host this file at a public URL (GitHub Gist recommended) and set:
#   PROVISIONING_SCRIPT=https://gist.githubusercontent.com/.../provisioning.sh
#
# Required env vars (set in Vast.ai instance template):
#   HF_TOKEN       — HuggingFace access token (for FLUX.1-Kontext-dev)
#   GITHUB_TOKEN   — GitHub PAT with repo read access (to fetch main.py)
#
# Optional env vars:
#   GITHUB_REPO    — owner/repo (default: tendailuke/huren)
#   BRANCH         — git branch (default: main)
#   KONTEXT_MODEL  — HF model ID (default: black-forest-labs/FLUX.1-Kontext-dev)
#   KONTEXT_STEPS  — diffusion steps (default: 28)
#   KONTEXT_MAX_PX — max input image side in px (default: 1024)
#   PORT           — service port (default: 8081)
# ──────────────────────────────────────────────────────────────────────────────
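# Example (hypothetical values) of wiring this into a vastai CLI launch; the
# offer ID and image are placeholders, not part of this script:
#   vastai create instance <OFFER_ID> --image <IMAGE> \
#     --env '-e PROVISIONING_SCRIPT=https://gist.githubusercontent.com/.../provisioning.sh -e HF_TOKEN=hf_... -e GITHUB_TOKEN=ghp_...'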
set -euo pipefail

WORKDIR="/workspace/watermark-remover"
GITHUB_REPO="${GITHUB_REPO:-tendailuke/huren}"
BRANCH="${BRANCH:-main}"
KONTEXT_MODEL="${KONTEXT_MODEL:-black-forest-labs/FLUX.1-Kontext-dev}"
PORT="${PORT:-8081}"
HOST="${HOST:-0.0.0.0}"

LOG_FILE="/workspace/watermark-remover.log"
LOG="$LOG_FILE"   # short alias; sibling logs are derived via ${LOG%.log}-<name>.log

mkdir -p /workspace "$WORKDIR"

STATUS_FILE="/tmp/provision-status.json"
# Write a status JSON visible to vastai-watch via the status-server.
set_status() {
  local stage="$1" progress="$2" detail="${3:-}"
  printf '{"status":"loading","stage":"%s","progress":%s,"detail":"%s"}\n' \
    "$stage" "$progress" "$detail" > "$STATUS_FILE"
  echo "[provision] [$stage $progress%] $detail" | tee -a "$LOG_FILE"
}
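# Example payload as written by set_status (hypothetical values):
#   {"status":"loading","stage":"model-download","progress":42,"detail":"Downloading model: 14.3/34.0 GB"}
# Note: detail is interpolated verbatim into the JSON, so avoid double quotes in it.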
set_status "starting" 0 "Provisioning started"

# ── Geo detection — enable CN mirrors if running in China ─────────────────────
COUNTRY=$(curl -fsSL --max-time 5 "https://ipinfo.io/country" 2>/dev/null | tr -d '[:space:]' || echo "")
CN_MODE=0
if [[ "$COUNTRY" == "CN" ]]; then
  CN_MODE=1
  echo "[provision] CN instance detected — will use CN mirrors for PyTorch and HuggingFace." | tee -a "$LOG_FILE"
fi
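# ipinfo.io/country returns a bare ISO 3166-1 code ("US", "CN", ...); on
# timeout COUNTRY stays empty and the script keeps the default mirrors.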
# ── Fetch status-server.py and start it immediately ──────────────────────────
# This occupies PORT so vastai-watch sees "loading" instead of connection refused.
# GITHUB_TOKEN is checked here, before its first use: under `set -u` an unset
# variable would otherwise kill the script with an opaque "unbound variable" error.
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
  echo "[provision] ERROR: GITHUB_TOKEN not set. Cannot fetch service code." | tee -a "$LOG_FILE"
  exit 1
fi
echo "[provision] Fetching status-server.py..." | tee -a "$LOG_FILE"
curl -fsSL \
  -H "Authorization: Bearer ${GITHUB_TOKEN}" \
  -H "Accept: application/vnd.github.raw" \
  "https://api.github.com/repos/${GITHUB_REPO}/contents/services/watermark-remover/status-server.py?ref=${BRANCH}" \
  -o "${WORKDIR}/status-server.py"
STATUS_FILE="$STATUS_FILE" PORT="$PORT" python3 "${WORKDIR}/status-server.py" >> "${LOG%.log}-status-server.log" 2>&1 &
STATUS_SERVER_PID=$!
echo "[provision] Status server PID ${STATUS_SERVER_PID} on :${PORT}" | tee -a "$LOG_FILE"
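# While provisioning runs, polling the mapped port should now return the JSON
# from $STATUS_FILE instead of a connection error. Assumption: status-server.py
# (fetched above, not shown here) serves the status file on every request:
#   curl -s "http://<instance-ip>:${PORT}/"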
# ── Explicit venv binaries (never rely on PATH order) ─────────────────────────
VENV="/venv/main"
PIP="$VENV/bin/pip"
PYTHON="$VENV/bin/python3"
UVICORN="$VENV/bin/uvicorn"

# ── Activate venv ──────────────────────────────────────────────────────────────
set_status "venv" 5 "Activating virtual environment"
if . "$VENV/bin/activate"; then
  echo "Virtual environment activated: $("$PYTHON" --version)" | tee -a "$LOG_FILE"
else
  echo "Failed to activate virtual environment" | tee -a "$LOG_FILE"
  exit 1
fi
# ── Change to workspace ────────────────────────────────────────────────────────
if cd "$WORKDIR"; then
  pwd | tee -a "$LOG_FILE"
else
  echo "Failed to change directory to ${WORKDIR}" | tee -a "$LOG_FILE"
  exit 1
fi
# ── Fetch service code from private repo ──────────────────────────────────────
# (GITHUB_TOKEN was already validated before the status-server fetch above.)
set_status "fetch-code" 8 "Fetching service code from ${GITHUB_REPO}@${BRANCH}"
for file in main.py coordinator.py requirements.txt; do
  curl -fsSL \
    -H "Authorization: Bearer ${GITHUB_TOKEN}" \
    -H "Accept: application/vnd.github.raw" \
    "https://api.github.com/repos/${GITHUB_REPO}/contents/services/watermark-remover/${file}?ref=${BRANCH}" \
    -o "${WORKDIR}/${file}"
  echo "[provision] Downloaded ${file}."
done
# ── Install Python deps ───────────────────────────────────────────────────────
# Install torch first with the wheel URL that matches the installed CUDA driver.
# Default PyPI torch wheels target CUDA 12.1+ and fail on older drivers.
echo "[provision] Detecting CUDA driver version..."
CUDA_DRIVER=$(nvidia-smi 2>/dev/null | grep -oP "CUDA Version: \K[0-9]+\.[0-9]+" | head -1 || echo "0.0")
CUDA_MAJOR=$(echo "$CUDA_DRIVER" | cut -d. -f1)
CUDA_MINOR=$(echo "$CUDA_DRIVER" | cut -d. -f2)
CUDA_NUM=$((CUDA_MAJOR * 10 + CUDA_MINOR))   # e.g. 12.1 → 121, 12.0 → 120
if   [ "$CUDA_NUM" -ge 124 ]; then TORCH_CU="cu124"
elif [ "$CUDA_NUM" -ge 121 ]; then TORCH_CU="cu121"
# cu120 wheels were never published, and a 12.0 driver (520.x) cannot run
# cu121 wheels, so 11.8–12.0 drivers all fall back to cu118.
elif [ "$CUDA_NUM" -ge 118 ]; then TORCH_CU="cu118"
else TORCH_CU="cu118"
fi
if [ "$CN_MODE" -eq 1 ]; then
  TORCH_INDEX="https://mirrors.nju.edu.cn/pytorch/whl/${TORCH_CU}"
else
  TORCH_INDEX="https://download.pytorch.org/whl/${TORCH_CU}"
fi
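# e.g. a 12.4 driver outside CN resolves to
#   https://download.pytorch.org/whl/cu124
# and the same driver with CN_MODE=1 to
#   https://mirrors.nju.edu.cn/pytorch/whl/cu124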
# Check if torch is already installed and CUDA works — skip reinstall if so.
install_torch() {
  set_status "torch" 12 "Installing torch ${TORCH_CU} from ${TORCH_INDEX}"
  "$PIP" install --quiet --no-cache-dir --force-reinstall \
    "torch>=2.3.0" "torchvision" "torchaudio" \
    --index-url "$TORCH_INDEX"
  echo "[provision] torch installed." | tee -a "$LOG_FILE"
}

echo "[provision] Checking existing torch/CUDA..." | tee -a "$LOG_FILE"
if "$PYTHON" -c "import torch; assert torch.cuda.is_available(), 'CUDA not available'" 2>/dev/null; then
  TORCH_VER=$("$PYTHON" -c "import torch; print(torch.__version__)" 2>/dev/null)
  echo "[provision] torch ${TORCH_VER} with CUDA already working — skipping reinstall." | tee -a "$LOG_FILE"
  set_status "torch" 12 "torch ${TORCH_VER} already installed"
else
  echo "[provision] torch/CUDA not working — installing from ${TORCH_INDEX}" | tee -a "$LOG_FILE"
  install_torch
fi
set_status "deps" 20 "Installing Python dependencies"
# Strip every torch* pin (torch/torchvision/torchaudio were installed above
# with the matching CUDA wheel index) and install the rest.
grep -vE '^torch' requirements.txt | "$PIP" install --quiet --no-cache-dir -r /dev/stdin
echo "[provision] Deps installed."
# ── HuggingFace login ──────────────────────────────────────────────────────────
if [[ -z "${HF_TOKEN:-}" ]]; then
  echo "[provision] ERROR: HF_TOKEN not set. Cannot download ${KONTEXT_MODEL}."
  exit 1
fi
if [ "$CN_MODE" -eq 1 ]; then
  export HF_ENDPOINT="https://hf-mirror.com"
  echo "[provision] CN mode: HF_ENDPOINT=${HF_ENDPOINT}" | tee -a "$LOG_FILE"
fi
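# With HF_ENDPOINT set, huggingface_hub fetches files from the mirror, e.g.
#   https://hf-mirror.com/<repo>/resolve/<revision>/<file>
# instead of the default huggingface.co host.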
set_status "hf-login" 25 "Logging in to HuggingFace"
"$PYTHON" -c "
import os
from huggingface_hub import login
login(token=os.environ['HF_TOKEN'], add_to_git_credential=True)
print('[provision] HuggingFace login OK.')
"
# ── Pre-download model weights (with progress reporting to status-server) ─────
set_status "model-download" 30 "Starting model download ${KONTEXT_MODEL} (~34 GB)"
"$PYTHON" - <<'PYEOF'
import json, os, time
from pathlib import Path

from huggingface_hub import snapshot_download
from huggingface_hub.utils import tqdm as hf_tqdm

STATUS_FILE = Path(os.getenv("STATUS_FILE", "/tmp/provision-status.json"))
model = os.environ.get("KONTEXT_MODEL", "black-forest-labs/FLUX.1-Kontext-dev")

def write_status(progress: int, detail: str) -> None:
    STATUS_FILE.write_text(json.dumps({
        "status": "loading",
        "stage": "model-download",
        "progress": progress,
        "detail": detail,
    }))

# huggingface_hub drives per-file tqdm bars; patch them to also write status.
# Byte totals are summed across bars, and the outer "Fetching N files" bar
# (whose total is a file count, not bytes) is skipped via a size heuristic,
# so the reported percentage is approximate.
_orig_init = hf_tqdm.__init__
_total_bytes = [0]
_done_bytes = [0]

def _patched_init(self, *args, **kwargs):
    _orig_init(self, *args, **kwargs)
    if self.total and self.total > 1024 * 1024:  # ignore file-count bars
        _total_bytes[0] += self.total

_orig_update = hf_tqdm.update
_last_write = [0.0]

def _patched_update(self, n=1):
    _orig_update(self, n)
    _done_bytes[0] += n or 0
    now = time.monotonic()
    if now - _last_write[0] >= 5 and _total_bytes[0] > 0:
        # Map download progress onto the 30–89% band of the overall pipeline.
        pct = min(30 + int(60 * _done_bytes[0] / _total_bytes[0]), 89)
        done_gb = _done_bytes[0] / 1024**3
        total_gb = _total_bytes[0] / 1024**3
        write_status(pct, f"Downloading model: {done_gb:.1f}/{total_gb:.1f} GB")
        _last_write[0] = now

hf_tqdm.__init__ = _patched_init
hf_tqdm.update = _patched_update

write_status(30, f"Downloading {model}...")
snapshot_download(model, ignore_patterns=["*.bin"])
write_status(90, "Model download complete")
print("[provision] Model download complete.")
PYEOF
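# Optional sanity check (kept commented out): resolving the snapshot with
# local_files_only=True prints the cached path, and fails fast if the download
# above was incomplete.
#   "$PYTHON" -c "import os; from huggingface_hub import snapshot_download; print(snapshot_download(os.environ.get('KONTEXT_MODEL', 'black-forest-labs/FLUX.1-Kontext-dev'), local_files_only=True))"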
# ── Start workers (one per GPU) + coordinator ─────────────────────────────────
# Each worker is pinned to one GPU via CUDA_VISIBLE_DEVICES.
# The coordinator on PORT handles routing and exposes the aggregate /health.
GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l)
GPU_COUNT=$(( GPU_COUNT > 0 ? GPU_COUNT : 1 ))
echo "[provision] GPUs detected: ${GPU_COUNT}"
set_status "vram-check" 91 "Checking per-GPU VRAM"
# ── VRAM check — abort early if any GPU has < 12 GB ───────────────────────────
# Flux.1-Kontext requires ≥12 GB per GPU (4-bit path). vast.ai reports gpu_ram
# as the *total* across all cards, so a 2x RTX 2080 Ti (2×11 GB = 22 GB) would
# pass a gpu_ram>=20 filter but OOM on every request.
MIN_VRAM_GB="${MIN_VRAM_GB:-12}"
echo "[provision] Checking per-GPU VRAM (minimum ${MIN_VRAM_GB} GB)..."
VRAM_FAIL=0
for i in $(seq 0 $((GPU_COUNT - 1))); do
  # Select the GPU with `nvidia-smi -i`: CUDA_VISIBLE_DEVICES does not affect
  # nvidia-smi, which enumerates every GPU regardless.
  VRAM_MIB=$(nvidia-smi -i "$i" --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d ' ')
  VRAM_GB=$(( ${VRAM_MIB:-0} / 1024 ))
  echo "[provision] GPU ${i}: ${VRAM_GB} GB VRAM"
  if [ "$VRAM_GB" -lt "$MIN_VRAM_GB" ]; then
    echo "[provision] ERROR: GPU ${i} has only ${VRAM_GB} GB VRAM — need ≥${MIN_VRAM_GB} GB for Flux.1-Kontext. Aborting."
    VRAM_FAIL=1
  fi
done
if [ "$VRAM_FAIL" -eq 1 ]; then
  exit 1
fi
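# For reference, the raw query prints the total memory in MiB, e.g. roughly
# 11264 on an RTX 2080 Ti:
#   nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits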
set_status "starting-workers" 93 "Starting GPU workers"
# start_workers is reused verbatim by the torch-reinstall retry path below.
start_workers() {
  WORKER_PORTS_LIST=""
  for i in $(seq 0 $((GPU_COUNT - 1))); do
    WORKER_PORT=$((PORT + 10 + i))   # e.g. PORT=8081 → workers on 8091, 8092, ...
    WORKER_LOG="${LOG%.log}-gpu${i}.log"
    CUDA_VISIBLE_DEVICES=$i nohup "$UVICORN" main:app \
      --app-dir "$WORKDIR" \
      --host 127.0.0.1 \
      --port "$WORKER_PORT" \
      --workers 1 \
      --log-level info \
      >> "$WORKER_LOG" 2>&1 &
    echo "[provision] GPU ${i} → worker on 127.0.0.1:${WORKER_PORT} (PID $!). Logs: ${WORKER_LOG}" | tee -a "$LOG_FILE"
    [[ -n "$WORKER_PORTS_LIST" ]] && WORKER_PORTS_LIST="${WORKER_PORTS_LIST},"
    WORKER_PORTS_LIST="${WORKER_PORTS_LIST}${WORKER_PORT}"
  done
}
start_workers
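# e.g. with PORT=8081 and 2 GPUs, WORKER_PORTS_LIST is now "8091,8092"; the
# coordinator receives it below as WORKER_PORTS.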
# ── CUDA smoke-test: wait for workers to init, check logs for torch errors ────
# Workers import torch + load the model on startup. If the pre-installed torch
# has a CUDA ABI mismatch, it surfaces as an error within the first ~30s.
# Patterns: "CUDA driver version is insufficient", "no kernel image",
#           "libcuda.so", "CUDA error", "torch.cuda.is_available() returned False"
CUDA_ERROR_PATTERNS="CUDA driver version is insufficient|no kernel image|libcuda\.so|CUDA error|cuda\.is_available.*False|RuntimeError.*CUDA|AssertionError.*CUDA"
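# Example of a line these patterns catch (a classic driver/wheel ABI mismatch):
#   RuntimeError: CUDA error: no kernel image is available for execution on the device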
echo "[provision] Waiting 30s for workers to initialise (CUDA smoke-test)..." | tee -a "$LOG_FILE"
sleep 30
CUDA_BAD=0
for i in $(seq 0 $((GPU_COUNT - 1))); do
  WORKER_LOG="${LOG%.log}-gpu${i}.log"
  if grep -qE "$CUDA_ERROR_PATTERNS" "$WORKER_LOG" 2>/dev/null; then
    echo "[provision] GPU ${i} worker CUDA error detected in ${WORKER_LOG} — will reinstall torch." | tee -a "$LOG_FILE"
    grep -E "$CUDA_ERROR_PATTERNS" "$WORKER_LOG" | head -3 | tee -a "$LOG_FILE"
    CUDA_BAD=1
  fi
done
if [ "$CUDA_BAD" -eq 1 ]; then
  set_status "torch-retry" 94 "CUDA error detected — reinstalling torch"
  echo "[provision] Killing workers for torch reinstall..." | tee -a "$LOG_FILE"
  # Kill all uvicorn worker processes
  pkill -f "uvicorn main:app" 2>/dev/null || true
  sleep 2
  # Force-reinstall torch; TORCH_INDEX (CN-aware) was already chosen above.
  install_torch
  set_status "starting-workers" 95 "Restarting GPU workers after torch reinstall"
  start_workers
  echo "[provision] Torch reinstalled and workers restarted." | tee -a "$LOG_FILE"
fi
# ── Hand off PORT from status-server to coordinator ───────────────────────────
set_status "starting-coordinator" 97 "Starting coordinator on :${PORT}"
if kill "$STATUS_SERVER_PID" 2>/dev/null; then
  echo "[provision] Status server (PID ${STATUS_SERVER_PID}) stopped." | tee -a "$LOG_FILE"
  # Brief pause so the OS releases the port before the coordinator binds it
  sleep 1
fi

# Coordinator: routes /remove-watermark to idle workers, aggregates /health
WORKER_PORTS="$WORKER_PORTS_LIST" nohup "$UVICORN" coordinator:app \
  --app-dir "$WORKDIR" \
  --host "$HOST" \
  --port "$PORT" \
  --workers 1 \
  --log-level info \
  >> "${LOG%.log}-coordinator.log" 2>&1 &
echo "[provision] Coordinator on ${HOST}:${PORT} → workers [${WORKER_PORTS_LIST}] (PID $!)"
echo "[provision] Logs: tail -f ${LOG%.log}*.log"
echo "[provision] Done."
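# Quick smoke-test once "Done." is printed. /health and /remove-watermark are
# the coordinator routes named above; the multipart field name "image" is an
# assumption, so check main.py for the actual request contract:
#   curl -s "http://localhost:${PORT}/health"
#   curl -s -X POST "http://localhost:${PORT}/remove-watermark" -F "image=@input.jpg" -o output.png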