---
license: apache-2.0
datasets:
- dair-ai/emotion
language:
- en
metrics:
- accuracy
- f1
- precision
- recall
pipeline_tag: text-classification
---

# Emotion Classification with BERT + RL Fine-tuning

This model combines a BERT architecture with reinforcement learning (RL) for emotion classification. The base model was first fine-tuned on the `dair-ai/emotion` dataset (20k English sentences labeled with 6 emotions); PPO reinforcement learning was then applied to optimize prediction behavior.

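For reference, the six emotion classes in `dair-ai/emotion` map to ids as follows (`EMOTION_LABELS` is an illustrative name, not part of this model's code):

```python
# Class ids of the dair-ai/emotion dataset (id -> label name)
EMOTION_LABELS = {
    0: "sadness",
    1: "joy",
    2: "love",
    3: "anger",
    4: "fear",
    5: "surprise",
}
```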
## 🧠 Training Approach

1. **Supervised Phase**:
   - Base BERT model fine-tuned with cross-entropy loss
   - Achieved strong baseline performance

2. **RL Phase**:
   - Implemented an Actor-Critic architecture
   - Policy-gradient optimization with a custom reward
   - PPO clipping (ε = 0.2) and entropy regularization
   - Custom reward function: `+1.0` for a correct prediction, `-0.1` for an incorrect one

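The RL phase above can be sketched per sample as follows. This is a minimal illustration of the stated reward scheme and PPO clipped surrogate, assuming ε = 0.2; `reward_fn` and `ppo_clip_objective` are hypothetical names, not the repository's actual API:

```python
import math

EPSILON = 0.2  # PPO clipping range, as stated above

def reward_fn(predicted_label: int, true_label: int) -> float:
    """+1.0 for a correct prediction, -0.1 for an incorrect one."""
    return 1.0 if predicted_label == true_label else -0.1

def ppo_clip_objective(log_prob_new: float, log_prob_old: float,
                       advantage: float, epsilon: float = EPSILON) -> float:
    """Clipped surrogate objective for a single sample (to be maximized).

    The probability ratio is clamped to [1 - eps, 1 + eps] so that a
    single update cannot move the policy too far from the old one.
    """
    ratio = math.exp(log_prob_new - log_prob_old)
    clipped_ratio = max(min(ratio, 1.0 + epsilon), 1.0 - epsilon)
    return min(ratio * advantage, clipped_ratio * advantage)
```

In the actual training loop, the advantage would come from the critic; entropy regularization (also used here) is simply added to this objective with a small coefficient.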
## 📊 Performance Comparison

| Metric    | Pre-RL | Post-RL | Δ (rel.) |
|-----------|--------|---------|----------|
| Accuracy  | 0.9205 | 0.9310  | +1.14%   |
| F1-Score  | 0.9227 | 0.9298  | +0.77%   |
| Precision | 0.9325 | 0.9305  | -0.21%   |
| Recall    | 0.9205 | 0.9310  | +1.14%   |

Key observation: RL fine-tuning provided modest but consistent improvements across most metrics, particularly accuracy and recall, at the cost of a slight drop in precision.

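The Δ column reports the relative (not absolute) change between the two runs; a quick sketch of how those percentages are derived (`relative_delta` is an illustrative helper, not part of the repo):

```python
def relative_delta(pre: float, post: float) -> float:
    """Relative change from pre to post, in percent."""
    return (post - pre) / pre * 100.0

# e.g. accuracy 0.9205 -> 0.9310 is roughly a +1.14% relative gain
```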
## 🚀 Usage

```python
from transformers import pipeline

# Load the fine-tuned model from the Hugging Face Hub
classifier = pipeline("text-classification",
                      model="SimoGiuffrida/SentimentRL",
                      tokenizer="bert-base-uncased")

results = classifier("I'm thrilled about this new opportunity!")
print(results)
```

## 💡 Key Features

- Hybrid training: supervised + reinforcement learning
- Optimized for nuanced emotion detection
- Handles class imbalance (see the confusion matrix in the repo)

For full training details and analysis, visit the [GitHub repository](https://github.com/SimoGiuffrida/DLA2).