import os
import time
import jwt
import logging
import asyncio
import hashlib
import tempfile
import subprocess
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Any
import aiohttp
import librosa
import uvicorn
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
# --- 1. CONFIGURATION ---
class GlobalConfig:
    """Central configuration: HF credentials, the model ensemble, and label mapping."""
    # Set these in HF Space Secrets
    HF_TOKEN = os.getenv("HF_TOKEN", "")
    # NOTE(review): fallback value is insecure — API_SECRET_KEY must be set in production
    API_SECRET = os.getenv("API_SECRET_KEY", "default_secret_change_me_in_production")
    # HF Inference API endpoints with per-model ensemble weight "w" (weights sum to 1.0)
    MODELS = {
        "emotion2vec": {"url": "https://api-inference.huggingface.co/models/emotion2vec/emotion2vec_plus_base", "w": 0.50},
        "meralion": {"url": "https://api-inference.huggingface.co/models/MERaLiON/MERaLiON-SER-v1", "w": 0.25},
        "wav2vec2": {"url": "https://api-inference.huggingface.co/models/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", "w": 0.15},
        "hubert": {"url": "https://api-inference.huggingface.co/models/superb/hubert-large-superb-er", "w": 0.07},
        "gigam": {"url": "https://api-inference.huggingface.co/models/salute-developers/GigaAM-emo", "w": 0.03}
    }
    # Standardized internal labels: substring keywords in raw model labels map to our 4 classes
    MAPPING = {
        "angry": ["ang", "fear"],  # Merging high-arousal negative
        "happy": ["hap", "joy", "surp"],
        "sad": ["sad"],
        "neutral": ["neu", "calm"]
    }
# Module-level singleton config shared by all handlers
cfg = GlobalConfig()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("EmotionAPI")
# --- 2. AUTHENTICATION ---
# Bearer-token scheme; FastAPI injects the credentials via Depends(security)
security = HTTPBearer()
def create_access_token(data: dict):
    """Issue an HS256-signed JWT carrying *data*, expiring 60 minutes from now (UTC)."""
    claims = dict(data)  # don't mutate the caller's dict
    claims["exp"] = datetime.now(timezone.utc) + timedelta(minutes=60)
    return jwt.encode(claims, cfg.API_SECRET, algorithm="HS256")
