Skip to content

Instantly share code, notes, and snippets.

@matyushkin
Last active March 12, 2026 18:57
Show Gist options
  • Select an option

  • Save matyushkin/bd8404dec55ff80479627d4339f847f5 to your computer and use it in GitHub Desktop.

Select an option

Save matyushkin/bd8404dec55ff80479627d4339f847f5 to your computer and use it in GitHub Desktop.
rna-sigmaborov-colab.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
},
"colab": {
"provenance": []
}
},
"cells": [
{
"cell_type": "code",
"metadata": {},
"source": [
"import os, glob\n",
"\n",
"!pip install -q kaggle\n",
"\n",
"# Kaggle credentials must come from the environment (e.g. Colab secrets),\n",
"# never hardcoded: a previous revision of this cell embedded a live API key\n",
"# in plain text. That key is compromised and should be rotated.\n",
"from getpass import getpass\n",
"if not os.environ.get('KAGGLE_USERNAME'):\n",
"    os.environ['KAGGLE_USERNAME'] = input('Kaggle username: ')\n",
"if not os.environ.get('KAGGLE_KEY'):\n",
"    os.environ['KAGGLE_KEY'] = getpass('Kaggle API key: ')\n",
"\n",
"os.makedirs('/data', exist_ok=True)\n",
"needed = ['test_sequences.csv', 'train_sequences.csv', 'train_labels.csv',\n",
"          'validation_sequences.csv', 'validation_labels.csv']\n",
"missing = [f for f in needed if not os.path.exists(f'/data/{f}')]\n",
"if missing:\n",
"    print('Downloading:', missing)\n",
"    for fname in missing:\n",
"        os.system(f'kaggle competitions download stanford-rna-3d-folding-2 -f {fname} -p /data/')\n",
"    for f in glob.glob('/data/*.zip'):\n",
"        os.system(f'unzip -q -o {f} -d /data/ && rm {f}')\n",
"else:\n",
"    print('Data already present, skipping download')\n",
"!ls /data/\n"
],
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import os, sys\n",
"\n",
"PROTENIX_DIR = '/protenix/Protenix-v1-adjust-v2/Protenix-v1-adjust-v2/Protenix-v1'\n",
"os.environ['PROTENIX_CODE_DIR'] = PROTENIX_DIR\n",
"os.environ['PROTENIX_ROOT_DIR'] = PROTENIX_DIR\n",
"\n",
"if not os.path.exists(PROTENIX_DIR):\n",
"    print('Downloading Protenix...')\n",
"    os.system('kaggle datasets download qiweiyin/protenix-v1-adjusted -p /protenix/ --unzip')\n",
"else:\n",
"    print('Protenix already present')\n",
"\n",
"# Dependencies must be (re)installed in every fresh runtime, so install them\n",
"# unconditionally. Pin biotite==1.0.0 here too: the old reinstall branch left\n",
"# it unpinned, so a restarted runtime could silently pull a different version\n",
"# than the first-run branch did.\n",
"os.system('pip install -q biopython==1.86 biotite==1.0.0 rdkit')\n",
"os.system(f'pip install -q -e {PROTENIX_DIR}')\n",
"\n",
"print('Done')\n"
],
"outputs": []
},
{
"cell_type": "code",
"source": [
"# ── Run-mode configuration ───────────────────────────────────────────\n",
"import os\n",
"import sys\n",
"\n",
"import pandas as pd\n",
"\n",
"IS_KAGGLE = False # running on Colab, not Kaggle\n",
"LOCAL_N_SAMPLES = 10 # quick smoke-test: only the first 10 sequences\n",
"\n",
"print(f'Running in LOCAL/COLAB mode — first {LOCAL_N_SAMPLES} targets only')\n"
],
"metadata": {
"trusted": true
},
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"source": "# Standard-library imports\nimport gc\nimport json\nimport os\nimport sys\nimport time\nfrom pathlib import Path\n\n# Presumably read by the Protenix code — keep these set before torch and the\n# Protenix-dependent modules below are imported (original ordering preserved).\nos.environ[\"LAYERNORM_TYPE\"] = \"torch\"\nos.environ.setdefault(\"RNA_MSA_DEPTH_LIMIT\", \"512\")\n\n# Third-party imports\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom Bio.Align import PairwiseAligner\nfrom tqdm import tqdm",
"metadata": {
"trusted": true
},
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"source": "def get_c1_mask(data: dict, atom_array) -> torch.Tensor:\n    \"\"\"Return a boolean mask picking one representative (C1') atom per residue.\"\"\"\n    # Route 1: derive the mask from the atom array, when one was provided.\n    if atom_array is not None:\n        try:\n            if hasattr(atom_array, \"centre_atom_mask\"):\n                picked = atom_array.centre_atom_mask == 1\n                if hasattr(atom_array, \"is_rna\"):\n                    picked = picked & atom_array.is_rna\n                return torch.from_numpy(picked).bool()\n            if hasattr(atom_array, \"atom_name\"):\n                picked = atom_array.atom_name == \"C1'\"\n                if hasattr(atom_array, \"is_rna\"):\n                    picked = picked & atom_array.is_rna\n                return torch.from_numpy(picked).bool()\n        except Exception:\n            pass  # best-effort: fall back to the feature dict below\n\n    # Route 2: an explicit centre-atom mask in the feature dict (both spellings).\n    feats = data[\"input_feature_dict\"]\n    for key in (\"centre_atom_mask\", \"center_atom_mask\"):\n        if key in feats:\n            return (feats[key] == 1).bool()\n\n    # Route 3: heuristic on atom_to_tokatom_idx — C1' is typically index 11 or\n    # 12 depending on the featurizer; keep whichever count lands closer to the\n    # token count (one atom per residue expected). Ties keep index 12, as before.\n    n_tokens = data.get(\"N_token\", torch.tensor(0)).item()\n    mask11 = (feats[\"atom_to_tokatom_idx\"] == 11).bool()\n    mask12 = (feats[\"atom_to_tokatom_idx\"] == 12).bool()\n    d11 = abs(mask11.sum().item() - n_tokens)\n    d12 = abs(mask12.sum().item() - n_tokens)\n    return mask11 if d11 < d12 else mask12",
"metadata": {
"trusted": true
},
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"source": "<a id=\"section1\"></a>\n\n# <div style=\"text-align:center; border-radius:15px; padding:15px; font-size:115%; font-family:'Arial Black', Arial, sans-serif; text-shadow:2px 2px 5px rgba(0,0,0,0.7); background:linear-gradient(90deg, #1e3c72, #c31432); box-shadow:0 2px 5px rgba(0,0,0,0.3); color:white;\"><b>TUNING</b></div>",
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# ─────────────── Paths & Constants ───────────────────────────────────────────\nDATA_BASE = \"/data\"\nDEFAULT_TEST_CSV = f\"{DATA_BASE}/test_sequences.csv\"\nDEFAULT_TRAIN_CSV = f\"{DATA_BASE}/train_sequences.csv\"\nDEFAULT_TRAIN_LBLS = f\"{DATA_BASE}/train_labels.csv\"\nDEFAULT_VAL_CSV = f\"{DATA_BASE}/validation_sequences.csv\"\nDEFAULT_VAL_LBLS = f\"{DATA_BASE}/validation_labels.csv\"\nDEFAULT_OUTPUT = \"/content/submission.csv\"\n\nDEFAULT_CODE_DIR = (\n \"/kaggle/input/datasets/qiweiyin/protenix-v1-adjusted\"\n \"/Protenix-v1-adjust-v2/Protenix-v1-adjust-v2/Protenix-v1\"\n)\nDEFAULT_ROOT_DIR = DEFAULT_CODE_DIR\n\nMODEL_NAME = \"protenix_base_20250630_v1.0.0\"\nN_SAMPLE = 5\nSEED = 42\nMAX_SEQ_LEN = int(os.environ.get(\"MAX_SEQ_LEN\", \"512\"))\nCHUNK_OVERLAP = int(os.environ.get(\"CHUNK_OVERLAP\", \"128\"))\n\n# TBM quality thresholds — sequences below these get routed to Protenix\nMIN_SIMILARITY = float(os.environ.get(\"MIN_SIMILARITY\", \"0.0\"))\nMIN_PERCENT_IDENTITY = float(os.environ.get(\"MIN_PERCENT_IDENTITY\", \"50.0\"))\n\n# Set False to skip Protenix and use de-novo fallback instead\nUSE_PROTENIX = True\n\n\ndef parse_bool(value: str, default: bool = False) -> str:\n v = str(value).strip().lower()\n if v in {\"1\", \"true\", \"t\", \"yes\", \"y\", \"on\"}:\n return \"true\"\n if v in {\"0\", \"false\", \"f\", \"no\", \"n\", \"off\"}:\n return \"false\"\n return \"true\" if default else \"false\"\n\n\nUSE_MSA = parse_bool(os.environ.get(\"USE_MSA\", \"false\"))\nUSE_TEMPLATE = parse_bool(os.environ.get(\"USE_TEMPLATE\", \"false\"))\nUSE_RNA_MSA = parse_bool(os.environ.get(\"USE_RNA_MSA\", \"true\"))\n\nMODEL_N_SAMPLE = int(os.environ.get(\"MODEL_N_SAMPLE\", str(N_SAMPLE)))\n\n\n# ─────────────── General Utilities ───────────────────────────────────────────\ndef seed_everything(seed: int) -> None:\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n torch.manual_seed(seed)\n torch.cuda.manual_seed(seed)\n torch.cuda.manual_seed_all(seed)\n 
np.random.seed(seed)\n torch.backends.cudnn.benchmark = False\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.enabled = True\n torch.use_deterministic_algorithms(True)\n\n\ndef resolve_paths():\n test_csv = os.environ.get(\"TEST_CSV\", DEFAULT_TEST_CSV)\n output_csv = os.environ.get(\"SUBMISSION_CSV\", DEFAULT_OUTPUT)\n code_dir = os.environ.get(\"PROTENIX_CODE_DIR\", DEFAULT_CODE_DIR)\n root_dir = os.environ.get(\"PROTENIX_ROOT_DIR\", DEFAULT_ROOT_DIR)\n return test_csv, output_csv, code_dir, root_dir\n\n\ndef ensure_required_files(root_dir: str) -> None:\n for p, name in [\n (Path(root_dir) / \"checkpoint\" / f\"{MODEL_NAME}.pt\", \"checkpoint\"),\n (Path(root_dir) / \"common\" / \"components.cif\", \"CCD file\"),\n (Path(root_dir) / \"common\" / \"components.cif.rdkit_mol.pkl\", \"CCD cache\"),\n ]:\n if not p.exists():\n raise FileNotFoundError(f\"Missing {name}: {p}\")\n\n\n# ─────────────── Protenix Input / Config Helpers ─────────────────────────────\ndef build_input_json(df: pd.DataFrame, json_path: str) -> None:\n data = [\n {\n \"name\": row[\"target_id\"],\n \"covalent_bonds\": [],\n \"sequences\": [{\"rnaSequence\": {\"sequence\": row[\"sequence\"], \"count\": 1}}],\n }\n for _, row in df.iterrows()\n ]\n with open(json_path, \"w\", encoding=\"utf-8\") as f:\n json.dump(data, f)\n\n\ndef build_configs(input_json_path: str, dump_dir: str, model_name: str):\n from configs.configs_base import configs as configs_base\n from configs.configs_data import data_configs\n from configs.configs_inference import inference_configs\n from configs.configs_model_type import model_configs\n from protenix.config.config import parse_configs\n\n base = {**configs_base, **{\"data\": data_configs}, **inference_configs}\n\n def deep_update(t, p):\n for k, v in p.items():\n if isinstance(v, dict) and k in t and isinstance(t[k], dict):\n deep_update(t[k], v)\n else:\n t[k] = v\n\n deep_update(base, model_configs[model_name])\n arg_str = \" \".join([\n 
f\"--model_name {model_name}\",\n f\"--input_json_path {input_json_path}\",\n f\"--dump_dir {dump_dir}\",\n f\"--use_msa {USE_MSA}\",\n f\"--use_template {USE_TEMPLATE}\",\n f\"--use_rna_msa {USE_RNA_MSA}\",\n f\"--sample_diffusion.N_sample {MODEL_N_SAMPLE}\",\n f\"--seeds {SEED}\",\n ])\n return parse_configs(configs=base, arg_str=arg_str, fill_required_with_null=True)\n\n\ndef get_c1_mask(data: dict, atom_array) -> torch.Tensor:\n # 1. Try atom_array attributes first\n if atom_array is not None:\n try:\n if hasattr(atom_array, \"centre_atom_mask\"):\n m = atom_array.centre_atom_mask == 1\n if hasattr(atom_array, \"is_rna\"):\n m = m & atom_array.is_rna\n return torch.from_numpy(m).bool()\n \n if hasattr(atom_array, \"atom_name\"):\n base = atom_array.atom_name == \"C1'\"\n if hasattr(atom_array, \"is_rna\"):\n base = base & atom_array.is_rna\n return torch.from_numpy(base).bool()\n except Exception:\n pass\n\n # 2. Fallback to feature dict\n f = data[\"input_feature_dict\"]\n \n # CASE A: center_atom_mask exists\n if \"center_atom_mask\" in f:\n return (f[\"center_atom_mask\"] == 1).bool()\n if \"centre_atom_mask\" in f:\n return (f[\"centre_atom_mask\"] == 1).bool()\n \n # CASE B: Use atom_name\n if \"atom_name\" in f:\n # Check against \"C1'\" (byte encoded or string?)\n # For now assume typical behavior is center_atom_mask is present.\n pass\n\n # CASE C: atom_to_tokatom_idx fallback\n # The index for C1' is typically 11 or 12 depending on featurizer.\n # Let's try to match exactly C1' if possible.\n # But usually 'centre_atom_mask' should be there.\n \n # If we fall through, assume standard mask\n return (f[\"atom_to_tokatom_idx\"] == 11).bool()\n\n\ndef get_feature_c1_mask(data: dict) -> torch.Tensor:\n f = data[\"input_feature_dict\"]\n if \"centre_atom_mask\" in f:\n return f[\"centre_atom_mask\"].long() == 1\n return f[\"atom_to_tokatom_idx\"].long() == 12\n\n\ndef coords_to_rows(target_id: str, seq: str, coords: np.ndarray) -> list:\n \"\"\"coords shape: 
(N_SAMPLE, seq_len, 3)\"\"\"\n rows = []\n for i in range(len(seq)):\n row = {\"ID\": f\"{target_id}_{i + 1}\", \"resname\": seq[i], \"resid\": i + 1}\n for s in range(N_SAMPLE):\n if s < coords.shape[0] and i < coords.shape[1]:\n x, y, z = coords[s, i]\n else:\n x, y, z = 0.0, 0.0, 0.0\n row[f\"x_{s + 1}\"] = float(x)\n row[f\"y_{s + 1}\"] = float(y)\n row[f\"z_{s + 1}\"] = float(z)\n rows.append(row)\n return rows\n\n\ndef pad_samples(coords: np.ndarray, n: int) -> np.ndarray:\n if coords.shape[0] >= n:\n return coords[:n]\n if coords.shape[0] == 0:\n return np.zeros((n, coords.shape[1], 3), dtype=coords.dtype)\n extra = np.repeat(coords[:1], n - coords.shape[0], axis=0)\n return np.concatenate([coords, extra], axis=0)\n\n\ndef split_into_chunks(seq_len: int, max_len: int, overlap: int) -> list:\n \"\"\"Split a sequence into overlapping (start, end) chunks.\"\"\"\n if seq_len <= max_len:\n return [(0, seq_len)]\n chunks = []\n step = max_len - overlap\n pos = 0\n while pos < seq_len:\n end = min(pos + max_len, seq_len)\n chunks.append((pos, end))\n if end == seq_len:\n break\n pos += step\n return chunks\n\n\ndef kabsch_align(P: np.ndarray, Q: np.ndarray):\n \"\"\"Compute optimal rotation R and translation t so that R @ P + t ≈ Q.\"\"\"\n centroid_P = P.mean(axis=0)\n centroid_Q = Q.mean(axis=0)\n Pc = P - centroid_P\n Qc = Q - centroid_Q\n H = Pc.T @ Qc\n U, _, Vt = np.linalg.svd(H)\n d = np.linalg.det(Vt.T @ U.T)\n S = np.eye(3)\n if d < 0:\n S[2, 2] = -1\n R = Vt.T @ S @ U.T\n t = centroid_Q - R @ centroid_P\n return R, t\n\n\ndef stitch_chunk_coords(chunk_coords_list: list,\n chunk_ranges: list,\n seq_len: int) -> np.ndarray:\n \"\"\"\n Merge overlapping chunk coordinates into a full sequence geometry.\n Applies Kabsch alignment on overlapping residues, and smoothly\n blends the coordinates using a linear weight ramp.\n \"\"\"\n if len(chunk_coords_list) == 1:\n coords = chunk_coords_list[0]\n if coords.shape[0] >= seq_len:\n return coords[:seq_len]\n out = 
np.zeros((seq_len, 3), dtype=coords.dtype)\n out[:coords.shape[0]] = coords\n return out\n\n # Start with the first chunk aligned to itself (identity)\n aligned = [chunk_coords_list[0].copy()]\n\n for i in range(1, len(chunk_coords_list)):\n prev_start, prev_end = chunk_ranges[i - 1]\n cur_start, cur_end = chunk_ranges[i]\n\n ov_start = cur_start\n ov_end = min(prev_end, cur_end)\n ov_len = ov_end - ov_start\n\n if ov_len < 3:\n # Cannot align reliably, just trust the coordinates as-is\n aligned.append(chunk_coords_list[i].copy())\n continue\n\n prev_ov = aligned[i - 1][ov_start - prev_start: ov_end - prev_start]\n cur_ov = chunk_coords_list[i][ov_start - cur_start: ov_end - cur_start]\n\n # Ignore invalid residues (e.g. padding/blank)\n valid = ~(np.isnan(prev_ov).any(axis=1) | np.isnan(cur_ov).any(axis=1))\n if valid.sum() < 3:\n aligned.append(chunk_coords_list[i].copy())\n continue\n\n # Align current chunk to previous chunk using only the overlap region\n R, t = kabsch_align(cur_ov[valid], prev_ov[valid])\n transformed = (chunk_coords_list[i] @ R.T) + t\n aligned.append(transformed)\n\n # Blend them together\n full = np.zeros((seq_len, 3), dtype=np.float64)\n weights = np.zeros(seq_len, dtype=np.float64)\n\n for i, ((s, e), coords) in enumerate(zip(chunk_ranges, aligned)):\n chunk_len = coords.shape[0]\n actual_end = min(s + chunk_len, seq_len)\n used_len = actual_end - s\n\n w = np.ones(used_len, dtype=np.float64)\n\n if i > 0:\n ov_start = s\n ov_end = min(chunk_ranges[i - 1][1], e)\n ramp_len = ov_end - ov_start\n if ramp_len > 0:\n w[:ramp_len] = np.linspace(0.0, 1.0, ramp_len)\n\n if i < len(chunk_ranges) - 1:\n next_s = chunk_ranges[i + 1][0]\n ramp_start = next_s - s\n ramp_len = actual_end - next_s\n if ramp_len > 0 and ramp_start < used_len:\n w[ramp_start:used_len] = np.linspace(1.0, 0.0, ramp_len)\n\n full[s:actual_end] += coords[:used_len] * w[:, None]\n weights[s:actual_end] += w\n\n mask = weights > 0\n full[mask] /= weights[mask, None]\n\n 
return full\n\n\n# ─────────────── TBM Core Functions ──────────────────────────────────────────\ndef _make_aligner() -> PairwiseAligner:\n al = PairwiseAligner()\n al.mode = \"global\"\n al.match_score = 2\n al.mismatch_score = -1.5\n al.open_gap_score = -8\n al.extend_gap_score = -0.4\n al.query_left_open_gap_score = -8\n al.query_left_extend_gap_score = -0.4\n al.query_right_open_gap_score = -8\n al.query_right_extend_gap_score = -0.4\n al.target_left_open_gap_score = -8\n al.target_left_extend_gap_score = -0.4\n al.target_right_open_gap_score = -8\n al.target_right_extend_gap_score = -0.4\n return al\n\n\n_aligner = _make_aligner()\n\n\ndef parse_stoichiometry(stoich: str) -> list:\n if pd.isna(stoich) or str(stoich).strip() == \"\":\n return []\n return [(ch.strip(), int(cnt)) for part in str(stoich).split(\";\")\n for ch, cnt in [part.split(\":\")]]\n\n\ndef parse_fasta(fasta_content: str) -> dict:\n out, cur, parts = {}, None, []\n for line in str(fasta_content).splitlines():\n line = line.strip()\n if not line:\n continue\n if line.startswith(\">\"):\n if cur is not None:\n out[cur] = \"\".join(parts)\n cur = line[1:].split()[0]\n parts = []\n else:\n parts.append(line.replace(\" \", \"\"))\n if cur is not None:\n out[cur] = \"\".join(parts)\n return out\n\n\ndef get_chain_segments(row) -> list:\n seq = row[\"sequence\"]\n stoich = row.get(\"stoichiometry\", \"\")\n all_sq = row.get(\"all_sequences\", \"\")\n if (pd.isna(stoich) or pd.isna(all_sq)\n or str(stoich).strip() == \"\" or str(all_sq).strip() == \"\"):\n return [(0, len(seq))]\n try:\n chain_dict = parse_fasta(all_sq)\n order = parse_stoichiometry(stoich)\n segs, pos = [], 0\n for ch, cnt in order:\n base = chain_dict.get(ch)\n if base is None:\n return [(0, len(seq))]\n for _ in range(cnt):\n segs.append((pos, pos + len(base)))\n pos += len(base)\n return segs if pos == len(seq) else [(0, len(seq))]\n except Exception:\n return [(0, len(seq))]\n\n\ndef build_segments_map(df: pd.DataFrame) -> 
tuple:\n seg_map, stoich_map = {}, {}\n for _, r in df.iterrows():\n tid = r[\"target_id\"]\n seg_map[tid] = get_chain_segments(r)\n raw_s = r.get(\"stoichiometry\", \"\")\n stoich_map[tid] = \"\" if pd.isna(raw_s) else str(raw_s)\n return seg_map, stoich_map\n\n\ndef process_labels(labels_df: pd.DataFrame) -> dict:\n coords = {}\n prefixes = labels_df[\"ID\"].str.rsplit(\"_\", n=1).str[0]\n for prefix, grp in labels_df.groupby(prefixes):\n coords[prefix] = grp.sort_values(\"resid\")[[\"x_1\", \"y_1\", \"z_1\"]].values\n return coords\n\n\ndef _build_aligned_strings(query_seq, template_seq, alignment):\n q_segs, t_segs = alignment.aligned\n aq, at, qi, ti = [], [], 0, 0\n for (qs, qe), (ts, te) in zip(q_segs, t_segs):\n while qi < qs: aq.append(query_seq[qi]); at.append(\"-\"); qi += 1\n while ti < ts: aq.append(\"-\"); at.append(template_seq[ti]); ti += 1\n for qp, tp in zip(range(qs, qe), range(ts, te)):\n aq.append(query_seq[qp]); at.append(template_seq[tp])\n qi, ti = qe, te\n while qi < len(query_seq): aq.append(query_seq[qi]); at.append(\"-\"); qi += 1\n while ti < len(template_seq): aq.append(\"-\"); at.append(template_seq[ti]); ti += 1\n return \"\".join(aq), \"\".join(at)\n\n\ndef find_similar_sequences_detailed(query_seq, train_seqs_df, train_coords_dict, top_n=30):\n results = []\n for _, row in train_seqs_df.iterrows():\n tid, tseq = row[\"target_id\"], row[\"sequence\"]\n if tid not in train_coords_dict:\n continue\n if abs(len(tseq) - len(query_seq)) / max(len(tseq), len(query_seq)) > 0.3:\n continue\n aln = next(iter(_aligner.align(query_seq, tseq)))\n norm_s = aln.score / (2 * min(len(query_seq), len(tseq)))\n identical = sum(\n 1 for (qs, qe), (ts, te) in zip(*aln.aligned)\n for qp, tp in zip(range(qs, qe), range(ts, te))\n if query_seq[qp] == tseq[tp]\n )\n pct_id = 100 * identical / len(query_seq)\n aq, at = _build_aligned_strings(query_seq, tseq, aln)\n results.append((tid, tseq, norm_s, train_coords_dict[tid], pct_id, aq, at))\n 
results.sort(key=lambda x: x[2], reverse=True)\n return results[:top_n]\n\n\ndef adapt_template_to_query(query_seq, template_seq, template_coords) -> np.ndarray:\n aln = next(iter(_aligner.align(query_seq, template_seq)))\n new_coords = np.full((len(query_seq), 3), np.nan)\n for (qs, qe), (ts, te) in zip(*aln.aligned):\n chunk = template_coords[ts:te]\n if len(chunk) == (qe - qs):\n new_coords[qs:qe] = chunk\n for i in range(len(new_coords)):\n if np.isnan(new_coords[i, 0]):\n pv = next((j for j in range(i - 1, -1, -1) if not np.isnan(new_coords[j, 0])), -1)\n nv = next((j for j in range(i + 1, len(new_coords)) if not np.isnan(new_coords[j, 0])), -1)\n if pv >= 0 and nv >= 0:\n w = (i - pv) / (nv - pv)\n new_coords[i] = (1 - w) * new_coords[pv] + w * new_coords[nv]\n elif pv >= 0:\n new_coords[i] = new_coords[pv] + [3, 0, 0]\n elif nv >= 0:\n new_coords[i] = new_coords[nv] + [3, 0, 0]\n else:\n new_coords[i] = [i * 3, 0, 0]\n return np.nan_to_num(new_coords)\n\n\ndef adaptive_rna_constraints(coords, target_id, segments_map, confidence=1.0, passes=2) -> np.ndarray:\n X = coords.copy()\n segments = segments_map.get(target_id, [(0, len(X))])\n strength = max(0.75 * (1.0 - min(confidence, 0.97)), 0.02)\n for _ in range(passes):\n for s, e in segments:\n C = X[s:e]; L = e - s\n if L < 3:\n continue\n # bond i–i+1 ~5.95 Å\n d = C[1:] - C[:-1]; dist = np.linalg.norm(d, axis=1) + 1e-6\n adj = d * ((5.95 - dist) / dist)[:, None] * (0.22 * strength)\n C[:-1] -= adj; C[1:] += adj\n # soft i–i+2 ~10.2 Å\n d2 = C[2:] - C[:-2]; d2n = np.linalg.norm(d2, axis=1) + 1e-6\n adj2 = d2 * ((10.2 - d2n) / d2n)[:, None] * (0.10 * strength)\n C[:-2] -= adj2; C[2:] += adj2\n # Laplacian smoothing\n C[1:-1] += (0.06 * strength) * (0.5 * (C[:-2] + C[2:]) - C[1:-1])\n # self-avoidance\n if L >= 25:\n idx = np.linspace(0, L - 1, min(L, 160)).astype(int) if L > 220 else np.arange(L)\n P = C[idx]; diff = P[:, None, :] - P[None, :, :]\n dm = np.linalg.norm(diff, axis=2) + 1e-6\n sep = 
np.abs(idx[:, None] - idx[None, :])\n mask = (sep > 2) & (dm < 3.2)\n if np.any(mask):\n vec = (diff * ((3.2 - dm) / dm)[:, :, None] * mask[:, :, None]).sum(axis=1)\n C[idx] += (0.015 * strength) * vec\n X[s:e] = C\n return X\n\n\ndef _rotmat(axis, ang):\n a = np.asarray(axis, float); a /= np.linalg.norm(a) + 1e-12\n x, y, z = a; c, s = np.cos(ang), np.sin(ang); CC = 1 - c\n return np.array([[c+x*x*CC, x*y*CC-z*s, x*z*CC+y*s],\n [y*x*CC+z*s, c+y*y*CC, y*z*CC-x*s],\n [z*x*CC-y*s, z*y*CC+x*s, c+z*z*CC]])\n\n\ndef apply_hinge(coords, seg, rng, deg=22):\n s, e = seg; L = e - s\n if L < 30: return coords\n pivot = s + int(rng.integers(10, L - 10))\n R = _rotmat(rng.normal(size=3), np.deg2rad(float(rng.uniform(-deg, deg))))\n X = coords.copy(); p0 = X[pivot].copy()\n X[pivot+1:e] = (X[pivot+1:e] - p0) @ R.T + p0\n return X\n\n\ndef jitter_chains(coords, segs, rng, deg=12, trans=1.5):\n X = coords.copy(); gc_ = X.mean(0, keepdims=True)\n for s, e in segs:\n R = _rotmat(rng.normal(size=3), np.deg2rad(float(rng.uniform(-deg, deg))))\n shift = rng.normal(size=3); shift = shift / (np.linalg.norm(shift) + 1e-12) * float(rng.uniform(0, trans))\n c = X[s:e].mean(0, keepdims=True)\n X[s:e] = (X[s:e] - c) @ R.T + c + shift\n X -= X.mean(0, keepdims=True) - gc_\n return X\n\n\ndef smooth_wiggle(coords, segs, rng, amp=0.8):\n X = coords.copy()\n for s, e in segs:\n L = e - s\n if L < 20: continue\n ctrl = np.linspace(0, L - 1, 6); disp = rng.normal(0, amp, (6, 3)); t = np.arange(L)\n X[s:e] += np.vstack([np.interp(t, ctrl, disp[:, k]) for k in range(3)]).T\n return X\n\n\ndef generate_rna_structure(sequence: str, seed=None) -> np.ndarray:\n \"\"\"Idealized A-form RNA helix — last-resort de-novo fallback.\"\"\"\n if seed is not None:\n np.random.seed(seed)\n n = len(sequence); coords = np.zeros((n, 3))\n for i in range(n):\n ang = i * 0.6\n coords[i] = [10.0 * np.cos(ang), 10.0 * np.sin(ang), i * 2.5]\n return coords\n\n\n# ─────────────── TBM Phase 
───────────────────────────────────────────────────\ndef tbm_phase(test_df, train_seqs_df, train_coords_dict, segments_map):\n \"\"\"\n Phase 1 — Template-Based Modeling.\n\n Returns\n -------\n template_predictions : {target_id: [np.ndarray(seq_len, 3), ...]}\n 0 to N_SAMPLE predictions per target, from real templates.\n protenix_queue : {target_id: (n_needed, full_sequence)}\n Targets that still need more predictions.\n \"\"\"\n print(f\"\\n{'='*60}\")\n print(f\"PHASE 1: Template-Based Modeling\")\n print(f\" MIN_SIMILARITY = {MIN_SIMILARITY} | MIN_PCT_IDENTITY = {MIN_PERCENT_IDENTITY}\")\n print(f\"{'='*60}\")\n t0 = time.time()\n\n template_predictions: dict = {}\n protenix_queue: dict = {}\n\n for _, row in test_df.iterrows():\n tid = row[\"target_id\"]\n seq = row[\"sequence\"]\n segs = segments_map.get(tid, [(0, len(seq))])\n\n similar = find_similar_sequences_detailed(seq, train_seqs_df, train_coords_dict, top_n=30)\n preds = []\n used = set()\n\n for i, (tmpl_id, tmpl_seq, sim, tmpl_coords, pct_id, _, _) in enumerate(similar):\n if len(preds) >= N_SAMPLE:\n break\n if sim < MIN_SIMILARITY or pct_id < MIN_PERCENT_IDENTITY:\n break # list is sorted by sim, so no point continuing\n if tmpl_id in used:\n continue\n\n rng = np.random.default_rng((row.name * 10000000000 + i * 10007) % (2**32))\n adapted = adapt_template_to_query(seq, tmpl_seq, tmpl_coords)\n\n # Diversity transforms (same strategy as the 0-409 TBM notebook)\n slot = len(preds)\n if slot == 0:\n X = adapted\n elif slot == 1:\n X = adapted + rng.normal(0, max(0.01, (0.40 - sim) * 0.06), adapted.shape)\n elif slot == 2:\n longest = max(segs, key=lambda se: se[1] - se[0])\n X = apply_hinge(adapted, longest, rng)\n elif slot == 3:\n X = jitter_chains(adapted, segs, rng)\n else:\n X = smooth_wiggle(adapted, segs, rng)\n\n refined = adaptive_rna_constraints(X, tid, segments_map, confidence=sim)\n preds.append(refined)\n used.add(tmpl_id)\n\n template_predictions[tid] = preds\n n_needed = N_SAMPLE - 
len(preds)\n if n_needed > 0:\n protenix_queue[tid] = (n_needed, seq)\n print(f\" {tid} ({len(seq)} nt): {len(preds)} TBM → need {n_needed} from Protenix\")\n else:\n print(f\" {tid} ({len(seq)} nt): all {N_SAMPLE} from TBM ✓\")\n\n elapsed = time.time() - t0\n n_full = len(test_df) - len(protenix_queue)\n print(f\"\\nPhase 1 done in {elapsed:.1f}s\")\n print(f\" Fully covered by TBM : {n_full}\")\n print(f\" Need Protenix : {len(protenix_queue)}\")\n return template_predictions, protenix_queue\n\n\n# ─────────────── Main ────────────────────────────────────────────────────────\ndef main() -> None:\n test_csv, output_csv, code_dir, root_dir = resolve_paths()\n\n if not os.path.isdir(code_dir):\n raise FileNotFoundError(\n f\"Missing PROTENIX_CODE_DIR: {code_dir}. \"\n \"Set PROTENIX_CODE_DIR to the repo path.\"\n )\n\n os.environ[\"PROTENIX_ROOT_DIR\"] = root_dir\n sys.path.append(code_dir)\n ensure_required_files(root_dir)\n seed_everything(SEED)\n\n # ── Load test data ──────────────────────────────────────────────────────\n test_df_full = pd.read_csv(test_csv)\n test_df = (test_df_full.head(LOCAL_N_SAMPLES) if not IS_KAGGLE\n else test_df_full).reset_index(drop=True)\n print(f\"Test targets : {len(test_df)}\"\n + (\" (LOCAL MODE)\" if not IS_KAGGLE else \"\"))\n\n seq_by_id = dict(zip(test_df[\"target_id\"], test_df[\"sequence\"]))\n\n # Truncated copy for Protenix (Protenix has token limits)\n test_df_trunc = test_df.copy()\n test_df_trunc[\"sequence\"] = test_df_trunc[\"sequence\"].str[:MAX_SEQ_LEN]\n\n # ── Load training data for TBM ──────────────────────────────────────────\n print(\"\\nLoading training data for TBM …\")\n train_seqs = pd.read_csv(DEFAULT_TRAIN_CSV)\n val_seqs = pd.read_csv(DEFAULT_VAL_CSV)\n train_labels = pd.read_csv(DEFAULT_TRAIN_LBLS, usecols=[\"ID\",\"resid\",\"x_1\",\"y_1\",\"z_1\"], low_memory=False, dtype={\"resid\": \"int32\", \"x_1\": \"float32\", \"y_1\": \"float32\", \"z_1\": \"float32\"})\n val_labels = 
pd.read_csv(DEFAULT_VAL_LBLS, usecols=[\"ID\",\"resid\",\"x_1\",\"y_1\",\"z_1\"], low_memory=False, dtype={\"resid\": \"int32\", \"x_1\": \"float32\", \"y_1\": \"float32\", \"z_1\": \"float32\"})\n\n combined_seqs = pd.concat([train_seqs, val_seqs], ignore_index=True)\n combined_labels = pd.concat([train_labels, val_labels], ignore_index=True)\n train_coords = process_labels(combined_labels)\n segments_map, _ = build_segments_map(test_df)\n\n print(f\"Template pool: {len(combined_seqs)} sequences, {len(train_coords)} structures\")\n\n # ─── PHASE 1: TBM ──────────────────────────────────────────────────────\n template_preds, protenix_queue = tbm_phase(\n test_df, combined_seqs, train_coords, segments_map\n )\n\n # ─── PHASE 2: Protenix (only for targets that need extra predictions) ──\n protenix_preds: dict = {} # target_id -> np.ndarray (n_needed, seq_len, 3)\n\n if protenix_queue and USE_PROTENIX:\n print(f\"\\n{'='*60}\")\n print(f\"PHASE 2: Protenix for {len(protenix_queue)} targets\")\n print(f\"{'='*60}\")\n\n work_dir = Path(\"/content\")\n work_dir.mkdir(parents=True, exist_ok=True)\n\n # ── 1. 
Preparation: create tasks for all sequences/chunks ────────────\n tasks = [] # list of dict for input_json\n chunk_info = {} # target_id -> list of {\"name\": chunk_name, \"range\": (s, e)}\n \n for target_id, (n_needed, full_seq) in protenix_queue.items():\n seq_len = len(full_seq)\n if seq_len <= MAX_SEQ_LEN:\n tasks.append({\"target_id\": target_id, \"sequence\": full_seq})\n chunk_info[target_id] = [{\"name\": target_id, \"range\": (0, seq_len)}]\n print(f\" {target_id} ({seq_len} nt): single pass queued\")\n else:\n chunks = split_into_chunks(seq_len, MAX_SEQ_LEN, CHUNK_OVERLAP)\n print(f\" {target_id} ({seq_len} nt): {len(chunks)} chunks queued \"\n f\"{[(s, e) for s, e in chunks]}\")\n \n chunk_info[target_id] = []\n for ci, (cs, ce) in enumerate(chunks):\n chunk_name = f\"{target_id}_chunk{ci}\"\n sub_seq = full_seq[cs:ce]\n tasks.append({\"target_id\": chunk_name, \"sequence\": sub_seq})\n chunk_info[target_id].append({\"name\": chunk_name, \"range\": (cs, ce)})\n\n # Build combined input JSON\n tasks_df = pd.DataFrame(tasks)\n input_json_path = str(work_dir / \"protenix_queue_input.json\")\n build_input_json(tasks_df, input_json_path)\n\n from protenix.data.inference.infer_dataloader import InferenceDataset\n from runner.inference import (InferenceRunner,\n update_gpu_compatible_configs,\n update_inference_configs)\n\n # Initialize model exactly ONCE\n configs = build_configs(input_json_path, str(work_dir / \"outputs\"), MODEL_NAME)\n configs = update_gpu_compatible_configs(configs)\n runner = InferenceRunner(configs)\n dataset = InferenceDataset(configs)\n\n # ── 2. 
Inference: process dataset and collect predictions ────────────\n raw_predictions = {} # sample_name -> coords (np.ndarray or None)\n\n def _extract_c1_coords(prediction, feat, chunk_seq_len, raw_coords):\n if \"centre_atom_mask\" in feat:\n mask = (feat[\"centre_atom_mask\"] == 1).to(raw_coords.device)\n elif \"atom_to_tokatom_idx\" in feat:\n m11 = (feat[\"atom_to_tokatom_idx\"] == 11).to(raw_coords.device)\n m12 = (feat[\"atom_to_tokatom_idx\"] == 12).to(raw_coords.device)\n c11, c12 = m11.sum(), m12.sum()\n mask = m11 if abs(c11 - chunk_seq_len) < abs(c12 - chunk_seq_len) else m12\n else:\n mask = torch.zeros(raw_coords.shape[1], dtype=torch.bool, device=raw_coords.device)\n \n coords = raw_coords[:, mask, :].detach().cpu().numpy()\n \n # Collapse check\n if coords.shape[1] > 1:\n diffs = np.linalg.norm(coords[0, 1:] - coords[0, :-1], axis=-1)\n if np.all(diffs < 1e-4):\n print(f\" WARNING: Collapsed coordinates detected\")\n return None\n \n if coords.shape[1] != chunk_seq_len:\n if coords.shape[1] == 1 and chunk_seq_len > 1:\n return None\n padded = np.zeros((coords.shape[0], chunk_seq_len, 3), dtype=np.float32)\n ml = min(coords.shape[1], chunk_seq_len)\n padded[:, :ml, :] = coords[:, :ml, :]\n coords = padded\n return coords\n\n for i in tqdm(range(len(dataset)), desc=\"Protenix Inference\"):\n data, atom_array, err = dataset[i]\n sample_name = data.get(\"sample_name\", f\"sample_{i}\")\n \n if err:\n print(f\" {sample_name} data error: {err}\")\n raw_predictions[sample_name] = None\n del data, atom_array, err\n gc.collect(); torch.cuda.empty_cache(); gc.collect()\n continue\n \n # Find how many samples are needed for this specific query\n target_id = sample_name.split(\"_chunk\")[0] if \"_chunk\" in sample_name else sample_name\n n_needed = protenix_queue.get(target_id, (N_SAMPLE, \"\"))[0]\n sub_seq_len = data[\"N_token\"].item() # roughly correct\n \n try:\n new_cfg = update_inference_configs(configs, sub_seq_len)\n new_cfg.sample_diffusion.N_sample = 
n_needed\n runner.update_model_configs(new_cfg)\n \n pred = runner.predict(data)\n raw_coords = pred[\"coordinate\"]\n \n coords = _extract_c1_coords(pred, data[\"input_feature_dict\"], \n sub_seq_len, raw_coords)\n raw_predictions[sample_name] = coords\n except Exception as exc:\n print(f\" {sample_name} inference failed: {exc}\")\n import traceback; traceback.print_exc()\n raw_predictions[sample_name] = None\n finally:\n try: del pred, data, atom_array, raw_coords\n except: pass\n gc.collect(); torch.cuda.empty_cache(); gc.collect()\n\n # ── 3. Post-processing: Stitching and final formatting ───────────────\n for target_id, (n_needed, full_seq) in protenix_queue.items():\n seq_len = len(full_seq)\n chunks = chunk_info.get(target_id, [])\n \n if not chunks:\n continue\n\n if len(chunks) == 1:\n # Single pass\n coords = raw_predictions.get(target_id)\n protenix_preds[target_id] = coords\n if coords is not None:\n print(f\" {target_id}: {coords.shape[0]} predictions generated\")\n else:\n print(f\" {target_id}: FAILED\")\n else:\n # Stitch chunks together\n chunk_results_per_sample = {s: [] for s in range(n_needed)}\n all_ok = True\n \n for ci, cinfo in enumerate(chunks):\n cname = cinfo[\"name\"]\n crange = cinfo[\"range\"]\n ccoords = raw_predictions.get(cname)\n \n if ccoords is None:\n all_ok = False\n break\n \n for s_idx in range(n_needed):\n if s_idx < ccoords.shape[0]:\n chunk_results_per_sample[s_idx].append((ccoords[s_idx], crange))\n else:\n chunk_results_per_sample[s_idx].append((ccoords[-1], crange))\n \n if not all_ok:\n print(f\" {target_id}: chunked inference incomplete, using fallback\")\n protenix_preds[target_id] = None\n continue\n \n stitched_samples = []\n for s_idx in range(n_needed):\n items = chunk_results_per_sample[s_idx]\n coords_list = [c for c, _ in items]\n ranges_list = [r for _, r in items]\n full_coords = stitch_chunk_coords(coords_list, ranges_list, seq_len)\n stitched_samples.append(full_coords)\n \n result = 
np.stack(stitched_samples, axis=0)\n protenix_preds[target_id] = result\n print(f\" {target_id}: {result.shape[0]} stitched predictions generated\")\n# ...existing code...\n\n elif protenix_queue and not USE_PROTENIX:\n print(f\"\\nPHASE 2 skipped (USE_PROTENIX=False). \"\n f\"De-novo fallback will cover {len(protenix_queue)} targets.\")\n\n # ─── PHASE 3: Combine everything ───────────────────────────────────────\n print(f\"\\n{'='*60}\")\n print(\"PHASE 3: Combine TBM + Protenix + de-novo fallback\")\n print(f\"{'='*60}\")\n\n all_rows = []\n\n for _, row in test_df.iterrows():\n tid = row[\"target_id\"]\n seq = row[\"sequence\"]\n\n combined: list = list(template_preds.get(tid, [])) # TBM predictions\n\n # Append Protenix predictions to fill remaining slots\n ptx = protenix_preds.get(tid)\n if ptx is not None and ptx.ndim == 3:\n for j in range(ptx.shape[0]):\n if len(combined) >= N_SAMPLE:\n break\n combined.append(ptx[j]) # (seq_len, 3)\n\n # De-novo fallback for any still-empty slots\n n_denovo = 0\n while len(combined) < N_SAMPLE:\n seed_val = row.name * 1000000 + len(combined) * 1000\n dn = generate_rna_structure(seq, seed=seed_val)\n combined.append(adaptive_rna_constraints(dn, tid, segments_map, confidence=0.2))\n n_denovo += 1\n\n if n_denovo:\n print(f\" {tid}: {n_denovo} slot(s) filled with de-novo fallback\")\n\n # Stack to (N_SAMPLE, seq_len, 3) and write rows\n stacked = np.stack(combined[:N_SAMPLE], axis=0)\n all_rows.extend(coords_to_rows(tid, seq, stacked))\n\n # ── Save ───────────────────────────────────────────────────────────────\n sub = pd.DataFrame(all_rows)\n cols = [\"ID\", \"resname\", \"resid\"] + [\n f\"{c}_{i}\" for i in range(1, N_SAMPLE + 1) for c in [\"x\", \"y\", \"z\"]\n ]\n coord_cols = [c for c in cols if c.startswith((\"x_\", \"y_\", \"z_\"))]\n sub[coord_cols] = sub[coord_cols].clip(-999.999, 9999.999)\n sub[cols].to_csv(output_csv, index=False)\n\n print(f\"\\n✓ Saved submission to {output_csv} ({len(sub):,} rows)\")"
],
"metadata": {
"trusted": true
},
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"source": "<a id=\"section2\"></a>\n\n# <div style=\"text-align:center; border-radius:15px; padding:15px; font-size:115%; font-family:'Arial Black', Arial, sans-serif; text-shadow:2px 2px 5px rgba(0,0,0,0.7); background:linear-gradient(90deg, #1e3c72, #c31432); box-shadow:0 2px 5px rgba(0,0,0,0.3); color:white;\"><b>MAIN</b></div>",
"metadata": {}
},
{
   "cell_type": "code",
   "source": "# Entry-point guard: run the end-to-end pipeline (data prep, TBM, Protenix\n# inference, de-novo fallback, submission write) only when executed directly.\nif __name__ == \"__main__\":\n    main()",
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "execution_count": null
  }
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment