provisioning.sh
#!/usr/bin/env bash
# ──────────────────────────────────────────────────────────────────────────────
# Vast.ai PROVISIONING_SCRIPT for the watermark-remover service.
#
# Host this file at a public URL (GitHub Gist recommended) and set:
# PROVISIONING_SCRIPT=https://gist.githubusercontent.com/.../provisioning.sh
#
# Required env vars (set in Vast.ai instance template):
# HF_TOKEN — HuggingFace access token (for FLUX.1-Kontext-dev)
# GITHUB_TOKEN — GitHub PAT with repo read access (to fetch main.py)
#
# Optional env vars:
# GITHUB_REPO — owner/repo (default: tendailuke/huren)
# BRANCH — git branch (default: main)
# KONTEXT_MODEL — HF model ID (default: black-forest-labs/FLUX.1-Kontext-dev)
# KONTEXT_STEPS — diffusion steps (default: 28)
# KONTEXT_MAX_PX — max input image side in px (default: 1024)
# PORT — service port (default: 8081)
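#
# Example template environment (illustrative values only; the tokens and gist URL
# below are placeholders, not real secrets):
#   PROVISIONING_SCRIPT=https://gist.githubusercontent.com/<user>/<gist-id>/raw/provisioning.sh
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx
#   GITHUB_TOKEN=ghp_xxxxxxxxxxxxxxxx
#   GITHUB_REPO=tendailuke/huren
#   PORT=8081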
# ──────────────────────────────────────────────────────────────────────────────
set -euo pipefail
WORKDIR="/workspace/watermark-remover"
GITHUB_REPO="${GITHUB_REPO:-tendailuke/huren}"
BRANCH="${BRANCH:-main}"
KONTEXT_MODEL="${KONTEXT_MODEL:-black-forest-labs/FLUX.1-Kontext-dev}"
PORT="${PORT:-8081}"
HOST="${HOST:-0.0.0.0}"
LOG_FILE="/workspace/watermark-remover.log"
LOG="$LOG_FILE"
mkdir -p /workspace "$WORKDIR"
STATUS_FILE="/tmp/provision-status.json"
# Write a status JSON visible to vastai-watch via the status-server.
set_status() {
    local stage="$1" progress="$2" detail="${3:-}"
    printf '{"status":"loading","stage":"%s","progress":%s,"detail":"%s"}\n' \
        "$stage" "$progress" "$detail" > "$STATUS_FILE"
    echo "[provision] [$stage $progress%] $detail" | tee -a "$LOG_FILE"
}
set_status "starting" 0 "Provisioning started"
# ── Geo detection — enable CN mirrors if running in China ─────────────────────
COUNTRY=$(curl -fsSL --max-time 5 "https://ipinfo.io/country" 2>/dev/null | tr -d '[:space:]' || echo "")
CN_MODE=0
if [[ "$COUNTRY" == "CN" ]]; then
CN_MODE=1
echo "[provision] CN instance detected — will use CN mirrors for PyTorch and HuggingFace." | tee -a "$LOG_FILE"
fi
# ── Fetch status-server.py and start it immediately ──────────────────────────
# This occupies PORT so vastai-watch sees "loading" instead of connection refused.
# GITHUB_TOKEN is needed here already; with `set -u`, an unset token would otherwise
# abort with an opaque "unbound variable" error, so fail fast with a clear message.
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
    echo "[provision] ERROR: GITHUB_TOKEN not set. Cannot fetch status-server.py." | tee -a "$LOG_FILE"
    exit 1
fi
echo "[provision] Fetching status-server.py..." | tee -a "$LOG_FILE"
curl -fsSL \
    -H "Authorization: Bearer ${GITHUB_TOKEN}" \
    -H "Accept: application/vnd.github.raw" \
    "https://api.github.com/repos/${GITHUB_REPO}/contents/services/watermark-remover/status-server.py?ref=${BRANCH}" \
    -o "${WORKDIR}/status-server.py"
STATUS_FILE="$STATUS_FILE" PORT="$PORT" python3 "${WORKDIR}/status-server.py" >> "${LOG%.log}-status-server.log" 2>&1 &
STATUS_SERVER_PID=$!
echo "[provision] Status server PID ${STATUS_SERVER_PID} on :${PORT}" | tee -a "$LOG_FILE"
# ── Explicit venv binaries (never rely on PATH order) ─────────────────────────
VENV="/venv/main"
PIP="$VENV/bin/pip"
PYTHON="$VENV/bin/python3"
UVICORN="$VENV/bin/uvicorn"
# ── Activate venv ─────────────────────────────────────────────────────────────
set_status "venv" 5 "Activating virtual environment"
if . "$VENV/bin/activate"; then
echo "Virtual environment activated: $("$PYTHON" --version)" | tee -a "$LOG_FILE"
else
echo "Failed to activate virtual environment" | tee -a "$LOG_FILE"
exit 1
fi
# ── Change to workspace ────────────────────────────────────────────────────────
if cd /workspace/; then
    pwd | tee -a "$LOG_FILE"
else
    echo "Failed to change directory to /workspace/" | tee -a "$LOG_FILE"
    exit 1
fi
cd "$WORKDIR"
# ── Fetch service code from private repo ──────────────────────────────────────
set_status "fetch-code" 8 "Fetching service code from ${GITHUB_REPO}@${BRANCH}"
if [[ -z "${GITHUB_TOKEN:-}" ]]; then
echo "[provision] ERROR: GITHUB_TOKEN not set. Cannot fetch service code."
exit 1
fi
for file in main.py coordinator.py requirements.txt; do
    curl -fsSL \
        -H "Authorization: Bearer ${GITHUB_TOKEN}" \
        -H "Accept: application/vnd.github.raw" \
        "https://api.github.com/repos/${GITHUB_REPO}/contents/services/watermark-remover/${file}?ref=${BRANCH}" \
        -o "${WORKDIR}/${file}"
    echo "[provision] Downloaded ${file}."
done
# ── Install Python deps ───────────────────────────────────────────────────────
# Install torch first with the wheel URL that matches the installed CUDA driver.
# Default PyPI torch wheels target CUDA 12.1+ and fail on older drivers.
echo "[provision] Detecting CUDA driver version..."
CUDA_DRIVER=$(nvidia-smi 2>/dev/null | grep -oP "CUDA Version: \K[0-9]+\.[0-9]+" | head -1 || echo "0.0")
CUDA_MAJOR=$(echo "$CUDA_DRIVER" | cut -d. -f1)
CUDA_MINOR=$(echo "$CUDA_DRIVER" | cut -d. -f2)
CUDA_NUM=$((CUDA_MAJOR * 10 + CUDA_MINOR)) # e.g. 12.1 → 121, 12.0 → 120
if [ "$CUDA_NUM" -ge 124 ]; then TORCH_CU="cu124"
elif [ "$CUDA_NUM" -ge 121 ]; then TORCH_CU="cu121"
elif [ "$CUDA_NUM" -ge 120 ]; then TORCH_CU="cu121" # cu120 wheels not published; cu121 works on 12.0 driver with minor mismatch—use cu118 for safety
elif [ "$CUDA_NUM" -ge 118 ]; then TORCH_CU="cu118"
else TORCH_CU="cu118"
fi
# CUDA 12.0 driver (520.x) cannot run cu121 wheels — use cu118 which is compatible.
if [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -eq 0 ]; then
TORCH_CU="cu118"
fi
if [ "$CN_MODE" -eq 1 ]; then
TORCH_INDEX="https://mirrors.nju.edu.cn/pytorch/whl/${TORCH_CU}"
else
TORCH_INDEX="https://download.pytorch.org/whl/${TORCH_CU}"
fi
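# Worked example: a driver reporting "CUDA Version: 12.2" gives CUDA_NUM=122 → cu121,
# so TORCH_INDEX becomes https://download.pytorch.org/whl/cu121 (or the NJU mirror in CN mode).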
# Check if torch is already installed and CUDA works — skip reinstall if so.
install_torch() {
    set_status "torch" 12 "Installing torch ${TORCH_CU} from ${TORCH_INDEX}"
    "$PIP" install --quiet --no-cache-dir --force-reinstall \
        "torch>=2.3.0" "torchvision" "torchaudio" \
        --index-url "$TORCH_INDEX"
    echo "[provision] torch installed." | tee -a "$LOG_FILE"
}
echo "[provision] Checking existing torch/CUDA..." | tee -a "$LOG_FILE"
if "$PYTHON" -c "import torch; assert torch.cuda.is_available(), 'CUDA not available'" 2>/dev/null; then
TORCH_VER=$("$PYTHON" -c "import torch; print(torch.__version__)" 2>/dev/null)
echo "[provision] torch ${TORCH_VER} with CUDA already working — skipping reinstall." | tee -a "$LOG_FILE"
set_status "torch" 12 "torch ${TORCH_VER} already installed"
else
echo "[provision] torch/CUDA not working — installing from ${TORCH_INDEX}" | tee -a "$LOG_FILE"
install_torch
fi
set_status "deps" 20 "Installing Python dependencies"
grep -vE '^torch' requirements.txt | "$PIP" install --quiet --no-cache-dir -r /dev/stdin
echo "[provision] Deps installed."
# ── HuggingFace login ─────────────────────────────────────────────────────────
if [[ -z "${HF_TOKEN:-}" ]]; then
echo "[provision] ERROR: HF_TOKEN not set. Cannot download ${KONTEXT_MODEL}."
exit 1
fi
if [ "$CN_MODE" -eq 1 ]; then
export HF_ENDPOINT="https://hf-mirror.com"
echo "[provision] CN mode: HF_ENDPOINT=${HF_ENDPOINT}" | tee -a "$LOG_FILE"
fi
set_status "hf-login" 25 "Logging in to HuggingFace"
"$PYTHON" -c "
import os
from huggingface_hub import login
login(token=os.environ['HF_TOKEN'], add_to_git_credential=True)
print('[provision] HuggingFace login OK.')
"
# ── Pre-download model weights (with progress reporting to status-server) ─────
set_status "model-download" 30 "Starting model download ${KONTEXT_MODEL} (~34 GB)"
"$PYTHON" - <<PYEOF
import json, os, time
from pathlib import Path
from huggingface_hub import snapshot_download
from huggingface_hub.utils import tqdm as hf_tqdm
STATUS_FILE = Path(os.getenv("STATUS_FILE", "/tmp/provision-status.json"))
model = os.environ.get("KONTEXT_MODEL", "black-forest-labs/FLUX.1-Kontext-dev")
def write_status(progress: int, detail: str) -> None:
STATUS_FILE.write_text(json.dumps({
"status": "loading",
"stage": "model-download",
"progress": progress,
"detail": detail,
}))
# huggingface_hub fires tqdm callbacks — patch them to also write status.
_orig_init = hf_tqdm.__init__
_total_bytes = [0]
_done_bytes = [0]
def _patched_init(self, *args, **kwargs):
_orig_init(self, *args, **kwargs)
if self.total:
_total_bytes[0] = max(_total_bytes[0], self.total)
_orig_update = hf_tqdm.update
_last_write = [0.0]
def _patched_update(self, n=1):
_orig_update(self, n)
_done_bytes[0] += n or 0
now = time.monotonic()
if now - _last_write[0] >= 5 and _total_bytes[0] > 0:
pct = min(30 + int(60 * _done_bytes[0] / _total_bytes[0]), 89)
done_gb = _done_bytes[0] / 1024**3
total_gb = _total_bytes[0] / 1024**3
write_status(pct, f"Downloading model: {done_gb:.1f}/{total_gb:.1f} GB")
_last_write[0] = now
hf_tqdm.__init__ = _patched_init
hf_tqdm.update = _patched_update
write_status(30, f"Downloading {model}...")
snapshot_download(model, ignore_patterns=["*.bin"])
write_status(90, "Model download complete")
print("[provision] Model download complete.")
PYEOF
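# The snapshot lands in the huggingface_hub cache (by default ~/.cache/huggingface/hub,
# unless HF_HOME or HF_HUB_CACHE points elsewhere); a rough size check afterwards:
#   du -sh ~/.cache/huggingface/hub 2>/dev/null || true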
# ── Start workers (one per GPU) + coordinator ─────────────────────────────────
# Each worker is pinned to one GPU via CUDA_VISIBLE_DEVICES.
# The coordinator on PORT handles routing and exposes the aggregate /health.
GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l)
GPU_COUNT=$(( GPU_COUNT > 0 ? GPU_COUNT : 1 ))
echo "[provision] GPUs detected: ${GPU_COUNT}"
set_status "vram-check" 91 "Checking per-GPU VRAM"
# ── VRAM check — abort early if any GPU has < 12 GB ──────────────────────────
# Flux.1-Kontext requires ≥12 GB per GPU (4-bit path). vast.ai reports gpu_ram
# as the *total* across all cards, so a 2x RTX 2080 Ti (2×11GB=22GB) would pass
# a gpu_ram>=20 filter but OOM on every request.
MIN_VRAM_GB="${MIN_VRAM_GB:-12}"
echo "[provision] Checking per-GPU VRAM (minimum ${MIN_VRAM_GB} GB)..."
VRAM_FAIL=0
for i in $(seq 0 $((GPU_COUNT - 1))); do
    # nvidia-smi does not honour CUDA_VISIBLE_DEVICES; select the GPU explicitly with -i.
    VRAM_MIB=$(nvidia-smi -i "$i" --query-gpu=memory.total --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d ' ')
    VRAM_GB=$(( VRAM_MIB / 1024 ))
    echo "[provision] GPU ${i}: ${VRAM_GB} GB VRAM"
    if [ "$VRAM_GB" -lt "$MIN_VRAM_GB" ]; then
        echo "[provision] ERROR: GPU ${i} has only ${VRAM_GB} GB VRAM — need ≥${MIN_VRAM_GB} GB for Flux.1-Kontext. Aborting."
        VRAM_FAIL=1
    fi
done
if [ "$VRAM_FAIL" -eq 1 ]; then
exit 1
fi
set_status "starting-workers" 93 "Starting GPU workers"
WORKER_PORTS_LIST=""
for i in $(seq 0 $((GPU_COUNT - 1))); do
WORKER_PORT=$((PORT + 10 + i)) # e.g. PORT=8081 → workers on 8091, 8092, ...
WORKER_LOG="${LOG%.log}-gpu${i}.log"
CUDA_VISIBLE_DEVICES=$i nohup "$UVICORN" main:app \
--app-dir "$WORKDIR" \
--host 127.0.0.1 \
--port "$WORKER_PORT" \
--workers 1 \
--log-level info \
>> "$WORKER_LOG" 2>&1 &
echo "[provision] GPU ${i} → worker on 127.0.0.1:${WORKER_PORT} (PID $!). Logs: ${WORKER_LOG}"
[[ -n "$WORKER_PORTS_LIST" ]] && WORKER_PORTS_LIST="${WORKER_PORTS_LIST},"
WORKER_PORTS_LIST="${WORKER_PORTS_LIST}${WORKER_PORT}"
done
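# Each worker should answer on its loopback port once the model is loaded, e.g.
# (assuming main.py exposes the per-worker /health route that the coordinator aggregates):
#   curl -s http://127.0.0.1:8091/health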
# ── CUDA smoke-test: wait for workers to init, check logs for torch errors ────
# Workers import torch + load the model on startup. If the pre-installed torch
# has a CUDA ABI mismatch, it surfaces as an error within the first ~30s.
# Patterns: "CUDA driver version is insufficient", "no kernel image",
# "libcuda.so", "CUDA error", "torch.cuda.is_available() returned False"
CUDA_ERROR_PATTERNS="CUDA driver version is insufficient|no kernel image|libcuda\.so|CUDA error|cuda\.is_available.*False|RuntimeError.*CUDA|AssertionError.*CUDA"
echo "[provision] Waiting 30s for workers to initialise (CUDA smoke-test)..." | tee -a "$LOG_FILE"
sleep 30
CUDA_BAD=0
for i in $(seq 0 $((GPU_COUNT - 1))); do
    WORKER_LOG="${LOG%.log}-gpu${i}.log"
    if grep -qE "$CUDA_ERROR_PATTERNS" "$WORKER_LOG" 2>/dev/null; then
        echo "[provision] GPU ${i} worker CUDA error detected in ${WORKER_LOG} — will reinstall torch." | tee -a "$LOG_FILE"
        grep -E "$CUDA_ERROR_PATTERNS" "$WORKER_LOG" | head -3 | tee -a "$LOG_FILE"
        CUDA_BAD=1
    fi
done
if [ "$CUDA_BAD" -eq 1 ]; then
set_status "torch-retry" 94 "CUDA error detected — reinstalling torch"
echo "[provision] Killing workers for torch reinstall..." | tee -a "$LOG_FILE"
# Kill all uvicorn worker processes
pkill -f "uvicorn main:app" 2>/dev/null || true
sleep 2
# Force reinstall torch regardless of CN_MODE (already set above)
install_torch
set_status "starting-workers" 95 "Restarting GPU workers after torch reinstall"
WORKER_PORTS_LIST=""
for i in $(seq 0 $((GPU_COUNT - 1))); do
WORKER_PORT=$((PORT + 10 + i))
WORKER_LOG="${LOG%.log}-gpu${i}.log"
CUDA_VISIBLE_DEVICES=$i nohup "$UVICORN" main:app \
--app-dir "$WORKDIR" \
--host 127.0.0.1 \
--port "$WORKER_PORT" \
--workers 1 \
--log-level info \
>> "$WORKER_LOG" 2>&1 &
echo "[provision] GPU ${i} → worker restarted on 127.0.0.1:${WORKER_PORT} (PID $!)" | tee -a "$LOG_FILE"
[[ -n "$WORKER_PORTS_LIST" ]] && WORKER_PORTS_LIST="${WORKER_PORTS_LIST},"
WORKER_PORTS_LIST="${WORKER_PORTS_LIST}${WORKER_PORT}"
done
echo "[provision] Torch reinstalled and workers restarted." | tee -a "$LOG_FILE"
fi
# ── Hand off PORT from status-server to coordinator ───────────────────────────
set_status "starting-coordinator" 97 "Starting coordinator on :${PORT}"
if kill "$STATUS_SERVER_PID" 2>/dev/null; then
echo "[provision] Status server (PID ${STATUS_SERVER_PID}) stopped." | tee -a "$LOG_FILE"
# Brief pause so the OS releases the port before coordinator binds it
sleep 1
fi
# Coordinator: routes /remove-watermark to idle workers, aggregates /health
WORKER_PORTS="$WORKER_PORTS_LIST" nohup "$UVICORN" coordinator:app \
--app-dir "$WORKDIR" \
--host "$HOST" \
--port "$PORT" \
--workers 1 \
--log-level info \
>> "${LOG%.log}-coordinator.log" 2>&1 &
echo "[provision] Coordinator on ${HOST}:${PORT} → workers [${WORKER_PORTS_LIST}] (PID $!)"
echo "[provision] Logs: tail -f ${LOG%.log}*.log"
echo "[provision] Done."