| | |
| | """ |
| | Decode all 80 generated sequences and test them with HMD-AMP. |
| | """ |
| |
|
| | import torch |
| | import numpy as np |
| | import pandas as pd |
| | from Bio import SeqIO |
| | from Bio.SeqRecord import SeqRecord |
| | from Bio.Seq import Seq |
| | import os |
| | from datetime import datetime |
| | from tqdm import tqdm |
| | import sys |
| |
|
| | |
| | from final_sequence_decoder import EmbeddingToSequenceConverter |
| |
|
| | |
| | sys.path.append('/home/edwardsun/flow/HMD-AMP') |
| | from sklearn.utils import shuffle |
| | import esm |
| | from deepforest import CascadeForestClassifier |
| | from src.utils import * |
| |
|
def load_generated_embeddings(base_path='/data2/edwardsun/generated_samples', run_date='20250829'):
    """Load the generated embedding tensors for every CFG configuration.

    Looks under ``base_path`` for one
    ``generated_amps_best_model_<cfg_type>_<run_date>.pt`` file per CFG
    strength. Files that do not exist are reported and skipped.

    Args:
        base_path: Directory containing the generated ``.pt`` files.
        run_date: ``YYYYMMDD`` tag embedded in the file names (defaults to the
            20250829 generation run).

    Returns:
        A tuple ``(embeddings, labels)``. ``embeddings`` holds one entry per
        index along dim 0 of every loaded tensor, and ``labels`` holds the
        matching ``"<cfg_type>_<k>"`` tags (``k`` is 1-based within each file).
    """
    # This order fixes the order of the returned embeddings/labels.
    cfg_types = ['no_cfg', 'weak_cfg', 'strong_cfg', 'very_strong_cfg']

    all_embeddings = []
    all_labels = []

    for cfg_type in cfg_types:
        # Derive the file name from the CFG type (rather than pattern-matching
        # the type back out of the name) so the two can never disagree.
        file = f'generated_amps_best_model_{cfg_type}_{run_date}.pt'
        file_path = os.path.join(base_path, file)
        if not os.path.exists(file_path):
            print(f"Warning: {file} not found in {base_path}, skipping")
            continue

        print(f"Loading {file}...")
        embeddings = torch.load(file_path, map_location='cpu')

        # One entry per generated sample (index along dim 0).
        n_samples = embeddings.shape[0]
        all_embeddings.extend(embeddings[i] for i in range(n_samples))
        all_labels.extend(f"{cfg_type}_{i + 1}" for i in range(n_samples))

    print(f"✓ Loaded {len(all_embeddings)} embeddings total")
    return all_embeddings, all_labels
| |
|
def decode_embeddings_to_sequences(embeddings, labels, method='diverse', temperature=0.8, device=None):
    """Decode generated embeddings back into amino-acid sequences.

    Args:
        embeddings: Sized sequence of embeddings, one per generated sample.
        labels: Tag for each embedding (same length as ``embeddings``). It is
            appended to the generated sequence id.
        method: Decoding strategy forwarded to
            ``EmbeddingToSequenceConverter.embedding_to_sequence``.
        temperature: Sampling temperature forwarded to the decoder.
        device: Device for the decoder. Defaults to ``"cuda"`` when CUDA is
            available, otherwise ``"cpu"``.

    Returns:
        ``(sequences, sequence_ids)`` where ``sequence_ids[i]`` is
        ``"generated_seq_<i+1>_<labels[i]>"``.

    Raises:
        ValueError: if ``embeddings`` and ``labels`` differ in length.
    """
    # zip() would silently drop the tail of the longer input while tqdm
    # still reported len(embeddings) as the total.
    if len(embeddings) != len(labels):
        raise ValueError(
            f"Got {len(embeddings)} embeddings but {len(labels)} labels"
        )

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("Initializing sequence decoder...")
    decoder = EmbeddingToSequenceConverter(device=device)

    sequences = []
    sequence_ids = []

    print("Decoding embeddings to sequences...")
    for i, (embedding, label) in enumerate(tqdm(zip(embeddings, labels), total=len(embeddings))):
        sequence = decoder.embedding_to_sequence(
            embedding,
            method=method,
            temperature=temperature,
        )
        sequences.append(sequence)
        sequence_ids.append(f"generated_seq_{i + 1}_{label}")

    return sequences, sequence_ids
