# switch / app.py
# Last commit by factorstudios: "Update app.py" (e330188, verified)
import os
import json
import time
import asyncio
import aiohttp
import zipfile
import shutil
from typing import Dict, List, Set, Optional, Tuple, Any
from urllib.parse import quote
from datetime import datetime
from pathlib import Path
import io
from fastapi import FastAPI, BackgroundTasks, HTTPException, status
from pydantic import BaseModel, Field
from huggingface_hub import HfApi, hf_hub_download
# --- Configuration ---
# Default (1-based) start index for POST /start_processing when the request omits it.
AUTO_START_INDEX = 1

# Identity / networking, overridable through environment variables.
FLOW_ID = os.environ.get("FLOW_ID", "flow_default")
FLOW_PORT = int(os.environ.get("FLOW_PORT", 8001))
HF_TOKEN = os.environ.get("HF_TOKEN", "")
HF_AUDIO_DATASET_ID = os.environ.get("HF_AUDIO_DATASET_ID", "Samfredoly/BG_VAUD")
HF_OUTPUT_DATASET_ID = os.environ.get("HF_OUTPUT_DATASET_ID", "samfred2/ATO_TG")

# Progress and state tracking.
PROGRESS_FILE = Path("processing_progress.json")  # local progress file for this instance
HF_STATE_FILE = "processing_state_transcriptions.json"  # shared state file name in the output dataset
LOCAL_STATE_FOLDER = Path(".state")  # local directory the shared state file is downloaded into

# Directory within the audio dataset where the audio files are located.
# NOTE(review): not referenced by the code visible here; the file listing takes every .wav.
AUDIO_FILE_PREFIX = "audio/"

# Hosted Whisper transcription endpoints (the 'eliasishere' prefix matches the hosts seen in the logs).
WHISPER_SERVERS = [
    f"https://eliasishere-makeitfr-mineo-{n}.hf.space/transcribe"
    for n in range(1, 21)
]

# Per-flow scratch directory for downloaded audio files.
TEMP_DIR = Path(f"temp_audio_{FLOW_ID}")

# Make sure both working directories exist before anything uses them.
LOCAL_STATE_FOLDER.mkdir(exist_ok=True)
TEMP_DIR.mkdir(exist_ok=True)
# --- Models ---
class ProcessStartRequest(BaseModel):
    """Request body for POST /start_processing."""

    # 1-based position in the audio file list; validated to be >= 1.
    start_index: int = Field(
        default=AUTO_START_INDEX,
        ge=1,
        description="The index number of the audio file to start processing from (1-indexed).",
    )
class WhisperServer:
    """Bookkeeping for one remote Whisper transcription endpoint.

    Holds the endpoint URL, a busy flag used by the scheduler, and cumulative
    statistics for successful transcriptions.
    """

    def __init__(self, url: str):
        self.url = url
        # True while a file is assigned to this endpoint.
        self.is_processing = False
        # Count of successful transcriptions and the seconds they took in total.
        self.total_processed = 0
        self.total_time = 0.0

    @property
    def fps(self):
        """Files per second over all timed transcriptions (0 before any time is recorded)."""
        if self.total_time > 0:
            return self.total_processed / self.total_time
        return 0

    def release(self):
        """Clear the busy flag so the endpoint can be handed a new file."""
        self.is_processing = False
# --- Scheduler state shared by every processing task ---
# One tracker per configured endpoint, in WHISPER_SERVERS order.
servers = list(map(WhisperServer, WHISPER_SERVERS))
# Guards the round-robin scan in get_available_server().
server_lock = asyncio.Lock()
# Position in `servers` where the next round-robin scan begins.
server_index = 0
# --- Progress and State Management Functions ---
def load_progress() -> Dict:
    """Load local progress from PROGRESS_FILE, filling in any missing keys.

    Returns:
        The stored progress dict with defaults added for ``last_processed_index``,
        ``processed_files``, ``file_list`` and ``uploaded_count``. A fresh default
        structure is returned when the file is missing, cannot be read, is not
        valid JSON, or does not contain a JSON object.
    """
    default_structure = {
        "last_processed_index": 0,
        "processed_files": {},
        "file_list": [],
        "uploaded_count": 0,
    }
    if not PROGRESS_FILE.exists():
        return default_structure
    try:
        with PROGRESS_FILE.open('r') as f:
            data = json.load(f)
    except OSError as e:
        print(f"[{FLOW_ID}] WARNING: Could not read progress file: {e}")
        return default_structure
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted.")
        return default_structure
    if not isinstance(data, dict):
        # A list/number/null would otherwise crash on the key assignment below.
        print(f"[{FLOW_ID}] WARNING: Progress file is corrupted.")
        return default_structure
    for key, value in default_structure.items():
        data.setdefault(key, value)
    return data
