Last active
March 12, 2026 18:57
-
-
Save matyushkin/bd8404dec55ff80479627d4339f847f5 to your computer and use it in GitHub Desktop.
rna-sigmaborov-colab.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 5, | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "name": "python", | |
| "version": "3.10.0" | |
| }, | |
| "colab": { | |
| "provenance": [] | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "import os, glob\n", | |
| "\n", | |
| "!pip install -q kaggle\n", | |
| "\n", | |
| "os.environ['KAGGLE_USERNAME'] = 'matyushkin'\n", | |
| "os.environ['KAGGLE_KEY'] = '20d8d6d4a3f4f45e4ef2a3d379cd35dd'\n", | |
| "\n", | |
| "os.makedirs('/data', exist_ok=True)\n", | |
| "needed = ['test_sequences.csv', 'train_sequences.csv', 'train_labels.csv',\n", | |
| " 'validation_sequences.csv', 'validation_labels.csv']\n", | |
| "missing = [f for f in needed if not os.path.exists(f'/data/{f}')]\n", | |
| "if missing:\n", | |
| " print('Downloading:', missing)\n", | |
| " for fname in missing:\n", | |
| " os.system(f'kaggle competitions download stanford-rna-3d-folding-2 -f {fname} -p /data/')\n", | |
| " for f in glob.glob('/data/*.zip'):\n", | |
| " os.system(f'unzip -q -o {f} -d /data/ && rm {f}')\n", | |
| "else:\n", | |
| " print('Data already present, skipping download')\n", | |
| "!ls /data/\n" | |
| ], | |
| "outputs": [] | |
| }, | |
import os, sys

PROTENIX_DIR = '/protenix/Protenix-v1-adjust-v2/Protenix-v1-adjust-v2/Protenix-v1'
os.environ['PROTENIX_CODE_DIR'] = PROTENIX_DIR
os.environ['PROTENIX_ROOT_DIR'] = PROTENIX_DIR

# FIX: the two branches previously installed different dependency versions
# (biotite==1.0.0 vs unpinned biotite), making re-runs non-reproducible.
# Pin once and install identically on both paths.
DEPS = 'biopython==1.86 biotite==1.0.0 rdkit'

if not os.path.exists(PROTENIX_DIR):
    print('Downloading Protenix...')
    os.system('kaggle datasets download qiweiyin/protenix-v1-adjusted -p /protenix/ --unzip')
else:
    print('Protenix already present')

# (Re)install in every runtime: pip state does not survive a fresh Colab VM.
os.system(f'pip install -q {DEPS} -e {PROTENIX_DIR}')

print('Done')
import os
import sys

import pandas as pd

# Runtime mode: this notebook is intended for Colab, not the Kaggle kernel.
IS_KAGGLE = False
# Quick-validation subset: only the first N test targets are processed locally.
LOCAL_N_SAMPLES = 10

print(f'Running in LOCAL/COLAB mode — first {LOCAL_N_SAMPLES} targets only')
| { | |
| "cell_type": "code", | |
| "source": "import gc\nimport json\nimport os\nimport time\n\nos.environ[\"LAYERNORM_TYPE\"] = \"torch\"\nos.environ.setdefault(\"RNA_MSA_DEPTH_LIMIT\", \"512\")\n\nimport sys\nfrom pathlib import Path\n\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom Bio.Align import PairwiseAligner\nfrom tqdm import tqdm", | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "outputs": [], | |
| "execution_count": null | |
| }, | |
def get_c1_mask(data: dict, atom_array) -> torch.Tensor:
    """Return a boolean mask selecting one C1' (token-centre) atom per residue.

    Resolution order:
      1. Annotations on ``atom_array`` — an explicit ``centre_atom_mask``, or a
         literal C1' atom-name match, each optionally ANDed with ``is_rna``.
      2. A centre-atom mask inside ``data['input_feature_dict']`` (both the
         British and American spellings are checked).
      3. A heuristic over ``atom_to_tokatom_idx``: of indices 11 and 12, pick
         the one whose atom count is closest to ``data['N_token']``.

    NOTE(review): a later cell re-defines a function of the same name, which
    shadows this one at runtime — confirm which version is intended.
    """
    if atom_array is not None:
        try:
            if hasattr(atom_array, "centre_atom_mask"):
                selected = atom_array.centre_atom_mask == 1
                if hasattr(atom_array, "is_rna"):
                    selected = selected & atom_array.is_rna
                return torch.from_numpy(selected).bool()
            if hasattr(atom_array, "atom_name"):
                selected = atom_array.atom_name == "C1'"
                if hasattr(atom_array, "is_rna"):
                    selected = selected & atom_array.is_rna
                return torch.from_numpy(selected).bool()
        except Exception:
            # Best-effort path: any failure falls through to the feature dict.
            pass

    features = data["input_feature_dict"]
    for key in ("centre_atom_mask", "center_atom_mask"):
        if key in features:
            return (features[key] == 1).bool()

    # Heuristic fallback: expect roughly one centre atom per token.
    n_tokens = data.get("N_token", torch.tensor(0)).item()
    by_index = {idx: (features["atom_to_tokatom_idx"] == idx).bool()
                for idx in (12, 11)}
    # Iteration order (12 first) reproduces the original tie-break, which
    # preferred index 12 when both counts are equally close to n_tokens.
    best = min(by_index, key=lambda idx: abs(by_index[idx].sum().item() - n_tokens))
    return by_index[best]