| |
|
def save_sequences_as_fasta(sequences, sequence_ids, filename):
    """Write sequences to ``filename`` in FASTA format via ``Bio.SeqIO.write``.

    Args:
        sequences: Amino-acid strings to write.
        sequence_ids: Record id for each sequence, in the same order. Pairs
            are formed with ``zip``, so extra items in the longer list are
            ignored.
        filename: Output path handed to ``SeqIO.write``.
    """
    # description="" so the header carries only the record id.
    records = [
        SeqRecord(Seq(seq), id=seq_id, description="")
        for seq_id, seq in zip(sequence_ids, sequences)
    ]

    # Report what was actually written rather than len(sequences), which can
    # differ when the two input lists have different lengths.
    n_written = SeqIO.write(records, filename, "fasta")
    print(f"✓ Saved {n_written} sequences to {filename}")
| |
|
def _cfg_type_from_id(seq_id):
    """Return the CFG configuration name embedded in a generated sequence id.

    Ids built by this script look like ``generated_seq_<n>_<cfg_type>_<k>``,
    where ``cfg_type`` may itself contain underscores (``very_strong_cfg``).
    Ids that do not follow this layout map to ``'unknown'``.
    """
    parts = seq_id.split('_')
    if len(parts) >= 5 and parts[:2] == ['generated', 'seq']:
        return '_'.join(parts[3:-1])
    return 'unknown'


def _pct(part, total):
    """Return ``part`` as a percentage of ``total`` (0.0 when ``total`` is 0)."""
    return part / total * 100 if total else 0.0


def test_with_hmd_amp(sequences, sequence_ids):
    """Classify sequences with the HMD-AMP cascade-forest model.

    Writes the sequences to a temporary FASTA file. It extracts features with
    ``amp_feature_extraction`` (from HMD-AMP's ``src.utils``) and predicts
    labels with the saved ``CascadeForestClassifier``. It then prints an
    overall summary, a per-CFG breakdown and the sequences predicted as AMPs.
    Finally it writes ``hmd_amp_detailed_results.csv`` and
    ``hmd_amp_cfg_analysis.csv`` to the current directory.

    Args:
        sequences: Amino-acid strings to classify.
        sequence_ids: Ids matching ``sequences``. They should follow the
            ``generated_seq_<n>_<cfg_type>_<k>`` layout so the CFG type can be
            recovered; other ids are grouped under ``'unknown'``.

    Returns:
        ``(results_df, cfg_analysis)``: per-sequence predictions, and the
        per-CFG-type summary with columns ``Total``, ``Predicted_AMPs`` and
        ``AMP_Rate``.

    Raises:
        RuntimeError: if the classifier returns a different number of
            predictions than there are sequences.
    """
    import tempfile

    print("\n🧬 Testing sequences with HMD-AMP classifier...")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    ftmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/ft_parts.pth'
    clsmodel_save_path = '/home/edwardsun/flow/HMD-AMP/AMP/clsmodel'

    # Use a unique temp file instead of a fixed name in the CWD, so concurrent
    # runs cannot clobber each other's input.
    fd, temp_fasta = tempfile.mkstemp(suffix='.fasta')
    os.close(fd)

    try:
        # Written inside the try so the temp file is removed even if this fails.
        save_sequences_as_fasta(sequences, sequence_ids, temp_fasta)

        seq_embeddings, _, _ = amp_feature_extraction(ftmodel_save_path, device, temp_fasta)

        cls_model = CascadeForestClassifier()
        cls_model.load(clsmodel_save_path)

        binary_pred = np.asarray(cls_model.predict(seq_embeddings))
        if len(binary_pred) != len(sequences):
            raise RuntimeError(
                f"HMD-AMP returned {len(binary_pred)} predictions for {len(sequences)} sequences"
            )

        n_total = len(sequences)
        n_amp = int(np.sum(binary_pred))
        n_non_amp = n_total - n_amp
        print("📊 HMD-AMP Results:")
        print(f"Total sequences: {n_total}")
        print(f"Predicted AMPs: {n_amp} ({_pct(n_amp, n_total):.1f}%)")
        print(f"Predicted non-AMPs: {n_non_amp} ({_pct(n_non_amp, n_total):.1f}%)")

        results_df = pd.DataFrame({
            'ID': sequence_ids,
            'Sequence': sequences,
            'AMP_Prediction': binary_pred,
            # split('_')[-2] always yielded 'cfg' for these ids; parse properly.
            'CFG_Type': [_cfg_type_from_id(seq_id) for seq_id in sequence_ids],
        })

        cfg_analysis = results_df.groupby('CFG_Type')['AMP_Prediction'].agg(['count', 'sum', 'mean']).round(3)
        cfg_analysis.columns = ['Total', 'Predicted_AMPs', 'AMP_Rate']

        print("\n📈 Results by CFG Configuration:")
        print(cfg_analysis)

        amp_results = results_df[results_df['AMP_Prediction'] == 1]
        if len(amp_results) > 0:
            print(f"\n🏆 Sequences predicted as AMPs ({len(amp_results)}):")
            for _, row in amp_results.iterrows():
                seq = row['Sequence']
                cationic = seq.count('K') + seq.count('R')
                # Crude net charge: K/R/H counted +1, D/E counted -1.
                net_charge = cationic + seq.count('H') - seq.count('D') - seq.count('E')
                print(f"  {row['ID']}: {seq}")
                print(f"    Length: {len(seq)}, Cationic (K+R): {cationic}, Net charge: {net_charge:+d}")
        else:
            print("\n❌ No sequences predicted as AMPs")

        results_df.to_csv('hmd_amp_detailed_results.csv', index=False)
        cfg_analysis.to_csv('hmd_amp_cfg_analysis.csv')

        print("\n💾 Results saved:")
        print("  - hmd_amp_detailed_results.csv (detailed per-sequence results)")
        print("  - hmd_amp_cfg_analysis.csv (summary by CFG type)")

        return results_df, cfg_analysis

    finally:
        if os.path.exists(temp_fasta):
            os.remove(temp_fasta)