def save_progress(progress_data: Dict):
    """Persist ``progress_data`` to PROGRESS_FILE (best-effort, never raises).

    The JSON is written to a sibling ``.tmp`` file and then moved into place
    with ``os.replace``. A crash mid-write therefore cannot leave a truncated
    progress file. Such a file would be discarded by load_progress(), losing
    the cached file list and upload counter.
    """
    tmp_path = PROGRESS_FILE.with_name(PROGRESS_FILE.name + ".tmp")
    try:
        with tmp_path.open('w') as f:
            json.dump(progress_data, f, indent=4)
        os.replace(tmp_path, PROGRESS_FILE)
    except Exception as e:
        print(f"[{FLOW_ID}] CRITICAL ERROR: Could not save progress: {e}")
def load_json_state(file_path: str, default_value: Dict[str, Any]) -> Dict[str, Any]:
    """Read a processing-state JSON file, normalising its required keys.

    Args:
        file_path: Path of the JSON state file.
        default_value: Returned as-is (same object) when the file is missing,
            unreadable, not valid JSON, or not a JSON object.

    Returns:
        The parsed dict, with ``file_states`` forced to a dict and
        ``next_download_index`` defaulted to 0 when absent.
    """
    if not os.path.exists(file_path):
        return default_value
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # Best-effort fallback, but say why instead of swallowing silently.
        print(f"[{FLOW_ID}] Warning: could not read state file {file_path}: {e}")
        return default_value
    if not isinstance(data, dict):
        print(f"[{FLOW_ID}] Warning: state file {file_path} is not a JSON object; using defaults.")
        return default_value
    if not isinstance(data.get("file_states"), dict):
        data["file_states"] = {}
    data.setdefault("next_download_index", 0)
    return data
def save_json_state(file_path: str, data: Dict[str, Any]):
    """Write ``data`` to ``file_path`` as JSON indented by two spaces.

    The data is serialised before the file is touched and written through a
    sibling ``.tmp`` file that replaces the target atomically. A
    non-serialisable value therefore raises without destroying the existing
    state file.

    Raises:
        TypeError / ValueError: if ``data`` cannot be JSON-encoded.
        OSError: if the file cannot be written.
    """
    serialized = json.dumps(data, indent=2)
    tmp_path = file_path + ".tmp"
    with open(tmp_path, "w") as handle:
        handle.write(serialized)
    os.replace(tmp_path, file_path)
async def download_hf_state() -> Dict[str, Any]:
    """Fetch the shared processing state file from the output dataset.

    ``hf_hub_download`` is a blocking network call. It runs in the default
    thread-pool executor so it does not stall the event loop, and with it the
    background loop and every FastAPI request. If the download fails, the copy
    already on disk (possibly stale) is used. If that is missing or
    unreadable, a fresh ``{"next_download_index": 0, "file_states": {}}`` is
    returned.
    """
    # NOTE(review): the root endpoint and the main loop may now download concurrently
    # into LOCAL_STATE_FOLDER; confirm the installed huggingface_hub serialises local_dir downloads.
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    default_state = {"next_download_index": 0, "file_states": {}}

    def _fetch() -> None:
        hf_hub_download(
            repo_id=HF_OUTPUT_DATASET_ID,
            filename=HF_STATE_FILE,
            repo_type="dataset",
            local_dir=LOCAL_STATE_FOLDER,
            local_dir_use_symlinks=False,
            token=HF_TOKEN,
        )

    try:
        await asyncio.get_running_loop().run_in_executor(None, _fetch)
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to download state file: {str(e)}")
    return load_json_state(str(local_path), default_state)
async def upload_hf_state(state: Dict[str, Any]) -> bool:
    """Save ``state`` locally and push it to the output dataset as HF_STATE_FILE.

    The state is serialised once, up front, and uploaded from memory, so the
    commit cannot pick up a concurrent rewrite of the local cache file. The
    blocking upload runs in the default executor to keep the event loop free.

    Returns:
        True if the upload succeeded, False otherwise (the error is logged).
    """
    local_path = LOCAL_STATE_FOLDER / HF_STATE_FILE
    commit_message = f"Update transcription state: next_index={state.get('next_download_index')}"
    try:
        # Same format save_json_state() writes locally (indent=2).
        payload = json.dumps(state, indent=2).encode("utf-8")
        save_json_state(str(local_path), state)

        def _push() -> None:
            HfApi(token=HF_TOKEN).upload_file(
                path_or_fileobj=payload,
                path_in_repo=HF_STATE_FILE,
                repo_id=HF_OUTPUT_DATASET_ID,
                repo_type="dataset",
                commit_message=commit_message,
            )

        await asyncio.get_running_loop().run_in_executor(None, _push)
        return True
    except Exception as e:
        print(f"[{FLOW_ID}] Failed to upload state file: {str(e)}")
        return False