| { | |
| "cell_type": "markdown", | |
| "source": "<a id=\"section1\"></a>\n\n# <div style=\"text-align:center; border-radius:15px; padding:15px; font-size:115%; font-family:'Arial Black', Arial, sans-serif; text-shadow:2px 2px 5px rgba(0,0,0,0.7); background:linear-gradient(90deg, #1e3c72, #c31432); box-shadow:0 2px 5px rgba(0,0,0,0.3); color:white;\"><b>TUNING</b></div>", | |
| "metadata": {} | |
| }, | |
# ─────────────── Paths & Constants ───────────────────────────────────────────
DATA_BASE = "/data"
DEFAULT_TEST_CSV = f"{DATA_BASE}/test_sequences.csv"
DEFAULT_TRAIN_CSV = f"{DATA_BASE}/train_sequences.csv"
DEFAULT_TRAIN_LBLS = f"{DATA_BASE}/train_labels.csv"
DEFAULT_VAL_CSV = f"{DATA_BASE}/validation_sequences.csv"
DEFAULT_VAL_LBLS = f"{DATA_BASE}/validation_labels.csv"
DEFAULT_OUTPUT = "/content/submission.csv"

DEFAULT_CODE_DIR = (
    "/kaggle/input/datasets/qiweiyin/protenix-v1-adjusted"
    "/Protenix-v1-adjust-v2/Protenix-v1-adjust-v2/Protenix-v1"
)
DEFAULT_ROOT_DIR = DEFAULT_CODE_DIR

MODEL_NAME = "protenix_base_20250630_v1.0.0"
N_SAMPLE = 5    # predictions per target required by the submission format
SEED = 42
MAX_SEQ_LEN = int(os.environ.get("MAX_SEQ_LEN", "512"))      # Protenix token budget
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "128"))  # residues shared by adjacent chunks

# TBM quality thresholds — sequences below these get routed to Protenix
MIN_SIMILARITY = float(os.environ.get("MIN_SIMILARITY", "0.0"))
MIN_PERCENT_IDENTITY = float(os.environ.get("MIN_PERCENT_IDENTITY", "50.0"))

# Set False to skip Protenix and use de-novo fallback instead
USE_PROTENIX = True


def parse_bool(value: str, default: bool = False) -> str:
    """Map a truthy/falsy string onto the literal 'true' or 'false'.

    Deliberately returns a *string*, not a bool: the result is spliced into
    the CLI-style argument string consumed by Protenix in build_configs().
    Unrecognized values fall back to `default`.
    """
    v = str(value).strip().lower()
    if v in {"1", "true", "t", "yes", "y", "on"}:
        return "true"
    if v in {"0", "false", "f", "no", "n", "off"}:
        return "false"
    return "true" if default else "false"


USE_MSA = parse_bool(os.environ.get("USE_MSA", "false"))
USE_TEMPLATE = parse_bool(os.environ.get("USE_TEMPLATE", "false"))
USE_RNA_MSA = parse_bool(os.environ.get("USE_RNA_MSA", "true"))

MODEL_N_SAMPLE = int(os.environ.get("MODEL_N_SAMPLE", str(N_SAMPLE)))


# ─────────────── General Utilities ───────────────────────────────────────────
def seed_everything(seed: int) -> None:
    """Seed torch/numpy and force deterministic CUDA kernels."""
    # Required by cuBLAS for deterministic matmuls when
    # torch.use_deterministic_algorithms(True) is set.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True
    torch.use_deterministic_algorithms(True)


def resolve_paths():
    """Resolve input/output/Protenix paths, allowing env-var overrides."""
    test_csv = os.environ.get("TEST_CSV", DEFAULT_TEST_CSV)
    output_csv = os.environ.get("SUBMISSION_CSV", DEFAULT_OUTPUT)
    code_dir = os.environ.get("PROTENIX_CODE_DIR", DEFAULT_CODE_DIR)
    root_dir = os.environ.get("PROTENIX_ROOT_DIR", DEFAULT_ROOT_DIR)
    return test_csv, output_csv, code_dir, root_dir


def ensure_required_files(root_dir: str) -> None:
    """Fail fast if the Protenix checkpoint or CCD assets are missing."""
    for p, name in [
        (Path(root_dir) / "checkpoint" / f"{MODEL_NAME}.pt", "checkpoint"),
        (Path(root_dir) / "common" / "components.cif", "CCD file"),
        (Path(root_dir) / "common" / "components.cif.rdkit_mol.pkl", "CCD cache"),
    ]:
        if not p.exists():
            raise FileNotFoundError(f"Missing {name}: {p}")


# ─────────────── Protenix Input / Config Helpers ─────────────────────────────
def build_input_json(df: pd.DataFrame, json_path: str) -> None:
    """Write a Protenix input JSON: one single-copy RNA entity per target row."""
    data = [
        {
            "name": row["target_id"],
            "covalent_bonds": [],
            "sequences": [{"rnaSequence": {"sequence": row["sequence"], "count": 1}}],
        }
        for _, row in df.iterrows()
    ]
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(data, f)


def build_configs(input_json_path: str, dump_dir: str, model_name: str):
    """Assemble Protenix configs (base + data + inference + model overrides)."""
    from configs.configs_base import configs as configs_base
    from configs.configs_data import data_configs
    from configs.configs_inference import inference_configs
    from configs.configs_model_type import model_configs
    from protenix.config.config import parse_configs

    base = {**configs_base, **{"data": data_configs}, **inference_configs}

    def deep_update(t, p):
        # Recursive dict merge: nested dicts are merged, scalars overwritten.
        for k, v in p.items():
            if isinstance(v, dict) and k in t and isinstance(t[k], dict):
                deep_update(t[k], v)
            else:
                t[k] = v

    deep_update(base, model_configs[model_name])
    arg_str = " ".join([
        f"--model_name {model_name}",
        f"--input_json_path {input_json_path}",
        f"--dump_dir {dump_dir}",
        f"--use_msa {USE_MSA}",
        f"--use_template {USE_TEMPLATE}",
        f"--use_rna_msa {USE_RNA_MSA}",
        f"--sample_diffusion.N_sample {MODEL_N_SAMPLE}",
        f"--seeds {SEED}",
    ])
    return parse_configs(configs=base, arg_str=arg_str, fill_required_with_null=True)


