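"""Incremental sync between a local data directory and a Hugging Face dataset repo.

Uploads a SQLite database plus any new or changed files under the uploads/,
processed/ and output/ folders, tracking progress in sync_state.json.

Usage: python hf_sync.py [download|upload|init]
"""
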
import os
import sys
import sqlite3
import json
import time
import hashlib
from datetime import datetime

from huggingface_hub import snapshot_download, HfApi
# Configuration
REPO_ID = os.environ.get("DATASET_REPO_ID")
HF_TOKEN = os.environ.get("HF_TOKEN")
DATA_DIR = "data_repo"
DB_FILE = os.path.join(DATA_DIR, "database.db")
STATE_FILE = os.path.join(DATA_DIR, "sync_state.json")
LOCK_FILE = "/tmp/hf_sync.lock"

api = HfApi(token=HF_TOKEN)
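
# The sync state records a fingerprint for every file already uploaded, so
# repeated runs only push what changed since the last successful sync.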
def get_state():
    if os.path.exists(STATE_FILE):
        try:
            with open(STATE_FILE, 'r') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            pass  # corrupt or unreadable state file: start fresh
    return {"uploaded_files": {}, "last_db_hash": None, "version": 0}

def save_state(state):
    state["last_update"] = datetime.now().isoformat()
    with open(STATE_FILE, 'w') as f:
        json.dump(state, f, indent=2)

def get_file_hash(path):
    if not os.path.exists(path):
        return None
    hasher = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hasher.update(chunk)
    return hasher.hexdigest()
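
# sqlite3's online backup API copies the database page by page, yielding a
# consistent snapshot even if the application is writing to it concurrently.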
def safe_db_backup():
    if not os.path.exists(DB_FILE):
        return None
    backup_db = DB_FILE + ".bak"
    try:
        source_conn = sqlite3.connect(DB_FILE)
        dest_conn = sqlite3.connect(backup_db)
        with dest_conn:
            source_conn.backup(dest_conn)
        source_conn.close()
        dest_conn.close()
        return backup_db
    except Exception as e:
        print(f"Database backup failed: {e}")
        return None
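
# The lock file prevents overlapping syncs; a lock older than 600 seconds is
# treated as a stale leftover from a crashed run and overwritten.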
def upload():
    if not REPO_ID or not HF_TOKEN:
        return
    if os.path.exists(LOCK_FILE):
        if time.time() - os.path.getmtime(LOCK_FILE) < 600:
            return
    try:
        with open(LOCK_FILE, 'w') as f:
            f.write(str(os.getpid()))
        state = get_state()
        changes_made = False

        # 1. Sync database (granular)
        backup_path = safe_db_backup()
        if backup_path:
            db_hash = get_file_hash(backup_path)
            if db_hash != state.get("last_db_hash"):
                print("Syncing database...")
                # Upload the snapshot itself; replacing the live DB file with
                # it could drop writes made since the backup was taken.
                api.upload_file(path_or_fileobj=backup_path, path_in_repo="database.db",
                                repo_id=REPO_ID, repo_type="dataset")
                state["last_db_hash"] = db_hash
                changes_made = True
            os.remove(backup_path)

        # 2. Sync files iteratively (immune to folder-upload timeouts)
        for sub_dir in ['uploads', 'processed', 'output']:
            dir_path = os.path.join(DATA_DIR, sub_dir)
            if not os.path.exists(dir_path):
                continue
            for root, _, files in os.walk(dir_path):
                for file in files:
                    full_path = os.path.join(root, file)
                    rel_path = os.path.relpath(full_path, DATA_DIR)
                    # Decide by a size/mtime fingerprint to avoid hashing
                    # thousands of images on every run
                    mtime = os.path.getmtime(full_path)
                    size = os.path.getsize(full_path)
                    file_id = f"{rel_path}_{size}_{mtime}"
                    if state["uploaded_files"].get(rel_path) != file_id:
                        print(f"Syncing new file: {rel_path}")
                        try:
                            api.upload_file(
                                path_or_fileobj=full_path,
                                path_in_repo=rel_path,
                                repo_id=REPO_ID,
                                repo_type="dataset"
                            )
                            state["uploaded_files"][rel_path] = file_id
                            changes_made = True
                        except Exception as e:
                            print(f"Failed to upload {rel_path}: {e}")

        if changes_made:
            state["version"] += 1
            save_state(state)
            # Sync the state file too
            api.upload_file(path_or_fileobj=STATE_FILE, path_in_repo="sync_state.json",
                            repo_id=REPO_ID, repo_type="dataset")
            print(f"Sync complete. Version {state['version']} saved.")
        else:
            print("Everything up to date.")
    except Exception as e:
        print(f"Upload process failed: {e}")
    finally:
        if os.path.exists(LOCK_FILE):
            os.remove(LOCK_FILE)
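
# snapshot_download mirrors the entire dataset repo into DATA_DIR.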
def download():
    if not REPO_ID:
        return
    print(f"Downloading data from {REPO_ID}...")
    try:
        snapshot_download(repo_id=REPO_ID, repo_type="dataset", local_dir=DATA_DIR,
                          token=HF_TOKEN, max_workers=8)
        print("Download successful.")
    except Exception as e:
        print(f"Download failed: {e}")

def init_local():
    for d in ['output', 'processed', 'uploads']:
        os.makedirs(os.path.join(DATA_DIR, d), exist_ok=True)

if __name__ == "__main__":
    action = sys.argv[1] if len(sys.argv) > 1 else "help"
    if action == "download":
        download()
    elif action == "upload":
        upload()
    elif action == "init":
        init_local()
    else:
        print("Usage: python hf_sync.py [download|upload|init]")