# morphism / decode.py
# (Hugging Face upload metadata: "acb — Upload 5 files", commit 5e284bb, verified)
#!/usr/bin/env python3
import os
import sys
import time
import hashlib
import numpy as np
import torch
import sqlite3
import logging
import argparse
import random
import traceback
import faiss
import pickle
from datetime import datetime
from collections import deque, defaultdict
from typing import List, Dict, Tuple, Optional, Union, Any
from pathlib import Path
# Import from our streaming module
from eegembed import EEGEmbeddingStream
PRINT_DEBUG_HASH = False
def fix_encoding(s):
    """Best-effort repair of mojibake in *s*.

    Accepts either ``str`` or ``bytes``. The value is round-tripped through
    UTF-8 (``surrogateescape`` on encode, ``replace`` on decode); strings that
    still contain tell-tale mojibake characters ('ì', 'í', 'ï') are rejected
    by returning an empty string.

    Returns the repaired string, ``""`` for rejected input, or the input
    unchanged when it is falsy.
    """
    if not s:
        return s
    if isinstance(s, str):
        b = s.encode('utf-8', 'surrogateescape')
    else:
        b = s
    fixed = b.decode('utf-8', 'replace')
    # BUG FIX: check the decoded text, not the raw input. The original tested
    # `'ì' in s`, which raises TypeError when `s` is bytes (str-in-bytes).
    # For str input this is equivalent, since the UTF-8 round trip preserves
    # ordinary characters.
    if 'ì' in fixed or 'í' in fixed or 'ï' in fixed:
        return ""
    return fixed
