ESMFold2-Fast / esmfold2_processor.py
lhallee's picture
Upload folder using huggingface_hub
fb8a87c verified
import random
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
from .esmfold2_conformers import load_ccd
from .esmfold2_output import build_molecular_complex_from_features
from .esmfold2_prepare_input import ChainInfo, prepare_esmfold2_input
from .esmfold2_types import (
MSA,
Modification,
ProteinInput,
StructurePredictionInput,
)
from .esmfold2_molecular_complex import MolecularComplexResult
@contextmanager
def _seed_context(seed: int | None):
if seed is None:
yield
return
py_state = random.getstate()
np_state = np.random.get_state()
torch_state = torch.random.get_rng_state()
cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
try:
yield
finally:
random.setstate(py_state)
np.random.set_state(np_state)
torch.random.set_rng_state(torch_state)
if cuda_state is not None:
torch.cuda.set_rng_state_all(cuda_state)
def _split_protein_chains(item: ProteinInput, chains: list[str]) -> list[ProteinInput]:
    """Expand one multi-chain ProteinInput into one ProteinInput per unique chain.

    Parameters
    ----------
    item : ProteinInput
        The original multi-chain entry; its ``id``, ``modifications`` and
        ``msa`` are read, the object itself is not modified.
    chains : list[str]
        The chain sequences of ``item`` in order, i.e. its sequence split on
        the chain-break separator.

    Returns
    -------
    list[ProteinInput]
        One entry per distinct chain sequence, ordered by first occurrence.
        Each entry lists the ids ``"<base_id>_<chain index>"`` of every
        occurrence of that sequence.
    """
    base_id = item.id[0] if isinstance(item.id, list) else item.id

    # Offset of each chain inside the joined sequence; the extra +1 skips the
    # single separator character that sat between consecutive chains.
    chain_starts: list[int] = []
    offset = 0
    for chain in chains:
        chain_starts.append(offset)
        offset += len(chain) + 1

    # All three maps are local to this one ProteinInput, so chains from
    # different inputs can never be resolved against each other's data.
    chain_to_ids: dict[str, list[str]] = {}
    chain_to_modifications: dict[str, list] = {}
    chain_to_msa: dict[str, MSA | None] = {}

    for chain_idx, (chain, chain_start) in enumerate(zip(chains, chain_starts)):
        chain_end = chain_start + len(chain)
        chain_to_ids.setdefault(chain, []).append(f"{base_id}_{chain_idx}")

        if item.modifications is not None:
            # Modifications found on any occurrence of the chain are merged
            # into the shared entry, re-based to chain-local positions.
            chain_to_modifications.setdefault(chain, []).extend(
                Modification(position=mod.position - chain_start, ccd=mod.ccd)
                for mod in item.modifications
                if chain_start <= mod.position < chain_end
            )

        if item.msa is not None and chain not in chain_to_msa:
            # Only the first occurrence of a chain contributes its MSA slice.
            chain_to_msa[chain] = item.msa.select_positions(  # type: ignore
                np.arange(chain_start, chain_end)
            )

    # dicts preserve insertion order, so unique chains come out in the order
    # they first appeared. The truthiness checks on ``item.msa`` /
    # ``item.modifications`` mean an empty value is passed on as ``None``.
    return [
        ProteinInput(
            id=chain_ids,
            sequence=chain,
            msa=chain_to_msa.get(chain) if item.msa else None,
            modifications=(
                chain_to_modifications.get(chain) if item.modifications else None
            ),
        )
        for chain, chain_ids in chain_to_ids.items()
    ]


def clean_esmfold2_input(input: StructurePredictionInput) -> StructurePredictionInput:
    """Group identical protein chains into one ProteinInput with multiple ids.

    Both ``"|"`` and ``":"`` inside a protein sequence are treated as chain
    breaks. A ProteinInput containing chain breaks is replaced by one
    ProteinInput per distinct chain sequence, in order of first occurrence.

    Example: a tetramer like [ProteinInput(id=["0"], seq="AAA|AAA|BBB|BBB")]
    becomes [ProteinInput(id=["0_0", "0_1"], seq="AAA"),
    ProteinInput(id=["0_2", "0_3"], seq="BBB")].

    Every multi-chain ProteinInput is expanded independently, so an input may
    contain several of them without their ids, modifications or MSAs being
    mixed up. Entries that are not ProteinInput, and ProteinInputs without a
    chain break, are passed through unchanged.

    Parameters
    ----------
    input : StructurePredictionInput
        The user-provided input; it is not modified.

    Returns
    -------
    StructurePredictionInput
        A new input with the cleaned sequence list and the original
        ``distogram_conditioning`` and ``covalent_bonds``.

    Raises
    ------
    ValueError
        If a protein sequence contains chain breaks while
        ``input.covalent_bonds`` is set.
    """
    cleaned_sequences: list = []
    for item in input.sequences:
        if not isinstance(item, ProteinInput):
            cleaned_sequences.append(item)
            continue

        sequence = item.sequence.replace("|", ":")
        if ":" not in sequence:
            cleaned_sequences.append(item)
            continue

        if input.covalent_bonds is not None:
            raise ValueError(
                "Covalent bonds are not supported when using chainbreaks. "
                "Chains must be separated into multiple ProteinInput objects."
            )

        cleaned_sequences.extend(_split_protein_chains(item, sequence.split(":")))

    return StructurePredictionInput(
        sequences=cleaned_sequences,
        distogram_conditioning=input.distogram_conditioning,
        covalent_bonds=input.covalent_bonds,
    )
