# diffusers-pr-api — src/slop_farmer/app/save_cache.py
# Deployed by evalstate (HF Staff): "Deploy Diffusers PR API", commit dbf7313 (verified)
from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
from typing import Any, Protocol, cast
from huggingface_hub import HfApi
from slop_farmer.config import SaveCacheOptions
from slop_farmer.data.parquet_io import read_json
from slop_farmer.data.snapshot_paths import ROOT_MANIFEST_FILENAME, resolve_snapshot_dir_from_output
# Name of the per-snapshot subdirectory that holds cached analysis artifacts;
# also used as the upload destination path inside the Hub dataset repo.
ANALYSIS_STATE_DIRNAME = "analysis-state"
class HubApiLike(Protocol):
    """Structural interface for the subset of the Hub client used by this module.

    ``huggingface_hub.HfApi`` satisfies this protocol. The return types are
    declared as ``object`` rather than ``None``: the real client's methods
    return values (a repo URL / commit info) that callers here ignore, and a
    declared ``None`` return would make any value-returning implementation
    fail structural type checks.
    """

    def create_repo(
        self,
        repo_id: str,
        *,
        repo_type: str,
        private: bool,
        exist_ok: bool,
    ) -> object: ...

    def upload_folder(
        self,
        *,
        repo_id: str,
        folder_path: Path,
        path_in_repo: str,
        repo_type: str,
        commit_message: str,
    ) -> object: ...
def run_save_cache(options: SaveCacheOptions) -> dict[str, Any]:
    """Resolve the snapshot directory from *options* and push its analysis cache to the Hub.

    Returns the summary dict produced by :func:`save_analysis_cache`.
    """
    resolved_snapshot = resolve_snapshot_dir_from_output(
        options.output_dir, options.snapshot_dir
    )
    return save_analysis_cache(
        snapshot_dir=resolved_snapshot,
        hf_repo_id=options.hf_repo_id,
        private=options.private_hf_repo,
    )
def save_analysis_cache(
    *,
    snapshot_dir: Path,
    hf_repo_id: str,
    private: bool,
    log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Upload the snapshot's analysis cache using a freshly constructed Hub client.

    Thin convenience wrapper around ``_save_analysis_cache_api``; the cast
    narrows the concrete ``HfApi`` instance to the structural ``HubApiLike``
    protocol expected by the worker function.
    """
    client = cast("HubApiLike", HfApi())
    return _save_analysis_cache_api(
        client,
        snapshot_dir=snapshot_dir,
        hf_repo_id=hf_repo_id,
        private=private,
        log=log,
    )
def _save_analysis_cache_api(
    api: HubApiLike,
    *,
    snapshot_dir: Path,
    hf_repo_id: str,
    private: bool,
    log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Validate the snapshot's analysis cache and upload it to a Hub dataset repo.

    Parameters:
        api: Hub client (structural ``HubApiLike``) used for repo creation and upload.
        snapshot_dir: Local snapshot directory expected to contain the cache subdir.
        hf_repo_id: Target dataset repo id on the Hub.
        private: Whether a newly created repo should be private.
        log: Optional progress callback; no output is produced when omitted.

    Returns:
        Summary dict with the dataset id, snapshot id, uploaded artifact paths,
        and — when the manifest names one — the source ``repo``.

    Raises:
        FileNotFoundError: The cache directory does not exist.
        NotADirectoryError: The cache path exists but is not a directory.
        ValueError: The cache directory is empty, or the manifest is not a JSON object.
    """

    def emit(message: str) -> None:
        # Honor the optional callback without scattering guards through the body.
        if log:
            log(message)

    cache_dir = snapshot_dir / ANALYSIS_STATE_DIRNAME
    if not cache_dir.exists():
        raise FileNotFoundError(f"Analysis cache directory is missing: {cache_dir}")
    if not cache_dir.is_dir():
        raise NotADirectoryError(f"Analysis cache path is not a directory: {cache_dir}")

    artifact_paths = _cache_artifact_paths(cache_dir)
    if not artifact_paths:
        raise ValueError(f"Analysis cache directory is empty: {cache_dir}")

    # The root manifest is optional; a missing file behaves like an empty object.
    manifest_path = snapshot_dir / ROOT_MANIFEST_FILENAME
    manifest = read_json(manifest_path) if manifest_path.exists() else {}
    if not isinstance(manifest, dict):
        raise ValueError(f"Snapshot manifest at {manifest_path} must contain a JSON object.")

    # Fall back to the directory name when the manifest lacks a snapshot id.
    snapshot_id = str(manifest.get("snapshot_id") or snapshot_dir.name).strip()
    source_repo = str(manifest.get("repo") or "").strip()

    emit(f"Ensuring Hub dataset repo exists: {hf_repo_id}")
    api.create_repo(hf_repo_id, repo_type="dataset", private=private, exist_ok=True)

    emit(f"Saving analysis cache for snapshot {snapshot_id}")
    api.upload_folder(
        repo_id=hf_repo_id,
        folder_path=cache_dir,
        path_in_repo=ANALYSIS_STATE_DIRNAME,
        repo_type="dataset",
        commit_message=f"Save analysis cache for snapshot {snapshot_id}",
    )

    summary: dict[str, Any] = {
        "dataset_id": hf_repo_id,
        "snapshot_id": snapshot_id,
        "artifact_paths": [f"{ANALYSIS_STATE_DIRNAME}/{path}" for path in artifact_paths],
    }
    if source_repo:
        summary["repo"] = source_repo
    emit(f"Saved analysis cache to {hf_repo_id}")
    return summary
def _cache_artifact_paths(cache_dir: Path) -> list[str]:
return sorted(
str(path.relative_to(cache_dir).as_posix())
for path in cache_dir.rglob("*")
if path.is_file()
)