async def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: validate the Bearer JWT and return its decoded claims.

    Raises:
        HTTPException(401): token is malformed, has a bad signature, or expired.
    """
    try:
        payload = jwt.decode(credentials.credentials, cfg.API_SECRET, algorithms=["HS256"])
        return payload
    except jwt.PyJWTError:
        # Narrowed from `except Exception`: only token problems become 401s,
        # so genuine programming errors are no longer silently masked as auth failures.
        raise HTTPException(status_code=401, detail="Invalid/Expired Token")
# --- 3. CORE LOGIC ---
async def process_audio(file: UploadFile):
    """Convert an uploaded audio file to 16 kHz mono WAV.

    Returns:
        tuple (audio_bytes, duration_sec).
    Raises:
        RuntimeError: when ffmpeg fails to transcode the input.

    Temp files are always removed, even on failure.
    """
    # Preserve the original extension so ffmpeg can sniff the container format;
    # tolerate a missing filename or an extension-less name (original crashed on None).
    name = file.filename or ""
    suffix = f".{name.rsplit('.', 1)[-1]}" if "." in name else ""
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_in:
        content = await file.read()
        tmp_in.write(content)
        input_path = tmp_in.name
    output_path = input_path + ".wav"
    try:
        # Standardize to 16kHz Mono WAV. Use asyncio's subprocess so the
        # (potentially slow) transcode does not block the event loop the way
        # a synchronous subprocess.run() would inside this async handler.
        proc = await asyncio.create_subprocess_exec(
            "ffmpeg", "-i", input_path, "-ar", "16000", "-ac", "1", "-y", output_path,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"FFmpeg error: {stderr.decode(errors='replace')}")
        with open(output_path, "rb") as f:
            audio_bytes = f.read()
        duration = librosa.get_duration(path=output_path)
        return audio_bytes, duration
    finally:
        # Best-effort cleanup of both temp files.
        for p in (input_path, output_path):
            if os.path.exists(p):
                os.unlink(p)
async def query_hf(session, name, url, data):
    """POST audio bytes to one HF inference endpoint.

    Makes up to 3 attempts; a 503 (model still loading) waits 5 s before the
    next try, and any other non-200 status is simply retried. Returns the
    parsed JSON on success, or None once all attempts are exhausted.
    """
    headers = {"Authorization": f"Bearer {cfg.HF_TOKEN}"}
    for _attempt in range(3):
        async with session.post(url, headers=headers, data=data) as resp:
            payload = await resp.json()
            if resp.status == 200:
                return payload
            if resp.status == 503:  # Model loading
                await asyncio.sleep(5)
    return None
def ensemble_logic(responses: dict):
    """Combine per-model predictions into one weighted emotion distribution.

    Args:
        responses: model name -> list of {"label": str, "score": float} dicts
            (HF audio-classification output); non-list values are skipped.
    Returns:
        dict with "primary" (top label or "unknown"), "confidence" (its
        rounded score, 0 when empty) and "distribution" (all rounded scores).

    Fix: the original used `defaultdict(float)` without ever importing it
    from `collections`, so every call raised NameError. A plain dict with
    .get() accumulation is self-contained and equivalent.
    """
    final_scores: Dict[str, float] = {}
    for name, preds in responses.items():
        if not isinstance(preds, list):
            continue  # malformed / error payload from this model
        weight = cfg.MODELS[name]["w"]
        for p in preds:
            label = p['label'].lower()
            # Map raw model labels to our standard set (substring match);
            # anything unrecognized falls back to "neutral".
            mapped = "neutral"
            for std, keywords in cfg.MAPPING.items():
                if any(k in label for k in keywords):
                    mapped = std
                    break
            final_scores[mapped] = final_scores.get(mapped, 0.0) + p['score'] * weight
    sorted_res = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)
    return {
        "primary": sorted_res[0][0] if sorted_res else "unknown",
        "confidence": round(sorted_res[0][1], 3) if sorted_res else 0,
        "distribution": {k: round(v, 3) for k, v in sorted_res}
    }
# --- 4. API ENDPOINTS ---
app = FastAPI(title="Emotion Ensemble API")
# Wide-open CORS: any origin, method, and header.
# NOTE(review): acceptable for a demo Space; tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
def health():
return {"status": "online", "hf_configured": bool(cfg.HF_TOKEN)}
@app.get("/token")
def get_token(user: str = "hf_user"):
return {"token": create_access_token({"sub": user})}
@app.post("/analyze")
async def analyze(file: UploadFile = File(...), auth=Depends(verify_jwt)):
start_time = time.time()
# 1. Process Audio
try:
audio_bytes, duration = await process_audio(file)
except Exception as e:
raise HTTPException(400, f"Audio processing failed: {str(e)}")
# 2. Run Parallel Inference
async with aiohttp.ClientSession() as session:
tasks = {name: query_hf(session, name, m["url"], audio_bytes)
for name, m in cfg.MODELS.items()}
results = await asyncio.gather(*tasks.values())
raw_responses = dict(zip(tasks.keys(), results))
# 3. Ensemble & Format
successful_models = {k: v for k, v in raw_responses.items() if v is not None}
if not successful_models:
raise HTTPException(503, "All upstream models failed.")
analysis = ensemble_logic(successful_models)
return {
"emotion": analysis["primary"],
"confidence": analysis["confidence"],
"scores": analysis["distribution"],
"meta": {
"duration_sec": round(duration, 2),
"latency_sec": round(time.time() - start_time, 2),
"models_responding": len(successful_models)
}
}
if __name__ == "__main__":
    # Bind to all interfaces; HF Spaces injects PORT (defaults to 7860).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))