# AndreyBot / Example_Loader.py
# Uploaded by root39058 — commit b8af3c7 ("Create Example_Loader.py")
import torch
import torch.nn as nn
import json
import random
from collections import deque
# ============ MODEL CLASS (must match the training-time architecture so the checkpoint's state_dict loads) ============
class AndreyBot(nn.Module):
    """Word-level LSTM model: token ids in, next-token logits out.

    Layer names and shapes must match the model used at training time,
    otherwise ``load_state_dict`` on a saved checkpoint will fail.

    Args:
        vocab_size: Number of token ids (embedding rows / fc outputs).
        hidden_size: LSTM hidden state size.
        num_layers: Number of stacked LSTM layers.
        embedding_dim: Size of each token embedding.
        dropout: Dropout probability used between stacked LSTM layers and on
            the LSTM output before the final projection (default 0.2).
    """

    def __init__(self, vocab_size=89, hidden_size=256, num_layers=2,
                 embedding_dim=64, dropout=0.2):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM applies its dropout only *between* stacked layers; with a
        # single layer a non-zero value does nothing and raises a UserWarning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True,
                            num_layers=num_layers, dropout=lstm_dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return next-token logits for the last position of each sequence.

        Args:
            x: Integer token ids of shape (seq_len,) or (batch, seq_len);
               a 1-D input is treated as a batch of one.

        Returns:
            Tensor of shape (batch, vocab_size).
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        emb = self.embedding(x)      # (batch, seq_len, embedding_dim)
        out, _ = self.lstm(emb)      # (batch, seq_len, hidden_size)
        out = self.dropout(out)
        out = out[:, -1, :]          # keep only the last time step
        return self.fc(out)
# ============ ЗАГРУЗЧИК ============
class AndreyLoader:
    """Load a saved AndreyBot checkpoint and generate chat replies from it.

    The checkpoint must be a dict containing the keys read in ``__init__``:
    'vocab_size', 'hidden_size', 'num_layers', 'embedding_dim',
    'model_state_dict', 'word_to_idx', 'idx_to_word' and 'epoch'.
    """

    def __init__(self, model_path='andrey_full.pt'):
        """Load the checkpoint at *model_path* onto the CPU and build the model.

        Raises:
            KeyError: if the checkpoint lacks one of the expected keys.
            Any error ``torch.load`` raises for a missing or corrupt file.
        """
        print("🤖 Загрузка Андрея...")
        # SECURITY: depending on the installed torch version, torch.load may
        # unpickle arbitrary objects from the file; only load trusted files.
        checkpoint = torch.load(model_path, map_location='cpu')
        # Rebuild the network with the hyper-parameters stored alongside it.
        self.model = AndreyBot(
            vocab_size=checkpoint['vocab_size'],
            hidden_size=checkpoint['hidden_size'],
            num_layers=checkpoint['num_layers'],
            embedding_dim=checkpoint['embedding_dim']
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()  # inference mode: disables dropout
        # Vocabularies. idx_to_word keys are coerced to int since they may have
        # been stored as strings (presumably via a JSON round-trip).
        self.word_to_idx = checkpoint['word_to_idx']
        self.idx_to_word = {int(k): v for k, v in checkpoint['idx_to_word'].items()}
        self.vocab_size = checkpoint['vocab_size']
        self.epoch = checkpoint['epoch']
        # Special token ids.
        self.PAD = 0
        self.UNK = 1
        print("✅ Андрей загружен!")
        print(f" 📚 Словарь: {self.vocab_size} слов")
        print(f" 🧠 Обучен: {self.epoch} эпох")

    def tokenize(self, text):
        """Lower-case and whitespace-split *text*; return ids of known words.

        Words that are not in the vocabulary are dropped (not mapped to UNK).
        """
        return [self.word_to_idx[w] for w in text.lower().split()
                if w in self.word_to_idx]

    def detokenize(self, tokens):
        """Join the words for *tokens*, skipping PAD/UNK; unknown ids become '?'."""
        special = (self.PAD, self.UNK)
        return ' '.join(self.idx_to_word.get(t, '?')
                        for t in tokens if t not in special)

    def _sample_next(self, logits, temperature, top_k):
        """Choose the next token id from a 1-D logits vector, never PAD or UNK.

        Special tokens are masked out before sampling rather than skipped
        afterwards, so every step yields a real word and PAD/UNK are never
        fed back into the model as input.
        """
        logits = logits.clone()  # don't modify the model output in place
        logits[[self.PAD, self.UNK]] = float('-inf')
        if temperature <= 0:
            # Division by a non-positive temperature is undefined: go greedy.
            return int(torch.argmax(logits).item())
        probs = torch.softmax(logits / temperature, dim=-1)
        k = max(1, min(top_k, probs.numel()))
        top_probs, top_idx = torch.topk(probs, k)
        top_probs = top_probs / top_probs.sum()
        return int(top_idx[torch.multinomial(top_probs, 1)].item())

    def generate(self, question, temperature=0.85, max_length=12, top_k=5):
        """Generate a reply to *question*.

        Args:
            question: Input text, tokenized with ``tokenize``.
            temperature: Softmax temperature; values <= 0 pick the most
                likely token (greedy decoding) instead of sampling.
            max_length: Bound on decoding steps. The reply stops once it has
                ``max_length - 2`` tokens (at least one token when
                ``max_length >= 1``).
            top_k: Sample only among the ``top_k`` most probable tokens.

        Returns:
            The reply text, or "..." if the question has no known word or
            nothing was generated.
        """
        tokens = self.tokenize(question)
        if not tokens:
            return "..."
        # NOTE(review): generation is seeded with only the first known word of
        # the question, and each step feeds the model a single token, so the
        # LSTM never sees longer context. Left unchanged; confirm whether the
        # model was trained on single-token inputs.
        current = tokens[0]
        response_tokens = []
        with torch.no_grad():
            for _ in range(max_length):
                logits = self.model(torch.tensor([[current]]))[0]
                current = self._sample_next(logits, temperature, top_k)
                response_tokens.append(current)
                if len(response_tokens) >= max_length - 2:
                    break
        if not response_tokens:
            return "..."
        return self.detokenize(response_tokens)

    def chat(self):
        """Run an interactive console chat until the user says goodbye.

        Typing 'пока', 'выход', 'exit' or 'quit', closing stdin (EOF) or
        pressing Ctrl+C ends the session. Empty lines are ignored.
        """
        print("\n" + "="*50)
        print("🤖 АНДРЕЙ")
        print(" Скажи 'пока' для выхода")
        print("="*50 + "\n")
        exit_words = {'пока', 'выход', 'exit', 'quit'}
        while True:
            try:
                user = input("👤 Вы: ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                # stdin closed or Ctrl+C: leave like a normal goodbye.
                print()
                user = 'пока'
            if user in exit_words:
                print("🤖 Андрей: Пока! 👋")
                break
            if not user:
                continue
            answer = self.generate(user)
            print(f"🤖 Андрей: {answer}\n")
# ============ ЗАПУСК ============
if __name__ == "__main__":
    # Script entry point: build the loader from the default checkpoint in
    # the working directory, then hand control to the console chat loop.
    andrey = AndreyLoader(model_path='andrey_full.pt')
    andrey.chat()