"""Unit tests for helper functions.""" import types from pathlib import Path import pandas as pd import constraints import design import file_utils import viewers # --------------------------------------------------------------------------- # _get_file_path # --------------------------------------------------------------------------- class TestGetFilePath: """Normalizes Gradio's various file input formats to a Path.""" def test_string_input(self): assert file_utils._get_file_path("/some/path.pdb") == Path("/some/path.pdb") def test_object_with_path_attr(self): obj = types.SimpleNamespace(path="/uploads/file.pdb") assert file_utils._get_file_path(obj) == Path("/uploads/file.pdb") def test_dict_with_path_key(self): result = file_utils._get_file_path({"path": "/uploads/file.pdb", "name": "file.pdb"}) assert result == Path("/uploads/file.pdb") def test_fallback_to_str(self): assert file_utils._get_file_path(42) == Path("42") # --------------------------------------------------------------------------- # _build_pos_constraint_df # --------------------------------------------------------------------------- class TestBuildPosConstraintDf: """Builds a positional constraint DataFrame for caliby.""" def test_all_empty_returns_none(self): assert constraints._build_pos_constraint_df("1YCR", "", "", "", "", "") is None def test_all_whitespace_returns_none(self): assert constraints._build_pos_constraint_df("1YCR", " ", " ", " ", " ", " ") is None def test_single_field_populated(self): df = constraints._build_pos_constraint_df("1YCR", "A1-100", "", "", "", "") assert df is not None assert len(df) == 1 assert df.iloc[0]["pdb_key"] == "1YCR" assert df.iloc[0]["fixed_pos_seq"] == "A1-100" # Only populated columns + pdb_key should be present assert "fixed_pos_scn" not in df.columns def test_all_fields_populated(self): df = constraints._build_pos_constraint_df("X", "A1", "B2", "A3:G", "A4:V", "A5,B5") assert set(df.columns) == { "pdb_key", "fixed_pos_seq", "fixed_pos_scn", "fixed_pos_override_seq", "pos_restrict_aatype", "symmetry_pos", } def test_columns_match_caliby_valid_columns(self): """All columns must be in caliby's _VALID_POS_CONSTRAINT_COLUMNS.""" valid = { "pdb_key", "fixed_pos_seq", "fixed_pos_scn", "fixed_pos_override_seq", "pos_restrict_aatype", "symmetry_pos", } df = constraints._build_pos_constraint_df("X", "A1", "B2", "A3:G", "A4:V", "A5,B5") assert set(df.columns).issubset(valid) # --------------------------------------------------------------------------- # _df_to_csv # --------------------------------------------------------------------------- class TestDfToCsv: """Writes a DataFrame to a temp CSV file.""" def test_none_returns_none(self): assert file_utils._df_to_csv(None) is None def test_empty_dataframe_returns_none(self): assert file_utils._df_to_csv(pd.DataFrame()) is None def test_valid_dataframe_roundtrips(self): df = pd.DataFrame({"pdb_key": ["1YCR"], "fixed_pos_seq": ["A1-100"]}) path = file_utils._df_to_csv(df) assert path is not None assert Path(path).exists() assert path.endswith(".csv") loaded = pd.read_csv(path) pd.testing.assert_frame_equal(df, loaded) def test_uses_sample_name_for_csv_basename(self): df = pd.DataFrame( { "Sample": ["1YCR_sample0"], "Sequence": ["ACDE"], "Energy (U)": [-1.0], } ) path = file_utils._df_to_csv(df) assert path is not None assert Path(path).name == "1YCR_results.csv" class TestCsvDownloadOutput: """Formats CSV downloads for the Gradio file component.""" def test_hides_component_for_empty_dataframe(self): update = viewers._csv_download_output(pd.DataFrame()) assert update["visible"] is False assert update["value"] is None def test_shows_named_csv_for_results_dataframe(self): df = pd.DataFrame( { "Sample": ["1YCR_sample0"], "Sequence": ["ACDE"], "Energy (U)": [-1.0], } ) update = viewers._csv_download_output(df) assert update["visible"] is True assert Path(update["value"]).name == "1YCR_results.csv" class TestFormatResultsDisplay: """Formats the on-screen results table without changing the raw dataframe.""" def test_formats_last_four_numeric_columns(self): df = pd.DataFrame( { "Sample": ["1YCR_sample0"], "Sequence": ["ACDE"], "Energy (U)": [-1.2345], "sc_ca_rmsd": [1.0], "avg_ca_plddt": [88.888], "tmalign_score": [0.12345], } ) styler = viewers._format_results_display(df) html = styler.to_html() assert "-1.23" in html assert ">1<" in html assert "88.89" in html assert "0.12" in html # --------------------------------------------------------------------------- # _format_outputs # --------------------------------------------------------------------------- class TestFormatOutputs: """Formats caliby output dict into (DataFrame, FASTA, out_pdb_list).""" def test_dataframe_structure(self, sample_outputs_with_out_pdbs): df, _, _ = design._format_outputs(sample_outputs_with_out_pdbs) assert list(df.columns) == ["Sample", "Sequence", "Energy (U)"] assert len(df) == 2 def test_sample_names_from_path_stems(self, sample_outputs_with_out_pdbs): df, _, _ = design._format_outputs(sample_outputs_with_out_pdbs) assert list(df["Sample"]) == ["1YCR_sample0", "1YCR_sample1"] def test_fasta_format(self, sample_outputs_with_out_pdbs): _, fasta, _ = design._format_outputs(sample_outputs_with_out_pdbs) lines = fasta.strip().split("\n") assert lines[0] == ">1YCR_sample0" assert lines[1] == "MTEEQWAQ" assert lines[2] == ">1YCR_sample1" assert lines[3] == "VSEQQWAQ" def test_uses_caliby_out_pdb_key(self, sample_outputs): assert "out_pdbs" not in sample_outputs df, fasta, out_pdb_list = design._format_outputs(sample_outputs) assert list(df["Sample"]) == ["1YCR_sample0", "1YCR_sample1"] assert ">1YCR_sample0" in fasta assert out_pdb_list == sample_outputs["out_pdb"] # --------------------------------------------------------------------------- # _get_best_sc_sample # --------------------------------------------------------------------------- class TestGetBestScSample: """Picks the sample with the highest tmalign_score.""" def test_picks_highest_tmalign_score(self): df = pd.DataFrame( { "Sample": ["1YCR_sample0", "1YCR_sample1", "1YCR_sample2"], "tmalign_score": [0.5, 0.9, 0.7], } ) assert viewers._get_best_sc_sample(df) == "1YCR_sample1" def test_falls_back_to_first_when_no_tmalign(self): df = pd.DataFrame({"Sample": ["1YCR_sample0", "1YCR_sample1"]}) assert viewers._get_best_sc_sample(df) == "1YCR_sample0" def test_falls_back_to_first_when_all_nan(self): df = pd.DataFrame( { "Sample": ["A_sample0", "A_sample1"], "tmalign_score": [float("nan"), float("nan")], } ) assert viewers._get_best_sc_sample(df) == "A_sample0" def test_returns_none_for_empty_df(self): assert viewers._get_best_sc_sample(pd.DataFrame()) is None # --------------------------------------------------------------------------- # _render_af2_viewer / _render_reference_viewer # --------------------------------------------------------------------------- _MINIMAL_PDB = "ATOM 1 CA ALA A 1 0.000 0.000 0.000 1.00 90.00 C\nEND\n" class TestRenderAf2Viewer: """Renders AF2 prediction with pLDDT coloring via molview.""" def test_returns_html_with_valid_data(self): html = viewers._render_af2_viewer("test_sample0", {"test_sample0": _MINIMAL_PDB}) assert "iframe" in html def test_returns_empty_for_missing_sample(self): assert viewers._render_af2_viewer("missing", {"other": _MINIMAL_PDB}) == "" def test_returns_empty_for_none_sample(self): assert viewers._render_af2_viewer(None, {"test": _MINIMAL_PDB}) == "" def test_returns_empty_for_empty_data(self): assert viewers._render_af2_viewer("test", {}) == "" class TestRenderReferenceViewer: """Renders original input PDB with chain coloring via molview.""" def test_maps_sample_to_input_key(self): html = viewers._render_reference_viewer("1YCR_sample0", {"1YCR": _MINIMAL_PDB}) assert "iframe" in html def test_returns_empty_when_input_key_missing(self): assert viewers._render_reference_viewer("1YCR_sample0", {"OTHER": _MINIMAL_PDB}) == "" def test_returns_empty_for_none_sample(self): assert viewers._render_reference_viewer(None, {"1YCR": _MINIMAL_PDB}) == "" # --------------------------------------------------------------------------- # _update_viewers # --------------------------------------------------------------------------- class TestUpdateViewers: """Combined handler for overlay toggle.""" def test_overlay_off_hides_reference(self): af2_html, ref_update = viewers._update_viewers("s0", {"s0": _MINIMAL_PDB}, {"s": _MINIMAL_PDB}, False) assert "iframe" in af2_html assert ref_update["visible"] is False def test_overlay_on_shows_reference(self): af2_html, ref_update = viewers._update_viewers( "s_sample0", {"s_sample0": _MINIMAL_PDB}, {"s": _MINIMAL_PDB}, True ) assert "iframe" in af2_html assert ref_update["visible"] is True assert "iframe" in ref_update["value"]