| |
|
def main():
    """Run the pipeline: load embeddings, decode, save FASTA, score with HMD-AMP.

    Writes ``generated_sequences_<YYYYmmdd_HHMMSS>.fasta`` to the current
    directory and prints a final AMP summary. Returns early, with a message,
    when no embeddings are found.
    """
    print("🚀 Starting sequence decoding and HMD-AMP testing...")

    embeddings, labels = load_generated_embeddings()
    if not embeddings:
        # Nothing to decode; avoid feeding an empty FASTA to HMD-AMP and a
        # division by zero in the summary below.
        print("❌ No generated embeddings found - nothing to decode.")
        return

    sequences, sequence_ids = decode_embeddings_to_sequences(embeddings, labels)

    fasta_filename = f'generated_sequences_{datetime.now().strftime("%Y%m%d_%H%M%S")}.fasta'
    save_sequences_as_fasta(sequences, sequence_ids, fasta_filename)

    results_df, _ = test_with_hmd_amp(sequences, sequence_ids)

    n_total = len(sequences)
    print(f"\n✅ Complete! Generated and tested {n_total} sequences")
    print(f"📁 Sequences saved as: {fasta_filename}")

    total_amps = int(results_df['AMP_Prediction'].sum())
    amp_pct = total_amps / n_total * 100 if n_total else 0.0
    print("\n📊 FINAL SUMMARY:")
    print(f"Generated sequences: {n_total}")
    print(f"HMD-AMP predicted AMPs: {total_amps}/{n_total} ({amp_pct:.1f}%)")

    if total_amps > 0:
        print(f"✨ Success! Your flow model generated {total_amps} sequences that HMD-AMP classifies as AMPs!")
    else:
        print("📝 No sequences classified as AMPs - this may indicate the need for stronger AMP conditioning.")
| |
|
if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()
| |
|