async def get_audio_file_list(progress_data: Dict) -> List[str]:
    """Return the sorted ``.wav`` paths of the audio dataset, using a cache.

    A non-empty ``progress_data['file_list']`` is returned as-is. Otherwise the
    dataset is listed, and the result is stored in ``progress_data`` and
    persisted with save_progress(). The blocking listing call runs in the
    default executor so the event loop stays responsive. Returns ``[]`` (and
    logs) if listing fails.

    NOTE: once cached, the list is never refreshed, so audio files added to the
    dataset later are not picked up.
    """
    cached = progress_data.get('file_list')
    if cached:
        return cached

    def _list_files() -> List[str]:
        return HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_AUDIO_DATASET_ID, repo_type="dataset"
        )

    try:
        repo_files = await asyncio.get_running_loop().run_in_executor(None, _list_files)
    except Exception as e:
        print(f"[{FLOW_ID}] Error fetching file list: {e}")
        return []
    wav_files = sorted(f for f in repo_files if f.endswith('.wav'))
    progress_data['file_list'] = wav_files
    save_progress(progress_data)
    return wav_files
# --- Core Processing Logic ---
async def transcribe_with_server(server: WhisperServer, wav_path: Path) -> Optional[Dict]:
    """POST ``wav_path`` as multipart field ``file`` to ``server.url``.

    Returns:
        The decoded JSON body of a 200 response, or None when the server
        answers with another status or any exception occurs (both are logged).
        On success, the call's wall-clock duration is added to the server's
        ``total_time`` and ``total_processed`` is incremented.
    """
    started = time.time()
    try:
        async with aiohttp.ClientSession() as session:
            with open(wav_path, 'rb') as audio:
                form = aiohttp.FormData()
                form.add_field('file', audio, filename=wav_path.name)
                # Total per-request budget of 600 seconds.
                request_timeout = aiohttp.ClientTimeout(total=600)
                async with session.post(server.url, data=form, timeout=request_timeout) as resp:
                    if resp.status != 200:
                        print(f"[{FLOW_ID}] Server {server.url} returned status {resp.status}")
                        return None
                    payload = await resp.json()
                    server.total_processed += 1
                    server.total_time += time.time() - started
                    return payload
    except Exception as exc:
        print(f"[{FLOW_ID}] Error transcribing with {server.url}: {exc}")
    return None
async def get_available_server() -> WhisperServer:
    """Claim the next idle Whisper server, waiting until one is free.

    While holding ``server_lock``, the pool is scanned once round-robin,
    starting at ``server_index``. The first idle server is marked busy and
    returned. If all are busy, the lock is dropped and the scan is retried
    after one second. Callers hand the server back with ``release()``.
    """
    global server_index
    while True:
        async with server_lock:
            pool_size = len(servers)
            for _attempt in range(pool_size):
                candidate = servers[server_index]
                server_index = (server_index + 1) % pool_size
                if candidate.is_processing:
                    continue
                candidate.is_processing = True
                return candidate
        await asyncio.sleep(1)
def _download_audio_blocking(wav_file: str) -> Path:
    """Download ``wav_file`` from the audio dataset into TEMP_DIR and return its local path."""
    downloaded_path_str = hf_hub_download(
        repo_id=HF_AUDIO_DATASET_ID,
        filename=wav_file,
        repo_type="dataset",
        local_dir=TEMP_DIR,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
    )
    wav_path = Path(downloaded_path_str)
    if not wav_path.exists():
        raise FileNotFoundError(f"Downloaded file not found at {wav_path}")
    return wav_path


def _upload_transcription_blocking(wav_file: str, result: Dict) -> None:
    """Upload ``result`` as ``<basename of wav_file>.json`` at the root of the output dataset."""
    json_filename = Path(wav_file).with_suffix('.json').name
    json_content = json.dumps(result, indent=2, ensure_ascii=False).encode('utf-8')
    HfApi(token=HF_TOKEN).upload_file(
        path_or_fileobj=io.BytesIO(json_content),
        path_in_repo=json_filename,
        repo_id=HF_OUTPUT_DATASET_ID,
        repo_type="dataset",
        commit_message=f"[{FLOW_ID}] Transcription for {wav_file}",
    )


