import math
from pathlib import Path
from contextlib import contextmanager

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
from lightning.pytorch import seed_everything

seed_everything(1986)


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

CSV_PATH = Path("./Classifier_Weight/training_data_cleaned/binding_affinity/c-binding_with_openfold_scores.csv")

OUT_ROOT = Path("./Classifier_Weight/training_data_cleaned/binding_affinity")

# Wild-type (amino-acid) branch: ESM-2 encoder.
WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
WT_MAX_LEN = 1022
WT_BATCH = 32

# SMILES branch: PeptideCLM encoder with a SMILES pair-encoding tokenizer.
SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
TOKENIZER_VOCAB = "./Classifier_Weight/tokenizer/new_vocab.txt"
TOKENIZER_SPLITS = "./Classifier_Weight/tokenizer/new_splits.txt"
SMI_MAX_LEN = 768
SMI_BATCH = 128

# Train/val split: stratified on quantile bins of the affinity column.
TRAIN_FRAC = 0.80
RANDOM_SEED = 1986
AFFINITY_Q_BINS = 30

# Column names expected in the input CSV.
COL_SEQ1 = "seq1"
COL_SEQ2 = "seq2"
COL_AFF = "affinity"
COL_F2S = "Fasta2SMILES"
COL_REACT = "REACT_SMILES"
COL_WT_IPTM = "wt_iptm_score"
COL_SMI_IPTM = "smiles_iptm_score"

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Logging: with QUIET=True nothing is printed; set LOG_FILE to also append
# messages to a file.
QUIET = True
USE_TQDM = False
LOG_FILE = None


def log(msg: str):
    """Append `msg` to LOG_FILE (if set) and print it unless QUIET."""
    if LOG_FILE is not None:
        Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
        with open(LOG_FILE, "a") as f:
            f.write(msg.rstrip() + "\n")
    if not QUIET:
        print(msg)


def pbar(it, **kwargs):
    """Wrap an iterable in tqdm when USE_TQDM is on; otherwise pass through."""
    return tqdm(it, **kwargs) if USE_TQDM else it


@contextmanager
def section(title: str):
    """Log matching start/done markers around a pipeline phase."""
    log(f"\n=== {title} ===")
    yield
    log(f"=== done: {title} ===")

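
# Usage sketch for the helpers above (illustrative only; nothing below runs
# as part of the pipeline, and "./logs/embed_run.log" is a hypothetical path):
#
#   LOG_FILE = "./logs/embed_run.log"   # tee messages to disk
#   QUIET = False                       # also echo to stdout
#   with section("demo phase"):
#       log("hello")
#       for _ in pbar(range(3)):
#           pass
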
# ---------------------------------------------------------------------------
# Split construction
# ---------------------------------------------------------------------------

def has_uaa(seq: str) -> bool:
    """True if the sequence contains 'X', used here to flag unnatural amino acids."""
    return "X" in str(seq).upper()


def affinity_to_class(a: float) -> str:
    """Bucket an affinity value: >= 9.0 -> High, >= 7.0 -> Moderate, else Low."""
    if a >= 9.0:
        return "High"
    elif a >= 7.0:
        return "Moderate"
    else:
        return "Low"

def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
    """Assign a train/val split whose affinity distribution matches across splits."""
    df = df.copy()

    # Coerce affinity to numeric and drop rows without a usable value.
    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
    df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)

    df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)

    # Stratify on affinity quantile bins; fall back to the coarse class
    # labels if qcut cannot produce the requested bins.
    try:
        df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
    except Exception:
        df["aff_bin"] = df["affinity_class"]
    strat_col = "aff_bin"

    rng = np.random.RandomState(RANDOM_SEED)

    # Shuffle within each stratum and send the first TRAIN_FRAC to train.
    df["split"] = None
    for _, g in df.groupby(strat_col, observed=True):
        idx = g.index.to_numpy()
        rng.shuffle(idx)
        n_train = int(math.floor(len(idx) * TRAIN_FRAC))
        df.loc[idx[:n_train], "split"] = "train"
        df.loc[idx[n_train:], "split"] = "val"

    df["split"] = df["split"].fillna("train")
    return df

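
def _demo_distribution_matched_split():
    """Illustrative sketch, not called by the pipeline: run the stratified
    split on synthetic affinities and print the per-class split proportions.
    All values here are hypothetical demo data."""
    rng = np.random.RandomState(0)
    demo = pd.DataFrame({COL_AFF: rng.uniform(4.0, 11.0, size=500)})
    out = make_distribution_matched_split(demo)
    # Each affinity class should land at roughly 80/20 train/val.
    print(out.groupby("affinity_class")["split"].value_counts(normalize=True))
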
def _summ(x):
    """Summary stats (n/mean/std/p50/p95) for a numeric array, ignoring NaNs."""
    x = np.asarray(x, dtype=float)
    x = x[~np.isnan(x)]
    if len(x) == 0:
        return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
    return {
        "n": int(len(x)),
        "mean": float(np.mean(x)),
        "std": float(np.std(x)),
        "p50": float(np.quantile(x, 0.50)),
        "p95": float(np.quantile(x, 0.95)),
    }


def _len_stats(seqs):
    """Same summary as `_summ`, computed over string lengths."""
    lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
    if len(lens) == 0:
        return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
    return {
        "n": int(len(lens)),
        "mean": float(lens.mean()),
        "std": float(lens.std()),
        "p50": float(np.quantile(lens, 0.50)),
        "p95": float(np.quantile(lens, 0.95)),
    }

def verify_split_before_embedding(
    df2: pd.DataFrame,
    affinity_col: str,
    split_col: str,
    seq_col: str,
    iptm_col: str,
    aff_class_col: str = "affinity_class",
    aff_bins: int = 30,
    save_report_prefix: str | None = None,
    verbose: bool = False,
):
    """Sanity-check a train/val split before spending GPU time on embeddings:
    compare affinity and length statistics across splits and report the
    largest per-bin proportion gap."""
    df2 = df2.copy()
    df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
    df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")

    assert split_col in df2.columns, f"Missing split col: {split_col}"
    assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
    assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."

    # Re-bin affinity the same way the split did, falling back to class labels.
    try:
        df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
    except Exception:
        df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)

    tr = df2[df2[split_col] == "train"].reset_index(drop=True)
    va = df2[df2[split_col] == "val"].reset_index(drop=True)

    tr_aff = _summ(tr[affinity_col].to_numpy())
    va_aff = _summ(va[affinity_col].to_numpy())
    tr_len = _len_stats(tr[seq_col].tolist())
    va_len = _len_stats(va[seq_col].tolist())

    # Per-split bin proportions; max_bin_diff is the largest absolute gap.
    bin_ct = (
        df2.groupby([split_col, "_aff_bin_dbg"], observed=True)
        .size()
        .groupby(level=0)
        .apply(lambda s: s / s.sum())
    )
    tr_bins = bin_ct.loc["train"]
    va_bins = bin_ct.loc["val"]
    all_bins = tr_bins.index.union(va_bins.index)
    tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
    va_bins = va_bins.reindex(all_bins, fill_value=0.0)
    max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))

    msg = (
        f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
        f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
        f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
        f"max_bin_diff={max_bin_diff:.4f}"
    )
    log(msg)

    if verbose and (not QUIET):
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
        print("\n[verbose] affinity_class counts:\n", class_ct)
        print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))

    if save_report_prefix is not None:
        out = Path(save_report_prefix)
        out.parent.mkdir(parents=True, exist_ok=True)

        stats_df = pd.DataFrame([
            {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
            {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
        ])
        class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
        class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()

        stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
        class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)


# ---------------------------------------------------------------------------
# ESM embeddings (wild-type branch)
# ---------------------------------------------------------------------------

@torch.no_grad()
def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
    """Mean-pooled ESM embeddings: average the last hidden state over all
    non-padding positions (special tokens included)."""
    embs = []
    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        inputs = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        out = model(**inputs)
        h = out.last_hidden_state

        # Masked mean: zero out padding, divide by the real token count.
        attn = inputs["attention_mask"].unsqueeze(-1)
        summed = (h * attn).sum(dim=1)
        denom = attn.sum(dim=1).clamp(min=1e-9)
        pooled = (summed / denom).detach().cpu().numpy()
        embs.append(pooled)

    return np.vstack(embs)

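
def _demo_masked_mean_pooling():
    """Minimal sketch (not part of the pipeline) of the masked mean pooling
    in `wt_pooled_embeddings`: padded positions contribute nothing and the
    divisor is the per-sequence count of real tokens. Toy tensors only."""
    h = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)  # (B=1, L=4, H=3)
    mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])                 # two real tokens
    attn = mask.unsqueeze(-1)
    pooled = (h * attn).sum(dim=1) / attn.sum(dim=1).clamp(min=1e-9)
    assert torch.allclose(pooled[0], h[0, :2].mean(dim=0))
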
@torch.no_grad()
def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
    """Per-residue ESM embeddings for one sequence, with CLS/EOS stripped and
    the result downcast to float16 for compact storage."""
    tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
    tok = {k: v.to(DEVICE) for k, v in tok.items()}
    out = model(**tok)
    h = out.last_hidden_state[0]
    attn = tok["attention_mask"][0].bool()
    ids = tok["input_ids"][0]

    # Keep attended positions that are not the CLS or EOS token.
    keep = attn.clone()
    if cls_id is not None:
        keep &= (ids != cls_id)
    if eos_id is not None:
        keep &= (ids != eos_id)

    return h[keep].detach().cpu().to(torch.float16).numpy()

def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
    """
    Expects df_split to have:
      - target_sequence (seq1)
      - sequence (WT binder, seq2)
      - label, affinity_class, COL_AFF, COL_WT_IPTM
    Saves a dataset where each row contains BOTH:
      - target_embedding (Lt, H), target_attention_mask, target_length
      - binder_embedding (Lb, H), binder_attention_mask, binder_length
    """
    cls_id = tokenizer.cls_token_id
    eos_id = tokenizer.eos_token_id
    H = model.config.hidden_size

    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_WT_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        for r in pbar(df.itertuples(index=False), total=len(df)):
            tgt = str(getattr(r, "target_sequence")).strip()
            bnd = str(getattr(r, "sequence")).strip()

            y = float(getattr(r, "label"))
            aff = float(getattr(r, COL_AFF))
            acls = str(getattr(r, "affinity_class"))

            iptm = getattr(r, COL_WT_IPTM)
            iptm = float(iptm) if pd.notna(iptm) else np.nan

            # Both sides of the pair go through the same ESM encoder.
            t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)
            b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN)

            t_list = t_emb.tolist()
            b_list = b_emb.tolist()
            Lt = len(t_list)
            Lb = len(b_list)

            yield {
                "target_sequence": tgt,
                "sequence": bnd,
                "label": np.float32(y),
                "affinity": np.float32(aff),
                "affinity_class": acls,

                "target_embedding": t_list,
                "target_attention_mask": [1] * Lt,
                "target_length": int(Lt),

                "binder_embedding": b_list,
                "binder_attention_mask": [1] * Lb,
                "binder_length": int(Lb),

                COL_WT_IPTM: np.float32(iptm),
                COL_AFF: np.float32(aff),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds

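
def _demo_pad_collate(batch):
    """Hedged sketch of a downstream collate_fn for the unpooled datasets
    saved above: pads the variable-length (L, H) embeddings to the batch
    max. Field names follow the Features schema; the function itself is
    illustrative and not used by this script."""
    def pad(key):
        seqs = [torch.tensor(ex[key], dtype=torch.float32) for ex in batch]
        max_len = max(s.shape[0] for s in seqs)
        emb = torch.zeros(len(seqs), max_len, seqs[0].shape[1])
        msk = torch.zeros(len(seqs), max_len, dtype=torch.long)
        for i, s in enumerate(seqs):
            emb[i, : s.shape[0]] = s
            msk[i, : s.shape[0]] = 1
        return emb, msk

    t_emb, t_mask = pad("target_embedding")
    b_emb, b_mask = pad("binder_embedding")
    labels = torch.tensor([ex["label"] for ex in batch], dtype=torch.float32)
    return {"target": (t_emb, t_mask), "binder": (b_emb, b_mask), "labels": labels}
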
def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
                                         smi_tok, smi_roformer):
    """
    Expects df_split to have:
      - target_sequence (seq1)
      - sequence (binder SMILES string)
      - label, affinity_class, COL_AFF, COL_SMI_IPTM
    Saves rows with:
      - target_embedding (Lt, Ht) from ESM
      - binder_embedding (Lb, Hb) from PeptideCLM
    """
    cls_id = wt_tokenizer.cls_token_id
    eos_id = wt_tokenizer.eos_token_id
    Ht = wt_model_unpooled.config.hidden_size

    # Hidden size of the SMILES encoder; some configs name it `dim`.
    Hb = getattr(smi_roformer.config, "hidden_size", None)
    if Hb is None:
        Hb = getattr(smi_roformer.config, "dim", None)
    if Hb is None:
        raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")

    features = Features({
        "target_sequence": Value("string"),
        "sequence": Value("string"),
        "label": Value("float32"),
        "affinity": Value("float32"),
        "affinity_class": Value("string"),

        "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
        "target_attention_mask": HFSequence(Value("int8")),
        "target_length": Value("int64"),

        "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
        "binder_attention_mask": HFSequence(Value("int8")),
        "binder_length": Value("int64"),

        COL_SMI_IPTM: Value("float32"),
        COL_AFF: Value("float32"),
    })

    def gen_rows(df: pd.DataFrame):
        for r in pbar(df.itertuples(index=False), total=len(df)):
            tgt = str(getattr(r, "target_sequence")).strip()
            bnd = str(getattr(r, "sequence")).strip()

            y = float(getattr(r, "label"))
            aff = float(getattr(r, COL_AFF))
            acls = str(getattr(r, "affinity_class"))

            iptm = getattr(r, COL_SMI_IPTM)
            iptm = float(iptm) if pd.notna(iptm) else np.nan

            # Target side: per-residue ESM embedding.
            t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
            t_list = t_emb.tolist()
            Lt = len(t_list)

            # Binder side: per-token PeptideCLM embedding (batch of one).
            _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
                [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
            )
            b_list = tok_list[0].tolist()
            Lb = int(lengths[0])
            b_mask = mask_list[0].astype(np.int8).tolist()

            yield {
                "target_sequence": tgt,
                "sequence": bnd,
                "label": np.float32(y),
                "affinity": np.float32(aff),
                "affinity_class": acls,

                "target_embedding": t_list,
                "target_attention_mask": [1] * Lt,
                "target_length": int(Lt),

                "binder_embedding": b_list,
                "binder_attention_mask": b_mask,
                "binder_length": int(Lb),

                COL_SMI_IPTM: np.float32(iptm),
                COL_AFF: np.float32(aff),
            }

    out_dir.mkdir(parents=True, exist_ok=True)
    ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
    ds.save_to_disk(str(out_dir), max_shard_size="1GB")
    return ds


# ---------------------------------------------------------------------------
# PeptideCLM (SMILES) embeddings
# ---------------------------------------------------------------------------

def get_special_ids(tokenizer_obj):
    """Collect whichever special-token ids the tokenizer defines."""
    cand = [
        getattr(tokenizer_obj, "pad_token_id", None),
        getattr(tokenizer_obj, "cls_token_id", None),
        getattr(tokenizer_obj, "sep_token_id", None),
        getattr(tokenizer_obj, "bos_token_id", None),
        getattr(tokenizer_obj, "eos_token_id", None),
        getattr(tokenizer_obj, "mask_token_id", None),
    ]
    return sorted({x for x in cand if x is not None})

@torch.no_grad()
def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
    """Encode a batch of SMILES and return both pooled (mean over valid
    tokens) and unpooled per-token embeddings, excluding special tokens."""
    tok = tokenizer_obj(
        batch_sequences,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = tok["input_ids"].to(DEVICE)
    attention_mask = tok["attention_mask"].to(DEVICE)

    outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = outputs.last_hidden_state

    # Valid positions: attended and not a special token. `torch.isin` needs
    # torch >= 1.10; fall back to an explicit loop otherwise.
    special_ids = get_special_ids(tokenizer_obj)
    valid = attention_mask.bool()
    if len(special_ids) > 0:
        sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
        if hasattr(torch, "isin"):
            valid = valid & (~torch.isin(input_ids, sid))
        else:
            m = torch.zeros_like(valid)
            for s in special_ids:
                m |= (input_ids == s)
            valid = valid & (~m)

    # Pooled: masked mean over valid positions.
    valid_f = valid.unsqueeze(-1).float()
    summed = torch.sum(last_hidden * valid_f, dim=1)
    denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
    pooled = (summed / denom).detach().cpu().numpy()

    # Unpooled: keep only the valid rows per sequence, stored as float16.
    token_emb_list, mask_list, lengths = [], [], []
    for b in range(last_hidden.shape[0]):
        emb = last_hidden[b, valid[b]]
        token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
        li = emb.shape[0]
        lengths.append(int(li))
        mask_list.append(np.ones((li,), dtype=np.int8))

    return pooled, token_emb_list, mask_list, lengths

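
def _demo_special_token_masking():
    """Toy sketch (illustrative only) of the special-token masking inside
    `smiles_embed_batch_return_both`: attended positions holding special ids
    are excluded from pooling and from the per-token outputs. The ids below
    are hypothetical."""
    input_ids = torch.tensor([[0, 11, 12, 2, 1]])   # e.g. 0=CLS, 2=EOS, 1=PAD
    attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
    special = torch.tensor([0, 1, 2])
    valid = attention_mask.bool() & ~torch.isin(input_ids, special)
    assert valid.tolist() == [[False, True, True, False, False]]
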
def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
    """Batched wrapper around `smiles_embed_batch_return_both`."""
    pooled_all = []
    token_emb_all = []
    mask_all = []
    lengths_all = []

    for i in pbar(range(0, len(seqs), batch_size)):
        batch = seqs[i:i + batch_size]
        pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
            batch, tokenizer_obj, model_roformer, max_length
        )
        pooled_all.append(pooled)
        token_emb_all.extend(tok_list)
        mask_all.extend(m_list)
        lengths_all.extend(lens)

    return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all

def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
    """Convenience helper: embed every target sequence in the WT views and
    return lookup maps from sequence to pooled embedding. Not called by
    main(), which embeds targets inline instead."""
    wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
    wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

    tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
    tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()

    wt_train_tgt_emb = wt_pooled_embeddings(
        tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
    )
    wt_val_tgt_emb = wt_pooled_embeddings(
        tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
    )

    # Map each target sequence to its pooled embedding.
    train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
    val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
    return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map


def main():
    log(f"[INFO] DEVICE: {DEVICE}")
    OUT_ROOT.mkdir(parents=True, exist_ok=True)

    with section("load csv + dedup"):
        df = pd.read_csv(CSV_PATH)
        # Strip surrounding whitespace from the sequence/SMILES columns.
        for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
            if c in df.columns:
                df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)

        DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
        df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)

        log(f"Rows after dedup on {DEDUP_COLS}: {len(df)}")

    need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
    missing = [c for c in need if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")

    with section("prepare wt/smiles subsets"):
        # WT branch: keep natural-amino-acid binders only (vectorized form
        # of has_uaa; 'X' marks an unnatural amino acid).
        df_wt = df.copy()
        df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
        df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
        df_wt = df_wt[~df_wt["wt_sequence"].isin(["nan", "None"])]
        df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)

        # SMILES branch: require a usable iptm score; UAA binders use
        # REACT_SMILES, everything else uses Fasta2SMILES.
        df_smi = df.copy()
        df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
        df_smi = df_smi[
            pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
        ].reset_index(drop=True)

        is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
        df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
        df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
        df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
        df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)

        log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")

    with section("split wt and smiles separately"):
        df_wt2 = make_distribution_matched_split(df_wt)
        df_smi2 = make_distribution_matched_split(df_smi)

        # Persist the split assignments alongside the metadata.
        wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
        smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
        df_wt2.to_csv(wt_split_csv, index=False)
        df_smi2.to_csv(smi_split_csv, index=False)
        log(f"Saved WT split meta: {wt_split_csv}")
        log(f"Saved SMILES split meta: {smi_split_csv}")

        verify_split_before_embedding(
            df2=df_wt2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="wt_sequence",
            iptm_col=COL_WT_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
            verbose=False,
        )
        verify_split_before_embedding(
            df2=df_smi2,
            affinity_col=COL_AFF,
            split_col="split",
            seq_col="smiles_sequence",
            iptm_col=COL_SMI_IPTM,
            aff_class_col="affinity_class",
            aff_bins=AFFINITY_Q_BINS,
            save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
            verbose=False,
        )

    def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
        """Normalize a branch dataframe to the common column layout used by
        the embedding builders below."""
        out = df_in.copy()
        out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip()
        out["sequence"] = out[binder_seq_col].astype(str).str.strip()
        out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
        out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
        out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
        return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]

    wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
    smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)

    wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
    wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
    smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
    smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)

    with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
        # Targets are always embedded with ESM, in both branches.
        wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
        wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()

        wt_train_tgt_emb = wt_pooled_embeddings(
            wt_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_tgt_emb = wt_pooled_embeddings(
            wt_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_train_tgt_emb = wt_pooled_embeddings(
            smi_train["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        smi_val_tgt_emb = wt_pooled_embeddings(
            smi_val["target_sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

    with section("WT pooled binder embeddings + save"):
        wt_train_emb = wt_pooled_embeddings(
            wt_train["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_val_emb = wt_pooled_embeddings(
            wt_val["sequence"].astype(str).str.strip().tolist(),
            wt_tok, wt_esm,
            batch_size=WT_BATCH,
            max_length=WT_MAX_LEN,
        ).astype(np.float32)

        wt_train_ds = Dataset.from_dict({
            "target_sequence": wt_train["target_sequence"].tolist(),
            "sequence": wt_train["sequence"].tolist(),
            "label": wt_train["label"].astype(float).tolist(),
            "target_embedding": wt_train_tgt_emb,
            "embedding": wt_train_emb,
            COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_train["affinity_class"].tolist(),
        })

        wt_val_ds = Dataset.from_dict({
            "target_sequence": wt_val["target_sequence"].tolist(),
            "sequence": wt_val["sequence"].tolist(),
            "label": wt_val["label"].astype(float).tolist(),
            "target_embedding": wt_val_tgt_emb,
            "embedding": wt_val_emb,
            COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
            COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
            "affinity_class": wt_val["affinity_class"].tolist(),
        })

        wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
        wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
        wt_pooled_dd.save_to_disk(str(wt_pooled_out))
        log(f"Saved WT pooled -> {wt_pooled_out}")

    with section("SMILES pooled binder embeddings + save"):
        smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
        # Take the RoFormer backbone out of the masked-LM wrapper.
        smi_roformer = (
            AutoModelForMaskedLM
            .from_pretrained(SMI_MODEL_NAME)
            .roformer
            .to(DEVICE)
            .eval()
        )

        smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_train["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
            smi_val["sequence"].astype(str).str.strip().tolist(),
            smi_tok, smi_roformer,
            batch_size=SMI_BATCH,
            max_length=SMI_MAX_LEN,
        )

        smi_train_ds = Dataset.from_dict({
            "target_sequence": smi_train["target_sequence"].tolist(),
            "sequence": smi_train["sequence"].tolist(),
            "label": smi_train["label"].astype(float).tolist(),
            "target_embedding": smi_train_tgt_emb,
            "embedding": smi_train_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_train["affinity_class"].tolist(),
        })

        smi_val_ds = Dataset.from_dict({
            "target_sequence": smi_val["target_sequence"].tolist(),
            "sequence": smi_val["sequence"].tolist(),
            "label": smi_val["label"].astype(float).tolist(),
            "target_embedding": smi_val_tgt_emb,
            "embedding": smi_val_pooled.astype(np.float32),
            COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
            COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
            "affinity_class": smi_val["affinity_class"].tolist(),
        })

        smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
        smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
        smi_pooled_dd.save_to_disk(str(smi_pooled_out))
        log(f"Saved SMILES pooled -> {smi_pooled_out}")

    with section("WT unpooled paired embeddings + save"):
        # Reuse the already-loaded ESM tokenizer/model.
        wt_tok_unpooled = wt_tok
        wt_esm_unpooled = wt_esm

        # build_wt_unpooled_dataset already writes each split; saving the
        # DatasetDict re-writes those folders and adds dataset_dict.json so
        # the root directory can be opened with load_from_disk.
        wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
        wt_unpooled_dd = DatasetDict({
            "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
                                               wt_tok_unpooled, wt_esm_unpooled),
            "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
                                             wt_tok_unpooled, wt_esm_unpooled),
        })
        wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
        log(f"Saved WT unpooled -> {wt_unpooled_out}")

    with section("SMILES unpooled paired embeddings + save"):
        smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
        smi_unpooled_dd = DatasetDict({
            "train": build_smiles_unpooled_paired_dataset(
                smi_train, smi_unpooled_out / "train",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
            "val": build_smiles_unpooled_paired_dataset(
                smi_val, smi_unpooled_out / "val",
                wt_tok, wt_esm,
                smi_tok, smi_roformer
            ),
        })
        smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
        log(f"Saved SMILES unpooled -> {smi_unpooled_out}")

    log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")


if __name__ == "__main__":
    main()
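
# Reload sketch (illustrative only; the path assumes the OUT_ROOT default above):
#
#   from datasets import load_from_disk
#   dd = load_from_disk("./Classifier_Weight/training_data_cleaned/"
#                       "binding_affinity/pair_wt_wt_pooled")
#   print(dd["train"][0].keys())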