def get_c1_mask(data: dict, atom_array) -> torch.Tensor:
    """Boolean mask selecting one centre (C1') atom per residue.

    NOTE(review): this re-definition shadows the identically named function
    from an earlier cell; the unreachable "CASE B/C" scaffolding that the
    previous revision carried here has been removed (it was dead code).
    """
    # 1. Try atom_array annotations first (best-effort: ignore any failure).
    if atom_array is not None:
        try:
            if hasattr(atom_array, "centre_atom_mask"):
                m = atom_array.centre_atom_mask == 1
                if hasattr(atom_array, "is_rna"):
                    m = m & atom_array.is_rna
                return torch.from_numpy(m).bool()

            if hasattr(atom_array, "atom_name"):
                base = atom_array.atom_name == "C1'"
                if hasattr(atom_array, "is_rna"):
                    base = base & atom_array.is_rna
                return torch.from_numpy(base).bool()
        except Exception:
            pass

    # 2. Fall back to the feature dict (both spellings occur in the wild).
    f = data["input_feature_dict"]
    if "center_atom_mask" in f:
        return (f["center_atom_mask"] == 1).bool()
    if "centre_atom_mask" in f:
        return (f["centre_atom_mask"] == 1).bool()

    # 3. Last resort: tokatom index 11 is assumed to be the C1' slot for this
    #    featurizer — TODO confirm against the Protenix feature layout.
    return (f["atom_to_tokatom_idx"] == 11).bool()


def get_feature_c1_mask(data: dict) -> torch.Tensor:
    """Feature-dict-only variant of get_c1_mask (tokatom index 12 fallback)."""
    f = data["input_feature_dict"]
    if "centre_atom_mask" in f:
        return f["centre_atom_mask"].long() == 1
    return f["atom_to_tokatom_idx"].long() == 12


def coords_to_rows(target_id: str, seq: str, coords: np.ndarray) -> list:
    """coords shape: (N_SAMPLE, seq_len, 3) — flatten into submission rows."""
    rows = []
    for i in range(len(seq)):
        row = {"ID": f"{target_id}_{i + 1}", "resname": seq[i], "resid": i + 1}
        for s in range(N_SAMPLE):
            if s < coords.shape[0] and i < coords.shape[1]:
                x, y, z = coords[s, i]
            else:
                # Pad missing samples/residues with zeros.
                x, y, z = 0.0, 0.0, 0.0
            row[f"x_{s + 1}"] = float(x)
            row[f"y_{s + 1}"] = float(y)
            row[f"z_{s + 1}"] = float(z)
        rows.append(row)
    return rows


def pad_samples(coords: np.ndarray, n: int) -> np.ndarray:
    """Pad/truncate the sample axis to exactly n by repeating the first sample."""
    if coords.shape[0] >= n:
        return coords[:n]
    if coords.shape[0] == 0:
        return np.zeros((n, coords.shape[1], 3), dtype=coords.dtype)
    extra = np.repeat(coords[:1], n - coords.shape[0], axis=0)
    return np.concatenate([coords, extra], axis=0)


def split_into_chunks(seq_len: int, max_len: int, overlap: int) -> list:
    """Split a sequence into overlapping (start, end) chunks."""
    if seq_len <= max_len:
        return [(0, seq_len)]
    chunks = []
    step = max_len - overlap
    pos = 0
    while pos < seq_len:
        end = min(pos + max_len, seq_len)
        chunks.append((pos, end))
        if end == seq_len:
            break
        pos += step
    return chunks


def kabsch_align(P: np.ndarray, Q: np.ndarray):
    """Compute optimal rotation R and translation t so that R @ P + t ≈ Q."""
    centroid_P = P.mean(axis=0)
    centroid_Q = Q.mean(axis=0)
    Pc = P - centroid_P
    Qc = Q - centroid_Q
    H = Pc.T @ Qc
    U, _, Vt = np.linalg.svd(H)
    d = np.linalg.det(Vt.T @ U.T)
    S = np.eye(3)
    if d < 0:
        S[2, 2] = -1  # reflection correction: keep R a proper rotation
    R = Vt.T @ S @ U.T
    t = centroid_Q - R @ centroid_P
    return R, t


def stitch_chunk_coords(chunk_coords_list: list,
                        chunk_ranges: list,
                        seq_len: int) -> np.ndarray:
    """
    Merge overlapping chunk coordinates into a full sequence geometry.
    Applies Kabsch alignment on overlapping residues, and smoothly
    blends the coordinates using a linear weight ramp.
    """
    if len(chunk_coords_list) == 1:
        coords = chunk_coords_list[0]
        if coords.shape[0] >= seq_len:
            return coords[:seq_len]
        out = np.zeros((seq_len, 3), dtype=coords.dtype)
        out[:coords.shape[0]] = coords
        return out

    # Start with the first chunk aligned to itself (identity)
    aligned = [chunk_coords_list[0].copy()]

    for i in range(1, len(chunk_coords_list)):
        prev_start, prev_end = chunk_ranges[i - 1]
        cur_start, cur_end = chunk_ranges[i]

        ov_start = cur_start
        ov_end = min(prev_end, cur_end)
        ov_len = ov_end - ov_start

        if ov_len < 3:
            # Cannot align reliably, just trust the coordinates as-is
            aligned.append(chunk_coords_list[i].copy())
            continue

        prev_ov = aligned[i - 1][ov_start - prev_start: ov_end - prev_start]
        cur_ov = chunk_coords_list[i][ov_start - cur_start: ov_end - cur_start]

        # Ignore invalid residues (e.g. padding/blank)
        valid = ~(np.isnan(prev_ov).any(axis=1) | np.isnan(cur_ov).any(axis=1))
        if valid.sum() < 3:
            aligned.append(chunk_coords_list[i].copy())
            continue

        # Align current chunk to previous chunk using only the overlap region
        R, t = kabsch_align(cur_ov[valid], prev_ov[valid])
        transformed = (chunk_coords_list[i] @ R.T) + t
        aligned.append(transformed)

    # Blend them together
    full = np.zeros((seq_len, 3), dtype=np.float64)
    weights = np.zeros(seq_len, dtype=np.float64)

    for i, ((s, e), coords) in enumerate(zip(chunk_ranges, aligned)):
        chunk_len = coords.shape[0]
        actual_end = min(s + chunk_len, seq_len)
        used_len = actual_end - s

        w = np.ones(used_len, dtype=np.float64)

        if i > 0:
            # Ramp 0→1 over the overlap with the previous chunk
            ov_start = s
            ov_end = min(chunk_ranges[i - 1][1], e)
            ramp_len = ov_end - ov_start
            if ramp_len > 0:
                w[:ramp_len] = np.linspace(0.0, 1.0, ramp_len)

        if i < len(chunk_ranges) - 1:
            # Ramp 1→0 over the overlap with the next chunk
            next_s = chunk_ranges[i + 1][0]
            ramp_start = next_s - s
            ramp_len = actual_end - next_s
            if ramp_len > 0 and ramp_start < used_len:
                w[ramp_start:used_len] = np.linspace(1.0, 0.0, ramp_len)

        full[s:actual_end] += coords[:used_len] * w[:, None]
        weights[s:actual_end] += w

    mask = weights > 0
    full[mask] /= weights[mask, None]

    return full


# ─────────────── TBM Core Functions ──────────────────────────────────────────
def _make_aligner() -> PairwiseAligner:
    """Global pairwise aligner with uniform gap penalties at all sequence ends."""
    al = PairwiseAligner()
    al.mode = "global"
    al.match_score = 2
    al.mismatch_score = -1.5
    al.open_gap_score = -8
    al.extend_gap_score = -0.4
    al.query_left_open_gap_score = -8
    al.query_left_extend_gap_score = -0.4
    al.query_right_open_gap_score = -8
    al.query_right_extend_gap_score = -0.4
    al.target_left_open_gap_score = -8
    al.target_left_extend_gap_score = -0.4
    al.target_right_open_gap_score = -8
    al.target_right_extend_gap_score = -0.4
    return al


_aligner = _make_aligner()


def parse_stoichiometry(stoich: str) -> list:
    """Parse 'A:2;B:1' into [('A', 2), ('B', 1)]; empty/NaN → []."""
    if pd.isna(stoich) or str(stoich).strip() == "":
        return []
    return [(ch.strip(), int(cnt)) for part in str(stoich).split(";")
            for ch, cnt in [part.split(":")]]


def parse_fasta(fasta_content: str) -> dict:
    """Parse FASTA text into {record_id: sequence} (id = first header token)."""
    out, cur, parts = {}, None, []
    for line in str(fasta_content).splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if cur is not None:
                out[cur] = "".join(parts)
            cur = line[1:].split()[0]
            parts = []
        else:
            parts.append(line.replace(" ", ""))
    if cur is not None:
        out[cur] = "".join(parts)
    return out


def get_chain_segments(row) -> list:
    """Return per-chain (start, end) segments of the concatenated sequence.

    Falls back to a single whole-sequence segment whenever stoichiometry or
    all_sequences is missing, malformed, or inconsistent with the total length.
    """
    seq = row["sequence"]
    stoich = row.get("stoichiometry", "")
    all_sq = row.get("all_sequences", "")
    if (pd.isna(stoich) or pd.isna(all_sq)
            or str(stoich).strip() == "" or str(all_sq).strip() == ""):
        return [(0, len(seq))]
    try:
        chain_dict = parse_fasta(all_sq)
        order = parse_stoichiometry(stoich)
        segs, pos = [], 0
        for ch, cnt in order:
            base = chain_dict.get(ch)
            if base is None:
                return [(0, len(seq))]
            for _ in range(cnt):
                segs.append((pos, pos + len(base)))
                pos += len(base)
        return segs if pos == len(seq) else [(0, len(seq))]
    except Exception:
        return [(0, len(seq))]


def build_segments_map(df: pd.DataFrame) -> tuple:
    """Build {target_id: segments} and {target_id: stoichiometry string}."""
    seg_map, stoich_map = {}, {}
    for _, r in df.iterrows():
        tid = r["target_id"]
        seg_map[tid] = get_chain_segments(r)
        raw_s = r.get("stoichiometry", "")
        stoich_map[tid] = "" if pd.isna(raw_s) else str(raw_s)
    return seg_map, stoich_map


def process_labels(labels_df: pd.DataFrame) -> dict:
    """Group label rows by target prefix → {target_id: (seq_len, 3) array}."""
    coords = {}
    prefixes = labels_df["ID"].str.rsplit("_", n=1).str[0]
    for prefix, grp in labels_df.groupby(prefixes):
        coords[prefix] = grp.sort_values("resid")[["x_1", "y_1", "z_1"]].values
    return coords


def _build_aligned_strings(query_seq, template_seq, alignment):
    """Reconstruct gapped alignment strings from Biopython's aligned blocks."""
    q_segs, t_segs = alignment.aligned
    aq, at, qi, ti = [], [], 0, 0
    for (qs, qe), (ts, te) in zip(q_segs, t_segs):
        while qi < qs: aq.append(query_seq[qi]); at.append("-"); qi += 1
        while ti < ts: aq.append("-"); at.append(template_seq[ti]); ti += 1
        for qp, tp in zip(range(qs, qe), range(ts, te)):
            aq.append(query_seq[qp]); at.append(template_seq[tp])
        qi, ti = qe, te
    while qi < len(query_seq): aq.append(query_seq[qi]); at.append("-"); qi += 1
    while ti < len(template_seq): aq.append("-"); at.append(template_seq[ti]); ti += 1
    return "".join(aq), "".join(at)


