import asyncio
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional

import aiohttp

from app.config import settings

logger = logging.getLogger(__name__)


class EmotionEnsemble:
    """Ensemble of emotion detection models.

    Every enabled model in ``settings.ENABLED_MODELS`` is queried
    concurrently over HTTP; the per-model label/score lists are then fused
    into one weighted score per standard emotion.
    """

    # Seconds to wait before the single retry after an HTTP 503 response.
    _RETRY_DELAY_SECONDS = 2

    # Short fragments that ``emotion_mapping`` does not cover, checked in this
    # order when no mapping entry matches a label.
    _FALLBACK_FRAGMENTS = (
        ("fea", "fear"),
        ("sur", "surprise"),
        ("dis", "disgust"),
    )

    def __init__(self):
        # name -> config dict; the code below reads the keys "url", "timeout",
        # "weight" and the optional "enabled" flag.
        self.models = settings.ENABLED_MODELS
        # Standard emotion -> substrings that identify it in a model label.
        self.emotion_mapping = {
            "angry": ["angry", "ang", "anger"],
            "happy": ["happy", "hap", "happiness", "joy"],
            "sad": ["sad", "sadness"],
            "fear": ["fear", "fearful"],
            "surprise": ["surprise", "surprised"],
            "disgust": ["disgust", "disgusted"],
            "neutral": ["neutral", "neu"],
        }

    async def predict(self, audio_bytes: bytes) -> Dict[str, Any]:
        """Run ensemble prediction on audio bytes.

        Args:
            audio_bytes: Raw audio payload, posted unchanged to every model.

        Returns:
            The fused prediction produced by ``_fuse_predictions``.

        Raises:
            RuntimeError: If no enabled model returned a usable prediction.
        """
        headers = {"Authorization": f"Bearer {settings.HF_TOKEN}"}
        enabled = {
            name: config
            for name, config in self.models.items()
            if config.get("enabled", True)
        }

        async with aiohttp.ClientSession() as session:
            # return_exceptions=True keeps one failing model from cancelling
            # the others; failures come back as values in ``results``.
            results = await asyncio.gather(
                *(
                    self._query_model(
                        session, name, config, audio_bytes, headers
                    )
                    for name, config in enabled.items()
                ),
                return_exceptions=True,
            )

        model_outputs = {}
        for name, result in zip(enabled, results):
            # Only a non-empty list can be fused. This rejects None,
            # exceptions (including CancelledError, which is not an
            # ``Exception`` subclass) and JSON objects such as error payloads.
            if isinstance(result, list) and result:
                model_outputs[name] = result
                logger.info("✓ %s succeeded", name)
            else:
                logger.warning("✗ %s failed: %s", name, result)

        if not model_outputs:
            raise RuntimeError("No models returned valid predictions")

        return self._fuse_predictions(model_outputs)

    async def _query_model(
        self,
        session: aiohttp.ClientSession,
        name: str,
        config: Dict[str, Any],
        audio_bytes: bytes,
        headers: Dict[str, str],
    ) -> Optional[Any]:
        """Query a single model, retrying once on HTTP 503.

        Both attempts use the per-model ``config["timeout"]``.

        Returns:
            The decoded JSON body of a 200 response, or ``None`` on any other
            status, on timeout, or on any other error (all are logged).
        """
        try:
            timeout = aiohttp.ClientTimeout(total=config["timeout"])
            for attempt in range(2):
                async with session.post(
                    config["url"],
                    headers=headers,
                    data=audio_bytes,
                    timeout=timeout,
                ) as response:
                    if response.status == 200:
                        return await response.json()
                    status = response.status

                # The first response is released before sleeping/retrying.
                if status == 503 and attempt == 0:
                    # 503 is treated as "model still loading": wait, retry once.
                    await asyncio.sleep(self._RETRY_DELAY_SECONDS)
                    continue

                logger.warning("%s returned %s", name, status)
                return None
        except asyncio.TimeoutError:
            logger.warning("%s timeout", name)
            return None
        except Exception as e:
            # Deliberately broad: a single model failure (network error, bad
            # JSON, malformed config) must not break the whole ensemble.
            logger.warning("%s error: %s", name, e)
            return None
        return None

    def _fuse_predictions(
        self, model_outputs: Dict[str, List]
    ) -> Dict[str, Any]:
        """Fuse predictions using weighted voting.

        Each prediction's score is multiplied by its model's configured
        weight and summed per standard emotion; the sums are then divided by
        the total weight of the contributing models.

        Args:
            model_outputs: Model name -> list of ``{"label", "score"}`` dicts.

        Returns:
            A dict with ``primary_emotion``, ``confidence``, ``all_emotions``
            (sorted by descending score) and ``ensemble_details``.
        """
        emotion_scores = defaultdict(float)
        total_weight = 0.0
        model_contributions = []

        for name, predictions in model_outputs.items():
            weight = self.models[name]["weight"]
            total_weight += weight

            mapped_predictions = []
            for pred in predictions:
                if not isinstance(pred, dict):
                    # Skip entries that are not label/score records.
                    continue
                label = str(pred.get("label") or "").lower()
                score = pred.get("score", 0.0)
                mapped = self._map_emotion(label)
                mapped_predictions.append(
                    {"original": label, "mapped": mapped, "score": score}
                )
                emotion_scores[mapped] += score * weight

            model_contributions.append(
                {
                    "model": name,
                    "weight": weight,
                    "predictions": mapped_predictions,
                }
            )

        # Normalize by the weight of the models that actually contributed.
        if total_weight > 0:
            emotion_scores = {
                emotion: total / total_weight
                for emotion, total in emotion_scores.items()
            }

        if emotion_scores:
            primary = max(emotion_scores.items(), key=lambda item: item[1])
        else:
            primary = ("unknown", 0.0)

        ranked = sorted(
            emotion_scores.items(), key=lambda item: item[1], reverse=True
        )
        return {
            "primary_emotion": primary[0],
            "confidence": round(primary[1], 4),
            "all_emotions": {
                emotion: round(value, 4) for emotion, value in ranked
            },
            "ensemble_details": {
                "models_used": list(model_outputs.keys()),
                "total_models": len(self.models),
                "model_contributions": model_contributions,
            },
        }

    def _map_emotion(self, label: str) -> str:
        """Map model-specific label to standard emotion.

        Matching is by substring, case-insensitively: the first entry of
        ``emotion_mapping`` with a variation contained in the label wins,
        then the short fallback fragments are tried. Labels that match
        nothing map to ``"neutral"``.
        """
        label_lower = label.lower()

        for std_emo, variations in self.emotion_mapping.items():
            if any(var in label_lower for var in variations):
                return std_emo

        # "ang", "hap", "sad" and "neu" are already covered by the mapping
        # above, so only these remaining abbreviations need a fallback.
        for fragment, std_emo in self._FALLBACK_FRAGMENTS:
            if fragment in label_lower:
                return std_emo

        return "neutral"


ensemble = EmotionEnsemble()