import asyncio
import hashlib
import logging
import os
import subprocess
import tempfile
import time
# FIX: `defaultdict` is used by ensemble_logic() below but was never imported.
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List

import aiohttp
import jwt
import librosa
import uvicorn
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer


# --- 1. CONFIGURATION ---
class GlobalConfig:
    """Static configuration for the emotion-recognition ensemble service.

    All secrets come from environment variables (set these in HF Space
    Secrets); everything else is hard-coded model routing/weighting.
    """

    # Set these in HF Space Secrets.
    HF_TOKEN = os.getenv("HF_TOKEN", "")
    # NOTE(security): the fallback secret must be overridden in production —
    # anyone who knows it can mint valid JWTs for this API.
    API_SECRET = os.getenv("API_SECRET_KEY", "default_secret_change_me_in_production")

    # HF Inference API endpoints and their ensemble weights.
    # Weights sum to 1.0 when every model responds.
    MODELS = {
        "emotion2vec": {"url": "https://api-inference.huggingface.co/models/emotion2vec/emotion2vec_plus_base", "w": 0.50},
        "meralion": {"url": "https://api-inference.huggingface.co/models/MERaLiON/MERaLiON-SER-v1", "w": 0.25},
        "wav2vec2": {"url": "https://api-inference.huggingface.co/models/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", "w": 0.15},
        "hubert": {"url": "https://api-inference.huggingface.co/models/superb/hubert-large-superb-er", "w": 0.07},
        "gigam": {"url": "https://api-inference.huggingface.co/models/salute-developers/GigaAM-emo", "w": 0.03},
    }

    # Standardized internal labels: each key is our canonical emotion, each
    # value is a list of substrings matched against the raw model labels.
    MAPPING = {
        "angry": ["ang", "fear"],  # merging high-arousal negative emotions
        "happy": ["hap", "joy", "surp"],
        "sad": ["sad"],
        "neutral": ["neu", "calm"],
    }


cfg = GlobalConfig()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("EmotionAPI")
# --- 2. AUTHENTICATION ---
security = HTTPBearer()


def create_access_token(data: dict) -> str:
    """Issue an HS256 JWT carrying *data* plus a 60-minute ``exp`` claim."""
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + timedelta(minutes=60)
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, cfg.API_SECRET, algorithm="HS256")