# Set up logging
# Root logging configuration: timestamped INFO-level records. The module-level
# `logger` is used throughout this file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("EEGSemanticStream")
def setup_eeg_logger(eeg_file_path):
    """Open a per-session log file named after the EEG file.

    The log is created in a ``session_logs`` directory next to this script,
    named ``<eeg-stem>_<timestamp>.log``, and a session header line is
    written immediately.

    Args:
        eeg_file_path: path to the EEG recording; only its basename is used.

    Returns:
        The open UTF-8 text file object. The caller owns it and is
        responsible for closing it.
    """
    base_name = os.path.basename(eeg_file_path)
    file_name = os.path.splitext(base_name)[0]
    logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "session_logs")
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists() / os.makedirs() pair.
    os.makedirs(logs_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = os.path.join(logs_dir, f"{file_name}_{timestamp}.log")
    log_file = open(log_file_path, "w", encoding="utf-8")
    log_file.write(f"--- Session started at {timestamp} for EEG file: {base_name} ---\n")
    log_file.flush()
    return log_file
class EmbeddingIndex:
    """Cosine-similarity search over message embeddings via FAISS.

    Embeddings are L2-normalized and stored in an inner-product flat index,
    so inner-product scores equal cosine similarity. Row ``i`` of the index
    maps back to ``message_ids[i]``. Falls back from GPU to CPU transparently
    when GPU FAISS is unavailable or fails.
    """
    def __init__(self, dim=1536, use_gpu=True):
        self.dim = dim
        # GPU only if requested AND faiss actually sees at least one device.
        self.use_gpu = use_gpu and faiss.get_num_gpus() > 0
        self.index = None
        self.gpu_resources = None
        self.message_ids = []
        if self.use_gpu:
            self.gpu_resources = faiss.StandardGpuResources()

    def add_embeddings(self, embeddings: np.ndarray, message_ids: List[int]):
        """Build the index from an (n, dim) float array and a parallel id list."""
        logger.info(f"Building FAISS index with {len(embeddings)} embeddings")
        # Normalize rows so inner product == cosine similarity; the epsilon
        # protects all-zero rows from division by zero.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / (norms + 1e-8)
        cpu_index = faiss.IndexFlatIP(self.dim)
        cpu_index.add(embeddings.astype(np.float32))
        if self.use_gpu:
            try:
                self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, cpu_index)
                logger.info("Using GPU FAISS index")
            except Exception as e:
                logger.warning(f"GPU failed: {e}. Using CPU.")
                self.index = cpu_index
                self.use_gpu = False
        else:
            self.index = cpu_index
            logger.info("Using CPU FAISS index")
        self.message_ids = message_ids

    def get_current_count(self):
        """Number of vectors currently indexed (0 before initialization)."""
        if self.index is None:
            return 0
        return self.index.ntotal

    def search(self, query: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
        """Return (cosine distances, message-id labels) for the top-k matches.

        Accepts a 1-D query vector or an already-batched (n, dim) array; rows
        are normalized independently. Returns empty arrays when the index
        holds no vectors.

        Raises:
            RuntimeError: if called before the index is built/loaded.
        """
        if self.index is None:
            raise RuntimeError("Index not initialized")
        # BUG FIX: faiss requires a 2-D (n_queries, dim) array; a bare 1-D
        # query previously crashed inside index.search().
        if query.ndim == 1:
            query = query.reshape(1, -1)
        # BUG FIX: normalize per row; the original used one global norm,
        # which mis-scales every row of a multi-query batch.
        norms = np.linalg.norm(query, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # leave zero vectors untouched, as before
        query = query / norms
        actual_k = min(k, self.get_current_count())
        if actual_k == 0:
            return np.array([]), np.array([])
        similarities, indices = self.index.search(query.astype(np.float32), actual_k)
        # Convert cosine similarity to cosine distance.
        distances = 1.0 - similarities
        labels = np.array([[self.message_ids[idx] for idx in row] for row in indices])
        return distances, labels

    def save(self, path: str):
        """Persist the index to ``<path>.index`` and ids to ``<path>_message_ids.pkl``."""
        if self.index is None:
            raise RuntimeError("Cannot save uninitialized index")
        # GPU indexes must be copied back to the host before serialization.
        if self.use_gpu:
            cpu_index = faiss.index_gpu_to_cpu(self.index)
            faiss.write_index(cpu_index, f"{path}.index")
        else:
            faiss.write_index(self.index, f"{path}.index")
        with open(f"{path}_message_ids.pkl", 'wb') as f:
            pickle.dump(self.message_ids, f)

    @classmethod
    def load(cls, path: str, use_gpu: bool = True) -> 'EmbeddingIndex':
        """Load an index written by :meth:`save`, moving it to GPU if possible."""
        with open(f"{path}_message_ids.pkl", 'rb') as f:
            message_ids = pickle.load(f)
        index = cls(use_gpu=use_gpu)
        index.message_ids = message_ids
        cpu_index = faiss.read_index(f"{path}.index")
        # BUG FIX: restore the true dimensionality from the stored index; the
        # constructor default (1536) could silently disagree with it.
        index.dim = cpu_index.d
        if index.use_gpu:
            try:
                index.index = faiss.index_cpu_to_gpu(index.gpu_resources, 0, cpu_index)
                logger.info("Loaded existing index and moved to GPU")
            except Exception as e:
                logger.warning(f"Failed to move loaded index to GPU: {e}. Using CPU.")
                index.index = cpu_index
                index.use_gpu = False
        else:
            index.index = cpu_index
            logger.info("Loaded existing index on CPU")
        return index
class EEGSemanticProcessor:
    """
    Process EEG data through autoencoder and semantic model pipeline,
    then lookup similar messages.
    """
    def __init__(
        self,
        autoencoder_model_path: str,
        semantic_model_path: str,
        nexus_db_path: str,
        embeddings_db_path: str,
        index_path: str = None,
        eeg_file_path: str = None,
        window_size: int = 624,
        stride: int = 64,
        batch_size: int = 32,
        normalize: bool = True,
        device: str = None,
        search_k: int = 180,
        final_k: int = 90,
        use_raw_eeg: bool = False,
        last_n_messages: int = 3,
        input_dim_override: int = None,
        save_vectors: bool = False,
        vector_output_path: str = None
    ):
        """Wire up the EEG stream, semantic model, databases and FAISS index.

        Args:
            autoencoder_model_path: traced autoencoder encoder, forwarded to
                EEGEmbeddingStream.
            semantic_model_path: TorchScript model mapping EEG embeddings to
                text-embedding space.
            nexus_db_path: SQLite DB with a ``messages`` table
                (id, role, content — per the queries below).
            embeddings_db_path: SQLite DB with an ``embeddings`` table
                (message_id, embedding blob of float32).
            index_path: base path used to save/load the FAISS index.
            eeg_file_path: EEG file to monitor; also names the session log.
            window_size, stride, batch_size, normalize: streaming parameters
                forwarded to EEGEmbeddingStream.
            device: 'cuda'/'cpu'; auto-detected when None.
            search_k: candidate pool size for similarity search.
            final_k: number of candidates returned by find_similar_messages.
            use_raw_eeg: stored but never read in this file — TODO confirm
                whether it is consumed elsewhere.
            last_n_messages: repetition-filter window (batches of lines to
                subtract in _print_unique_lines).
            input_dim_override: stored but never read in this file — TODO
                confirm whether it is consumed elsewhere.
            save_vectors: collect semantic vectors instead of printing matches.
            vector_output_path: .npz destination when save_vectors is on.
        """
        # Pick CUDA when available unless the caller forces a device.
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        logger.info(f"Using device: {self.device}")
        self.last_n_messages = last_n_messages
        self.use_raw_eeg = use_raw_eeg
        self.input_dim_override = input_dim_override
        # Initialize EEG stream
        self.eeg_stream = EEGEmbeddingStream(
            file_path=eeg_file_path if eeg_file_path else "",
            model_path=autoencoder_model_path,
            window_size=window_size,
            stride=stride,
            normalize=normalize,
            batch_size=batch_size,
            device=self.device
        )
        # Load traced semantic model
        logger.info(f"Loading traced semantic model from {semantic_model_path}")
        self.semantic_model = torch.jit.load(semantic_model_path, map_location=self.device)
        self.semantic_model.eval()
        # Probe to get input/output dims
        # Try a few common input sizes to find the right one
        # (the first dummy batch that the traced model accepts wins).
        self._semantic_input_dim = None
        self._semantic_output_dim = None
        for test_dim in [64, 10112]:
            try:
                dummy = torch.randn(1, test_dim, device=self.device)
                with torch.no_grad():
                    out = self.semantic_model(dummy)
                self._semantic_input_dim = test_dim
                self._semantic_output_dim = out.shape[1]
                logger.info(f"Semantic model: input_dim={self._semantic_input_dim}, output_dim={self._semantic_output_dim}")
                break
            except Exception:
                continue
        if self._semantic_input_dim is None:
            logger.warning("Could not auto-detect semantic model input dim. Will adapt at runtime.")
        # Session log file (opened for the lifetime of the processor).
        self.log_file = setup_eeg_logger(eeg_file_path) if eeg_file_path else None
        # Initialize database connections
        self.nexus_conn = sqlite3.connect(nexus_db_path)
        self.embeddings_conn = sqlite3.connect(embeddings_db_path)
        # Message tracking system
        self.search_k = search_k
        self.final_k = final_k
        # NOTE(review): message_counts, recent_messages and repetition_penalty
        # are never read anywhere in this file — possibly leftovers from an
        # earlier scoring scheme; confirm before removing.
        self.message_counts = defaultdict(int)
        self.recent_messages = deque(maxlen=10)
        self.repetition_penalty = 1.5
        logger.info("Creating embedding index")
        self.embedding_index = self._create_index(index_path)
        self.error_count = 0
        self.max_consecutive_errors = 5
        self.save_vectors = save_vectors
        self.vector_output_path = vector_output_path
        if self.save_vectors:
            self.vectors_list = []
            self.timestamps = []
            logger.info(f"Vector saving enabled. Output will be saved to {self.vector_output_path}")
        # NOTE(review): if last_n_messages is None (main() forwards the
        # argparse default), maxlen=None makes this deque unbounded — the
        # repetition filter then subtracts every line ever seen. Confirm
        # whether that is intended.
        self.previous_message_sets = deque(maxlen=self.last_n_messages)

    def _create_index(self, index_path: str = None) -> EmbeddingIndex:
        """Create or load the embedding index for similarity search"""
        # Snapshot the DB's row count and max id to detect staleness of any
        # previously saved index.
        cursor = self.embeddings_conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM embeddings")
        db_count = cursor.fetchone()[0]
        cursor.execute("SELECT MAX(message_id) FROM embeddings")
        db_max_id = cursor.fetchone()[0]
        if index_path and os.path.exists(f"{index_path}.index"):
            try:
                logger.info(f"Checking existing index at {index_path}")
                index = EmbeddingIndex.load(index_path)
                # Sidecar metadata records the DB state the index was built from.
                metadata_path = f"{index_path}_metadata.npz"
                if os.path.exists(metadata_path):
                    metadata = np.load(metadata_path, allow_pickle=True)
                    saved_count = int(metadata.get('count', 0))
                    saved_max_id = int(metadata.get('max_message_id', 0))
                    logger.info(f"Saved index: {saved_count} items, max_id={saved_max_id}")
                    logger.info(f"Database: {db_count} items, max_id={db_max_id}")
                    if db_count != saved_count or db_max_id != saved_max_id:
                        # Stale: fall through to the rebuild path below.
                        logger.info("Database has changed. Recreating index...")
                    else:
                        logger.info("Database unchanged. Using existing index...")
                        return index
                # NOTE(review): when the metadata file is missing, the loaded
                # index is discarded and rebuilt — presumably deliberate
                # (freshness cannot be verified); confirm.
            except Exception as e:
                logger.warning(f"Error checking existing index: {str(e)}")
                logger.info("Will create new index")
        logger.info("Creating new index from database...")
        # Stable ORDER BY keeps row<->message_id alignment deterministic.
        cursor.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")
        embeddings = []
        message_ids = []
        for message_id, emb in cursor.fetchall():
            # Blobs are raw little-endian float32 vectors.
            embedding = np.frombuffer(emb, dtype=np.float32)
            embeddings.append(embedding)
            message_ids.append(message_id)
        if not embeddings:
            raise ValueError("No embeddings found in database")
        embeddings = np.vstack(embeddings)
        logger.info(f"Loaded {len(embeddings)} embeddings with shape: {embeddings.shape}")
        index = EmbeddingIndex(dim=embeddings.shape[1])
        index.add_embeddings(embeddings, message_ids)
        if index_path:
            logger.info(f"Saving index to {index_path}")
            index.save(index_path)
            # Persist the DB snapshot alongside the index for the staleness
            # check above.
            metadata = {
                'count': db_count,
                'max_message_id': db_max_id
            }
            np.savez(f"{index_path}_metadata.npz", **metadata)
        return index

    def process_eeg_embedding(self, eeg_embedding: np.ndarray) -> torch.Tensor:
        """Convert EEG embedding to text embedding using the traced semantic model"""
        with torch.no_grad():
            tensor = torch.tensor(eeg_embedding, dtype=torch.float32).to(self.device)
            # Ensure a (batch, features) shape: promote 1-D input to a batch
            # of one, then flatten any trailing dims.
            if len(tensor.shape) < 2:
                tensor = tensor.unsqueeze(0)
            batch_size = tensor.shape[0]
            tensor = tensor.reshape(batch_size, -1)
            # Adapt dimensions if needed: zero-pad short inputs, truncate long
            # ones, to match the probed model input size.
            if self._semantic_input_dim is not None:
                current_features = tensor.shape[1]
                if current_features != self._semantic_input_dim:
                    if current_features < self._semantic_input_dim:
                        padded = torch.zeros(batch_size, self._semantic_input_dim, device=self.device)
                        padded[:, :current_features] = tensor
                        tensor = padded
                    else:
                        tensor = tensor[:, :self._semantic_input_dim]
            return self.semantic_model(tensor)

    def find_similar_messages(self, embedding: torch.Tensor, assistant_only=False) -> List[str]:
        """Find similar messages using the embedding index"""
        embedding_np = embedding.detach().cpu().numpy()
        # Collapse any batch/extra dims into a single (1, dim) query row.
        if len(embedding_np.shape) > 1:
            embedding_np = embedding_np.reshape(1, -1)
        try:
            distances, indices = self.embedding_index.search(embedding_np, self.search_k)
            distances = distances.flatten()
            indices = indices.flatten()
            cursor = self.nexus_conn.cursor()
            candidates = []
            # Role filter is applied at lookup time; non-matching ids simply
            # yield no row and are skipped.
            if assistant_only:
                query = """
                SELECT content FROM messages
                WHERE id = ? AND role = 'assistant'
                """
            else:
                query = """
                SELECT content FROM messages
                WHERE id = ?
                """
            # NOTE(review): `distance` is fetched but unused here — candidates
            # keep the index's similarity ordering only.
            for message_id, distance in zip(indices, distances):
                cursor.execute(query, (int(message_id),))
                if result := cursor.fetchone():
                    content = result[0]
                    candidates.append(content)
            # Trim the search_k-sized pool down to the configured final_k.
            return candidates[:self.final_k]
        except Exception as e:
            logger.error(f"Error during similarity search: {str(e)}")
            traceback.print_exc()
            # Best-effort: an empty result lets the streaming loop continue.
            return []

    def save_vectors_to_disk(self):
        """Save the collected vectors and timestamps to disk"""
        if not self.vectors_list:
            logger.warning("No vectors to save")
            return
        output_dir = os.path.dirname(self.vector_output_path)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        vectors_array = np.vstack(self.vectors_list)
        timestamps_array = np.array(self.timestamps)
        logger.info(f"Saving {len(self.vectors_list)} vectors to {self.vector_output_path}")
        np.savez(
            self.vector_output_path,
            vectors=vectors_array,
            timestamps=timestamps_array
        )
        logger.info(f"Vectors saved successfully to {self.vector_output_path}")

    def process_streaming_embeddings(self, callback=None):
        """
        Process streaming EEG embeddings through the semantic model
        and find similar messages.

        Runs until KeyboardInterrupt or until max_consecutive_errors errors
        occur back to back. If `callback` is given it receives each result
        dict; otherwise results are printed via _print_unique_lines. In
        save_vectors mode, vectors are only collected (no lookup) and
        written to disk on shutdown.
        """
        self.eeg_stream.start()
        try:
            consecutive_errors = 0
            while True:
                try:
                    for embedding_data in self.eeg_stream.get_embeddings(timeout=0.5):
                        try:
                            autoencoder_embedding = embedding_data['embedding']
                            semantic_embedding = self.process_eeg_embedding(autoencoder_embedding)
                            if self.save_vectors:
                                # Collection-only mode: stash vector +
                                # timestamps and skip similarity lookup.
                                embedding_np = semantic_embedding.detach().cpu().numpy()
                                self.vectors_list.append(embedding_np)
                                self.timestamps.append({
                                    'start': embedding_data['start_timestamp'],
                                    'end': embedding_data['end_timestamp']
                                })
                                if len(self.vectors_list) % 100 == 0:
                                    logger.info(f"Collected {len(self.vectors_list)} vectors so far")
                                continue
                            similar_messages = self.find_similar_messages(semantic_embedding)
                            result = {
                                'start_timestamp': embedding_data['start_timestamp'],
                                'end_timestamp': embedding_data['end_timestamp'],
                                'processing_time': 0,
                                'similar_messages': similar_messages
                            }
                            if callback:
                                callback(result)
                            else:
                                self._print_unique_lines(result)
                            # A successful item resets the error streak.
                            consecutive_errors = 0
                        except Exception as e:
                            # Per-item failure: count it and keep streaming.
                            print(f"Error: {str(e)}", file=sys.stderr)
                            consecutive_errors += 1
                            if consecutive_errors >= self.max_consecutive_errors:
                                raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")
                    time.sleep(0.01)
                except Exception as e:
                    # Re-raise the give-up signal from the inner handler;
                    # anything else is a stream-level error with a longer
                    # back-off.
                    if "Too many consecutive errors" in str(e):
                        raise
                    print(f"Error: {str(e)}", file=sys.stderr)
                    consecutive_errors += 1
                    if consecutive_errors >= self.max_consecutive_errors:
                        raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")
                    time.sleep(1)
        except KeyboardInterrupt:
            pass
        except Exception as e:
            print(f"Fatal error: {str(e)}", file=sys.stderr)
        finally:
            # Flush collected vectors (if any) and stop the stream on any exit.
            if self.save_vectors and self.vectors_list:
                self.save_vectors_to_disk()
            self.eeg_stream.stop()

    def _print_unique_lines(self, result):
        """Print only lines that aren't in common with the last n batches of messages"""
        if not result['similar_messages']:
            return
        # Randomly sample up to 42 candidate messages to bound output volume.
        sample_size = min(42, len(result['similar_messages']))
        current_messages = random.sample(result['similar_messages'], sample_size)
        current_lines = set()
        for message in current_messages:
            for line in message.splitlines():
                line = line.strip()
                if line:
                    current_lines.add(line)
        # Subtract every line seen in the previous window of batches.
        unique_lines = current_lines.copy()
        for previous_lines in self.previous_message_sets:
            unique_lines -= previous_lines
        self.previous_message_sets.append(current_lines)
        # NOTE(review): this flag is hard-coded False, so the "No unique
        # lines" branch below is dead — a disabled debug toggle, presumably.
        __uniq_log_empty = False
        if unique_lines:
            if PRINT_DEBUG_HASH:
                # Prefix each line with a short md5 of itself. (The name
                # `hash` shadows the builtin inside this comprehension only.)
                unique_lines = [f"{hash} | {line}" for (hash, line) in zip(
                    map(lambda s: hashlib.md5(s.encode()).hexdigest()[:8], unique_lines),
                    unique_lines)]
            # Drop lines that fix_encoding rejects as mojibake.
            unique_lines = filter(lambda s: bool(s), map(fix_encoding, unique_lines))
            output_text = "\n".join(sorted(unique_lines))
            print(output_text)
            if hasattr(self, 'log_file') and self.log_file:
                try:
                    self.log_file.write(output_text + "\n")
                    self.log_file.flush()
                except Exception as e:
                    print(f"Error writing to log file: {str(e)}", file=sys.stderr)
        elif __uniq_log_empty:
            logger.info(f"No unique lines")
def main():
    """CLI entry point: parse arguments, build the processor, and stream."""
    parser = argparse.ArgumentParser(description='Process EEG data through semantic model and lookup similar messages')
    parser.add_argument('--autoencoder', '-a', type=str, required=True,
                        help='Path to the traced autoencoder encoder model')
    parser.add_argument('--semantic-model', '-s', type=str, required=True,
                        help='Path to the traced semantic model')
    parser.add_argument('--nexus-db', '-n', type=str,
                        default=os.path.expanduser('~/.nexus/data/nexus-new.db'),
                        help='Path to the nexus database')
    parser.add_argument('--embeddings-db', '-e', type=str, default='emb_full.db',
                        help='Path to the embeddings database')
    parser.add_argument('--index', '-i', type=str, default='embedding_index',
                        help='Path to save/load the FAISS index')
    parser.add_argument('--eeg-file', '-f', type=str, required=True,
                        help='Path to the EEG data file to monitor')
    parser.add_argument('--window-size', type=int, default=624,
                        help='Window size in samples')
    parser.add_argument('--stride', type=int, default=32,
                        help='Stride between windows')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for processing')
    parser.add_argument('--no-normalize', dest='normalize', action='store_false',
                        help='Disable normalization of EEG data')
    parser.add_argument('--search-k', type=int, default=180,
                        help='Number of candidates to retrieve for selection')
    parser.add_argument('--final-k', type=int, default=90,
                        help='Number of results to show')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda or cpu)')
    # BUG FIX: the default was None, which EEGSemanticProcessor forwards into
    # deque(maxlen=None) — an unbounded repetition-filter window that grows
    # without limit and eventually subtracts every line ever seen, silencing
    # all output. Default now matches the constructor's own default of 3.
    parser.add_argument('--last_n', type=int, default=3,
                        help='Window queue size for repetition filter')
    parser.add_argument('--use-raw-eeg', action='store_true',
                        help='Use raw EEG data with semantic model (skip autoencoder)')
    parser.add_argument('--input-dim', type=int,
                        help='Override the input dimension for the semantic model')
    parser.add_argument('--save-vectors', action='store_true',
                        help='Save semantic vectors to disk without generating output')
    parser.add_argument('--vector-output', type=str, default='semantic_vectors.npz',
                        help='Path to save the semantic vectors')
    args = parser.parse_args()
    processor = EEGSemanticProcessor(
        autoencoder_model_path=args.autoencoder,
        semantic_model_path=args.semantic_model,
        nexus_db_path=args.nexus_db,
        embeddings_db_path=args.embeddings_db,
        index_path=args.index,
        last_n_messages=args.last_n,
        eeg_file_path=args.eeg_file,
        window_size=args.window_size,
        stride=args.stride,
        batch_size=args.batch_size,
        normalize=args.normalize,
        device=args.device,
        search_k=args.search_k,
        final_k=args.final_k,
        use_raw_eeg=args.use_raw_eeg,
        input_dim_override=args.input_dim,
        save_vectors=args.save_vectors,
        vector_output_path=args.vector_output
    )
    try:
        processor.process_streaming_embeddings()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Error: {str(e)}", file=sys.stderr)
if __name__ == "__main__":
    main()