async def process_file_task(wav_file: str, state: Dict, progress: Dict):
    """Download, transcribe and upload one audio file, recording the outcome.

    On success ``state["file_states"][wav_file]`` becomes ``"processed"`` and
    ``progress["uploaded_count"]`` is incremented. On an empty transcription
    result or any error, it becomes ``"failed_transcription"``. The blocking Hub
    download and upload run in the default executor so the event loop keeps
    serving requests. The downloaded wav is deleted and the Whisper server
    released in every case.
    """
    loop = asyncio.get_running_loop()
    server = await get_available_server()
    wav_path: Optional[Path] = None
    try:
        print(f"[{FLOW_ID}] Downloading {wav_file}...")
        wav_path = await loop.run_in_executor(None, _download_audio_blocking, wav_file)
        result = await transcribe_with_server(server, wav_path)
        if result:
            await loop.run_in_executor(None, _upload_transcription_blocking, wav_file, result)
            state["file_states"][wav_file] = "processed"
            progress["uploaded_count"] = progress.get("uploaded_count", 0) + 1
            print(f"[{FLOW_ID}] ✅ Success: {wav_file}")
        else:
            state["file_states"][wav_file] = "failed_transcription"
            print(f"[{FLOW_ID}] ❌ Failed: {wav_file}")
    except Exception as e:
        print(f"[{FLOW_ID}] Error processing {wav_file}: {e}")
        state["file_states"][wav_file] = "failed_transcription"
    finally:
        # Cleanup lives here so a failed upload no longer leaks the temp file,
        # and a failed delete no longer flips a successful file to "failed".
        if wav_path is not None and wav_path.exists():
            try:
                wav_path.unlink()
            except OSError as cleanup_error:
                print(f"[{FLOW_ID}] Warning: could not delete {wav_path}: {cleanup_error}")
        server.release()
async def _list_existing_json_outputs() -> Set[str]:
    """Return the ``.json`` paths already present in the output dataset.

    The blocking ``list_repo_files`` call runs in the default executor. On
    failure a warning is logged and an empty set is returned, which makes
    every candidate look unprocessed.
    """
    print(f"[{FLOW_ID}] Checking {HF_OUTPUT_DATASET_ID} for existing JSON outputs...")

    def _list_files() -> List[str]:
        return HfApi(token=HF_TOKEN).list_repo_files(
            repo_id=HF_OUTPUT_DATASET_ID, repo_type="dataset"
        )

    try:
        existing_files = await asyncio.get_running_loop().run_in_executor(None, _list_files)
    except Exception as e:
        print(f"[{FLOW_ID}] Warning: Could not fetch existing files: {e}")
        return set()
    existing_json_files = {f for f in existing_files if f.endswith('.json')}
    print(f"[{FLOW_ID}] Found {len(existing_json_files)} existing JSON files.")
    return existing_json_files


def _advance_next_index(state: Dict, index_of: Dict[str, int], wav_file: str) -> bool:
    """Move ``next_download_index`` just past ``wav_file`` unless the index is already beyond it.

    Returns True when the index changed.
    """
    position = index_of.get(wav_file)
    if position is None or position < state.get("next_download_index", 0):
        return False
    state["next_download_index"] = position + 1
    return True


def _skip_known_prefix(state: Dict, file_list: List[str]) -> bool:
    """Advance ``next_download_index`` past files that already have a recorded state.

    Files already in ``file_states`` never become candidates. A window made
    only of such files (e.g. after /start_processing moved the index back) can
    otherwise stall the index indefinitely. Returns True if the index moved.
    """
    file_states = state["file_states"]
    start = state.get("next_download_index", 0)
    position = start
    while position < len(file_list) and file_list[position] in file_states:
        position += 1
    if position == start:
        return False
    state["next_download_index"] = position
    return True


def _pick_next_file(
    files_to_check: List[str],
    window: Set[str],
    state: Dict,
    existing_json_files: Set[str],
    index_of: Dict[str, int],
) -> Tuple[Optional[str], bool]:
    """Return the first candidate lacking an output JSON; mark the ones before it processed.

    A candidate whose ``<basename>.json`` already exists in the output dataset
    is recorded as ``"processed"``. If it lies in the current index window,
    the download index is also moved past it. Returns
    ``(file_to_process_or_None, state_changed)``.
    """
    file_states = state["file_states"]
    changed = False
    for wav_file in files_to_check:
        if Path(wav_file).with_suffix('.json').name not in existing_json_files:
            return wav_file, changed
        if file_states.get(wav_file) != "processed":
            file_states[wav_file] = "processed"
            changed = True
        if wav_file in window and _advance_next_index(state, index_of, wav_file):
            changed = True
    return None, changed


