File size: 5,017 Bytes
b8af3c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
import torch.nn as nn
import json
import random
from collections import deque

# ============ КЛАСС МОДЕЛИ (ДОЛЖЕН СОВПАДАТЬ) ============
class AndreyBot(nn.Module):
    def __init__(self, vocab_size=89, hidden_size=256, num_layers=2, embedding_dim=64):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, num_layers=num_layers, dropout=0.2)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.2)
        
    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        emb = self.embedding(x)
        out, _ = self.lstm(emb)
        out = self.dropout(out)
        out = out[:, -1, :]
        return self.fc(out)

# ============ ЗАГРУЗЧИК ============
class AndreyLoader:
    def __init__(self, model_path='andrey_full.pt'):
        print("🤖 Загрузка Андрея...")
        
        # Загружаем чекпоинт
        checkpoint = torch.load(model_path, map_location='cpu')
        
        # Создаём модель
        self.model = AndreyBot(
            vocab_size=checkpoint['vocab_size'],
            hidden_size=checkpoint['hidden_size'],
            num_layers=checkpoint['num_layers'],
            embedding_dim=checkpoint['embedding_dim']
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        
        # Загружаем словари
        self.word_to_idx = checkpoint['word_to_idx']
        self.idx_to_word = {int(k): v for k, v in checkpoint['idx_to_word'].items()}
        
        # Параметры
        self.vocab_size = checkpoint['vocab_size']
        self.epoch = checkpoint['epoch']
        
        # Константы
        self.PAD = 0
        self.UNK = 1
        
        print(f"✅ Андрей загружен!")
        print(f"   📚 Словарь: {self.vocab_size} слов")
        print(f"   🧠 Обучен: {self.epoch} эпох")
    
    def tokenize(self, text):
        """Текст → список токенов"""
        words = text.lower().split()
        return [self.word_to_idx.get(w, self.UNK) for w in words if w in self.word_to_idx]
    
    def detokenize(self, tokens):
        """Токены → текст"""
        words = []
        for t in tokens:
            if t not in [self.PAD, self.UNK]:
                words.append(self.idx_to_word.get(t, '?'))
        return ' '.join(words)
    
    def generate(self, question, temperature=0.85, max_length=12):
        """Генерация ответа"""
        tokens = self.tokenize(question)
        
        if not tokens:
            return "..."
        
        current = tokens[0]
        response_tokens = []
        
        with torch.no_grad():
            for _ in range(max_length):
                # Получаем логиты от модели
                logits = self.model(torch.tensor([[current]]))
                
                # Применяем температуру
                probs = torch.softmax(logits[0] / temperature, dim=-1)
                
                # Top-K выбор
                top_k = min(5, self.vocab_size)
                top_probs, top_idx = torch.topk(probs, top_k)
                probs = top_probs / top_probs.sum()
                current = top_idx[torch.multinomial(probs, 1)].item()
                
                # Пропускаем спецтокены
                if current in [self.PAD, self.UNK]:
                    continue
                
                response_tokens.append(current)
                
                # Останавливаемся при длине
                if len(response_tokens) >= max_length - 2:
                    break
        
        if not response_tokens:
            return "..."
        
        return self.detokenize(response_tokens)
    
    def chat(self):
        """Запуск чата"""
        print("\n" + "="*50)
        print("🤖 АНДРЕЙ")
        print("   Скажи 'пока' для выхода")
        print("="*50 + "\n")
        
        while True:
            user = input("👤 Вы: ").strip().lower()
            
            if user in ['пока', 'выход', 'exit', 'quit']:
                print("🤖 Андрей: Пока! 👋")
                break
            
            if user == '':
                continue
            
            answer = self.generate(user)
            print(f"🤖 Андрей: {answer}\n")

# ============ ЗАПУСК ============
if __name__ == "__main__":
    # Загружаем Андрея
    andrey = AndreyLoader('andrey_full.pt')
    
    # Запускаем чат
    andrey.chat()