import asyncio
import functools
import io
import json
import os
import shutil
import time
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from urllib.parse import quote

import aiohttp
from fastapi import BackgroundTasks, FastAPI, HTTPException, status
from huggingface_hub import HfApi, hf_hub_download
from pydantic import BaseModel, Field

# --- Configuration ---
AUTO_START_INDEX = 1  # Hardcoded default start index if no progress is found
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")

# Progress and State Tracking
PROGRESS_FILE = Path("processing_progress.json")  # local, per-worker progress
HF_STATE_FILE = "processing_state_transcriptions.json"  # shared state in the output dataset
LOCAL_STATE_FOLDER = Path(".state")  # local copy of HF_STATE_FILE
LOCAL_STATE_FOLDER.mkdir(exist_ok=True)

# Directory within the HF dataset where the audio files are located.
# NOTE(review): not referenced anywhere in this module.
AUDIO_FILE_PREFIX = "audio/"

# Whisper transcription endpoints: 20 HF Spaces sharing the 'eliasishere' prefix.
WHISPER_SERVERS = [
    f"https://eliasishere-makeitfr-mineo-{i}.hf.space/transcribe" for i in range(1, 21)
]

# Temporary storage for audio files
TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
TEMP_DIR.mkdir(exist_ok=True)

# Number of file_list entries (starting at next_download_index) examined per round.
SCAN_CHUNK_SIZE = 1000
# Total time budget for one transcription request, in seconds.
TRANSCRIBE_TIMEOUT_SECONDS = 600
# How many unprocessed files are picked per round. 1 keeps the historical
# one-file-at-a-time behaviour; values up to len(servers) let several Whisper
# servers work in parallel within a single batch.
MAX_FILES_PER_ROUND = 1


# --- Models ---
class ProcessStartRequest(BaseModel):
    """Request body for ``POST /start_processing``."""

    start_index: int = Field(AUTO_START_INDEX, ge=1, description="The index number of the audio file to start processing from (1-indexed).")


class WhisperServer:
    """Book-keeping for one remote Whisper endpoint."""

    def __init__(self, url: str):
        self.url = url
        self.is_processing = False  # True while a file is assigned to this server
        self.total_processed = 0  # number of successful (HTTP 200) transcriptions
        self.total_time = 0.0  # seconds spent on those successful transcriptions

    @property
    def fps(self):
        """Files per second"""
        return self.total_processed / self.total_time if self.total_time > 0 else 0

    def release(self):
        """Release the server for a new file"""
        self.is_processing = False


# Global state for whisper servers
servers = [WhisperServer(url) for url in WHISPER_SERVERS]
server_lock = asyncio.Lock()
server_index = 0  # For round-robin selection

# Strong references to background tasks: the event loop only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected mid-run.
_background_tasks: Set[asyncio.Task] = set()

# Serializes download_hf_state / upload_hf_state, which share one local file.
# Created lazily so it is bound to the running event loop.
_state_file_lock: Optional[asyncio.Lock] = None

_T = TypeVar("_T")


# --- Helpers ---
async def _run_blocking(func: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
    """Run a synchronous call in the default thread pool and await its result.

    The huggingface_hub calls used here are synchronous; running them in an
    executor keeps the event loop (and the FastAPI endpoints) responsive and
    lets concurrent ``process_file_task`` coroutines actually overlap.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))


def _get_state_file_lock() -> asyncio.Lock:
    """Return the lock guarding the local copy of ``HF_STATE_FILE``."""
    global _state_file_lock
    if _state_file_lock is None:
        _state_file_lock = asyncio.Lock()
    return _state_file_lock


def _write_json_atomic(path: Path, data: Any, indent: int) -> None:
    """Serialize ``data`` as JSON to ``path`` without leaving a half-written file.

    The JSON is written to a sibling ``.tmp`` file and then moved over ``path``
    with ``os.replace`` (an atomic rename on POSIX). A crash or serialization
    error therefore leaves the previous file intact. Exceptions propagate.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("w") as f:
            json.dump(data, f, indent=indent)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise


def _advance_next_index(state: Dict[str, Any], position: int) -> None:
    """Move ``state["next_download_index"]`` just past ``position`` (never backwards)."""
    if position >= state.get("next_download_index", 0):
        state["next_download_index"] = position + 1


# --- Progress and State Management Functions ---
def load_progress() -> Dict:
    """Load the local progress file, filling in any missing keys.

    Returns a fresh default structure if the file does not exist or cannot be
    decoded as a JSON object.
    """
    default_structure = {
        "last_processed_index": 0,
        "processed_files": {},
        "file_list": [],
        "uploaded_count": 0,
    }
    if not PROGRESS_FILE.exists():
        return default_structure
    try:
        with PROGRESS_FILE.open('r') as f:
            data = json.load(f)
    except ValueError:  # JSONDecodeError and UnicodeDecodeError
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted.")
        return default_structure
    if not isinstance(data, dict):
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted.")
        return default_structure
    for key, value in default_structure.items():
        data.setdefault(key, value)
    return data


def save_progress(progress_data: Dict):
    """Persist ``progress_data`` to the local progress file; errors are logged, not raised."""
    try:
        _write_json_atomic(PROGRESS_FILE, progress_data, indent=4)
    except Exception as e:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress: {e}")


def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Read a state JSON file, normalizing ``file_states`` and ``next_download_index``.

    Returns ``default_value`` (the same object) if the file is missing,
    unreadable, or does not contain a JSON object.
    """
    if not os.path.exists(file_path):
        return default_value
    try:
        with open(file_path, "r") as f:
            data = json.load(f)
    except Exception as e:
        print(f"[{FLOW_ID}] WARNING: Could not read state file {file_path}: {e}")
        return default_value
    if not isinstance(data, dict):
        print(f"[{FLOW_ID}] WARNING: State file {file_path} does not contain a JSON object.")
        return default_value
    if not isinstance(data.get("file_states"), dict):
        data["file_states"] = {}
    data.setdefault("next_download_index", 0)
    return data


def save_json_state(file_path: str, data: Dict[str, Any]):
    """Write ``data`` to ``file_path`` as indented JSON (raises on failure)."""
    _write_json_atomic(Path(file_path), data, indent=2)


async def download_hf_state() -> Dict[str, Any]:
    """Fetch the shared state file from the output dataset and load it.

    If the download fails, whatever local copy exists is loaded instead; with
    no usable copy a fresh ``{"next_download_index": 0, "file_states": {}}``
    is returned.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}
    async with _get_state_file_lock():
        try:
            await _run_blocking(
                hf_hub_download,
                repo_id=HF_OUTPUT_DATASET_ID,
                filename=HF_STATE_FILE,
                repo_type="dataset",
                local_dir=LOCAL_STATE_FOLDER,
                local_dir_use_symlinks=False,
                token=HF_TOKEN,
            )
        except Exception as e:
            print(f"[{FLOW_ID}] Failed to download state file: {str(e)}")
        return load_json_state(str(local_path), default_state)


async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Save ``state`` locally and upload it to the output dataset.

    Returns True on success, False (after logging) on any failure.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    async with _get_state_file_lock():
        try:
            # Written synchronously, so the file is a snapshot of ``state`` even
            # if other coroutines mutate it while the upload runs.
            save_json_state(str(local_path), state)
            await _run_blocking(
                HfApi(token=HF_TOKEN).upload_file,
                path_or_fileobj=str(local_path),
                path_in_repo=HF_STATE_FILE,
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
                commit_message=f"Update transcription state: next_index={state.get('next_download_index')}",
            )
            return True
        except Exception as e:
            print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
            return False