def find_similar_sequences_detailed(query_seq, train_seqs_df, train_coords_dict, top_n=30):
    """Rank templates by normalized alignment score; keep the top_n.

    Templates whose length differs from the query by >30% are skipped, as are
    templates without coordinates. Returns tuples of
    (tid, tseq, norm_score, coords, pct_identity, aligned_query, aligned_tmpl).
    """
    results = []
    for _, row in train_seqs_df.iterrows():
        tid, tseq = row["target_id"], row["sequence"]
        if tid not in train_coords_dict:
            continue
        if abs(len(tseq) - len(query_seq)) / max(len(tseq), len(query_seq)) > 0.3:
            continue
        aln = next(iter(_aligner.align(query_seq, tseq)))
        norm_s = aln.score / (2 * min(len(query_seq), len(tseq)))
        identical = sum(
            1 for (qs, qe), (ts, te) in zip(*aln.aligned)
            for qp, tp in zip(range(qs, qe), range(ts, te))
            if query_seq[qp] == tseq[tp]
        )
        pct_id = 100 * identical / len(query_seq)
        aq, at = _build_aligned_strings(query_seq, tseq, aln)
        results.append((tid, tseq, norm_s, train_coords_dict[tid], pct_id, aq, at))
    results.sort(key=lambda x: x[2], reverse=True)
    return results[:top_n]


def adapt_template_to_query(query_seq, template_seq, template_coords) -> np.ndarray:
    """Transfer template coordinates onto the query via pairwise alignment.

    Unmapped residues are linearly interpolated between flanking mapped
    residues, or extrapolated 3 Å along x at the termini.
    """
    aln = next(iter(_aligner.align(query_seq, template_seq)))
    new_coords = np.full((len(query_seq), 3), np.nan)
    for (qs, qe), (ts, te) in zip(*aln.aligned):
        chunk = template_coords[ts:te]
        if len(chunk) == (qe - qs):
            new_coords[qs:qe] = chunk
    for i in range(len(new_coords)):
        if np.isnan(new_coords[i, 0]):
            pv = next((j for j in range(i - 1, -1, -1) if not np.isnan(new_coords[j, 0])), -1)
            nv = next((j for j in range(i + 1, len(new_coords)) if not np.isnan(new_coords[j, 0])), -1)
            if pv >= 0 and nv >= 0:
                w = (i - pv) / (nv - pv)
                new_coords[i] = (1 - w) * new_coords[pv] + w * new_coords[nv]
            elif pv >= 0:
                new_coords[i] = new_coords[pv] + [3, 0, 0]
            elif nv >= 0:
                new_coords[i] = new_coords[nv] + [3, 0, 0]
            else:
                new_coords[i] = [i * 3, 0, 0]
    return np.nan_to_num(new_coords)


def adaptive_rna_constraints(coords, target_id, segments_map, confidence=1.0, passes=2) -> np.ndarray:
    """Relax geometry toward idealized RNA constraints, per chain segment.

    Correction strength scales inversely with template confidence. Softly
    enforces: i,i+1 ≈ 5.95 Å; i,i+2 ≈ 10.2 Å; Laplacian smoothing; and mild
    self-avoidance below 3.2 Å for chains of 25+ residues.
    """
    X = coords.copy()
    segments = segments_map.get(target_id, [(0, len(X))])
    strength = max(0.75 * (1.0 - min(confidence, 0.97)), 0.02)
    for _ in range(passes):
        for s, e in segments:
            C = X[s:e]; L = e - s
            if L < 3:
                continue
            # bond i–i+1 ~5.95 Å
            d = C[1:] - C[:-1]; dist = np.linalg.norm(d, axis=1) + 1e-6
            adj = d * ((5.95 - dist) / dist)[:, None] * (0.22 * strength)
            C[:-1] -= adj; C[1:] += adj
            # soft i–i+2 ~10.2 Å
            d2 = C[2:] - C[:-2]; d2n = np.linalg.norm(d2, axis=1) + 1e-6
            adj2 = d2 * ((10.2 - d2n) / d2n)[:, None] * (0.10 * strength)
            C[:-2] -= adj2; C[2:] += adj2
            # Laplacian smoothing
            C[1:-1] += (0.06 * strength) * (0.5 * (C[:-2] + C[2:]) - C[1:-1])
            # self-avoidance (subsample very long chains to cap the O(L²) cost)
            if L >= 25:
                idx = np.linspace(0, L - 1, min(L, 160)).astype(int) if L > 220 else np.arange(L)
                P = C[idx]; diff = P[:, None, :] - P[None, :, :]
                dm = np.linalg.norm(diff, axis=2) + 1e-6
                sep = np.abs(idx[:, None] - idx[None, :])
                mask = (sep > 2) & (dm < 3.2)
                if np.any(mask):
                    vec = (diff * ((3.2 - dm) / dm)[:, :, None] * mask[:, :, None]).sum(axis=1)
                    C[idx] += (0.015 * strength) * vec
            X[s:e] = C
    return X


def _rotmat(axis, ang):
    """Rotation matrix for `ang` radians about `axis` (Rodrigues form)."""
    a = np.asarray(axis, float); a /= np.linalg.norm(a) + 1e-12
    x, y, z = a; c, s = np.cos(ang), np.sin(ang); CC = 1 - c
    return np.array([[c+x*x*CC, x*y*CC-z*s, x*z*CC+y*s],
                     [y*x*CC+z*s, c+y*y*CC, y*z*CC-x*s],
                     [z*x*CC-y*s, z*y*CC+x*s, c+z*z*CC]])


def apply_hinge(coords, seg, rng, deg=22):
    """Rotate the tail of one segment about a random interior pivot (±deg)."""
    s, e = seg; L = e - s
    if L < 30: return coords
    pivot = s + int(rng.integers(10, L - 10))
    R = _rotmat(rng.normal(size=3), np.deg2rad(float(rng.uniform(-deg, deg))))
    X = coords.copy(); p0 = X[pivot].copy()
    X[pivot+1:e] = (X[pivot+1:e] - p0) @ R.T + p0
    return X


