Skip to content

Instantly share code, notes, and snippets.

@matyushkin
Last active March 13, 2026 11:14
Show Gist options
  • Select an option

  • Save matyushkin/a88013d5cb0e07a20d8a930e1e78a843 to your computer and use it in GitHub Desktop.

Select an option

Save matyushkin/a88013d5cb0e07a20d8a930e1e78a843 to your computer and use it in GitHub Desktop.
Protenix validation on Colab for Stanford RNA 3D Folding 2
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": "# Protenix Validation — Lightning AI / Colab\n\nRun Protenix inference on validation targets that lack good TBM templates.\nSaves per-target predictions as `.npy` files for local merge with TBM results.\n\n**Platform**: Lightning AI Studio (T4 free, 22h/mo) or Colab (T4/A100).\n**Runtime**: ~30-60 min for 28 targets on T4."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 1. Setup — install dependencies.\n",
"# %pip (rather than !pip) guarantees the install targets this kernel's environment.\n",
"%pip install -q biopython biotite\n",
"%pip install -q rdkit-pypi\n",
"\n",
"# Check GPU: report device name and total VRAM (only when CUDA is available).\n",
"import torch\n",
"print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}\")\n",
"# Fix: the device-properties attribute is `total_memory`; `total_mem` raises AttributeError.\n",
"print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\" if torch.cuda.is_available() else \"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# 2. Output directory\nimport os\n\n# Detect platform from environment markers.\nIS_COLAB = 'COLAB_GPU' in os.environ or os.path.exists('/content')\n# Check each variable name individually: joining all keys into one string can\n# create a spurious 'LIGHTNING_' match across two adjacent key names.\nIS_LIGHTNING = any('LIGHTNING_' in k for k in os.environ) or os.path.exists('/teamspace')\n\nif IS_COLAB:\n    from google.colab import drive\n    drive.mount('/content/drive')\n    OUTPUT_DIR = '/content/drive/MyDrive/rna_folding_2/val_protenix'\n    WORK_BASE = '/content'\nelif IS_LIGHTNING:\n    # Lightning AI: persistent storage at /teamspace/studios/this_studio/\n    OUTPUT_DIR = '/teamspace/studios/this_studio/val_protenix'\n    WORK_BASE = '/teamspace/studios/this_studio'\nelse:\n    OUTPUT_DIR = './val_protenix'\n    WORK_BASE = '.'\n\nos.makedirs(OUTPUT_DIR, exist_ok=True)\nprint(f\"Platform: {'Colab' if IS_COLAB else 'Lightning' if IS_LIGHTNING else 'Other'}\")\nprint(f\"Results → {OUTPUT_DIR}\")"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# 3. Download competition data via Kaggle API\n%pip install -q kaggle\n\nimport os\nkaggle_json = os.path.expanduser('~/.kaggle/kaggle.json')\nif not os.path.exists(kaggle_json):\n    if IS_COLAB:\n        from google.colab import files\n        print(\"Upload your kaggle.json:\")\n        uploaded = files.upload()\n        # chmod 600: the Kaggle CLI refuses world-readable credential files.\n        !mkdir -p ~/.kaggle && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json\n    else:\n        # Lightning / other: place kaggle.json manually\n        print(f\"Put kaggle.json at {kaggle_json}\")\n        print(\"Get it from: https://www.kaggle.com/settings → API → Create New Token\")\n\n# IPython interpolates {DATA_DIR_PATH} into the shell commands below.\nDATA_DIR_PATH = f'{WORK_BASE}/data'\n!mkdir -p {DATA_DIR_PATH}\n!kaggle competitions download stanford-rna-3d-folding-2 -f train_sequences.csv -p {DATA_DIR_PATH}/\n!kaggle competitions download stanford-rna-3d-folding-2 -f train_labels.csv -p {DATA_DIR_PATH}/\n!kaggle competitions download stanford-rna-3d-folding-2 -f validation_sequences.csv -p {DATA_DIR_PATH}/\n!kaggle competitions download stanford-rna-3d-folding-2 -f validation_labels.csv -p {DATA_DIR_PATH}/\n!ls -lh {DATA_DIR_PATH}/"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# 4. Download Protenix model from Kaggle datasets\n# IPython interpolates {WORK_BASE} (set in cell 2) into the shell command.\n!kaggle datasets download qiweiyin/protenix-v1-adjusted -p {WORK_BASE}/protenix_model/ --unzip\n# Sanity check: list the unpacked model files.\n!ls {WORK_BASE}/protenix_model/"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import os\nimport sys\nimport json\nimport gc\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom Bio.Align import PairwiseAligner\nfrom tqdm import tqdm\n\n# Paths — use WORK_BASE detected in cell 2\nDATA_DIR = Path(f'{WORK_BASE}/data')\nPROTENIX_DIR = Path(f'{WORK_BASE}/protenix_model/Protenix-v1-adjust-v2/Protenix-v1-adjust-v2/Protenix-v1')\n\n# Environment knobs — presumably read by the Protenix code; TODO confirm against its docs.\nos.environ['LAYERNORM_TYPE'] = 'torch'\nos.environ['RNA_MSA_DEPTH_LIMIT'] = '512'\nos.environ['PROTENIX_ROOT_DIR'] = str(PROTENIX_DIR)\n# Make the bundled Protenix package importable from this kernel.\nsys.path.append(str(PROTENIX_DIR))\n\nMODEL_NAME = 'protenix_base_20250630_v1.0.0'\nN_SAMPLE = 5  # diffusion samples per target (passed as --sample_diffusion.N_sample in cell 7)\nSEED = 42  # inference seed (passed as --seeds)\nMAX_SEQ_LEN = 512  # sequences longer than this are split into chunks (cell 8)\nCHUNK_OVERLAP = 256  # overlap between consecutive chunks, used for local stitching"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 5. TBM functions (same as local script)\n",
"\n",
"def make_aligner():\n",
"    \"\"\"Build a global PairwiseAligner with uniform gap penalties.\n",
"\n",
"    End gaps on both query and target are penalised exactly like internal\n",
"    gaps, so the alignment is truly end-to-end global.\n",
"    \"\"\"\n",
"    al = PairwiseAligner()\n",
"    al.mode = 'global'\n",
"    al.match_score = 2\n",
"    al.mismatch_score = -1.5\n",
"    al.open_gap_score = -8\n",
"    al.extend_gap_score = -0.4\n",
"    al.query_left_open_gap_score = -8\n",
"    al.query_left_extend_gap_score = -0.4\n",
"    al.query_right_open_gap_score = -8\n",
"    al.query_right_extend_gap_score = -0.4\n",
"    al.target_left_open_gap_score = -8\n",
"    al.target_left_extend_gap_score = -0.4\n",
"    al.target_right_open_gap_score = -8\n",
"    al.target_right_extend_gap_score = -0.4\n",
"    return al\n",
"\n",
"_aligner = make_aligner()\n",
"\n",
"def process_labels(labels_df):\n",
"    \"\"\"Group label rows by target (ID prefix before the last '_').\n",
"\n",
"    Returns {target_id: (L, 3) array of x_1/y_1/z_1 coords sorted by resid}.\n",
"    \"\"\"\n",
"    coords = {}\n",
"    prefixes = labels_df['ID'].str.rsplit('_', n=1).str[0]\n",
"    for prefix, grp in labels_df.groupby(prefixes):\n",
"        coords[prefix] = grp.sort_values('resid')[['x_1', 'y_1', 'z_1']].values\n",
"    return coords\n",
"\n",
"def find_similar_sequences(query_seq, train_seqs_df, train_coords_dict, top_n=30):\n",
"    \"\"\"Rank training sequences by global alignment score against query_seq.\n",
"\n",
"    Skips training targets without coordinates and candidates whose length\n",
"    differs from the query by more than 30%. Returns up to top_n tuples\n",
"    (tid, tseq, norm_score, coords, pct_identity), sorted by normalised\n",
"    score (raw score / maximum attainable = 2 * shorter length).\n",
"    \"\"\"\n",
"    results = []\n",
"    for _, row in train_seqs_df.iterrows():\n",
"        tid, tseq = row['target_id'], row['sequence']\n",
"        if tid not in train_coords_dict:\n",
"            continue\n",
"        if abs(len(tseq) - len(query_seq)) / max(len(tseq), len(query_seq)) > 0.3:\n",
"            continue\n",
"        aln = next(iter(_aligner.align(query_seq, tseq)))\n",
"        norm_s = aln.score / (2 * min(len(query_seq), len(tseq)))\n",
"        # aln.aligned pairs up matched (query, target) index blocks; count\n",
"        # aligned positions where the residues are identical.\n",
"        identical = sum(\n",
"            1 for (qs, qe), (ts, te) in zip(*aln.aligned)\n",
"            for qp, tp in zip(range(qs, qe), range(ts, te))\n",
"            if query_seq[qp] == tseq[tp]\n",
"        )\n",
"        pct_id = 100 * identical / len(query_seq)\n",
"        results.append((tid, tseq, norm_s, train_coords_dict[tid], pct_id))\n",
"    results.sort(key=lambda x: x[2], reverse=True)\n",
"    return results[:top_n]\n",
"\n",
"print('TBM functions loaded')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 6. Identify which targets need Protenix\n",
"# (those without good TBM templates)\n",
"\n",
"# Thresholds for counting a template as 'good'.\n",
"MIN_SIMILARITY = 0.10\n",
"MIN_PERCENT_IDENTITY = 55.0\n",
"\n",
"train_seqs = pd.read_csv(DATA_DIR / 'train_sequences.csv')\n",
"train_labels = pd.read_csv(DATA_DIR / 'train_labels.csv', low_memory=False)\n",
"val_seqs = pd.read_csv(DATA_DIR / 'validation_sequences.csv')\n",
"\n",
"train_coords = process_labels(train_labels)\n",
"print(f'Train: {len(train_seqs)} sequences, {len(train_coords)} with coords')\n",
"print(f'Val: {len(val_seqs)} targets')\n",
"\n",
"# Check which targets lack TBM coverage\n",
"protenix_targets = []\n",
"tbm_targets = []\n",
"\n",
"for _, row in tqdm(val_seqs.iterrows(), total=len(val_seqs), desc='Scanning'):\n",
"    tid = row['target_id']\n",
"    seq = row['sequence']\n",
"    similar = find_similar_sequences(seq, train_seqs, train_coords, top_n=10)\n",
"\n",
"    # A template counts only if both similarity and identity clear thresholds.\n",
"    n_good = sum(1 for _, _, sim, _, pct in similar\n",
"                 if sim >= MIN_SIMILARITY and pct >= MIN_PERCENT_IDENTITY)\n",
"\n",
"    # Protenix must supply whatever TBM cannot (up to N_SAMPLE conformations).\n",
"    n_needed = N_SAMPLE - min(n_good, N_SAMPLE)\n",
"    if n_needed > 0:\n",
"        protenix_targets.append((tid, seq, n_needed))\n",
"        print(f' {tid} ({len(seq)} nt): {n_good} TBM → need {n_needed} from Protenix')\n",
"    else:\n",
"        tbm_targets.append(tid)\n",
"        print(f' {tid} ({len(seq)} nt): fully covered by TBM')\n",
"\n",
"print(f'\\nTargets needing Protenix: {len(protenix_targets)}')\n",
"print(f'Targets fully TBM: {len(tbm_targets)}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 7. Protenix inference\n",
"\n",
"def build_input_json(records, json_path):\n",
"    \"\"\"Write Protenix input JSON: one single-copy rnaSequence entry per record.\"\"\"\n",
"    data = [{\n",
"        'name': r['target_id'],\n",
"        'covalent_bonds': [],\n",
"        'sequences': [{'rnaSequence': {'sequence': r['sequence'], 'count': 1}}],\n",
"    } for r in records]\n",
"    with open(json_path, 'w') as f:\n",
"        json.dump(data, f)\n",
"\n",
"def split_into_chunks(seq_len, max_len, overlap):\n",
"    \"\"\"Split [0, seq_len) into windows of at most max_len overlapping by\n",
"    `overlap` positions. Returns a list of (start, end) pairs; a sequence\n",
"    that already fits in one window yields [(0, seq_len)].\n",
"\n",
"    Raises ValueError when overlap >= max_len, which would otherwise make\n",
"    the window start fail to advance and loop forever.\n",
"    \"\"\"\n",
"    if seq_len <= max_len:\n",
"        return [(0, seq_len)]\n",
"    if overlap >= max_len:\n",
"        raise ValueError('overlap must be smaller than max_len')\n",
"    chunks = []\n",
"    start = 0\n",
"    while start < seq_len:\n",
"        end = min(start + max_len, seq_len)\n",
"        chunks.append((start, end))\n",
"        if end == seq_len:\n",
"            break\n",
"        start = end - overlap\n",
"    return chunks\n",
"\n",
"def build_configs(input_json_path, dump_dir, model_name):\n",
"    \"\"\"Assemble a Protenix config.\n",
"\n",
"    Merges base/data/inference configs, overlays the model-specific config,\n",
"    then applies CLI-style overrides (no protein MSA or templates, RNA MSA\n",
"    on, N_SAMPLE diffusion samples, fixed seed).\n",
"    \"\"\"\n",
"    from configs.configs_base import configs as configs_base\n",
"    from configs.configs_data import data_configs\n",
"    from configs.configs_inference import inference_configs\n",
"    from configs.configs_model_type import model_configs\n",
"    from protenix.config.config import parse_configs\n",
"\n",
"    base = {**configs_base, **{'data': data_configs}, **inference_configs}\n",
"\n",
"    def deep_update(t, p):\n",
"        # Recursively merge dict p into dict t (p wins on leaf conflicts).\n",
"        for k, v in p.items():\n",
"            if isinstance(v, dict) and k in t and isinstance(t[k], dict):\n",
"                deep_update(t[k], v)\n",
"            else:\n",
"                t[k] = v\n",
"\n",
"    deep_update(base, model_configs[model_name])\n",
"    arg_str = ' '.join([\n",
"        f'--model_name {model_name}',\n",
"        f'--input_json_path {input_json_path}',\n",
"        f'--dump_dir {dump_dir}',\n",
"        '--use_msa false',\n",
"        '--use_template false',\n",
"        '--use_rna_msa true',\n",
"        f'--sample_diffusion.N_sample {N_SAMPLE}',\n",
"        f'--seeds {SEED}',\n",
"    ])\n",
"    return parse_configs(configs=base, arg_str=arg_str, fill_required_with_null=True)\n",
"\n",
"print('Protenix helpers loaded')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# 8. Run Protenix on targets that need it\n\nfrom protenix.data.inference.infer_dataloader import InferenceDataset\nfrom runner.inference import InferenceRunner, update_gpu_compatible_configs, update_inference_configs\n\nwork_dir = Path(f'{WORK_BASE}/working')\nwork_dir.mkdir(exist_ok=True)\n\n# Build tasks (with chunking for long sequences)\ntasks = []\nchunk_info = {}  # target_id -> [{'name': task_name, 'range': (start, end)}, ...]\n\nfor tid, seq, n_needed in protenix_targets:\n    seq_len = len(seq)\n    if seq_len <= MAX_SEQ_LEN:\n        tasks.append({'target_id': tid, 'sequence': seq})\n        chunk_info[tid] = [{'name': tid, 'range': (0, seq_len)}]\n    else:\n        # Long sequence: predict overlapping windows, stitch locally later.\n        chunks = split_into_chunks(seq_len, MAX_SEQ_LEN, CHUNK_OVERLAP)\n        chunk_info[tid] = []\n        for ci, (cs, ce) in enumerate(chunks):\n            chunk_name = f'{tid}_chunk{ci}'\n            tasks.append({'target_id': chunk_name, 'sequence': seq[cs:ce]})\n            chunk_info[tid].append({'name': chunk_name, 'range': (cs, ce)})\n\nprint(f'Total Protenix tasks: {len(tasks)} (from {len(protenix_targets)} targets)')\n\n# Write input JSON\ninput_json = str(work_dir / 'protenix_input.json')\nbuild_input_json(tasks, input_json)\n\n# Init model\nconfigs = build_configs(input_json, str(work_dir / 'outputs'), MODEL_NAME)\nconfigs = update_gpu_compatible_configs(configs)\nrunner = InferenceRunner(configs)\ndataset = InferenceDataset(configs)\n\nprint(f'Model loaded. Running inference on {len(dataset)} samples...')"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 9. Inference loop\n",
"\n",
"def get_c1_mask(data, atom_array):\n",
"    \"\"\"Boolean mask selecting one representative (C1') atom per residue.\n",
"\n",
"    Tries atom_array annotations first (centre_atom_mask, then literal C1'\n",
"    atom names), falling back to masks in data['input_feature_dict'].\n",
"    Last resort: pick whichever tokatom index (11 or 12) yields an atom\n",
"    count closest to the token count.\n",
"    \"\"\"\n",
"    if atom_array is not None:\n",
"        try:\n",
"            if hasattr(atom_array, 'centre_atom_mask'):\n",
"                m = atom_array.centre_atom_mask == 1\n",
"                if hasattr(atom_array, 'is_rna'):\n",
"                    m = m & atom_array.is_rna\n",
"                return torch.from_numpy(m).bool()\n",
"            if hasattr(atom_array, 'atom_name'):\n",
"                base = atom_array.atom_name == \"C1'\"\n",
"                if hasattr(atom_array, 'is_rna'):\n",
"                    base = base & atom_array.is_rna\n",
"                return torch.from_numpy(base).bool()\n",
"        except Exception:\n",
"            # Best-effort: any failure here falls through to the feature masks.\n",
"            pass\n",
"    f = data['input_feature_dict']\n",
"    if 'centre_atom_mask' in f:\n",
"        return (f['centre_atom_mask'] == 1).bool()\n",
"    if 'center_atom_mask' in f:\n",
"        return (f['center_atom_mask'] == 1).bool()\n",
"    n_tokens = data.get('N_token', torch.tensor(0)).item()\n",
"    m11 = (f['atom_to_tokatom_idx'] == 11).bool()\n",
"    m12 = (f['atom_to_tokatom_idx'] == 12).bool()\n",
"    return m11 if abs(m11.sum().item() - n_tokens) < abs(m12.sum().item() - n_tokens) else m12\n",
"\n",
"raw_predictions = {}  # task_name -> (N_SAMPLE, seq_len, 3)\n",
"\n",
"for idx in tqdm(range(len(dataset)), desc='Protenix inference'):\n",
"    data = dataset[idx]\n",
"    sample_name = data.get('sample_name', f'sample_{idx}')\n",
"\n",
"    try:\n",
"        prediction = runner.predict(data)\n",
"\n",
"        # Extract all-atom coords regardless of the prediction container type.\n",
"        if hasattr(prediction, 'atom_coordinate'):\n",
"            raw_coords = prediction.atom_coordinate\n",
"        elif isinstance(prediction, dict) and 'atom_coordinate' in prediction:\n",
"            raw_coords = prediction['atom_coordinate']\n",
"        else:\n",
"            raw_coords = prediction\n",
"\n",
"        if isinstance(raw_coords, torch.Tensor):\n",
"            raw_coords = raw_coords.cpu()\n",
"\n",
"        # NOTE(review): if `data` is a plain dict, getattr() always returns\n",
"        # None here and only the feature-dict fallbacks run — confirm the\n",
"        # dataset item type actually carries an atom_array attribute.\n",
"        mask = get_c1_mask(data, getattr(data, 'atom_array', None))\n",
"\n",
"        # raw_coords shape: (N_SAMPLE, N_atoms, 3)\n",
"        if raw_coords.dim() == 3:\n",
"            c1_coords = raw_coords[:, mask, :].numpy()\n",
"        else:\n",
"            c1_coords = raw_coords[mask, :].unsqueeze(0).numpy()\n",
"\n",
"        raw_predictions[sample_name] = c1_coords\n",
"        print(f' {sample_name}: {c1_coords.shape}')\n",
"\n",
"    except Exception as e:\n",
"        # Best-effort batch: record the failure and keep going.\n",
"        print(f' {sample_name}: FAILED - {e}')\n",
"\n",
"    gc.collect()\n",
"    torch.cuda.empty_cache()\n",
"\n",
"print(f'\\nCompleted: {len(raw_predictions)}/{len(dataset)} tasks')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 10. Persist per-task predictions and the chunk map for local stitching\n",
"\n",
"for task_name, task_coords in raw_predictions.items():\n",
"    out_path = f'{OUTPUT_DIR}/{task_name}.npy'\n",
"    np.save(out_path, task_coords)\n",
"    print(f'Saved {task_name}.npy: shape {task_coords.shape}')\n",
"\n",
"# chunk_info lets the local merge script map chunk tasks back onto targets\n",
"with open(f'{OUTPUT_DIR}/chunk_info.json', 'w') as f:\n",
"    json.dump(chunk_info, f)\n",
"\n",
"print(f'\\nAll saved to {OUTPUT_DIR}')\n",
"print('Download these files and run the local merge script.')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment