File size: 6,360 Bytes
28a1786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import asyncio
import aiohttp
from typing import Dict, List, Any, Optional
from collections import defaultdict
import logging
from app.config import settings

logger = logging.getLogger(__name__)

class EmotionEnsemble:
    """Weighted ensemble of remote emotion-detection models.

    Queries every enabled model endpoint concurrently, maps each model's
    native label vocabulary onto a shared emotion set, and fuses the
    per-model scores with a weighted vote.
    """

    def __init__(self):
        # Model registry from config: name -> {"url", "timeout", "weight", "enabled"?}.
        # (Shape assumed from usage below — TODO confirm against app.config.)
        self.models = settings.ENABLED_MODELS
        # Canonical emotion -> label variants emitted by individual models.
        self.emotion_mapping = {
            "angry": ["angry", "ang", "anger"],
            "happy": ["happy", "hap", "happiness", "joy"],
            "sad": ["sad", "sadness"],
            "fear": ["fear", "fearful"],
            "surprise": ["surprise", "surprised"],
            "disgust": ["disgust", "disgusted"],
            "neutral": ["neutral", "neu"]
        }

    async def predict(self, audio_bytes: bytes) -> Dict[str, Any]:
        """
        Run ensemble prediction on raw audio bytes.

        Queries all enabled models concurrently and returns the fused
        (weighted-vote) result from :meth:`_fuse_predictions`.

        Args:
            audio_bytes: Raw audio payload forwarded to each model endpoint.

        Returns:
            Fused prediction dict (primary emotion, confidence, per-emotion
            scores, and ensemble details).

        Raises:
            RuntimeError: If no model returned a usable prediction.
        """
        headers = {"Authorization": f"Bearer {settings.HF_TOKEN}"}

        async with aiohttp.ClientSession() as session:
            # Create tasks for all enabled models (enabled defaults to True).
            tasks = []
            model_names = []

            for name, config in self.models.items():
                if config.get("enabled", True):
                    tasks.append(self._query_model(
                        session, name, config, audio_bytes, headers
                    ))
                    model_names.append(name)

            # return_exceptions=True so one failing model cannot abort the rest.
            results = await asyncio.gather(*tasks, return_exceptions=True)

        # Keep only successful predictions; _query_model returns None on
        # failure, and gather may also hand back an exception object.
        model_outputs = {}
        for name, result in zip(model_names, results):
            if result and not isinstance(result, Exception):
                model_outputs[name] = result
                logger.info(f"✓ {name} succeeded")
            else:
                logger.warning(f"✗ {name} failed: {result}")

        if not model_outputs:
            # RuntimeError (not bare Exception) so callers can catch precisely;
            # backward-compatible with existing `except Exception` handlers.
            raise RuntimeError("No models returned valid predictions")

        # Fuse predictions
        return self._fuse_predictions(model_outputs)

    async def _query_model(self, session, name, config, audio_bytes, headers):
        """Query one model endpoint; return parsed JSON or None on any failure."""
        # Hoisted so the 503 retry gets the SAME bound — previously the retry
        # request had no timeout at all and could hang indefinitely.
        timeout = aiohttp.ClientTimeout(total=config["timeout"])
        try:
            async with session.post(
                config["url"],
                headers=headers,
                data=audio_bytes,
                timeout=timeout
            ) as response:
                if response.status == 200:
                    return await response.json()
                elif response.status == 503:
                    # 503: model is cold-loading on the remote side — wait
                    # briefly and retry exactly once.
                    await asyncio.sleep(2)
                    async with session.post(
                        config["url"],
                        headers=headers,
                        data=audio_bytes,
                        timeout=timeout
                    ) as retry:
                        if retry.status == 200:
                            return await retry.json()

                logger.warning(f"{name} returned {response.status}")
                return None

        except asyncio.TimeoutError:
            logger.warning(f"{name} timeout")
            return None
        except Exception as e:
            # Best-effort: any network/parse error just drops this model for
            # the current request; the ensemble continues with the rest.
            logger.warning(f"{name} error: {e}")
            return None

    def _fuse_predictions(self, model_outputs: Dict[str, List]) -> Dict[str, Any]:
        """Fuse per-model predictions into one result via weighted voting.

        Each model's scores are multiplied by its configured weight,
        accumulated per canonical emotion, then normalized by the total
        weight of the models that actually responded.
        """
        emotion_scores = defaultdict(float)
        total_weight = 0.0
        model_contributions = []

        for name, predictions in model_outputs.items():
            weight = self.models[name]["weight"]
            total_weight += weight

            contribution = {
                "model": name,
                "weight": weight,
                "predictions": []
            }

            for pred in predictions:
                label = pred.get("label", "").lower()
                score = pred.get("score", 0.0)

                # Map the model-specific label to a canonical emotion.
                mapped = self._map_emotion(label)
                contribution["predictions"].append({
                    "original": label,
                    "mapped": mapped,
                    "score": score
                })

                emotion_scores[mapped] += score * weight

            model_contributions.append(contribution)

        # Normalize by the weight of responding models only, so a missing
        # model does not depress every score.
        if total_weight > 0:
            emotion_scores = {
                k: v / total_weight
                for k, v in emotion_scores.items()
            }

        # Primary emotion = highest fused score; "unknown" if nothing scored.
        if emotion_scores:
            primary = max(emotion_scores.items(), key=lambda x: x[1])
        else:
            primary = ("unknown", 0.0)

        return {
            "primary_emotion": primary[0],
            "confidence": round(primary[1], 4),
            "all_emotions": {
                k: round(v, 4)
                for k, v in sorted(
                    emotion_scores.items(),
                    key=lambda x: x[1],
                    reverse=True
                )
            },
            "ensemble_details": {
                "models_used": list(model_outputs.keys()),
                "total_models": len(self.models),
                "model_contributions": model_contributions
            }
        }

    def _map_emotion(self, label: str) -> str:
        """Map a model-specific label to a canonical emotion.

        Matching is by substring: first against the full variant lists in
        ``emotion_mapping``, then against three-letter abbreviations, then
        falls back to "neutral".
        """
        label_lower = label.lower()

        for std_emo, variations in self.emotion_mapping.items():
            if any(var in label_lower for var in variations):
                return std_emo

        # Abbreviation fallback. The original if/elif chain also tested
        # "ang"/"hap"/"sad"/"neu", but those were dead code: they are already
        # variants in emotion_mapping and caught by the loop above. Order
        # preserved for labels matching multiple abbreviations.
        for std_emo, abbrev in (("fear", "fea"),
                                ("surprise", "sur"),
                                ("disgust", "dis")):
            if abbrev in label_lower:
                return std_emo

        return "neutral"

# Module-level singleton; constructed at import time from settings.ENABLED_MODELS.
ensemble = EmotionEnsemble()