"""Tests for sf_cluster: shapes, determinism, in-pool guarantee.""" from __future__ import annotations import os import sys from pathlib import Path import numpy as np import pytest # Allow `python -m pytest tests/` from the repo root before installing. sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) from sf_cluster import ( # noqa: E402 contrast_hvlv, high_variance_mask, method_gradient, method_mosaic, pool_msa, read_a3m, write_a3m, ) from sf_cluster.methods import N_SUBSETS, TARGET_SIZE # noqa: E402 # --------------------------------------------------------------------------- # fixtures # --------------------------------------------------------------------------- @pytest.fixture def synthetic_pool(tmp_path): """Synthetic A3M + FI matrix written to disk; returns paths.""" rng = np.random.default_rng(0) N, L = 200, 50 alphabet = np.array(list("ACDEFGHIKLMNPQRSTVWY-")) seqs = rng.choice(alphabet, size=(N, L)) a3m_path = tmp_path / "syn.a3m" with open(a3m_path, "w") as f: f.write(f"#{L}\t1\n") for i, row in enumerate(seqs): tag = "query" if i == 0 else f"seq{i:04d}" f.write(f">{tag}\n{''.join(row)}\n") fi = rng.normal(0, 0.3, size=(N, L)).astype(np.float64) hv_cols = rng.choice(L, size=L // 5, replace=False) fi[:, hv_cols] += rng.normal(0, 1.5, size=(N, len(hv_cols))) fi_path = tmp_path / "syn_fi.npy" np.save(fi_path, fi) return a3m_path, fi_path, N, L # --------------------------------------------------------------------------- # pool / a3m # --------------------------------------------------------------------------- def test_a3m_roundtrip(tmp_path): p = tmp_path / "rt.a3m" write_a3m(p, "#5\t1", [("query", "ACDEF"), ("h2 desc", "ACDef")]) hl, seqs = read_a3m(p) assert hl == "#5\t1" assert seqs == [("query", "ACDEF"), ("h2 desc", "ACDef")] def test_pool_shapes(synthetic_pool): a3m, fi, N, L = synthetic_pool pool = pool_msa(a3m, fi) assert pool.n_seq == N assert pool.n_cols == L assert pool.fi_matrix.shape == (N, L) assert len(pool.sequences) == N assert pool.headers[0] == "query" def test_pool_rejects_shape_mismatch(tmp_path, synthetic_pool): a3m, fi, N, L = synthetic_pool bad = tmp_path / "bad_fi.npy" np.save(bad, np.zeros((N + 1, L))) with pytest.raises(ValueError, match="FI rows"): pool_msa(a3m, bad) # --------------------------------------------------------------------------- # score # --------------------------------------------------------------------------- def test_hv_mask_fraction(): rng = np.random.default_rng(1) F = rng.normal(size=(100, 50)) hv = high_variance_mask(F, percentile=80) # At p=80 we expect ~20% True (allow some slack since percentile is a # threshold, not an exact split). frac = hv.mean() assert 0.1 <= frac <= 0.4 def test_contrast_hvlv_shape_and_finite(synthetic_pool): a3m, fi, N, L = synthetic_pool pool = pool_msa(a3m, fi) score = contrast_hvlv(pool.fi_matrix) assert score.shape == (N,) assert np.all(np.isfinite(score)) # --------------------------------------------------------------------------- # methods: mosaic # --------------------------------------------------------------------------- def test_mosaic_shapes(synthetic_pool): a3m, fi, N, _ = synthetic_pool pool = pool_msa(a3m, fi) score = contrast_hvlv(pool.fi_matrix) subs = method_mosaic(score) assert len(subs) == N_SUBSETS for s in subs: assert len(s) == TARGET_SIZE def test_mosaic_determinism(synthetic_pool): a3m, fi, _, _ = synthetic_pool pool = pool_msa(a3m, fi) score = contrast_hvlv(pool.fi_matrix) a = method_mosaic(score) b = method_mosaic(score) assert a == b def test_mosaic_in_pool(synthetic_pool): a3m, fi, N, _ = synthetic_pool pool = pool_msa(a3m, fi) score = contrast_hvlv(pool.fi_matrix) subs = method_mosaic(score) for s in subs: assert all(0 <= i < N for i in s), "out-of-pool index in mosaic subset" def test_mosaic_tier_composition(synthetic_pool): """High tier draws should come from upper third of sorted score.""" a3m, fi, N, _ = synthetic_pool pool = pool_msa(a3m, fi) score = contrast_hvlv(pool.fi_matrix) sorted_idx = np.argsort(score) high_set = set(sorted_idx[2 * N // 3:].tolist()) low_set = set(sorted_idx[: N // 3].tolist()) mid_set = set(sorted_idx[N // 3: 2 * N // 3].tolist()) subs = method_mosaic(score) # First 11 = high, next 11 = low, last 10 = mid. for s in subs: assert all(i in high_set for i in s[:11]) assert all(i in low_set for i in s[11:22]) assert all(i in mid_set for i in s[22:32]) # --------------------------------------------------------------------------- # methods: gradient # --------------------------------------------------------------------------- def test_gradient_shapes(synthetic_pool): a3m, fi, _, _ = synthetic_pool pool = pool_msa(a3m, fi) score = contrast_hvlv(pool.fi_matrix) subs = method_gradient(score) assert len(subs) == N_SUBSETS for s in subs: assert len(s) == TARGET_SIZE def test_gradient_determinism(synthetic_pool): a3m, fi, _, _ = synthetic_pool pool = pool_msa(a3m, fi) score = contrast_hvlv(pool.fi_matrix) a = method_gradient(score) b = method_gradient(score) assert a == b def test_gradient_in_pool_and_homogeneous(synthetic_pool): a3m, fi, N, _ = synthetic_pool pool = pool_msa(a3m, fi) score = contrast_hvlv(pool.fi_matrix) sorted_idx = np.argsort(score) bins = [] for b in range(4): bins.append(set(sorted_idx[(b * N) // 4: ((b + 1) * N) // 4].tolist())) subs = method_gradient(score) for grp_i in range(4): for s_i in range(3): sub = subs[grp_i * 3 + s_i] assert all(0 <= i < N for i in sub), "out-of-pool index" assert all(i in bins[grp_i] for i in sub), \ f"gradient subset {grp_i*3+s_i} leaked outside quartile {grp_i}" # --------------------------------------------------------------------------- # CLI smoke # --------------------------------------------------------------------------- def test_cli_build_smoke(tmp_path, synthetic_pool): from sf_cluster.cli import main as cli_main a3m, fi, _, _ = synthetic_pool out = tmp_path / "subs_mosaic" rc = cli_main([ "build", "--a3m", str(a3m), "--fi", str(fi), "--method", "mosaic", "--out", str(out), ]) assert rc == 0 files = sorted(out.glob("mosaic_subset_*.a3m")) assert len(files) == N_SUBSETS assert (out / "mosaic_subset_index.tsv").exists() assert (out / "mosaic_meta.json").exists()