# PsalmsJava's picture
# Some other change
# 28a1786
import asyncio
import aiohttp
from typing import Dict, List, Any, Optional
from collections import defaultdict
import logging
from app.config import settings
# Module-level logger; the ensemble reports per-model success, failure and
# timeout events through it.
logger = logging.getLogger(__name__)
class EmotionEnsemble:
    """Ensemble of emotion detection models.

    Every enabled model in ``settings.ENABLED_MODELS`` is sent the raw audio
    concurrently over HTTP. Each model's ``{"label", "score"}`` predictions are
    mapped onto a fixed set of standard emotions and fused by weighted
    averaging.
    """

    # Standard emotion -> substrings that identify it inside a model's label.
    # Checked in insertion order; the first emotion with a matching substring wins.
    EMOTION_MAPPING: Dict[str, List[str]] = {
        "angry": ["angry", "ang", "anger"],
        "happy": ["happy", "hap", "happiness", "joy"],
        "sad": ["sad", "sadness"],
        "fear": ["fear", "fearful"],
        "surprise": ["surprise", "surprised"],
        "disgust": ["disgust", "disgusted"],
        "neutral": ["neutral", "neu"],
    }

    # Second-pass (substring, emotion) table for labels that matched nothing in
    # the mapping. With the default mapping only "fea"/"sur"/"dis" can reach it;
    # the other rows keep the lookup working if ``emotion_mapping`` is customised.
    _FALLBACK_FRAGMENTS = (
        ("ang", "angry"),
        ("hap", "happy"),
        ("sad", "sad"),
        ("neu", "neutral"),
        ("fea", "fear"),
        ("sur", "surprise"),
        ("dis", "disgust"),
    )

    # Pause before the single retry of a model that answered HTTP 503.
    _RETRY_DELAY_SECONDS = 2

    def __init__(self):
        # Model name -> config dict; "enabled", "url", "timeout" and "weight"
        # are read from each entry below.
        self.models = settings.ENABLED_MODELS
        # Per-instance copy so callers can extend it without touching the class default.
        self.emotion_mapping = {
            emotion: list(variants)
            for emotion, variants in self.EMOTION_MAPPING.items()
        }

    async def predict(self, audio_bytes: bytes) -> Dict[str, Any]:
        """Run ensemble prediction on audio bytes.

        Every model whose config does not set ``enabled`` to a falsy value is
        queried concurrently. Models that fail are logged and left out of the
        fusion.

        Args:
            audio_bytes: Raw audio, sent unchanged as each request body.

        Returns:
            The fused result built by ``_fuse_predictions``.

        Raises:
            RuntimeError: If no model returned usable predictions.
        """
        headers = {"Authorization": f"Bearer {settings.HF_TOKEN}"}
        model_names = [
            name for name, config in self.models.items()
            if config.get("enabled", True)
        ]
        async with aiohttp.ClientSession() as session:
            results = await asyncio.gather(
                *(
                    self._query_model(
                        session, name, self.models[name], audio_bytes, headers
                    )
                    for name in model_names
                ),
                return_exceptions=True,
            )

        model_outputs = {}
        for name, result in zip(model_names, results):
            # gather(return_exceptions=True) returns exceptions as values. Check
            # BaseException so a cancelled child (CancelledError) is not
            # mistaken for a prediction payload.
            if result and not isinstance(result, BaseException):
                model_outputs[name] = result
                logger.info("✓ %s succeeded", name)
            else:
                logger.warning("✗ %s failed: %s", name, result)

        if not model_outputs:
            raise RuntimeError("No models returned valid predictions")

        return self._fuse_predictions(model_outputs)

    async def _query_model(self, session, name, config, audio_bytes, headers):
        """Query a single model, retrying once if it answers HTTP 503.

        Args:
            session: Open aiohttp client session used for the request(s).
            name: Model name, used for logging.
            config: Model config. ``config["url"]`` and ``config["timeout"]``
                are read; the timeout is a total per-request budget.
            audio_bytes: Raw audio sent as the request body.
            headers: HTTP headers (authorization) for the request.

        Returns:
            The model's predictions as a non-empty list of dicts, or ``None``
            on failure. Failure covers non-200 status, timeout, transport or
            parse error, and an unexpected payload shape. Errors derived from
            ``Exception`` are logged, not raised.
        """
        try:
            # The same budget applies to the retry, so a slow retry cannot
            # outlive the configured timeout.
            timeout = aiohttp.ClientTimeout(total=config["timeout"])
            status, payload = await self._post(
                session, config["url"], headers, audio_bytes, timeout
            )
            if status == 503:
                # Model loading - wait and retry once. _post has already
                # released the first response, so no connection is held here.
                await asyncio.sleep(self._RETRY_DELAY_SECONDS)
                status, payload = await self._post(
                    session, config["url"], headers, audio_bytes, timeout
                )
            if status != 200:
                logger.warning("%s returned %s", name, status)
                return None
            predictions = self._normalize_predictions(payload)
            if not predictions:
                logger.warning("%s returned unusable payload: %.200r", name, payload)
                return None
            return predictions
        except asyncio.TimeoutError:
            logger.warning("%s timeout", name)
            return None
        except Exception as e:
            logger.warning("%s error: %s", name, e)
            return None

    @staticmethod
    async def _post(session, url, headers, audio_bytes, timeout):
        """POST *audio_bytes* to *url* once and return ``(status, json_or_None)``.

        The JSON body is parsed only for HTTP 200. The response is released
        before this coroutine returns.
        """
        async with session.post(
            url, headers=headers, data=audio_bytes, timeout=timeout
        ) as response:
            if response.status != 200:
                return response.status, None
            return response.status, await response.json()

    @staticmethod
    def _normalize_predictions(payload: Any) -> Optional[List[Dict[str, Any]]]:
        """Coerce a model's JSON payload into a list of prediction dicts.

        Returns ``None`` when the payload is not a list. This covers e.g. a
        dict such as ``{"error": ...}``, which would otherwise break fusion.
        Non-dict list entries are dropped.
        """
        # Tolerate a singly-nested list ([[{...}, ...]]).
        # NOTE(review): some inference endpoints wrap results this way;
        # confirm for the configured models.
        if (
            isinstance(payload, list)
            and len(payload) == 1
            and isinstance(payload[0], list)
        ):
            payload = payload[0]
        if not isinstance(payload, list):
            return None
        return [pred for pred in payload if isinstance(pred, dict)]

    def _fuse_predictions(self, model_outputs: Dict[str, List]) -> Dict[str, Any]:
        """Fuse per-model predictions using weighted voting.

        Each prediction's score is multiplied by its model's ``weight``
        (default 1.0 when absent) and added to its mapped emotion. The totals
        are then divided by the summed weights of the models present.

        Args:
            model_outputs: Model name -> list of ``{"label", "score"}`` dicts.
                Non-dict entries are skipped. A missing or ``None`` label is
                treated as ``""``, and a missing or non-numeric score as 0.0.

        Returns:
            A dict with these keys:

            - ``primary_emotion``: the highest fused score, or ``"unknown"``
              when nothing was scored.
            - ``confidence``: that score, rounded to 4 places.
            - ``all_emotions``: every fused score, highest first.
            - ``ensemble_details``: per-model contributions.
        """
        emotion_scores = defaultdict(float)
        total_weight = 0.0
        model_contributions = []

        for name, predictions in model_outputs.items():
            weight = self.models.get(name, {}).get("weight", 1.0)
            total_weight += weight
            contribution = {"model": name, "weight": weight, "predictions": []}

            for pred in predictions:
                if not isinstance(pred, dict):
                    continue
                label = str(pred.get("label") or "").lower()
                raw_score = pred.get("score")
                score = raw_score if isinstance(raw_score, (int, float)) else 0.0

                # Map to standard emotions
                mapped = self._map_emotion(label)
                contribution["predictions"].append(
                    {"original": label, "mapped": mapped, "score": score}
                )
                emotion_scores[mapped] += score * weight

            model_contributions.append(contribution)

        # Normalize by total weight so fused scores stay on the models' scale.
        if total_weight > 0:
            emotion_scores = {k: v / total_weight for k, v in emotion_scores.items()}

        # sorted() is stable, so ties resolve to the first-inserted emotion,
        # the same one max() would pick.
        ranked = sorted(emotion_scores.items(), key=lambda kv: kv[1], reverse=True)
        primary_emotion, confidence = ranked[0] if ranked else ("unknown", 0.0)

        return {
            "primary_emotion": primary_emotion,
            "confidence": round(confidence, 4),
            "all_emotions": {k: round(v, 4) for k, v in ranked},
            "ensemble_details": {
                "models_used": list(model_outputs.keys()),
                # Counts every configured model, including disabled ones.
                "total_models": len(self.models),
                "model_contributions": model_contributions,
            },
        }

    def _map_emotion(self, label: str) -> str:
        """Map a model-specific label to a standard emotion.

        The label is first matched by substring against ``emotion_mapping``,
        then against the fallback fragment table. Anything unmatched maps to
        ``"neutral"``.
        """
        label_lower = label.lower()
        for std_emo, variations in self.emotion_mapping.items():
            if any(var in label_lower for var in variations):
                return std_emo
        for fragment, std_emo in self._FALLBACK_FRAGMENTS:
            if fragment in label_lower:
                return std_emo
        return "neutral"
# Shared module-level instance, created at import time; its constructor reads
# settings.ENABLED_MODELS.
ensemble = EmotionEnsemble()