Last active
March 13, 2026 11:14
-
-
Save matyushkin/a88013d5cb0e07a20d8a930e1e78a843 to your computer and use it in GitHub Desktop.
Protenix validation on Colab for Stanford RNA 3D Folding 2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": "# Protenix Validation — Lightning AI / Colab\n\nRuns Protenix inference on the validation targets that lack good template-based modeling (TBM) templates and saves the per-target predictions as `.npy` files (plus `chunk_info.json`) so they can be merged locally with the TBM results.\n\n- **Platform**: Lightning AI Studio (free T4, 22 h/month) or Colab (T4/A100).\n- **Runtime**: ~30–60 min for 28 targets on a T4." | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# 1. Setup — install dependencies\n", | |
| "# %pip (rather than !pip) installs into the environment of the running kernel.\n", | |
| "%pip install -q biopython biotite\n", | |
| "%pip install -q rdkit-pypi\n", | |
| "\n", | |
| "# Check GPU\n", | |
| "import torch\n", | |
| "\n", | |
| "if torch.cuda.is_available():\n", | |
| "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n", | |
| "    # The device-properties field with the memory size in bytes is `total_memory`;\n", | |
| "    # `total_mem` is not an attribute and raises AttributeError on GPU runtimes.\n", | |
| "    print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n", | |
| "else:\n", | |
| "    print(\"GPU: None\")\n", | |
| "    print(\"\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": "# 2. Output directory\nimport os\n\n# Detect platform\nIS_COLAB = 'COLAB_GPU' in os.environ or os.path.exists('/content')\n# Check each variable name on its own: concatenating all names could match\n# 'LIGHTNING_' across the boundary between two unrelated names.\nIS_LIGHTNING = any('LIGHTNING_' in name for name in os.environ) or os.path.exists('/teamspace')\n\nif IS_COLAB:\n    from google.colab import drive\n    drive.mount('/content/drive')\n    OUTPUT_DIR = '/content/drive/MyDrive/rna_folding_2/val_protenix'\n    WORK_BASE = '/content'\nelif IS_LIGHTNING:\n    # Lightning AI: persistent storage at /teamspace/studios/this_studio/\n    OUTPUT_DIR = '/teamspace/studios/this_studio/val_protenix'\n    WORK_BASE = '/teamspace/studios/this_studio'\nelse:\n    OUTPUT_DIR = './val_protenix'\n    WORK_BASE = '.'\n\nos.makedirs(OUTPUT_DIR, exist_ok=True)\nprint(f\"Platform: {'Colab' if IS_COLAB else 'Lightning' if IS_LIGHTNING else 'Other'}\")\nprint(f\"Results → {OUTPUT_DIR}\")" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": "# 3. Download competition data via Kaggle API\n# %pip (rather than !pip) installs into the environment of the running kernel.\n%pip install -q kaggle\n\nimport os\nkaggle_json = os.path.expanduser('~/.kaggle/kaggle.json')\nif not os.path.exists(kaggle_json):\n    if IS_COLAB:\n        from google.colab import files\n        print(\"Upload your kaggle.json:\")\n        uploaded = files.upload()\n        !mkdir -p ~/.kaggle && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json\n    else:\n        # Lightning / other: place kaggle.json manually\n        print(f\"Put kaggle.json at {kaggle_json}\")\n        print(\"Get it from: https://www.kaggle.com/settings → API → Create New Token\")\n\nDATA_DIR_PATH = f'{WORK_BASE}/data'\n!mkdir -p {DATA_DIR_PATH}\n# One download per required file, in the order the later cells read them.\nfor csv_name in ('train_sequences.csv', 'train_labels.csv', 'validation_sequences.csv', 'validation_labels.csv'):\n    !kaggle competitions download stanford-rna-3d-folding-2 -f {csv_name} -p {DATA_DIR_PATH}/\n!ls -lh {DATA_DIR_PATH}/" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": "# 4. Download Protenix model from Kaggle datasets\n# Uses the kaggle CLI installed in step 3 and WORK_BASE from step 2.\n# The archive is unpacked under WORK_BASE/protenix_model/; the configuration cell points PROTENIX_DIR at\n# Protenix-v1-adjust-v2/Protenix-v1-adjust-v2/Protenix-v1 inside it, so compare the listing below with that path.\n!kaggle datasets download qiweiyin/protenix-v1-adjusted -p {WORK_BASE}/protenix_model/ --unzip\n!ls {WORK_BASE}/protenix_model/" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": "# Imports and run-wide configuration used by the cells below\nimport os\nimport sys\nimport json\nimport gc\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom Bio.Align import PairwiseAligner\nfrom tqdm import tqdm\n\n# Paths — use WORK_BASE detected in cell 2\nDATA_DIR = Path(f'{WORK_BASE}/data')\nPROTENIX_DIR = Path(f'{WORK_BASE}/protenix_model/Protenix-v1-adjust-v2/Protenix-v1-adjust-v2/Protenix-v1')\n\n# Settings handed to Protenix through the environment.\n# NOTE(review): what each variable controls is defined by the bundled Protenix code — confirm there.\nos.environ['LAYERNORM_TYPE'] = 'torch'\nos.environ['RNA_MSA_DEPTH_LIMIT'] = '512'\nos.environ['PROTENIX_ROOT_DIR'] = str(PROTENIX_DIR)\n# Presumably what lets the `configs`, `protenix` and `runner` imports in later cells resolve from PROTENIX_DIR\nsys.path.append(str(PROTENIX_DIR))\n\nMODEL_NAME = 'protenix_base_20250630_v1.0.0'  # key into model_configs and value of --model_name (see build_configs)\nN_SAMPLE = 5  # structures wanted per target; also passed as --sample_diffusion.N_sample\nSEED = 42  # passed as --seeds\nMAX_SEQ_LEN = 512  # longer sequences are split into chunks before inference\nCHUNK_OVERLAP = 256  # overlap argument given to split_into_chunks" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# 5. TBM functions (same as local script)\n", | |
| "\n", | |
| "def make_aligner():\n", | |
| "    \"\"\"Build the global pairwise aligner used for the template search.\n", | |
| "\n", | |
| "    Scoring: match 2, mismatch -1.5, gap open -8, gap extend -0.4. The same gap\n", | |
| "    penalties are also assigned explicitly to the left and right end-gap\n", | |
| "    attributes of both the query and the target.\n", | |
| "    \"\"\"\n", | |
| "    gap_open, gap_extend = -8, -0.4\n", | |
| "\n", | |
| "    aligner = PairwiseAligner()\n", | |
| "    aligner.mode = 'global'\n", | |
| "    aligner.match_score = 2\n", | |
| "    aligner.mismatch_score = -1.5\n", | |
| "    aligner.open_gap_score = gap_open\n", | |
| "    aligner.extend_gap_score = gap_extend\n", | |
| "    # Assignment order: query left, query right, target left, target right.\n", | |
| "    for strand in ('query', 'target'):\n", | |
| "        for side in ('left', 'right'):\n", | |
| "            setattr(aligner, f'{strand}_{side}_open_gap_score', gap_open)\n", | |
| "            setattr(aligner, f'{strand}_{side}_extend_gap_score', gap_extend)\n", | |
| "    return aligner\n", | |
| "\n", | |
| "# Shared aligner instance used by find_similar_sequences.\n", | |
| "_aligner = make_aligner()\n", | |
| "\n", | |
| "def process_labels(labels_df):\n", | |
| "    \"\"\"Group label rows into one coordinate array per target.\n", | |
| "\n", | |
| "    The target id is the ``ID`` value with its last ``_``-separated field\n", | |
| "    removed. Rows of each target are ordered by ``resid`` and reduced to the\n", | |
| "    ``x_1``, ``y_1``, ``z_1`` columns.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        dict mapping target id -> array with one (x, y, z) row per label row.\n", | |
| "    \"\"\"\n", | |
| "    xyz_columns = ['x_1', 'y_1', 'z_1']\n", | |
| "    target_ids = labels_df['ID'].str.rsplit('_', n=1).str[0]\n", | |
| "    return {\n", | |
| "        target: rows.sort_values('resid')[xyz_columns].values\n", | |
| "        for target, rows in labels_df.groupby(target_ids)\n", | |
| "    }\n", | |
| "\n", | |
| "def find_similar_sequences(query_seq, train_seqs_df, train_coords_dict, top_n=30):\n", | |
| "    \"\"\"Rank training sequences by global-alignment similarity to ``query_seq``.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        query_seq: Query sequence.\n", | |
| "        train_seqs_df: DataFrame with ``target_id`` and ``sequence`` columns.\n", | |
| "        train_coords_dict: Mapping target id -> coordinates; training sequences\n", | |
| "            without an entry are skipped.\n", | |
| "        top_n: Maximum number of hits to return.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        Up to ``top_n`` tuples ``(target_id, sequence, normalized_score, coords,\n", | |
| "        percent_identity)``, best normalized score first. The score is divided\n", | |
| "        by match score times the shorter length; percent identity is relative\n", | |
| "        to the query length. An empty query yields an empty list.\n", | |
| "    \"\"\"\n", | |
| "    query_len = len(query_seq)\n", | |
| "    if query_len == 0:\n", | |
| "        # Nothing to align; the ratios below would divide by zero.\n", | |
| "        return []\n", | |
| "\n", | |
| "    results = []\n", | |
| "    for tid, tseq in zip(train_seqs_df['target_id'], train_seqs_df['sequence']):\n", | |
| "        if tid not in train_coords_dict:\n", | |
| "            continue\n", | |
| "        # Missing CSV values are read as NaN floats, which have no length.\n", | |
| "        if not isinstance(tseq, str):\n", | |
| "            continue\n", | |
| "        template_len = len(tseq)\n", | |
| "        # Skip templates whose length differs from the query by more than 30 %.\n", | |
| "        # This also removes empty templates because the query is non-empty.\n", | |
| "        if abs(template_len - query_len) / max(template_len, query_len) > 0.3:\n", | |
| "            continue\n", | |
| "        aln = next(iter(_aligner.align(query_seq, tseq)))\n", | |
| "        # Normalize by the aligner's own match score so the two cannot drift apart.\n", | |
| "        norm_s = aln.score / (_aligner.match_score * min(query_len, template_len))\n", | |
| "        identical = sum(\n", | |
| "            1 for (qs, qe), (ts, te) in zip(*aln.aligned)\n", | |
| "            for qp, tp in zip(range(qs, qe), range(ts, te))\n", | |
| "            if query_seq[qp] == tseq[tp]\n", | |
| "        )\n", | |
| "        pct_id = 100 * identical / query_len\n", | |
| "        results.append((tid, tseq, norm_s, train_coords_dict[tid], pct_id))\n", | |
| "    results.sort(key=lambda hit: hit[2], reverse=True)\n", | |
| "    return results[:top_n]\n", | |
| "\n", | |
| "print('TBM functions loaded')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# 6. Identify which targets need Protenix\n", | |
| "# (those without good TBM templates)\n", | |
| "\n", | |
| "MIN_SIMILARITY = 0.10\n", | |
| "MIN_PERCENT_IDENTITY = 55.0\n", | |
| "\n", | |
| "train_seqs = pd.read_csv(DATA_DIR / 'train_sequences.csv')\n", | |
| "train_labels = pd.read_csv(DATA_DIR / 'train_labels.csv', low_memory=False)\n", | |
| "val_seqs = pd.read_csv(DATA_DIR / 'validation_sequences.csv')\n", | |
| "\n", | |
| "train_coords = process_labels(train_labels)\n", | |
| "print(f'Train: {len(train_seqs)} sequences, {len(train_coords)} with coords')\n", | |
| "print(f'Val: {len(val_seqs)} targets')\n", | |
| "\n", | |
| "# Split the validation targets by TBM coverage\n", | |
| "protenix_targets = []  # (target_id, sequence, n_needed) for targets short of templates\n", | |
| "tbm_targets = []  # ids of targets with at least N_SAMPLE good templates\n", | |
| "\n", | |
| "val_pairs = zip(val_seqs['target_id'], val_seqs['sequence'])\n", | |
| "for tid, seq in tqdm(val_pairs, total=len(val_seqs), desc='Scanning'):\n", | |
| "    similar = find_similar_sequences(seq, train_seqs, train_coords, top_n=10)\n", | |
| "\n", | |
| "    # A hit counts as a usable template only if it clears both thresholds;\n", | |
| "    # hit[2] is the normalized score and hit[4] the percent identity.\n", | |
| "    n_good = len([\n", | |
| "        hit for hit in similar\n", | |
| "        if hit[2] >= MIN_SIMILARITY and hit[4] >= MIN_PERCENT_IDENTITY\n", | |
| "    ])\n", | |
| "\n", | |
| "    n_needed = max(N_SAMPLE - n_good, 0)\n", | |
| "    if n_needed == 0:\n", | |
| "        tbm_targets.append(tid)\n", | |
| "        print(f' {tid} ({len(seq)} nt): fully covered by TBM')\n", | |
| "        continue\n", | |
| "\n", | |
| "    protenix_targets.append((tid, seq, n_needed))\n", | |
| "    print(f' {tid} ({len(seq)} nt): {n_good} TBM → need {n_needed} from Protenix')\n", | |
| "\n", | |
| "print(f'\\nTargets needing Protenix: {len(protenix_targets)}')\n", | |
| "print(f'Targets fully TBM: {len(tbm_targets)}')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# 7. Protenix inference\n", | |
| "\n", | |
| "def build_input_json(records, json_path):\n", | |
| "    \"\"\"Write the Protenix inference input file.\n", | |
| "\n", | |
| "    Each record (a mapping with ``target_id`` and ``sequence``) becomes one job\n", | |
| "    named after the target, with an empty ``covalent_bonds`` list and a single\n", | |
| "    ``rnaSequence`` entry of count 1. The list of jobs is dumped as JSON to\n", | |
| "    ``json_path``.\n", | |
| "    \"\"\"\n", | |
| "    jobs = []\n", | |
| "    for record in records:\n", | |
| "        rna_entry = {'rnaSequence': {'sequence': record['sequence'], 'count': 1}}\n", | |
| "        jobs.append({\n", | |
| "            'name': record['target_id'],\n", | |
| "            'covalent_bonds': [],\n", | |
| "            'sequences': [rna_entry],\n", | |
| "        })\n", | |
| "    with open(json_path, 'w') as handle:\n", | |
| "        json.dump(jobs, handle)\n", | |
| "\n", | |
| "def split_into_chunks(seq_len, max_len, overlap):\n", | |
| "    \"\"\"Cover ``range(seq_len)`` with windows of at most ``max_len`` positions.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        seq_len: Length of the sequence to cover.\n", | |
| "        max_len: Maximum window length.\n", | |
| "        overlap: Number of positions shared by consecutive windows.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        List of half-open ``(start, end)`` ranges. A sequence that fits in one\n", | |
| "        window yields ``[(0, seq_len)]``.\n", | |
| "\n", | |
| "    Raises:\n", | |
| "        ValueError: if chunking is required but ``max_len`` is not positive or\n", | |
| "            ``overlap`` is outside ``[0, max_len)``. Such values would stop the\n", | |
| "            window from advancing or leave gaps between windows.\n", | |
| "    \"\"\"\n", | |
| "    if seq_len <= max_len:\n", | |
| "        return [(0, seq_len)]\n", | |
| "    if max_len <= 0 or not 0 <= overlap < max_len:\n", | |
| "        raise ValueError(\n", | |
| "            f'need max_len > 0 and 0 <= overlap < max_len, got max_len={max_len}, overlap={overlap}'\n", | |
| "        )\n", | |
| "    stride = max_len - overlap  # > 0, so the loop below always terminates\n", | |
| "    chunks = []\n", | |
| "    start = 0\n", | |
| "    while True:\n", | |
| "        end = min(start + max_len, seq_len)\n", | |
| "        chunks.append((start, end))\n", | |
| "        if end == seq_len:\n", | |
| "            return chunks\n", | |
| "        start += stride\n", | |
| "\n", | |
| "def build_configs(input_json_path, dump_dir, model_name):\n", | |
| "    \"\"\"Assemble the Protenix inference configuration for ``model_name``.\n", | |
| "\n", | |
| "    The base, data and inference configs are merged, the overrides stored under\n", | |
| "    ``model_configs[model_name]`` are applied recursively, and the result is\n", | |
| "    parsed together with command-line style options built from the arguments\n", | |
| "    and the notebook constants ``N_SAMPLE`` and ``SEED``.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        input_json_path: Path of the inference input JSON.\n", | |
| "        dump_dir: Directory passed as ``--dump_dir``.\n", | |
| "        model_name: Key into ``model_configs``; also passed as ``--model_name``.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        Whatever ``parse_configs`` returns for the merged configuration.\n", | |
| "\n", | |
| "    Raises:\n", | |
| "        KeyError: if ``model_name`` is not a key of ``model_configs``.\n", | |
| "    \"\"\"\n", | |
| "    import copy\n", | |
| "\n", | |
| "    from configs.configs_base import configs as configs_base\n", | |
| "    from configs.configs_data import data_configs\n", | |
| "    from configs.configs_inference import inference_configs\n", | |
| "    from configs.configs_model_type import model_configs\n", | |
| "    from protenix.config.config import parse_configs\n", | |
| "\n", | |
| "    # Deep-copy before merging: deep_update writes into nested dicts, and a\n", | |
| "    # shallow merge shares those with the imported module-level configs, so the\n", | |
| "    # overrides of one call would leak into every later call in this kernel.\n", | |
| "    base = copy.deepcopy({**configs_base, **{'data': data_configs}, **inference_configs})\n", | |
| "    overrides = copy.deepcopy(model_configs[model_name])\n", | |
| "\n", | |
| "    def deep_update(t, p):\n", | |
| "        # Recursively overlay p onto t; non-dict values replace what is there.\n", | |
| "        for k, v in p.items():\n", | |
| "            if isinstance(v, dict) and k in t and isinstance(t[k], dict):\n", | |
| "                deep_update(t[k], v)\n", | |
| "            else:\n", | |
| "                t[k] = v\n", | |
| "\n", | |
| "    deep_update(base, overrides)\n", | |
| "    arg_str = ' '.join([\n", | |
| "        f'--model_name {model_name}',\n", | |
| "        f'--input_json_path {input_json_path}',\n", | |
| "        f'--dump_dir {dump_dir}',\n", | |
| "        '--use_msa false',\n", | |
| "        '--use_template false',\n", | |
| "        '--use_rna_msa true',\n", | |
| "        f'--sample_diffusion.N_sample {N_SAMPLE}',\n", | |
| "        f'--seeds {SEED}',\n", | |
| "    ])\n", | |
| "    return parse_configs(configs=base, arg_str=arg_str, fill_required_with_null=True)\n", | |
| "\n", | |
| "print('Protenix helpers loaded')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": "# 8. Run Protenix on targets that need it\n\nfrom protenix.data.inference.infer_dataloader import InferenceDataset\nfrom runner.inference import InferenceRunner, update_gpu_compatible_configs, update_inference_configs\n\nwork_dir = Path(f'{WORK_BASE}/working')\n# parents=True so a missing parent directory does not abort the cell\nwork_dir.mkdir(parents=True, exist_ok=True)\n\n# Build tasks (with chunking for long sequences)\ntasks = []  # one dict per Protenix job: target_id (job name) and sequence\nchunk_info = {}  # target id -> list of {'name': job name, 'range': (start, end)}\n\nfor tid, seq, n_needed in protenix_targets:\n    seq_len = len(seq)\n    if seq_len <= MAX_SEQ_LEN:\n        tasks.append({'target_id': tid, 'sequence': seq})\n        chunk_info[tid] = [{'name': tid, 'range': (0, seq_len)}]\n    else:\n        # Longer than MAX_SEQ_LEN: submit overlapping windows as separate jobs\n        chunks = split_into_chunks(seq_len, MAX_SEQ_LEN, CHUNK_OVERLAP)\n        chunk_info[tid] = []\n        for ci, (cs, ce) in enumerate(chunks):\n            chunk_name = f'{tid}_chunk{ci}'\n            tasks.append({'target_id': chunk_name, 'sequence': seq[cs:ce]})\n            chunk_info[tid].append({'name': chunk_name, 'range': (cs, ce)})\n\nprint(f'Total Protenix tasks: {len(tasks)} (from {len(protenix_targets)} targets)')\n\n# Write input JSON\ninput_json = str(work_dir / 'protenix_input.json')\nbuild_input_json(tasks, input_json)\n\n# Init model\nconfigs = build_configs(input_json, str(work_dir / 'outputs'), MODEL_NAME)\nconfigs = update_gpu_compatible_configs(configs)\nrunner = InferenceRunner(configs)\ndataset = InferenceDataset(configs)\n\nprint(f'Model loaded. Running inference on {len(dataset)} samples...')" | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# 9. Inference loop\n", | |
| "\n", | |
| "def get_c1_mask(data, atom_array):\n", | |
| "    \"\"\"Return a boolean per-atom mask used to pick the C1' coordinates.\n", | |
| "\n", | |
| "    Sources are tried in this order:\n", | |
| "      1. ``atom_array.centre_atom_mask == 1`` (AND ``is_rna`` when present),\n", | |
| "      2. ``atom_array.atom_name == \"C1'\"`` (AND ``is_rna`` when present),\n", | |
| "      3. ``centre_atom_mask`` / ``center_atom_mask`` in ``data['input_feature_dict']``,\n", | |
| "      4. atoms whose ``atom_to_tokatom_idx`` is 11 or 12, whichever selection\n", | |
| "         has a size closer to ``data['N_token']``.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        data: Mapping with ``input_feature_dict`` and optionally ``N_token``.\n", | |
| "        atom_array: Object with per-atom annotations, or None.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        Boolean torch tensor with one entry per atom.\n", | |
| "    \"\"\"\n", | |
| "    if atom_array is not None:\n", | |
| "        try:\n", | |
| "            if hasattr(atom_array, 'centre_atom_mask'):\n", | |
| "                m = atom_array.centre_atom_mask == 1\n", | |
| "                if hasattr(atom_array, 'is_rna'):\n", | |
| "                    m = m & atom_array.is_rna\n", | |
| "                return torch.from_numpy(m).bool()\n", | |
| "            if hasattr(atom_array, 'atom_name'):\n", | |
| "                base = atom_array.atom_name == \"C1'\"\n", | |
| "                if hasattr(atom_array, 'is_rna'):\n", | |
| "                    base = base & atom_array.is_rna\n", | |
| "                return torch.from_numpy(base).bool()\n", | |
| "        except Exception as err:\n", | |
| "            # Still best effort, but report the problem before falling back to\n", | |
| "            # the feature dict instead of hiding it.\n", | |
| "            print(f'get_c1_mask: atom_array lookup failed ({err!r}); using input features')\n", | |
| "    f = data['input_feature_dict']\n", | |
| "    for key in ('centre_atom_mask', 'center_atom_mask'):\n", | |
| "        if key in f:\n", | |
| "            return (f[key] == 1).bool()\n", | |
| "    n_tokens = data.get('N_token', 0)\n", | |
| "    if hasattr(n_tokens, 'item'):\n", | |
| "        n_tokens = n_tokens.item()  # tensor -> Python number; plain ints pass through\n", | |
| "    m11 = (f['atom_to_tokatom_idx'] == 11).bool()\n", | |
| "    m12 = (f['atom_to_tokatom_idx'] == 12).bool()\n", | |
| "    # Keep the candidate whose number of selected atoms is closer to the token count.\n", | |
| "    return m11 if abs(m11.sum().item() - n_tokens) < abs(m12.sum().item() - n_tokens) else m12\n", | |
| "\n", | |
| "raw_predictions = {}  # task_name -> (N_SAMPLE, seq_len, 3)\n", | |
| "\n", | |
| "for idx in tqdm(range(len(dataset)), desc='Protenix inference'):\n", | |
| "    item = dataset[idx]\n", | |
| "    # NOTE(review): upstream Protenix's InferenceDataset is believed to return\n", | |
| "    # (data, atom_array, error_message) tuples rather than a bare mapping;\n", | |
| "    # confirm against the bundled version. Both shapes are accepted here.\n", | |
| "    if isinstance(item, (tuple, list)):\n", | |
| "        data = item[0]\n", | |
| "        atom_array = item[1] if len(item) > 1 else None\n", | |
| "    else:\n", | |
| "        data = item\n", | |
| "        # `data` is used as a mapping, so getattr(data, 'atom_array', None) was\n", | |
| "        # always None; look the atom array up by key instead.\n", | |
| "        atom_array = data.get('atom_array')\n", | |
| "    sample_name = data.get('sample_name', f'sample_{idx}')\n", | |
| "\n", | |
| "    try:\n", | |
| "        prediction = runner.predict(data)\n", | |
| "        feat = data['input_feature_dict']\n", | |
| "\n", | |
| "        # Extract all-atom coords\n", | |
| "        # NOTE(review): 'coordinate' is accepted as well because upstream Protenix\n", | |
| "        # is believed to use that key — confirm against the bundled version.\n", | |
| "        raw_coords = prediction\n", | |
| "        for key in ('atom_coordinate', 'coordinate'):\n", | |
| "            if hasattr(prediction, key):\n", | |
| "                raw_coords = getattr(prediction, key)\n", | |
| "                break\n", | |
| "            if isinstance(prediction, dict) and key in prediction:\n", | |
| "                raw_coords = prediction[key]\n", | |
| "                break\n", | |
| "\n", | |
| "        if not isinstance(raw_coords, torch.Tensor):\n", | |
| "            raise TypeError(f'no coordinate tensor found in prediction of type {type(prediction).__name__}')\n", | |
| "        raw_coords = raw_coords.detach().cpu()\n", | |
| "\n", | |
| "        # Index tensors must live on the same device as the tensor they index.\n", | |
| "        mask = get_c1_mask(data, atom_array).cpu()\n", | |
| "\n", | |
| "        # raw_coords shape: (N_SAMPLE, N_atoms, 3)\n", | |
| "        if raw_coords.dim() == 3:\n", | |
| "            c1_coords = raw_coords[:, mask, :].numpy()\n", | |
| "        else:\n", | |
| "            c1_coords = raw_coords[mask, :].unsqueeze(0).numpy()\n", | |
| "\n", | |
| "        raw_predictions[sample_name] = c1_coords\n", | |
| "        print(f' {sample_name}: {c1_coords.shape}')\n", | |
| "\n", | |
| "    except Exception as e:\n", | |
| "        # Keep going so one failing target does not lose the others.\n", | |
| "        print(f' {sample_name}: FAILED - {e}')\n", | |
| "\n", | |
| "    # Release memory held by the previous sample before loading the next one.\n", | |
| "    gc.collect()\n", | |
| "    torch.cuda.empty_cache()\n", | |
| "\n", | |
| "print(f'\\nCompleted: {len(raw_predictions)}/{len(dataset)} tasks')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# 10. Save predictions to OUTPUT_DIR (Google Drive when running on Colab)\n", | |
| "\n", | |
| "for task_name in raw_predictions:\n", | |
| "    task_coords = raw_predictions[task_name]\n", | |
| "    np.save(f'{OUTPUT_DIR}/{task_name}.npy', task_coords)\n", | |
| "    print(f'Saved {task_name}.npy: shape {task_coords.shape}')\n", | |
| "\n", | |
| "# Also save chunk_info for local stitching\n", | |
| "chunk_info_path = f'{OUTPUT_DIR}/chunk_info.json'\n", | |
| "with open(chunk_info_path, 'w') as handle:\n", | |
| "    json.dump(chunk_info, handle)\n", | |
| "\n", | |
| "print(f'\\nAll saved to {OUTPUT_DIR}')\n", | |
| "print('Download these files and run the local merge script.')" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "accelerator": "GPU", | |
| "colab": { | |
| "gpuType": "T4", | |
| "provenance": [] | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "name": "python", | |
| "version": "3.11.0" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment