# NOTE: removed Hugging Face Spaces page artifacts ("Spaces: / Sleeping / Sleeping")
# that were captured with this source and are not part of the program.
# Standard library
import asyncio
import io
import json
import os
import shutil
import time
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
from urllib.parse import quote

# Third party
import aiohttp
from fastapi import BackgroundTasks, FastAPI, HTTPException, status
from huggingface_hub import HfApi, hf_hub_download
from pydantic import BaseModel, Field
# --- Configuration ---

# Hardcoded default start index if no progress is found (1-indexed).
AUTO_START_INDEX = 1

# Identity and network settings for this worker, overridable via environment.
FLOW_ID = os.getenv("FLOW_ID", "flow_default")
FLOW_PORT = int(os.getenv("FLOW_PORT", 8001))

# Hugging Face credentials and dataset repos (input audio / output JSON).
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_AUDIO_DATASET_ID = os.getenv("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
HF_OUTPUT_DATASET_ID = os.getenv("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")

# Progress and State Tracking
PROGRESS_FILE = Path("processing_progress.json")
HF_STATE_FILE = "processing_state_transcriptions.json"
LOCAL_STATE_FOLDER = Path(".state")
LOCAL_STATE_FOLDER.mkdir(exist_ok=True)

# Directory within the HF dataset where the audio files are located
AUDIO_FILE_PREFIX = "audio/"

# FIX: Updated server list based on the logs showing 'eliasishere' prefix
WHISPER_SERVERS = [
    f"https://eliasishere-makeitfr-mineo-{i}.hf.space/transcribe"
    for i in range(1, 21)
]

# Temporary storage for audio files, namespaced per flow so parallel
# workers do not clobber each other's downloads.
TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")
TEMP_DIR.mkdir(exist_ok=True)
| # --- Models --- | |
class ProcessStartRequest(BaseModel):
    """Request body for resetting where processing should resume."""

    # 1-indexed position into the sorted audio file list.
    start_index: int = Field(
        AUTO_START_INDEX,
        ge=1,
        description="The index number of the audio file to start processing from (1-indexed).",
    )
class WhisperServer:
    """Availability flag plus throughput stats for one remote whisper endpoint."""

    def __init__(self, url: str):
        self.url = url
        self.is_processing = False  # True while a worker task has claimed this server
        self.total_processed = 0    # count of files transcribed via this server
        self.total_time = 0.0       # cumulative seconds spent on those files

    def fps(self):
        """Files per second"""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0

    def release(self):
        """Release the server for a new file"""
        self.is_processing = False
# Global state for whisper servers
servers = [WhisperServer(url) for url in WHISPER_SERVERS]
server_lock = asyncio.Lock()  # guards is_processing flags and server_index
server_index = 0  # For round-robin selection
| # --- Progress and State Management Functions --- | |
def load_progress() -> Dict:
    """Read the local progress file, back-filling any missing keys.

    Returns a fresh default structure when the file is absent or corrupted.
    """
    defaults = {
        "last_processed_index": 0,
        "processed_files": {},
        "file_list": [],
        "uploaded_count": 0,
    }
    if PROGRESS_FILE.exists():
        try:
            with PROGRESS_FILE.open('r') as f:
                data = json.load(f)
        except json.JSONDecodeError:
            print(f"[{FLOW_ID}] WARNING: Progress file is corrupted.")
        else:
            # Older progress files may predate some keys; fill them in.
            for key, value in defaults.items():
                data.setdefault(key, value)
            return data
    return defaults
def save_progress(progress_data: Dict):
    """Persist *progress_data* to the local progress file (best effort).

    Failures are logged rather than raised so a transient disk error does
    not kill the processing loop.
    """
    try:
        with PROGRESS_FILE.open('w') as f:
            json.dump(progress_data, f, indent=4)
    except Exception as e:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress: {e}")
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Load a JSON state dict from *file_path*, normalizing required keys.

    Guarantees the returned dict has a dict-valued ``file_states`` and an
    integer ``next_download_index``. Falls back to *default_value* if the
    file is missing or unreadable.
    """
    if not os.path.exists(file_path):
        return default_value
    try:
        with open(file_path, "r") as f:
            data = json.load(f)
        # Normalize shape: older/hand-edited state files may miss keys.
        if not isinstance(data.get("file_states"), dict):
            data["file_states"] = {}
        data.setdefault("next_download_index", 0)
        return data
    except Exception:
        # Any parse/shape failure degrades to the caller-supplied default.
        return default_value
def save_json_state(file_path: str, data: Dict[str, Any]):
    """Serialize *data* as pretty-printed JSON to *file_path* (overwrites)."""
    with open(file_path, "w") as f:
        json.dump(data, f, indent=2)
async def download_hf_state() -> Dict[str, Any]:
    """Fetch the shared transcription state file from the output dataset.

    On download failure, falls back to whatever local copy (or default
    structure) is available, so the loop can keep running offline.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    fallback = {"next_download_index": 0, "file_states": {}}
    try:
        hf_hub_download(
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to download state file: {str(e)}")
    # Both success and failure paths read the local file (fresh or stale).
    return load_json_state(str(local_path), fallback)
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Write *state* locally, then push it to the output dataset.

    Returns True on success, False (after logging) on any failure.
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    try:
        save_json_state(str(local_path), state)
        api = HfApi(token=HF_TOKEN)
        api.upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=HF_STATE_FILE,
            repo_id=HF_OUTPUT_DATASET_ID,
            repo_type="dataset",
            commit_message=f"Update transcription state: next_index={state.get('next_download_index')}",
        )
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
        return False
    return True
async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted list of .wav files in the audio dataset.

    The listing is cached inside *progress_data* (and persisted) so the
    repo is only enumerated once. Returns [] on failure.
    """
    cached = progress_data.get('file_list')
    if cached:
        return cached
    try:
        api = HfApi(token=HF_TOKEN)
        repo_files = api.list_repo_files(repo_id=HF_AUDIO_DATASET_ID, repo_type="dataset")
        wav_files = sorted(f for f in repo_files if f.endswith('.wav'))
        progress_data['file_list'] = wav_files
        save_progress(progress_data)
        return wav_files
    except Exception as e:
        print(f"[{FLOW_ID}] Error fetching file list: {e}")
        return []
| # --- Core Processing Logic --- | |
async def transcribe_with_server(server: WhisperServer, wav_path: Path) -> Optional[Dict]:
    """POST *wav_path* to one whisper server and return its JSON result.

    Updates the server's throughput stats on success. Returns None on a
    non-200 response or any transport error (both are logged).
    """
    started = time.time()
    try:
        async with aiohttp.ClientSession() as session:
            with open(wav_path, 'rb') as audio:
                form = aiohttp.FormData()
                form.add_field('file', audio, filename=wav_path.name)
                # Long timeout: transcription of a full file can be slow.
                async with session.post(server.url, data=form, timeout=600) as resp:
                    if resp.status == 200:
                        result = await resp.json()
                        server.total_processed += 1
                        server.total_time += time.time() - started
                        return result
                    print(f"[{FLOW_ID}] Server {server.url} returned status {resp.status}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error transcribing with {server.url}: {e}")
    return None
async def get_available_server() -> WhisperServer:
    """Block until a whisper server is free and claim it (round-robin).

    Marks the returned server as busy; the caller must call ``release()``.
    """
    global server_index
    while True:
        async with server_lock:
            # Try each server at most once per pass, starting where we left off.
            for _ in range(len(servers)):
                candidate = servers[server_index]
                server_index = (server_index + 1) % len(servers)
                if not candidate.is_processing:
                    candidate.is_processing = True
                    return candidate
        # Every server busy: back off before scanning again.
        await asyncio.sleep(1)
async def process_file_task(wav_file: str, state: Dict, progress: Dict):
    """Download one audio file, transcribe it, and upload the JSON result.

    Mutates *state* (``file_states[wav_file]`` set to "processed" or
    "failed_transcription") and *progress* (``uploaded_count``) in place;
    the caller is responsible for persisting them. Always releases the
    claimed whisper server.
    """
    server = await get_available_server()
    wav_path: Optional[Path] = None
    try:
        print(f"[{FLOW_ID}] Downloading {wav_file}...")
        downloaded_path_str = hf_hub_download(
            repo_id=HF_AUDIO_DATASET_ID,
            filename=wav_file,
            repo_type="dataset",
            local_dir=TEMP_DIR,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )
        wav_path = Path(downloaded_path_str)
        if not wav_path.exists():
            raise FileNotFoundError(f"Downloaded file not found at {wav_path}")
        # Transcribe
        result = await transcribe_with_server(server, wav_path)
        if result:
            json_filename = Path(wav_file).with_suffix('.json').name
            json_content = json.dumps(result, indent=2, ensure_ascii=False).encode('utf-8')
            api = HfApi(token=HF_TOKEN)
            api.upload_file(
                path_or_fileobj=io.BytesIO(json_content),
                path_in_repo=json_filename,
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
                commit_message=f"[{FLOW_ID}] Transcription for {wav_file}"
            )
            state["file_states"][wav_file] = "processed"
            progress["uploaded_count"] = progress.get("uploaded_count", 0) + 1
            print(f"[{FLOW_ID}] ✅ Success: {wav_file}")
        else:
            state["file_states"][wav_file] = "failed_transcription"
            print(f"[{FLOW_ID}] ❌ Failed: {wav_file}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error processing {wav_file}: {e}")
        state["file_states"][wav_file] = "failed_transcription"
    finally:
        # BUGFIX: the temp wav was previously deleted only on the success
        # path inside ``try``, so any exception (download, transcription,
        # upload) leaked the file on disk. Clean up unconditionally here.
        if wav_path is not None and wav_path.exists():
            wav_path.unlink()
        server.release()
async def main_processing_loop():
    """Forever: sync state, find the next unprocessed audio file, process it.

    Each pass downloads the shared HF state, reconciles it against the JSON
    outputs already present in the output dataset (so other workers' results
    are skipped), then transcribes the first genuinely unprocessed file and
    advances ``next_download_index``. Errors back off for 60 seconds.
    """
    print(f"[{FLOW_ID}] Starting main processing loop...")
    while True:
        try:
            state = await download_hf_state()
            progress = load_progress()
            file_list = await get_audio_file_list(progress)
            if not file_list:
                await asyncio.sleep(60)
                continue

            print(f"[{FLOW_ID}] Checking {HF_OUTPUT_DATASET_ID} for existing JSON outputs...")
            try:
                api = HfApi(token=HF_TOKEN)
                existing_files = api.list_repo_files(repo_id=HF_OUTPUT_DATASET_ID, repo_type="dataset")
                existing_json_files = {f for f in existing_files if f.endswith('.json')}
                print(f"[{FLOW_ID}] Found {len(existing_json_files)} existing JSON files.")
            except Exception as e:
                # Best effort: without the listing we just re-check everything.
                print(f"[{FLOW_ID}] Warning: Could not fetch existing files: {e}")
                existing_json_files = set()

            # Candidates = previously failed files + the next window of new files.
            failed_files = [f for f, s in state.get("file_states", {}).items() if s == "failed_transcription"]
            next_idx = state.get("next_download_index", 0)
            new_files_chunk = file_list[next_idx:next_idx + 1000]
            files_to_check = failed_files + [f for f in new_files_chunk if f not in state["file_states"]]
            if not files_to_check:
                await asyncio.sleep(60)
                continue

            files_to_process = []
            state_changed_locally = False
            print(f"[{FLOW_ID}] Scanning {len(files_to_check)} files for existing results...")
            for f in files_to_check:
                expected_json_name = Path(f).with_suffix('.json').name
                if expected_json_name in existing_json_files:
                    # Already done by someone: record it and advance the cursor.
                    if state["file_states"].get(f) != "processed":
                        state["file_states"][f] = "processed"
                        state_changed_locally = True
                    if f in new_files_chunk:
                        current_idx = file_list.index(f)
                        if current_idx >= state.get("next_download_index", 0):
                            state["next_download_index"] = current_idx + 1
                    continue
                print(f"[{FLOW_ID}] Found unprocessed file: {f}")
                if state_changed_locally:
                    # Flush the skip bookkeeping before taking on real work.
                    print(f"[{FLOW_ID}] Synchronizing skipped files to HF state...")
                    await upload_hf_state(state)
                    state_changed_locally = False
                files_to_process.append(f)
                break  # one file per scan; the next pass rescans fresh state

            if state_changed_locally and not files_to_process:
                print(f"[{FLOW_ID}] Uploading final batch of skips...")
                await upload_hf_state(state)
            if not files_to_process:
                continue

            # Process the unprocessed file
            batch_size = len(servers)
            for i in range(0, len(files_to_process), batch_size):
                batch = files_to_process[i:i + batch_size]
                tasks = [process_file_task(f, state, progress) for f in batch]
                await asyncio.gather(*tasks)
                for f in batch:
                    if f in file_list:
                        current_idx = file_list.index(f)
                        if current_idx >= state.get("next_download_index", 0):
                            state["next_download_index"] = current_idx + 1
                await upload_hf_state(state)
                save_progress(progress)
            await asyncio.sleep(2)
        except Exception as e:
            # Top-level boundary: log and back off rather than crash the worker.
            print(f"[{FLOW_ID}] Error in main loop: {e}")
            await asyncio.sleep(60)
# --- FastAPI App ---
app = FastAPI(title=f"Flow Server {FLOW_ID} API")


@app.on_event("startup")
async def startup_event():
    """Launch the background processing loop when the server starts."""
    # BUGFIX: without registering this as a startup handler, FastAPI never
    # calls it and the processing loop would never run.
    asyncio.create_task(main_processing_loop())
@app.get("/")
async def root():
    """Status endpoint: current index, failure count, and upload totals.

    BUGFIX: the route decorator was missing, so this handler was never
    reachable over HTTP.
    """
    progress = load_progress()
    state = await download_hf_state()
    failed_count = sum(1 for s in state.get("file_states", {}).values() if s == "failed_transcription")
    return {
        "flow_id": FLOW_ID,
        "status": "running",
        "next_download_index": state.get("next_download_index", 0),
        "failed_transcriptions": failed_count,
        "uploaded_count": progress.get("uploaded_count", 0),
        "total_files_in_list": len(progress.get('file_list', []))
    }
@app.post("/start")  # NOTE(review): route path assumed — confirm against clients
async def start_processing(request: ProcessStartRequest):
    """Reset the shared download cursor to start from a given 1-indexed file.

    BUGFIX: the route decorator was missing, so the endpoint was never
    registered with FastAPI.
    """
    state = await download_hf_state()
    # Cursor is 0-indexed internally; the request is 1-indexed.
    state["next_download_index"] = request.start_index - 1
    await upload_hf_state(state)
    return {"status": "index_reset", "new_index": request.start_index}
if __name__ == "__main__":
    # Run the ASGI app directly when executed as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=FLOW_PORT)