async def main_processing_loop():
    """Run forever: find the next untranscribed audio file, process it, persist state.

    Each iteration does the following:
    - reloads the shared HF state and local progress;
    - skips files whose output JSON already exists, recording them as processed;
    - transcribes at most one file.
    New files from the current 1000-file window are tried before previously
    failed ones, so a file that keeps failing cannot starve the rest of the
    dataset; failures are retried once the window has no new work.
    Unexpected errors are logged and the loop retries after 60 seconds.
    """
    print(f"[{FLOW_ID}] Starting main processing loop...")
    while True:
        try:
            state = await download_hf_state()
            state.setdefault("file_states", {})
            progress = load_progress()
            file_list = await get_audio_file_list(progress)
            if not file_list:
                await asyncio.sleep(60)
                continue

            existing_json_files = await _list_existing_json_outputs()
            # Position lookup built once instead of file_list.index() per file.
            index_of: Dict[str, int] = {}
            for position, name in enumerate(file_list):
                index_of.setdefault(name, position)

            state_changed = _skip_known_prefix(state, file_list)
            file_states = state["file_states"]
            next_idx = state.get("next_download_index", 0)
            window = file_list[next_idx:next_idx + 1000]
            new_files = [f for f in window if f not in file_states]
            failed_files = [f for f, s in file_states.items() if s == "failed_transcription"]
            files_to_check = new_files + failed_files
            if not files_to_check:
                if state_changed:
                    await upload_hf_state(state)
                await asyncio.sleep(60)
                continue

            print(f"[{FLOW_ID}] Scanning {len(files_to_check)} files for existing results...")
            next_file, skipped = _pick_next_file(
                files_to_check, set(window), state, existing_json_files, index_of
            )
            state_changed = state_changed or skipped

            if next_file is None:
                uploaded = True
                if state_changed:
                    print(f"[{FLOW_ID}] Uploading final batch of skips...")
                    uploaded = await upload_hf_state(state)
                if not uploaded:
                    # Back off rather than re-fetching the same stale state in a tight loop.
                    await asyncio.sleep(60)
                continue

            print(f"[{FLOW_ID}] Found unprocessed file: {next_file}")
            if state_changed:
                print(f"[{FLOW_ID}] Synchronizing skipped files to HF state...")
                await upload_hf_state(state)

            await process_file_task(next_file, state, progress)
            _advance_next_index(state, index_of, next_file)
            await upload_hf_state(state)
            save_progress(progress)
            await asyncio.sleep(2)
        except Exception as e:
            print(f"[{FLOW_ID}] Error in main loop: {e}")
            await asyncio.sleep(60)
# --- FastAPI App ---
app = FastAPI(title=f"Flow Server {FLOW_ID} API")


@app.on_event("startup")
async def startup_event():
    """Start the background transcription loop when the server boots.

    The task is stored on ``app.state``. asyncio keeps only weak references to
    running tasks, so a task whose handle is discarded can be
    garbage-collected mid-run, silently stopping all processing.
    """
    app.state.processing_task = asyncio.create_task(main_processing_loop())
@app.get("/")
async def root():
    """Status endpoint: shared HF state counters plus this instance's local progress."""
    progress = load_progress()
    state = await download_hf_state()
    file_states = state.get("file_states", {})
    failed_count = len([value for value in file_states.values() if value == "failed_transcription"])
    return dict(
        flow_id=FLOW_ID,
        status="running",
        next_download_index=state.get("next_download_index", 0),
        failed_transcriptions=failed_count,
        uploaded_count=progress.get("uploaded_count", 0),
        total_files_in_list=len(progress.get('file_list', [])),
    )
@app.post("/start_processing")
async def start_processing(request: ProcessStartRequest):
    """Set the shared ``next_download_index`` from the 1-based ``request.start_index``."""
    # NOTE(review): main_processing_loop keeps its own copy of the state while a file
    # is being transcribed and uploads it afterwards, which can overwrite this reset.
    # Confirm whether a shared lock or reset flag is needed.
    zero_based_index = request.start_index - 1
    state = await download_hf_state()
    state["next_download_index"] = zero_based_index
    await upload_hf_state(state)
    return {"status": "index_reset", "new_index": request.start_index}
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly.
    import uvicorn

    uvicorn.run(app, port=FLOW_PORT, host="0.0.0.0")