class ESMFold2InputBuilder:
    """Bridge between user-facing inputs, the folding model and decoded results.

    ``prepare_input`` (also reachable by calling the instance) produces the
    batched feature dict, ``decode`` converts model outputs into
    ``MolecularComplexResult`` objects, and ``fold`` runs both around a single
    model call.
    """

    def __init__(self, ccd_cache: Path | None = None):
        """Invoke ``load_ccd`` once at construction time.

        Parameters
        ----------
        ccd_cache : Path, optional
            Passed unchanged to ``load_ccd``.
        """
        load_ccd(ccd_cache)

    def prepare_input(
        self,
        input: StructurePredictionInput,
        seed: int | None = None,
        device: torch.device | str | None = None,
    ) -> tuple[dict, list[ChainInfo]]:
        """Build batched model features from a StructurePredictionInput.

        The input is first normalised with ``clean_esmfold2_input`` and then
        featurised by ``prepare_esmfold2_input``. Every tensor in the resulting
        dict gets a leading batch axis of size one; non-tensor values are
        passed through as they are.

        Parameters
        ----------
        input : StructurePredictionInput
            Input specification (sequences, structures, constraints, etc.).
        seed : int, optional
            When given, featurisation runs with the global RNGs seeded with
            this value, and the value is also forwarded to
            ``prepare_esmfold2_input``.
        device : torch.device or str, optional
            Where to move the tensors. ``None`` leaves them on the device they
            were created on.

        Returns
        -------
        tuple[dict, list[ChainInfo]]
            The batched feature dict and the chain metadata needed later by
            :meth:`decode`.
        """
        cleaned = clean_esmfold2_input(input)

        rng_scope = nullcontext() if seed is None else _seed_context(seed)
        with rng_scope:
            raw_features, chain_infos = prepare_esmfold2_input(cleaned, seed=seed)

        batched: dict = {}
        for key, value in raw_features.items():
            if isinstance(value, torch.Tensor):
                value = value[None]  # prepend the batch axis
                if device is not None:
                    value = value.to(device)
            batched[key] = value
        return batched, chain_infos

    def __call__(
        self,
        input: StructurePredictionInput,
        seed: int | None = None,
        device: torch.device | str | None = None,
    ) -> tuple[dict, list[ChainInfo]]:
        """Alias for :meth:`prepare_input` with the same arguments."""
        return self.prepare_input(input, seed=seed, device=device)

    def decode(
        self,
        output: dict[str, torch.Tensor],
        features: dict[str, torch.Tensor],
        chain_infos: list[ChainInfo],
        *,
        num_diffusion_samples: int = 1,
        complex_id: str = "pred",
    ) -> MolecularComplexResult | list[MolecularComplexResult]:
        """Turn raw model outputs into one MolecularComplexResult per sample.

        ``sample_atom_coords`` and ``plddt`` are required keys of ``output``;
        ``ptm``, ``iptm``, ``pae``, ``distogram_logits``, ``pair_chains_iptm``,
        ``residue_index`` and ``entity_id`` are optional and yield ``None`` in
        the result when absent. Per-sample values are read at the sample
        index, whereas the distogram, residue index and entity id are always
        read at batch index 0, as are the feature tensors.

        Parameters
        ----------
        output : dict[str, Tensor]
            Output dict returned by the model.
        features : dict[str, Tensor]
            Batched feature dict from :meth:`prepare_input`.
        chain_infos : list[ChainInfo]
            Chain metadata returned alongside ``features``.
        num_diffusion_samples : int
            Number of diffusion samples requested from the model.
        complex_id : str
            Identifier given to every built complex.

        Returns
        -------
        MolecularComplexResult or list[MolecularComplexResult]
            The bare result when ``num_diffusion_samples == 1`` and exactly one
            sample was decoded; otherwise a list with one entry per sample.
        """

        def _slice_to_cpu(tensor, index):
            # Absent optional outputs stay ``None``; present ones are sliced,
            # detached and moved to the CPU.
            if tensor is None:
                return None
            return tensor[index].detach().cpu()

        def _slice_to_float(tensor, index):
            if tensor is None:
                return None
            return float(tensor[index].item())

        # Arguments that are identical for every sample.
        shared_inputs = {
            "atom_mask": features["atom_attention_mask"][0],
            "ref_element": features["ref_element"][0],
            "ref_atom_name_chars": features["ref_atom_name_chars"][0],
            "chain_infos": chain_infos,
            "complex_id": complex_id,
        }

        coords_per_sample = output["sample_atom_coords"]
        plddt_per_sample = output["plddt"]
        sample_count = coords_per_sample.shape[0]

        ptm_all = output.get("ptm")
        iptm_all = output.get("iptm")
        pae_all = output.get("pae")
        distogram_all = output.get("distogram_logits")
        pair_chains_all = output.get("pair_chains_iptm")
        residue_index_all = output.get("residue_index")
        entity_id_all = output.get("entity_id")

        results: list[MolecularComplexResult] = []
        for sample_idx in range(sample_count):
            sample_plddt = plddt_per_sample[sample_idx]
            predicted = build_molecular_complex_from_features(
                coords=coords_per_sample[sample_idx],
                plddt=sample_plddt,
                **shared_inputs,
            )
            results.append(
                MolecularComplexResult(
                    complex=predicted,
                    plddt=sample_plddt.detach().cpu(),
                    ptm=_slice_to_float(ptm_all, sample_idx),
                    iptm=_slice_to_float(iptm_all, sample_idx),
                    pae=_slice_to_cpu(pae_all, sample_idx),
                    distogram=_slice_to_cpu(distogram_all, 0),
                    pair_chains_iptm=_slice_to_cpu(pair_chains_all, sample_idx),
                    residue_index=_slice_to_cpu(residue_index_all, 0),
                    entity_id=_slice_to_cpu(entity_id_all, 0),
                )
            )

        unwrap = num_diffusion_samples == 1 and len(results) == 1
        return results[0] if unwrap else results

    def fold(
        self,
        model: Any,
        input: StructurePredictionInput,
        *,
        num_loops: int = 3,
        num_sampling_steps: int = 200,
        num_diffusion_samples: int = 1,
        seed: int | None = None,
        noise_scale: float | None = None,
        step_scale: float | None = None,
        max_inference_sigma: int | None = None,
        early_exit: bool = False,
        complex_id: str = "pred",
    ) -> MolecularComplexResult | list[MolecularComplexResult]:
        """Run the full pipeline: prepare features, call the model, decode.

        Parameters
        ----------
        model : ESMFold2Model
            The folding model. Its ``device`` attribute decides where the
            features are placed; it is called under ``torch.no_grad()``.
        input : StructurePredictionInput
            User-facing input specification.
        num_loops, num_sampling_steps, num_diffusion_samples : int
            Inference settings forwarded to the model call.
        seed : int, optional
            Used twice: for feature preparation and, separately, to seed the
            global RNGs around the model call.
        noise_scale, step_scale, max_inference_sigma
            Sampler overrides; each is forwarded to the model only when it is
            not ``None``.
        early_exit : bool
            Always forwarded to the model.
        complex_id : str
            Identifier given to the decoded complex(es).

        Returns
        -------
        MolecularComplexResult or list[MolecularComplexResult]
            Whatever :meth:`decode` returns for the model output.
        """
        features, chain_infos = self.prepare_input(
            input, seed=seed, device=model.device
        )

        # Keep only the overrides the caller actually supplied.
        candidate_overrides: dict[str, Any] = {
            "noise_scale": noise_scale,
            "step_scale": step_scale,
            "max_inference_sigma": max_inference_sigma,
        }
        sampler_kwargs: dict[str, Any] = {
            name: value
            for name, value in candidate_overrides.items()
            if value is not None
        }

        rng_scope = nullcontext() if seed is None else _seed_context(seed)
        with torch.no_grad(), rng_scope:
            output = model(
                **features,
                num_loops=num_loops,
                num_sampling_steps=num_sampling_steps,
                num_diffusion_samples=num_diffusion_samples,
                early_exit=early_exit,
                **sampler_kwargs,
            )

        return self.decode(
            output,
            features,
            chain_infos,
            num_diffusion_samples=num_diffusion_samples,
            complex_id=complex_id,
        )
# Names exported by a star-import of this module; the seeding helper stays private.
__all__ = ["ESMFold2InputBuilder", "clean_esmfold2_input"]