async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted ``.wav`` paths of the audio dataset.

    The list is cached in ``progress_data["file_list"]`` and persisted with
    ``save_progress``. Once cached, the dataset is not listed again, so files
    added later are not picked up. Returns ``[]`` if listing fails.
    """
    if progress_data.get('file_list'):
        return progress_data['file_list']
    try:
        api = HfApi(token=HF_TOKEN)
        repo_files = await _run_blocking(api.list_repo_files, repo_id=HF_AUDIO_DATASET_ID, repo_type="dataset")
    except Exception as e:
        print(f"[{FLOW_ID}] Error fetching file list: {e}")
        return []
    wav_files = sorted(f for f in repo_files if f.endswith('.wav'))
    progress_data['file_list'] = wav_files
    save_progress(progress_data)
    return wav_files


async def _list_existing_json_outputs() -> Set[str]:
    """Return the ``.json`` paths already in the output dataset (empty set on error)."""
    print(f"[{FLOW_ID}] Checking {HF_OUTPUT_DATASET_ID} for existing JSON outputs...")
    try:
        api = HfApi(token=HF_TOKEN)
        existing_files = await _run_blocking(api.list_repo_files, repo_id=HF_OUTPUT_DATASET_ID, repo_type="dataset")
    except Exception as e:
        print(f"[{FLOW_ID}] Warning: Could not fetch existing files: {e}")
        return set()
    existing_json_files = {f for f in existing_files if f.endswith('.json')}
    print(f"[{FLOW_ID}] Found {len(existing_json_files)} existing JSON files.")
    return existing_json_files


# --- Core Processing Logic ---
async def transcribe_with_server(server: WhisperServer, wav_path: Path) -> Optional[Dict]:
    """POST ``wav_path`` to ``server`` and return the decoded JSON response.

    Returns None (after logging) on a non-200 status, timeout, or any other
    error. Only successful requests update the server's counters.
    """
    start_time = time.monotonic()
    timeout = aiohttp.ClientTimeout(total=TRANSCRIBE_TIMEOUT_SECONDS)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            with open(wav_path, 'rb') as f:
                data = aiohttp.FormData()
                data.add_field('file', f, filename=wav_path.name)
                async with session.post(server.url, data=data) as resp:
                    if resp.status != 200:
                        print(f"[{FLOW_ID}] Server {server.url} returned status {resp.status}")
                        return None
                    result = await resp.json()
    except Exception as e:
        # type name included because e.g. asyncio.TimeoutError has an empty message
        print(f"[{FLOW_ID}] Error transcribing with {server.url}: {type(e).__name__}: {e}")
        return None
    server.total_processed += 1
    server.total_time += time.monotonic() - start_time
    return result


async def get_available_server() -> WhisperServer:
    """Wait for an idle server, mark it busy and return it.

    Servers are tried round-robin, starting after the last one handed out.
    When all are busy this polls once per second. The caller must call
    ``release()`` on the returned server.
    """
    global server_index
    while True:
        async with server_lock:
            for _ in range(len(servers)):
                s = servers[server_index]
                server_index = (server_index + 1) % len(servers)
                if not s.is_processing:
                    s.is_processing = True
                    return s
        await asyncio.sleep(1)


async def process_file_task(wav_file: str, state: Dict, progress: Dict):
    """Download, transcribe and upload one audio file, recording the outcome.

    Args:
        wav_file: path of the file inside ``HF_AUDIO_DATASET_ID``.
        state: shared state; ``state["file_states"][wav_file]`` is set to
            ``"processed"`` or ``"failed_transcription"``.
        progress: local progress; ``uploaded_count`` is incremented on success.

    Never raises (errors are logged). The downloaded audio is always deleted.
    """
    wav_path: Optional[Path] = None
    try:
        print(f"[{FLOW_ID}] Downloading {wav_file}...")
        downloaded_path_str = await _run_blocking(
            hf_hub_download,
            repo_id=HF_AUDIO_DATASET_ID,
            filename=wav_file,
            repo_type="dataset",
            local_dir=TEMP_DIR,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
        wav_path = Path(downloaded_path_str)
        if not wav_path.exists():
            raise FileNotFoundError(f"Downloaded file not found at {wav_path}")

        # Hold a Whisper server only for the transcription itself, so it is not
        # reserved while this file downloads or its result uploads.
        server = await get_available_server()
        try:
            result = await transcribe_with_server(server, wav_path)
        finally:
            server.release()

        if result:
            json_filename = Path(wav_file).with_suffix('.json').name
            json_content = json.dumps(result, indent=2, ensure_ascii=False).encode('utf-8')
            api = HfApi(token=HF_TOKEN)
            await _run_blocking(
                api.upload_file,
                path_or_fileobj=io.BytesIO(json_content),
                path_in_repo=json_filename,
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
                commit_message=f"[{FLOW_ID}] Transcription for {wav_file}",
            )
            state["file_states"][wav_file] = "processed"
            progress["uploaded_count"] = progress.get("uploaded_count", 0) + 1
            print(f"[{FLOW_ID}] ✅ Success: {wav_file}")
        else:
            state["file_states"][wav_file] = "failed_transcription"
            print(f"[{FLOW_ID}] ❌ Failed: {wav_file}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error processing {wav_file}: {e}")
        state["file_states"][wav_file] = "failed_transcription"
    finally:
        # Cleanup runs on every path (including a failed upload) so temp audio
        # does not accumulate on disk.
        if wav_path is not None:
            try:
                wav_path.unlink()
            except FileNotFoundError:
                pass
            except OSError as e:
                print(f"[{FLOW_ID}] Warning: could not delete {wav_path}: {e}")


async def main_processing_loop():
    """Run forever, transcribing audio files that have no JSON output yet.

    Each round:
    1. Reload the shared state from the output dataset.
    2. Examine the ``SCAN_CHUNK_SIZE`` files starting at ``next_download_index``
       that the state has never seen, followed by every file recorded as
       ``"failed_transcription"``.
    3. Mark files whose JSON output already exists as processed.
    4. Transcribe up to ``MAX_FILES_PER_ROUND`` of the rest.

    Errors are logged and the round is retried after 60 seconds.
    """
    print(f"[{FLOW_ID}] Starting main processing loop...")
    while True:
        try:
            state = await download_hf_state()
            progress = load_progress()
            file_list = await get_audio_file_list(progress)
            if not file_list:
                await asyncio.sleep(60)
                continue

            existing_json_files = await _list_existing_json_outputs()
            file_states = state["file_states"]
            next_idx = state.get("next_download_index", 0)

            # file_list position of every never-seen file in the scan window
            # (dicts keep insertion order, i.e. file order).
            window = file_list[next_idx:next_idx + SCAN_CHUNK_SIZE]
            new_positions = {
                f: next_idx + offset
                for offset, f in enumerate(window)
                if f not in file_states
            }
            if window and not new_positions:
                # The whole window is already tracked (e.g. /start_processing
                # rewound the index). Without this jump the loop would never
                # reach the files after the window.
                state["next_download_index"] = next_idx + len(window)
                print(f"[{FLOW_ID}] Files {next_idx + 1}-{next_idx + len(window)} already tracked; advancing index.")
                await upload_hf_state(state)
                await asyncio.sleep(2)
                continue

            failed_files = [f for f, s in file_states.items() if s == "failed_transcription"]
            # Never-seen files go first. If failures were retried first, one
            # file that always fails would be picked every round and starve
            # everything behind it.
            files_to_check = list(new_positions) + failed_files
            if not files_to_check:
                await asyncio.sleep(60)
                continue

            files_to_process: List[str] = []
            state_changed_locally = False
            print(f"[{FLOW_ID}] Scanning {len(files_to_check)} files for existing results...")
            for f in files_to_check:
                expected_json_name = Path(f).with_suffix('.json').name
                if expected_json_name in existing_json_files:
                    if file_states.get(f) != "processed":
                        file_states[f] = "processed"
                        state_changed_locally = True
                    if f in new_positions:
                        _advance_next_index(state, new_positions[f])
                    continue
                print(f"[{FLOW_ID}] Found unprocessed file: {f}")
                files_to_process.append(f)
                if len(files_to_process) >= MAX_FILES_PER_ROUND:
                    break

            if state_changed_locally:
                if files_to_process:
                    print(f"[{FLOW_ID}] Synchronizing skipped files to HF state...")
                else:
                    print(f"[{FLOW_ID}] Uploading final batch of skips...")
                await upload_hf_state(state)

            if not files_to_process:
                # Brief pause so this path cannot spin on Hub requests
                # (e.g. if the state upload keeps failing).
                await asyncio.sleep(2)
                continue

            batch_size = len(servers)
            for start in range(0, len(files_to_process), batch_size):
                batch = files_to_process[start:start + batch_size]
                await asyncio.gather(*(process_file_task(f, state, progress) for f in batch))
                # Only never-seen files move the index; retried failures sit
                # behind it already.
                for f in batch:
                    if f in new_positions:
                        _advance_next_index(state, new_positions[f])
                await upload_hf_state(state)
                save_progress(progress)
            await asyncio.sleep(2)
        except Exception as e:
            print(f"[{FLOW_ID}] Error in main loop: {e}")
            await asyncio.sleep(60)


# --- FastAPI App ---
app = FastAPI(title=f"Flow Server {FLOW_ID} API")


@app.on_event("startup")
async def startup_event():
    """Start the processing loop in the background, keeping a reference to the task."""
    task = asyncio.create_task(main_processing_loop())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


@app.get("/")
async def root():
    """Report status: next index, failed count, uploads and file-list size."""
    progress = load_progress()
    state = await download_hf_state()
    failed_count = sum(1 for s in state.get("file_states", {}).values() if s == "failed_transcription")
    return {
        "flow_id": FLOW_ID,
        "status": "running",
        "next_download_index": state.get("next_download_index", 0),
        "failed_transcriptions": failed_count,
        "uploaded_count": progress.get("uploaded_count", 0),
        "total_files_in_list": len(progress.get('file_list', [])),
    }


@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest):
    """Set the shared ``next_download_index`` to ``start_index - 1`` (0-based).

    NOTE(review): the main loop works on its own copy of the state and uploads
    it after each round, so a reset issued mid-round may be overwritten.
    """
    state = await download_hf_state()
    state["next_download_index"] = request.start_index - 1
    await upload_hf_state(state)
    return {"status": "index_reset", "new_index": request.start_index}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)