Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2-Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2-Fast with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="Synthyra/ESMFold2-Fast", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Synthyra/ESMFold2-Fast", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
| import random | |
| from contextlib import contextmanager, nullcontext | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| from .esmfold2_conformers import load_ccd | |
| from .esmfold2_output import build_molecular_complex_from_features | |
| from .esmfold2_prepare_input import ChainInfo, prepare_esmfold2_input | |
| from .esmfold2_types import ( | |
| MSA, | |
| Modification, | |
| ProteinInput, | |
| StructurePredictionInput, | |
| ) | |
| from .esmfold2_molecular_complex import MolecularComplexResult | |
| def _seed_context(seed: int | None): | |
| if seed is None: | |
| yield | |
| return | |
| py_state = random.getstate() | |
| np_state = np.random.get_state() | |
| torch_state = torch.random.get_rng_state() | |
| cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None | |
| random.seed(seed) | |
| np.random.seed(seed) | |
| torch.manual_seed(seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(seed) | |
| try: | |
| yield | |
| finally: | |
| random.setstate(py_state) | |
| np.random.set_state(np_state) | |
| torch.random.set_rng_state(torch_state) | |
| if cuda_state is not None: | |
| torch.cuda.set_rng_state_all(cuda_state) | |
| def clean_esmfold2_input(input: StructurePredictionInput) -> StructurePredictionInput: | |
| """Group identical protein sequences into the same ProteinInput with multiple ids. | |
| Example: Passing a tetramer like [ProteinInput(id=["0"], seq="AAA|AAA|BBB|BBB")] | |
| gets converted into [ProteinInput(id=["0_0", "0_1"], seq="AAA"), | |
| ProteinInput(id=["0_2", "0_3"], seq="BBB")] | |
| Preserves the original order of unique sequences. Also converts "|" chainbreak | |
| tokens to ":" in the sequence. | |
| """ | |
| cleaned_sequences: list = [] | |
| chain_to_ids: dict[str, list[str]] = {} | |
| chain_to_modifications: dict[str, list] = {} | |
| chain_to_msa: dict[str, MSA | None] = {} | |
| for item in input.sequences: | |
| if isinstance(item, ProteinInput): | |
| sequence = ":".join(item.sequence.split("|")) | |
| if ":" not in sequence: | |
| cleaned_sequences.append(item) | |
| continue | |
| if ":" in sequence and input.covalent_bonds is not None: | |
| raise ValueError( | |
| "Covalent bonds are not supported when using chainbreaks. " | |
| "Chains must be separated into multiple ProteinInput objects." | |
| ) | |
| base_id = item.id[0] if isinstance(item.id, list) else item.id | |
| chain_to_ids = {} | |
| chain_to_modifications = {} | |
| chain_to_msa = {} | |
| chains = sequence.split(":") | |
| chain_start_positions = [] | |
| pos = 0 | |
| for chain in chains: | |
| chain_start_positions.append(pos) | |
| pos += len(chain) + 1 | |
| if item.modifications is not None: | |
| for chain_idx, chain in enumerate(chains): | |
| chain_start = chain_start_positions[chain_idx] | |
| chain_end = chain_start + len(chain) | |
| chain_modifications = [] | |
| for mod in item.modifications: | |
| if chain_start <= mod.position < chain_end: | |
| adjusted_mod = Modification( | |
| position=mod.position - chain_start, ccd=mod.ccd | |
| ) | |
| chain_modifications.append(adjusted_mod) | |
| if chain not in chain_to_modifications: | |
| chain_to_modifications[chain] = chain_modifications | |
| else: | |
| chain_to_modifications[chain].extend(chain_modifications) | |
| if item.msa is not None: | |
| for chain_idx, chain in enumerate(chains): | |
| if chain not in chain_to_msa: | |
| chain_start = chain_start_positions[chain_idx] | |
| chain_end = chain_start + len(chain) | |
| chain_msa = item.msa.select_positions( # type: ignore | |
| np.arange(chain_start, chain_end) | |
| ) | |
| chain_to_msa[chain] = chain_msa | |
| for i, chain in enumerate(chains): | |
| chain_id = base_id + "_" + str(i) | |
| if chain in chain_to_ids: | |
| chain_to_ids[chain].append(chain_id) | |
| else: | |
| chain_to_ids[chain] = [chain_id] | |
| cleaned_sequences.append((item, chain)) | |
| else: | |
| cleaned_sequences.append(item) | |
| for i in range(len(cleaned_sequences)): | |
| if isinstance(cleaned_sequences[i], tuple): | |
| item, chain = cleaned_sequences[i] | |
| chain_ids = chain_to_ids[chain] | |
| chain_modifications = ( | |
| chain_to_modifications.get(chain) if item.modifications else None | |
| ) | |
| chain_msa = chain_to_msa.get(chain) if item.msa else None | |
| cleaned_sequences[i] = ProteinInput( | |
| id=chain_ids, | |
| sequence=chain, | |
| msa=chain_msa, | |
| modifications=chain_modifications, | |
| ) | |
| return StructurePredictionInput( | |
| sequences=cleaned_sequences, | |
| distogram_conditioning=input.distogram_conditioning, | |
| covalent_bonds=input.covalent_bonds, | |
| ) | |
| class ESMFold2InputBuilder: | |
| def __init__(self, ccd_cache: Path | None = None): | |
| load_ccd(ccd_cache) | |
| def prepare_input( | |
| self, | |
| input: StructurePredictionInput, | |
| seed: int | None = None, | |
| device: torch.device | str | None = None, | |
| ) -> tuple[dict, list[ChainInfo]]: | |
| """Prepare raw input for the folding model. | |
| Converts user-provided StructurePredictionInput into batched tensors | |
| ready for model inference. | |
| Parameters | |
| ---------- | |
| input : StructurePredictionInput | |
| Input specification (sequences, structures, constraints, etc.). | |
| seed : int, optional | |
| Random seed for reproducibility. | |
| device : torch.device or str, optional | |
| Target device for the returned tensors. Defaults to CPU; pass | |
| ``model.device`` to skip a separate ``.to(...)`` step. ``fold()`` | |
| forwards ``model.device`` automatically. | |
| Returns | |
| ------- | |
| tuple[dict, list[ChainInfo]] | |
| Batched input tensors and chain metadata for output processing. | |
| """ | |
| structure_prediction_input = clean_esmfold2_input(input) | |
| with _seed_context(seed) if seed is not None else nullcontext(): | |
| features, chain_infos = prepare_esmfold2_input( | |
| structure_prediction_input, seed=seed | |
| ) | |
| features = { | |
| k: (v[None].to(device) if device is not None else v[None]) | |
| if isinstance(v, torch.Tensor) | |
| else v | |
| for k, v in features.items() | |
| } | |
| return features, chain_infos | |
| def __call__( | |
| self, | |
| input: StructurePredictionInput, | |
| seed: int | None = None, | |
| device: torch.device | str | None = None, | |
| ) -> tuple[dict, list[ChainInfo]]: | |
| return self.prepare_input(input, seed=seed, device=device) | |
| def decode( | |
| self, | |
| output: dict[str, torch.Tensor], | |
| features: dict[str, torch.Tensor], | |
| chain_infos: list[ChainInfo], | |
| *, | |
| num_diffusion_samples: int = 1, | |
| complex_id: str = "pred", | |
| ) -> MolecularComplexResult | list[MolecularComplexResult]: | |
| """Convert raw model outputs into one MolecularComplexResult per sample. | |
| Parameters | |
| ---------- | |
| output : dict[str, Tensor] | |
| Output dict returned by ESMFold2Model.forward. | |
| features : dict[str, Tensor] | |
| Feature dict from :meth:`prepare_input` (batched, on the model device). | |
| chain_infos : list[ChainInfo] | |
| Chain metadata returned alongside `features`. | |
| num_diffusion_samples : int | |
| Number of diffusion samples present in the output (Bm = B * num_diffusion_samples). | |
| complex_id : str | |
| Identifier assigned to each MolecularComplex. | |
| Returns | |
| ------- | |
| MolecularComplexResult or list[MolecularComplexResult] | |
| A single result when num_diffusion_samples == 1, otherwise a list of length Bm. | |
| """ | |
| atom_mask = features["atom_attention_mask"][0] | |
| ref_element = features["ref_element"][0] | |
| ref_atom_name_chars = features["ref_atom_name_chars"][0] | |
| sample_coords = output["sample_atom_coords"] | |
| plddts = output["plddt"] | |
| Bm = sample_coords.shape[0] | |
| ptm_t = output.get("ptm") | |
| iptm_t = output.get("iptm") | |
| pae_t = output.get("pae") | |
| distogram_t = output.get("distogram_logits") | |
| pair_chains_t = output.get("pair_chains_iptm") | |
| residue_index_t = output.get("residue_index") | |
| entity_id_t = output.get("entity_id") | |
| results: list[MolecularComplexResult] = [] | |
| for i in range(Bm): | |
| mc = build_molecular_complex_from_features( | |
| coords=sample_coords[i], | |
| plddt=plddts[i], | |
| atom_mask=atom_mask, | |
| ref_element=ref_element, | |
| ref_atom_name_chars=ref_atom_name_chars, | |
| chain_infos=chain_infos, | |
| complex_id=complex_id, | |
| ) | |
| results.append( | |
| MolecularComplexResult( | |
| complex=mc, | |
| plddt=plddts[i].detach().cpu(), | |
| ptm=float(ptm_t[i].item()) if ptm_t is not None else None, | |
| iptm=float(iptm_t[i].item()) if iptm_t is not None else None, | |
| pae=pae_t[i].detach().cpu() if pae_t is not None else None, | |
| distogram=( | |
| distogram_t[0].detach().cpu() | |
| if distogram_t is not None | |
| else None | |
| ), | |
| pair_chains_iptm=( | |
| pair_chains_t[i].detach().cpu() | |
| if pair_chains_t is not None | |
| else None | |
| ), | |
| residue_index=( | |
| residue_index_t[0].detach().cpu() | |
| if residue_index_t is not None | |
| else None | |
| ), | |
| entity_id=( | |
| entity_id_t[0].detach().cpu() | |
| if entity_id_t is not None | |
| else None | |
| ), | |
| ) | |
| ) | |
| if num_diffusion_samples == 1 and len(results) == 1: | |
| return results[0] | |
| return results | |
| def fold( | |
| self, | |
| model: Any, | |
| input: StructurePredictionInput, | |
| *, | |
| num_loops: int = 3, | |
| num_sampling_steps: int = 200, | |
| num_diffusion_samples: int = 1, | |
| seed: int | None = None, | |
| noise_scale: float | None = None, | |
| step_scale: float | None = None, | |
| max_inference_sigma: int | None = None, | |
| early_exit: bool = False, | |
| complex_id: str = "pred", | |
| ) -> MolecularComplexResult | list[MolecularComplexResult]: | |
| """Fold a structure end-to-end: encode → model → decode. | |
| Parameters | |
| ---------- | |
| model : ESMFold2Model | |
| The folding model. Must already be on the target device and in eval mode. | |
| input : StructurePredictionInput | |
| User-facing input specification. | |
| num_loops, num_sampling_steps, num_diffusion_samples : int | |
| Inference knobs forwarded to the model. | |
| seed : int, optional | |
| Seeds both input prep (SMILES conformer generation) and diffusion sampling. | |
| noise_scale, step_scale, max_inference_sigma, early_exit | |
| Optional sampler overrides forwarded to the model when not None. | |
| complex_id : str | |
| Identifier assigned to the predicted MolecularComplex(es). | |
| Returns | |
| ------- | |
| MolecularComplexResult or list[MolecularComplexResult] | |
| A single result when num_diffusion_samples == 1, otherwise a list. | |
| """ | |
| features, chain_infos = self.prepare_input( | |
| input, seed=seed, device=model.device | |
| ) | |
| sampler_kwargs: dict[str, Any] = {} | |
| if noise_scale is not None: | |
| sampler_kwargs["noise_scale"] = noise_scale | |
| if step_scale is not None: | |
| sampler_kwargs["step_scale"] = step_scale | |
| if max_inference_sigma is not None: | |
| sampler_kwargs["max_inference_sigma"] = max_inference_sigma | |
| with torch.no_grad(): | |
| with _seed_context(seed) if seed is not None else nullcontext(): | |
| output = model( | |
| **features, | |
| num_loops=num_loops, | |
| num_sampling_steps=num_sampling_steps, | |
| num_diffusion_samples=num_diffusion_samples, | |
| early_exit=early_exit, | |
| **sampler_kwargs, | |
| ) | |
| return self.decode( | |
| output, | |
| features, | |
| chain_infos, | |
| num_diffusion_samples=num_diffusion_samples, | |
| complex_id=complex_id, | |
| ) | |
| __all__ = ["ESMFold2InputBuilder", "clean_esmfold2_input"] | |