"""Upload a snapshot's analysis-state cache directory to a Hugging Face Hub dataset repo."""

from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
from typing import Any, Protocol, cast

from huggingface_hub import HfApi

from slop_farmer.config import SaveCacheOptions
from slop_farmer.data.parquet_io import read_json
from slop_farmer.data.snapshot_paths import ROOT_MANIFEST_FILENAME, resolve_snapshot_dir_from_output

# Name of the cache directory inside a snapshot directory. It is also used as
# the destination path (``path_in_repo``) inside the Hub dataset repo.
ANALYSIS_STATE_DIRNAME = "analysis-state"


class HubApiLike(Protocol):
    """The subset of the Hub client API that this module calls.

    ``save_analysis_cache`` satisfies it with a real ``HfApi``. Any object
    exposing matching ``create_repo`` / ``upload_folder`` methods can be
    passed to ``_save_analysis_cache_api`` instead.
    """

    def create_repo(
        self,
        repo_id: str,
        *,
        repo_type: str,
        private: bool,
        exist_ok: bool,
    ) -> None: ...

    def upload_folder(
        self,
        *,
        repo_id: str,
        folder_path: Path,
        path_in_repo: str,
        repo_type: str,
        commit_message: str,
    ) -> None: ...


def run_save_cache(options: SaveCacheOptions) -> dict[str, Any]:
    """Resolve the snapshot directory from ``options`` and upload its analysis cache.

    Args:
        options: Supplies ``output_dir`` / ``snapshot_dir`` (used to locate the
            snapshot), ``hf_repo_id`` and ``private_hf_repo``.

    Returns:
        The summary dict produced by ``save_analysis_cache``.
    """
    snapshot_dir = resolve_snapshot_dir_from_output(options.output_dir, options.snapshot_dir)
    return save_analysis_cache(
        snapshot_dir=snapshot_dir,
        hf_repo_id=options.hf_repo_id,
        private=options.private_hf_repo,
    )


def save_analysis_cache(
    *,
    snapshot_dir: Path,
    hf_repo_id: str,
    private: bool,
    log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Upload ``snapshot_dir``'s analysis cache using a real ``HfApi`` client.

    See ``_save_analysis_cache_api`` for the argument, return and exception
    contract.
    """
    return _save_analysis_cache_api(
        cast("HubApiLike", HfApi()),
        snapshot_dir=snapshot_dir,
        hf_repo_id=hf_repo_id,
        private=private,
        log=log,
    )


def _save_analysis_cache_api(
    api: HubApiLike,
    *,
    snapshot_dir: Path,
    hf_repo_id: str,
    private: bool,
    log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Upload ``<snapshot_dir>/analysis-state`` to the dataset repo ``hf_repo_id``.

    The dataset repo is created if needed (``exist_ok=True``). The cache
    directory is then uploaded under the same ``analysis-state`` path in the
    repo.

    Args:
        api: Hub client used for ``create_repo`` and ``upload_folder``.
        snapshot_dir: Snapshot directory containing the cache directory and,
            optionally, the root manifest JSON.
        hf_repo_id: Target dataset repo id.
        private: Visibility passed to ``create_repo``.
        log: Optional callback that receives progress messages.

    Returns:
        A dict with ``dataset_id``, ``snapshot_id`` and ``artifact_paths``
        (repo-relative POSIX paths of the local cache files). It also
        includes ``repo`` when the manifest has a non-blank ``"repo"`` value.

    Raises:
        FileNotFoundError: The cache directory does not exist.
        NotADirectoryError: The cache path exists but is not a directory.
        ValueError: The cache directory holds no files, or the manifest is
            not a JSON object.
    """
    cache_dir = snapshot_dir / ANALYSIS_STATE_DIRNAME
    if not cache_dir.exists():
        raise FileNotFoundError(f"Analysis cache directory is missing: {cache_dir}")
    if not cache_dir.is_dir():
        raise NotADirectoryError(f"Analysis cache path is not a directory: {cache_dir}")
    artifact_paths = _cache_artifact_paths(cache_dir)
    if not artifact_paths:
        raise ValueError(f"Analysis cache directory is empty: {cache_dir}")

    manifest_path = snapshot_dir / ROOT_MANIFEST_FILENAME
    # A missing manifest is tolerated; identity then falls back to the directory name.
    manifest = read_json(manifest_path) if manifest_path.exists() else {}
    if not isinstance(manifest, dict):
        raise ValueError(f"Snapshot manifest at {manifest_path} must contain a JSON object.")

    # Strip *before* falling back, so a blank/whitespace-only manifest value
    # still yields the directory name instead of an empty snapshot id.
    manifest_snapshot_id = str(manifest.get("snapshot_id") or "").strip()
    snapshot_id = manifest_snapshot_id or snapshot_dir.name.strip()
    repo = str(manifest.get("repo") or "").strip()

    if log:
        log(f"Ensuring Hub dataset repo exists: {hf_repo_id}")
    api.create_repo(hf_repo_id, repo_type="dataset", private=private, exist_ok=True)

    if log:
        log(f"Saving analysis cache for snapshot {snapshot_id}")
    api.upload_folder(
        repo_id=hf_repo_id,
        folder_path=cache_dir,
        path_in_repo=ANALYSIS_STATE_DIRNAME,
        repo_type="dataset",
        commit_message=f"Save analysis cache for snapshot {snapshot_id}",
    )

    # NOTE(review): artifact_paths lists local files found before the upload.
    # If upload_folder skips any (e.g. through its ignore patterns), this list
    # may overstate what actually landed in the repo. Confirm if that matters.
    result: dict[str, Any] = {
        "dataset_id": hf_repo_id,
        "snapshot_id": snapshot_id,
        "artifact_paths": [f"{ANALYSIS_STATE_DIRNAME}/{path}" for path in artifact_paths],
    }
    if repo:
        result["repo"] = repo
    if log:
        log(f"Saved analysis cache to {hf_repo_id}")
    return result


def _cache_artifact_paths(cache_dir: Path) -> list[str]:
    """Return every regular file under ``cache_dir`` as a sorted, POSIX-style relative path."""
    return sorted(
        path.relative_to(cache_dir).as_posix()
        for path in cache_dir.rglob("*")
        if path.is_file()
    )