def jitter_chains(coords, segs, rng, deg=12, trans=1.5):
    """Rigid-body perturb each chain, preserving the global centroid."""
    X = coords.copy(); gc_ = X.mean(0, keepdims=True)
    for s, e in segs:
        R = _rotmat(rng.normal(size=3), np.deg2rad(float(rng.uniform(-deg, deg))))
        shift = rng.normal(size=3); shift = shift / (np.linalg.norm(shift) + 1e-12) * float(rng.uniform(0, trans))
        c = X[s:e].mean(0, keepdims=True)
        X[s:e] = (X[s:e] - c) @ R.T + c + shift
    X -= X.mean(0, keepdims=True) - gc_
    return X


def smooth_wiggle(coords, segs, rng, amp=0.8):
    """Add a smooth random displacement (6 interpolated control points) per chain."""
    X = coords.copy()
    for s, e in segs:
        L = e - s
        if L < 20: continue
        ctrl = np.linspace(0, L - 1, 6); disp = rng.normal(0, amp, (6, 3)); t = np.arange(L)
        X[s:e] += np.vstack([np.interp(t, ctrl, disp[:, k]) for k in range(3)]).T
    return X


def generate_rna_structure(sequence: str, seed=None) -> np.ndarray:
    """Idealized A-form RNA helix — last-resort de-novo fallback.

    NOTE(review): `seed` seeds numpy's global RNG, but the helix below is fully
    deterministic; the seeding is kept only for backward compatibility.
    """
    if seed is not None:
        np.random.seed(seed)
    n = len(sequence); coords = np.zeros((n, 3))
    for i in range(n):
        ang = i * 0.6
        coords[i] = [10.0 * np.cos(ang), 10.0 * np.sin(ang), i * 2.5]
    return coords
# ─────────────── TBM Phase ───────────────────────────────────────────────────
def tbm_phase(test_df, train_seqs_df, train_coords_dict, segments_map):
    """
    Phase 1 — Template-Based Modeling.

    Returns
    -------
    template_predictions : {target_id: [np.ndarray(seq_len, 3), ...]}
        0 to N_SAMPLE predictions per target, from real templates.
    protenix_queue : {target_id: (n_needed, full_sequence)}
        Targets that still need more predictions.
    """
    print(f"\n{'='*60}")
    print(f"PHASE 1: Template-Based Modeling")
    print(f"  MIN_SIMILARITY = {MIN_SIMILARITY} | MIN_PCT_IDENTITY = {MIN_PERCENT_IDENTITY}")
    print(f"{'='*60}")
    t0 = time.time()

    template_predictions: dict = {}
    protenix_queue: dict = {}

    for _, row in test_df.iterrows():
        tid = row["target_id"]
        seq = row["sequence"]
        segs = segments_map.get(tid, [(0, len(seq))])

        similar = find_similar_sequences_detailed(seq, train_seqs_df, train_coords_dict, top_n=30)
        preds = []
        used = set()

        for i, (tmpl_id, tmpl_seq, sim, tmpl_coords, pct_id, _, _) in enumerate(similar):
            if len(preds) >= N_SAMPLE:
                break
            if sim < MIN_SIMILARITY or pct_id < MIN_PERCENT_IDENTITY:
                break  # list is sorted by sim, so no point continuing
            if tmpl_id in used:
                continue

            # Deterministic per-(target, template) RNG so re-runs reproduce.
            rng = np.random.default_rng((row.name * 10000000000 + i * 10007) % (2**32))
            adapted = adapt_template_to_query(seq, tmpl_seq, tmpl_coords)

            # Diversity transforms (same strategy as the 0-409 TBM notebook)
            slot = len(preds)
            if slot == 0:
                X = adapted
            elif slot == 1:
                X = adapted + rng.normal(0, max(0.01, (0.40 - sim) * 0.06), adapted.shape)
            elif slot == 2:
                longest = max(segs, key=lambda se: se[1] - se[0])
                X = apply_hinge(adapted, longest, rng)
            elif slot == 3:
                X = jitter_chains(adapted, segs, rng)
            else:
                X = smooth_wiggle(adapted, segs, rng)

            refined = adaptive_rna_constraints(X, tid, segments_map, confidence=sim)
            preds.append(refined)
            used.add(tmpl_id)

        template_predictions[tid] = preds
        n_needed = N_SAMPLE - len(preds)
        if n_needed > 0:
            protenix_queue[tid] = (n_needed, seq)
            print(f"  {tid} ({len(seq)} nt): {len(preds)} TBM → need {n_needed} from Protenix")
        else:
            print(f"  {tid} ({len(seq)} nt): all {N_SAMPLE} from TBM ✓")

    elapsed = time.time() - t0
    n_full = len(test_df) - len(protenix_queue)
    print(f"\nPhase 1 done in {elapsed:.1f}s")
    print(f"  Fully covered by TBM : {n_full}")
    print(f"  Need Protenix        : {len(protenix_queue)}")
    return template_predictions, protenix_queue


