| |
| import matplotlib.pyplot as plt |
| import matplotlib as mpl |
| import numpy as np |
| import os |
| import pandas as pd |
| from lightning.pytorch import seed_everything |
| import torch |
| from tqdm import tqdm |
| from datasets import Dataset, DatasetDict, Features, Value, Sequence |
| from transformers import AutoModelForMaskedLM |
| import sys |
| from transformers import AutoTokenizer, EsmModel |
| from datasets import Dataset, DatasetDict |
| import tqdm |
|
|
# Seed the random number generators handled by Lightning's seed_everything
# once, before any data processing or model code runs.
seed_everything(seed=1986)
| |
| |
| |
# 21-symbol alphabet: the padding token followed by 20 single-letter residue
# codes. Not referenced by the rest of the visible code.
m1 = ["[PAD]"] + list("ARNDCQEGHILKMFPSTWYV")

# 30-symbol vocabulary in id order: five special tokens, then 25 letters.
_m2_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + list(
    "LAGVESIKRDTPNQFYMHCWXUBZO"
)

# Token -> integer id (0..29).
m2 = {token: token_id for token_id, token in enumerate(_m2_tokens)}

# Integer id -> token; used to decode stored id arrays back into strings.
reverse_m2 = dict(enumerate(_m2_tokens))

# Accumulators for the decoded sequences and their integer class labels.
sequences = []
labels = []
|
|
| |
def _decode_rows(rows):
    """Turn each row of token ids into a string, skipping id 0 ([PAD])."""
    return [
        ''.join(reverse_m2[token] for token in row if token != 0)
        for row in rows
    ]


# Positive examples are stored under the default 'arr_0' key and get label 1.
print("Processing positive sequences...")
with np.load('nf-positive.npz') as pos:
    pos_data = pos['arr_0']
_positive_strings = _decode_rows(pos_data)
sequences.extend(_positive_strings)
labels.extend([1] * len(_positive_strings))

# Negative examples: same layout, label 0, appended after the positives.
print("Processing negative sequences...")
with np.load('nf-negative.npz') as neg:
    neg_data = neg['arr_0']
_negative_strings = _decode_rows(neg_data)
sequences.extend(_negative_strings)
labels.extend([0] * len(_negative_strings))
| |
| |
# One row per decoded sequence, keyed by a zero-padded running identifier.
ids = [f"seq_{i:06d}" for i in range(len(sequences))]
df = pd.DataFrame({"id": ids, "sequence": sequences, "label": labels})
print("Before dedup:", len(df))

# Keep only the first row for each distinct sequence, then renumber the index.
# NOTE(review): positives are appended before negatives, so a sequence present
# under both labels keeps label 1 here — confirm that is intended.
df = df.drop_duplicates(subset=["sequence"])
df = df.reset_index(drop=True)
print("After dedup:", len(df))

df.to_csv("nf_all.csv", index=False)
print("Saved nf_all.csv")

# FASTA export: a ">id" header line followed by the sequence on its own line.
_fasta_records = (
    f">{seq_id}\n{seq}\n" for seq_id, seq in zip(df["id"], df["sequence"])
)
with open("nf_all.fasta", "w") as f:
    f.writelines(_fasta_records)
print("Saved nf_all.fasta")
|
|
| |
| |
| |
|
|
| |
| """ |
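# Shell step, run outside Python after nf_all.fasta has been written and
# before the split code below: cluster the sequences with mmseqs and export
# the cluster assignments to clusters-nf.tsv.
mkdir -p mmseqs_tmp

mmseqs createdb nf_all.fasta nfDB

mmseqs cluster nfDB nfDB_clu mmseqs_tmp \
    --min-seq-id 0.3 -c 0.8 --cov-mode 0

mmseqs createtsv nfDB nfDB nfDB_clu clusters-nf.tsv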
| """ |
| |
|
|
|
|
| |
| |
| |
|
|
# Cluster-aware train/val split: every cluster listed in the clusters TSV is
# assigned as a whole to one side, so cluster members never straddle the split.
train_fraction = 0.8
csv_path = "nf_all.csv"
clusters_tsv = "clusters-nf.tsv"
# Seeded with the same value passed to seed_everything() at the top of the
# file. An unseeded default_rng() draws fresh entropy on every run, which
# would make the split (and everything derived from it) non-reproducible.
rng = np.random.default_rng(1986)

df = pd.read_csv(csv_path)

# Sequence id -> positional row index in df.
id_to_index = {sid: i for i, sid in enumerate(df["id"])}

# Member id -> representative id. Each non-blank TSV line is
# "<representative>\t<member>".
cluster_map = {}
with open(clusters_tsv) as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        rep_id, member_id = line.split('\t')
        cluster_map[member_id] = rep_id

# The TSV is produced by a separate manual step, so it can be stale. Report
# ids that are not in the CSV up front instead of failing later on a bare
# dictionary lookup.
unknown_ids = [sid for sid in cluster_map if sid not in id_to_index]
if unknown_ids:
    raise KeyError(
        f"{len(unknown_ids)} id(s) in {clusters_tsv} are missing from "
        f"{csv_path}, e.g. {unknown_ids[:5]}"
    )

# Ids that the TSV does not mention become their own single-member cluster.
for sid in df["id"]:
    cluster_map.setdefault(sid, sid)

# Cluster id -> list of row indices belonging to that cluster.
cluster_to_indices = {}
for sid, cid in cluster_map.items():
    cluster_to_indices.setdefault(cid, []).append(id_to_index[sid])

# Visit clusters in random order.
cluster_ids = list(cluster_to_indices)
rng.shuffle(cluster_ids)

total_n = len(df)
train_target = int(train_fraction * total_n)

# Greedy fill: a cluster goes to train only if it fits entirely within the
# remaining train budget; otherwise the whole cluster goes to val.
train_indices = []
val_indices = []
current_train = 0
for cid in cluster_ids:
    indices = cluster_to_indices[cid]
    if current_train + len(indices) <= train_target:
        train_indices.extend(indices)
        current_train += len(indices)
    else:
        val_indices.extend(indices)

# Per-row split label: default "val", overwritten for the train rows.
split = np.full(total_n, "val", dtype=object)
split[train_indices] = "train"

df_with_split = df.copy()
df_with_split["split"] = split
df_with_split.to_csv("nf_meta_with_split.csv", index=False)

df_train = df_with_split[df_with_split["split"] == "train"].reset_index(drop=True)
df_val = df_with_split[df_with_split["split"] == "val"].reset_index(drop=True)

df_train.to_csv("nf_train.csv", index=False)
df_val.to_csv("nf_val.csv", index=False)

print("Split counts:")
print(df_with_split["split"].value_counts())
print()
print(f"Train size: {len(df_train)}")
print(f"Val size: {len(df_val)}")
# These names now match the files actually written above (previously the
# messages listed sol_*.csv).
print("Wrote:")
print(" - nf_meta_with_split.csv")
print(" - nf_train.csv")
print(" - nf_val.csv")
|
|
|
|
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the script still runs (slowly) instead of failing at model.to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Input metadata (with the train/val column) and output directory for the
# mean-pooled embedding dataset.
# NOTE(review): the split step earlier in this file writes
# nf_meta_with_split.csv to the current working directory, not to this
# location — confirm the file is copied here before this stage runs.
meta_path = "./Classifier_Weight/training_data_cleaned/nf/nf_meta_with_split.csv"
save_path = "./Classifier_Weight/training_data_cleaned/nf/nf_wt_with_embeddings"

# Tokenizer and encoder are loaded from the same checkpoint; the model is put
# in eval mode because it is only used for inference here.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = model.to(device)
model.eval()
|
|
|
|
def compute_embeddings(sequences, batch_size=32):
    """Return mean-pooled embeddings as a numpy array of shape (N, hidden_dim).

    Sequences are tokenized in batches with the module-level ``tokenizer``
    (padded to the longest item in the batch, truncated to 1022 tokens), run
    through the module-level ``model`` on ``device``, and averaged over all
    positions whose attention mask is 1.

    Args:
        sequences: Sliceable collection of sequence strings.
        batch_size: Number of sequences per forward pass.

    Returns:
        A 2-D numpy array with one row per input sequence. For empty input
        the array has shape ``(0, hidden_size)``.
    """
    if len(sequences) == 0:
        # torch.cat() rejects an empty list, so return an empty result with
        # the model's hidden width and dtype instead of raising.
        return torch.empty(0, model.config.hidden_size, dtype=model.dtype).numpy()

    embeddings = []
    for i in tqdm.trange(0, len(sequences), batch_size):
        batch_sequences = sequences[i:i + batch_size]

        inputs = tokenizer(
            batch_sequences,
            padding=True,
            max_length=1022,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            last_hidden_states = model(**inputs).last_hidden_state

            # Zero out padded positions, then divide the per-sequence sum by
            # the number of attended positions (masked mean over dim 1).
            # Every position with attention_mask == 1 contributes, including
            # any special tokens the tokenizer adds.
            attention_mask = inputs["attention_mask"].unsqueeze(-1)
            masked_hidden_states = last_hidden_states * attention_mask
            sum_hidden_states = masked_hidden_states.sum(dim=1)
            seq_lengths = attention_mask.sum(dim=1)
            batch_embeddings = sum_hidden_states / seq_lengths

        # Move each batch off the device right away so only one batch of
        # activations stays there at a time.
        embeddings.append(batch_embeddings.cpu())

    return torch.cat(embeddings, dim=0).numpy()
|
|
|
|
def create_and_save_datasets():
    """Embed every sequence in the metadata CSV and save a train/val DatasetDict.

    Reads ``meta_path`` (columns ``sequence``, ``label``, ``split``), computes
    one pooled embedding per sequence with ``compute_embeddings``, splits the
    rows by their ``split`` value, drops that column, and writes the result to
    ``save_path``.

    Returns:
        The saved ``DatasetDict`` with ``"train"`` and ``"val"`` entries.
    """
    meta = pd.read_csv(meta_path)
    sequences, labels, splits = (
        meta[column].tolist() for column in ("sequence", "label", "split")
    )

    print(f"Total sequences: {len(sequences)}")
    print("Split counts:", pd.Series(splits).value_counts().to_dict())

    print("Computing ESM embeddings...")
    embeddings = compute_embeddings(sequences)

    columns = {
        "sequence": sequences,
        "embedding": embeddings,
        "label": labels,
        "split": splits,
    }
    full_ds = Dataset.from_dict(columns)

    def _subset(split_name):
        # Rows of one split, without the now-redundant "split" column.
        selected = full_ds.filter(lambda row: row["split"] == split_name)
        return selected.remove_columns("split")

    ds_dict = DatasetDict({
        "train": _subset("train"),
        "val": _subset("val"),
    })

    ds_dict.save_to_disk(save_path)
    print(f"Saved DatasetDict with train/val to: {save_path}")
    print("Train size:", len(ds_dict["train"]))
    print("Val size:", len(ds_dict["val"]))

    return ds_dict
|
|
|
|
# Run the pooled-embedding pipeline and keep the resulting DatasetDict.
ds = create_and_save_datasets()

# Sanity check: print the first training record.
ex = ds["train"][0]
print("\nExample from train:")
for caption, value in (
    ("Sequence:", ex["sequence"]),
    ("Embedding shape:", np.array(ex["embedding"]).shape),
    ("Label:", ex["label"]),
):
    print(caption, value)

# Ask PyTorch to release cached CUDA memory before the next stage.
torch.cuda.empty_cache()
|
|
# Second stage: per-token (unpooled) embeddings. Same metadata file as the
# pooled stage, different output directory.
meta_path = "./Classifier_Weight/training_data_cleaned/nf/nf_meta_with_split.csv"
save_path = "./Classifier_Weight/training_data_cleaned/nf/nf_wt_with_embeddings_unpooled"

# Reload tokenizer and encoder; this time the encoder is built without its
# pooling layer, then moved to the device and put in eval mode.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D", add_pooling_layer=False)
model = model.to(device)
model = model.eval()

# Ids of the tokenizer's CLS and EOS tokens; embed_one() drops these positions.
cls_id, eos_id = tokenizer.cls_token_id, tokenizer.eos_token_id
|
|
@torch.no_grad()
def embed_one(seq, max_length=1022):
    """Return the per-token hidden states of one sequence as a float16 array.

    The sequence is tokenized on its own (no padding, truncated to
    ``max_length`` tokens) and passed through the module-level ``model``.
    Positions holding the tokenizer's CLS or EOS id, and positions whose
    attention mask is 0, are removed from the output.

    Args:
        seq: The sequence string to embed.
        max_length: Truncation length handed to the tokenizer.

    Returns:
        A 2-D float16 numpy array with one row per kept token.
    """
    encoded = tokenizer(
        seq,
        padding=False,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Index 0 selects the only item in this batch of one.
    hidden = model(**encoded).last_hidden_state[0]
    token_ids = encoded["input_ids"][0]

    # Start from the attention mask and knock out the special-token positions.
    keep = encoded["attention_mask"][0].bool()
    for special_id in (cls_id, eos_id):
        if special_id is not None:
            keep = keep & (token_ids != special_id)

    selected = hidden[keep]
    return selected.detach().cpu().to(torch.float16).numpy()
|
|
# Width of each per-token embedding vector stored in the dataset.
H = 1280

# One fixed-length float16 vector of width H per token.
_token_vector = Sequence(Value("float16"), length=H)

# Explicit schema for the unpooled dataset records yielded by make_generator().
features = Features({
    "sequence": Value("string"),
    "label": Value("int64"),
    "embedding": Sequence(_token_vector),
    "attention_mask": Sequence(Value("int8")),
    "length": Value("int64"),
})
|
|
def make_generator(df):
    """Yield one unpooled-embedding record per row of *df*.

    Args:
        df: DataFrame with a ``sequence`` column and a ``label`` column that
            can be cast to int.

    Yields:
        Dicts with keys ``sequence``, ``label``, ``embedding`` (nested list
        from ``embed_one``), ``attention_mask`` (all ones, one per embedding
        row) and ``length`` (number of embedding rows).
    """
    all_sequences = df["sequence"].tolist()
    all_labels = df["label"].astype(int).tolist()

    for seq, lab in tqdm.tqdm(zip(all_sequences, all_labels), total=len(df)):
        token_vectors = embed_one(seq).tolist()
        n_tokens = len(token_vectors)
        yield dict(
            sequence=seq,
            label=int(lab),
            embedding=token_vectors,
            attention_mask=[1] * n_tokens,
            length=n_tokens,
        )
|
|
def build_and_save_split(df, out_dir):
    """Build the unpooled dataset for one split, save it, and return it.

    Args:
        df: Rows of a single split, passed through to ``make_generator``.
        out_dir: Directory the dataset is written to.

    Returns:
        The ``Dataset`` produced from the generator.
    """
    generator_args = {"df": df}
    split_ds = Dataset.from_generator(
        make_generator,
        features=features,
        gen_kwargs=generator_args,
    )
    # Written in shards of at most 1GB each.
    split_ds.save_to_disk(out_dir, max_shard_size="1GB")
    return split_ds
|
|
# Load the metadata once and separate the rows by their split label.
meta = pd.read_csv(meta_path)
train_df = meta[meta["split"] == "train"].reset_index(drop=True)
val_df = meta[meta["split"] == "val"].reset_index(drop=True)

# Each split is written to its own sub-directory of save_path.
train_dir = os.path.join(save_path, "train")
val_dir = os.path.join(save_path, "val")
os.makedirs(save_path, exist_ok=True)

train_ds = build_and_save_split(train_df, train_dir)
val_ds = build_and_save_split(val_df, val_dir)

ds_dict = DatasetDict({"train": train_ds, "val": val_ds})
# DatasetDict.save_to_disk writes every split into <save_path>/<split name>,
# i.e. the same train/ and val/ directories build_and_save_split() has just
# filled. Pass the same shard-size limit used there: with the library default
# the shard count (and therefore the shard file names) can differ, leaving the
# first set of shard files behind next to the second.
ds_dict.save_to_disk(save_path, max_shard_size="1GB")
print(ds_dict)
|
|