async def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: decode the bearer token or reject with 401.

    Catches only PyJWT errors (expired/invalid signature/claims) so that
    genuine programming errors are not masked as auth failures.
    """
    try:
        return jwt.decode(credentials.credentials, cfg.API_SECRET, algorithms=["HS256"])
    except jwt.PyJWTError:
        raise HTTPException(status_code=401, detail="Invalid/Expired Token")


# --- 3. CORE LOGIC ---
async def process_audio(file: UploadFile):
    """Convert an uploaded audio file to 16 kHz mono WAV.

    Returns:
        (audio_bytes, duration_sec) — the standardized WAV payload and its
        duration as measured by librosa.

    Raises:
        RuntimeError: if ffmpeg fails to transcode the input.

    Both temp files are removed in all cases (``finally``).
    """
    suffix = f".{file.filename.split('.')[-1]}"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_in:
        tmp_in.write(await file.read())
        input_path = tmp_in.name
    output_path = input_path + ".wav"
    try:
        # Standardize to 16 kHz mono WAV (what the SER models expect).
        proc = subprocess.run(
            ["ffmpeg", "-i", input_path, "-ar", "16000", "-ac", "1", "-y", output_path],
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            raise RuntimeError(f"FFmpeg error: {proc.stderr}")
        with open(output_path, "rb") as f:
            audio_bytes = f.read()
        duration = librosa.get_duration(path=output_path)
        return audio_bytes, duration
    finally:
        for p in (input_path, output_path):
            if os.path.exists(p):
                os.unlink(p)


async def query_hf(session, name, url, data):
    """POST audio bytes to one HF Inference endpoint.

    Retries up to 3 times while the model is cold-starting (HTTP 503).
    Returns the parsed JSON on success, or ``None`` on any failure —
    a ``None`` result means "this model abstains" and the caller simply
    drops it from the ensemble.

    Fixes vs. original: the response body is only parsed as JSON on a 200
    (an HTML/plaintext error page no longer raises), and transport errors
    are caught here so one flaky model cannot crash the whole fan-out.
    """
    headers = {"Authorization": f"Bearer {cfg.HF_TOKEN}"}
    for _ in range(3):  # simple retry while the model is loading
        try:
            async with session.post(url, headers=headers, data=data) as resp:
                if resp.status == 200:
                    return await resp.json()
                if resp.status == 503:  # model loading / cold start
                    await asyncio.sleep(5)
                    continue
                logger.warning("Model %s returned HTTP %s", name, resp.status)
                return None
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.warning("Model %s request failed: %s", name, e)
            return None
    return None  # still loading after all retries


def ensemble_logic(responses: dict):
    """Weighted soft-vote over per-model prediction lists.

    Each model's [{"label": ..., "score": ...}] predictions are mapped onto
    the canonical label set via substring matching against cfg.MAPPING,
    then accumulated with that model's configured weight.

    Returns a dict with the top label ("primary"), its weighted score
    ("confidence"), and the full score "distribution".
    NOTE: scores are not renormalized when some models fail, so
    "confidence" can be < 1 even for a unanimous vote.
    """
    final_scores: Dict[str, float] = {}
    for name, preds in responses.items():
        if not isinstance(preds, list):
            continue  # skip error payloads such as {"error": "..."}
        weight = cfg.MODELS[name]["w"]
        for p in preds:
            label = p["label"].lower()
            # Map raw model labels to our standard set; unmatched labels
            # fall into the "neutral" bucket.
            mapped = "neutral"
            for std, keywords in cfg.MAPPING.items():
                if any(k in label for k in keywords):
                    mapped = std
                    break
            # Plain dict accumulation — the original used `defaultdict`,
            # which was never imported (NameError at runtime).
            final_scores[mapped] = final_scores.get(mapped, 0.0) + p["score"] * weight
    sorted_res = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)
    return {
        "primary": sorted_res[0][0] if sorted_res else "unknown",
        "confidence": round(sorted_res[0][1], 3) if sorted_res else 0,
        "distribution": {k: round(v, 3) for k, v in sorted_res},
    }


# --- 4. API ENDPOINTS ---
app = FastAPI(title="Emotion Ensemble API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health():
    """Liveness probe; also reports whether an HF token is configured."""
    return {"status": "online", "hf_configured": bool(cfg.HF_TOKEN)}


@app.get("/token")
def get_token(user: str = "hf_user"):
    """Mint a short-lived JWT for the given user id (demo convenience)."""
    return {"token": create_access_token({"sub": user})}


@app.post("/analyze")
async def analyze(file: UploadFile = File(...), auth=Depends(verify_jwt)):
    """Run the full pipeline: transcode → parallel inference → ensemble.

    Returns the winning emotion, its confidence, the full score
    distribution, and timing/participation metadata.
    """
    start_time = time.time()

    # 1. Process audio (any ffmpeg-readable format → 16 kHz mono WAV).
    try:
        audio_bytes, duration = await process_audio(file)
    except Exception as e:
        raise HTTPException(400, f"Audio processing failed: {str(e)}")

    # 2. Fan out to all models in parallel. return_exceptions=True keeps
    # one crashed task from aborting the whole gather; exceptions are
    # treated the same as a None ("model abstained") result.
    async with aiohttp.ClientSession() as session:
        tasks = {
            name: query_hf(session, name, m["url"], audio_bytes)
            for name, m in cfg.MODELS.items()
        }
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
    raw_responses = {
        name: (None if isinstance(res, BaseException) else res)
        for name, res in zip(tasks.keys(), results)
    }

    # 3. Ensemble over whichever models actually responded.
    successful_models = {k: v for k, v in raw_responses.items() if v is not None}
    if not successful_models:
        raise HTTPException(503, "All upstream models failed.")
    analysis = ensemble_logic(successful_models)

    return {
        "emotion": analysis["primary"],
        "confidence": analysis["confidence"],
        "scores": analysis["distribution"],
        "meta": {
            "duration_sec": round(duration, 2),
            "latency_sec": round(time.time() - start_time, 2),
            "models_responding": len(successful_models),
        },
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))