# ─────────────── Main ────────────────────────────────────────────────────────
def main() -> None:
    """End-to-end pipeline: TBM → Protenix gap-fill → de-novo → submission CSV."""
    test_csv, output_csv, code_dir, root_dir = resolve_paths()

    if not os.path.isdir(code_dir):
        raise FileNotFoundError(
            f"Missing PROTENIX_CODE_DIR: {code_dir}. "
            "Set PROTENIX_CODE_DIR to the repo path."
        )

    os.environ["PROTENIX_ROOT_DIR"] = root_dir
    sys.path.append(code_dir)
    ensure_required_files(root_dir)
    seed_everything(SEED)

    # ── Load test data ──────────────────────────────────────────────────────
    test_df_full = pd.read_csv(test_csv)
    test_df = (test_df_full.head(LOCAL_N_SAMPLES) if not IS_KAGGLE
               else test_df_full).reset_index(drop=True)
    print(f"Test targets : {len(test_df)}"
          + (" (LOCAL MODE)" if not IS_KAGGLE else ""))
    # NOTE(review): the previous revision built an unused `seq_by_id` dict and
    # an unused truncated copy of test_df here; both were dead locals and have
    # been removed. Long sequences are handled by the chunking path below.

    # ── Load training data for TBM ──────────────────────────────────────────
    print("\nLoading training data for TBM …")
    label_cols = ["ID", "resid", "x_1", "y_1", "z_1"]
    label_dtypes = {"resid": "int32", "x_1": "float32", "y_1": "float32", "z_1": "float32"}
    train_seqs = pd.read_csv(DEFAULT_TRAIN_CSV)
    val_seqs = pd.read_csv(DEFAULT_VAL_CSV)
    train_labels = pd.read_csv(DEFAULT_TRAIN_LBLS, usecols=label_cols,
                               low_memory=False, dtype=label_dtypes)
    val_labels = pd.read_csv(DEFAULT_VAL_LBLS, usecols=label_cols,
                             low_memory=False, dtype=label_dtypes)

    combined_seqs = pd.concat([train_seqs, val_seqs], ignore_index=True)
    combined_labels = pd.concat([train_labels, val_labels], ignore_index=True)
    train_coords = process_labels(combined_labels)
    segments_map, _ = build_segments_map(test_df)

    print(f"Template pool: {len(combined_seqs)} sequences, {len(train_coords)} structures")

    # ─── PHASE 1: TBM ──────────────────────────────────────────────────────
    template_preds, protenix_queue = tbm_phase(
        test_df, combined_seqs, train_coords, segments_map
    )

    # ─── PHASE 2: Protenix (only for targets that need extra predictions) ──
    protenix_preds: dict = {}  # target_id -> np.ndarray (n_needed, seq_len, 3)

    if protenix_queue and USE_PROTENIX:
        print(f"\n{'='*60}")
        print(f"PHASE 2: Protenix for {len(protenix_queue)} targets")
        print(f"{'='*60}")

        work_dir = Path("/content")
        work_dir.mkdir(parents=True, exist_ok=True)

        # ── 1. Preparation: create tasks for all sequences/chunks ──────────
        tasks = []       # rows for the combined input JSON
        chunk_info = {}  # target_id -> list of {"name": chunk_name, "range": (s, e)}

        for target_id, (n_needed, full_seq) in protenix_queue.items():
            seq_len = len(full_seq)
            if seq_len <= MAX_SEQ_LEN:
                tasks.append({"target_id": target_id, "sequence": full_seq})
                chunk_info[target_id] = [{"name": target_id, "range": (0, seq_len)}]
                print(f"  {target_id} ({seq_len} nt): single pass queued")
            else:
                chunks = split_into_chunks(seq_len, MAX_SEQ_LEN, CHUNK_OVERLAP)
                print(f"  {target_id} ({seq_len} nt): {len(chunks)} chunks queued "
                      f"{[(s, e) for s, e in chunks]}")

                chunk_info[target_id] = []
                for ci, (cs, ce) in enumerate(chunks):
                    chunk_name = f"{target_id}_chunk{ci}"
                    sub_seq = full_seq[cs:ce]
                    tasks.append({"target_id": chunk_name, "sequence": sub_seq})
                    chunk_info[target_id].append({"name": chunk_name, "range": (cs, ce)})

        # Build combined input JSON
        tasks_df = pd.DataFrame(tasks)
        input_json_path = str(work_dir / "protenix_queue_input.json")
        build_input_json(tasks_df, input_json_path)

        from protenix.data.inference.infer_dataloader import InferenceDataset
        from runner.inference import (InferenceRunner,
                                      update_gpu_compatible_configs,
                                      update_inference_configs)

        # Initialize model exactly ONCE
        configs = build_configs(input_json_path, str(work_dir / "outputs"), MODEL_NAME)
        configs = update_gpu_compatible_configs(configs)
        runner = InferenceRunner(configs)
        dataset = InferenceDataset(configs)

        # ── 2. Inference: process dataset and collect predictions ──────────
        raw_predictions = {}  # sample_name -> coords (np.ndarray or None)

        def _extract_c1_coords(prediction, feat, chunk_seq_len, raw_coords):
            # Select C1'/centre atoms, preferring an explicit feature mask;
            # otherwise choose tokatom index 11 vs 12 by count proximity.
            if "centre_atom_mask" in feat:
                mask = (feat["centre_atom_mask"] == 1).to(raw_coords.device)
            elif "atom_to_tokatom_idx" in feat:
                m11 = (feat["atom_to_tokatom_idx"] == 11).to(raw_coords.device)
                m12 = (feat["atom_to_tokatom_idx"] == 12).to(raw_coords.device)
                c11, c12 = m11.sum(), m12.sum()
                mask = m11 if abs(c11 - chunk_seq_len) < abs(c12 - chunk_seq_len) else m12
            else:
                mask = torch.zeros(raw_coords.shape[1], dtype=torch.bool, device=raw_coords.device)

            coords = raw_coords[:, mask, :].detach().cpu().numpy()

            # Collapse check: all-identical points mean the model degenerated.
            if coords.shape[1] > 1:
                diffs = np.linalg.norm(coords[0, 1:] - coords[0, :-1], axis=-1)
                if np.all(diffs < 1e-4):
                    print(f"    WARNING: Collapsed coordinates detected")
                    return None

            if coords.shape[1] != chunk_seq_len:
                if coords.shape[1] == 1 and chunk_seq_len > 1:
                    return None
                padded = np.zeros((coords.shape[0], chunk_seq_len, 3), dtype=np.float32)
                ml = min(coords.shape[1], chunk_seq_len)
                padded[:, :ml, :] = coords[:, :ml, :]
                coords = padded
            return coords

        for i in tqdm(range(len(dataset)), desc="Protenix Inference"):
            data, atom_array, err = dataset[i]
            sample_name = data.get("sample_name", f"sample_{i}")

            if err:
                print(f"  {sample_name} data error: {err}")
                raw_predictions[sample_name] = None
                del data, atom_array, err
                gc.collect(); torch.cuda.empty_cache(); gc.collect()
                continue

            # Find how many samples are needed for this specific query
            target_id = sample_name.split("_chunk")[0] if "_chunk" in sample_name else sample_name
            n_needed = protenix_queue.get(target_id, (N_SAMPLE, ""))[0]
            sub_seq_len = data["N_token"].item()  # roughly correct

            try:
                new_cfg = update_inference_configs(configs, sub_seq_len)
                new_cfg.sample_diffusion.N_sample = n_needed
                runner.update_model_configs(new_cfg)

                pred = runner.predict(data)
                raw_coords = pred["coordinate"]

                coords = _extract_c1_coords(pred, data["input_feature_dict"],
                                            sub_seq_len, raw_coords)
                raw_predictions[sample_name] = coords
            except Exception as exc:
                print(f"  {sample_name} inference failed: {exc}")
                import traceback; traceback.print_exc()
                raw_predictions[sample_name] = None
            finally:
                try:
                    del pred, data, atom_array, raw_coords
                except NameError:
                    # Some of these names are unbound when prediction fails early.
                    pass
                gc.collect(); torch.cuda.empty_cache(); gc.collect()

        # ── 3. Post-processing: Stitching and final formatting ─────────────
        for target_id, (n_needed, full_seq) in protenix_queue.items():
            seq_len = len(full_seq)
            chunks = chunk_info.get(target_id, [])

            if not chunks:
                continue

            if len(chunks) == 1:
                # Single pass
                coords = raw_predictions.get(target_id)
                protenix_preds[target_id] = coords
                if coords is not None:
                    print(f"  {target_id}: {coords.shape[0]} predictions generated")
                else:
                    print(f"  {target_id}: FAILED")
            else:
                # Stitch chunks together
                chunk_results_per_sample = {s: [] for s in range(n_needed)}
                all_ok = True

                for ci, cinfo in enumerate(chunks):
                    cname = cinfo["name"]
                    crange = cinfo["range"]
                    ccoords = raw_predictions.get(cname)

                    if ccoords is None:
                        all_ok = False
                        break

                    for s_idx in range(n_needed):
                        if s_idx < ccoords.shape[0]:
                            chunk_results_per_sample[s_idx].append((ccoords[s_idx], crange))
                        else:
                            # Fewer samples than requested: reuse the last one.
                            chunk_results_per_sample[s_idx].append((ccoords[-1], crange))

                if not all_ok:
                    print(f"  {target_id}: chunked inference incomplete, using fallback")
                    protenix_preds[target_id] = None
                    continue

                stitched_samples = []
                for s_idx in range(n_needed):
                    items = chunk_results_per_sample[s_idx]
                    coords_list = [c for c, _ in items]
                    ranges_list = [r for _, r in items]
                    full_coords = stitch_chunk_coords(coords_list, ranges_list, seq_len)
                    stitched_samples.append(full_coords)

                result = np.stack(stitched_samples, axis=0)
                protenix_preds[target_id] = result
                print(f"  {target_id}: {result.shape[0]} stitched predictions generated")

    elif protenix_queue and not USE_PROTENIX:
        print(f"\nPHASE 2 skipped (USE_PROTENIX=False). "
              f"De-novo fallback will cover {len(protenix_queue)} targets.")

    # ─── PHASE 3: Combine everything ───────────────────────────────────────
    print(f"\n{'='*60}")
    print("PHASE 3: Combine TBM + Protenix + de-novo fallback")
    print(f"{'='*60}")

    all_rows = []

    for _, row in test_df.iterrows():
        tid = row["target_id"]
        seq = row["sequence"]

        combined: list = list(template_preds.get(tid, []))  # TBM predictions

        # Append Protenix predictions to fill remaining slots
        ptx = protenix_preds.get(tid)
        if ptx is not None and ptx.ndim == 3:
            for j in range(ptx.shape[0]):
                if len(combined) >= N_SAMPLE:
                    break
                combined.append(ptx[j])  # (seq_len, 3)

        # De-novo fallback for any still-empty slots
        n_denovo = 0
        while len(combined) < N_SAMPLE:
            seed_val = row.name * 1000000 + len(combined) * 1000
            dn = generate_rna_structure(seq, seed=seed_val)
            combined.append(adaptive_rna_constraints(dn, tid, segments_map, confidence=0.2))
            n_denovo += 1

        if n_denovo:
            print(f"  {tid}: {n_denovo} slot(s) filled with de-novo fallback")

        # Stack to (N_SAMPLE, seq_len, 3) and write rows
        stacked = np.stack(combined[:N_SAMPLE], axis=0)
        all_rows.extend(coords_to_rows(tid, seq, stacked))

    # ── Save ───────────────────────────────────────────────────────────────
    sub = pd.DataFrame(all_rows)
    cols = ["ID", "resname", "resid"] + [
        f"{c}_{i}" for i in range(1, N_SAMPLE + 1) for c in ["x", "y", "z"]
    ]
    coord_cols = [c for c in cols if c.startswith(("x_", "y_", "z_"))]
    # Clip to the coordinate range accepted by the submission format.
    sub[coord_cols] = sub[coord_cols].clip(-999.999, 9999.999)
    sub[cols].to_csv(output_csv, index=False)

    print(f"\n✓ Saved submission to {output_csv} ({len(sub):,} rows)")
| ], | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "outputs": [], | |
| "execution_count": null | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": "<a id=\"section2\"></a>\n\n# <div style=\"text-align:center; border-radius:15px; padding:15px; font-size:115%; font-family:'Arial Black', Arial, sans-serif; text-shadow:2px 2px 5px rgba(0,0,0,0.7); background:linear-gradient(90deg, #1e3c72, #c31432); box-shadow:0 2px 5px rgba(0,0,0,0.3); color:white;\"><b>MAIN</b></div>", | |
| "metadata": {} | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": "def _entry() -> None:\n    \"\"\"Script entry point: delegate to the pipeline's main().\"\"\"\n    main()\n\n\nif __name__ == \"__main__\":\n    _entry()", | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "outputs": [], | |
| "execution_count": null | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment