#!/usr/bin/env python3
"""Stream EEG data through an autoencoder + traced semantic model, then look
up similar chat messages via a FAISS cosine-similarity index and print the
lines that are new relative to the last few batches."""
import os
import sys
import time
import hashlib
import numpy as np
import torch
import sqlite3
import logging
import argparse
import random
import traceback
import faiss
import pickle
from datetime import datetime
from collections import deque, defaultdict
from typing import List, Dict, Tuple, Optional, Union, Any
from pathlib import Path

# Import from our streaming module
from eegembed import EEGEmbeddingStream

PRINT_DEBUG_HASH = False


def fix_encoding(s):
    """Best-effort re-decode of possibly mis-encoded text.

    Returns falsy input unchanged; otherwise round-trips through UTF-8 and
    returns "" for strings that still contain common mojibake markers.
    """
    if not s:
        return s
    if isinstance(s, str):
        b = s.encode('utf-8', 'surrogateescape')
    else:
        b = s
    fixed = b.decode('utf-8', 'replace')
    # BUGFIX: test the decoded text, not the raw input -- the original
    # checked `'ì' in s`, which raises TypeError when `s` is bytes.
    if 'ì' in fixed or 'í' in fixed or 'ï' in fixed:
        return ""
    return fixed


# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("EEGSemanticStream")


def setup_eeg_logger(eeg_file_path):
    """Set up a file logger based on the EEG filename.

    Returns an open text file handle; the caller owns (and must close) it.
    """
    base_name = os.path.basename(eeg_file_path)
    file_name = os.path.splitext(base_name)[0]
    logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "session_logs")
    # exist_ok avoids the exists()/makedirs race of the original
    os.makedirs(logs_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = os.path.join(logs_dir, f"{file_name}_{timestamp}.log")
    log_file = open(log_file_path, "w", encoding="utf-8")
    log_file.write(f"--- Session started at {timestamp} for EEG file: {base_name} ---\n")
    log_file.flush()
    return log_file


class EmbeddingIndex:
    """Cosine-similarity index over message embeddings.

    Uses FAISS inner product on L2-normalized vectors; optionally hosted on
    GPU when one is available.
    """

    def __init__(self, dim=1536, use_gpu=True):
        self.dim = dim
        self.use_gpu = use_gpu and faiss.get_num_gpus() > 0
        self.index = None
        self.gpu_resources = None
        self.message_ids = []  # row position in the index -> message_id
        if self.use_gpu:
            self.gpu_resources = faiss.StandardGpuResources()

    def add_embeddings(self, embeddings: np.ndarray, message_ids: List[int]):
        """Build the index from an (n, dim) embedding matrix.

        Falls back to CPU if moving the index to GPU fails.
        """
        logger.info(f"Building FAISS index with {len(embeddings)} embeddings")
        # L2-normalize so inner product == cosine similarity
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / (norms + 1e-8)
        cpu_index = faiss.IndexFlatIP(self.dim)
        cpu_index.add(embeddings.astype(np.float32))
        if self.use_gpu:
            try:
                self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, cpu_index)
                logger.info("Using GPU FAISS index")
            except Exception as e:
                logger.warning(f"GPU failed: {e}. Using CPU.")
                self.index = cpu_index
                self.use_gpu = False
        else:
            self.index = cpu_index
            logger.info("Using CPU FAISS index")
        self.message_ids = message_ids

    def get_current_count(self):
        """Number of vectors currently indexed (0 before add_embeddings)."""
        if self.index is None:
            return 0
        return self.index.ntotal

    def search(self, query: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
        """Return (distances, message_id labels) for the k nearest neighbors.

        Distances are 1 - cosine similarity. Accepts a 1-D vector or an
        (n_queries, dim) matrix.

        Raises:
            RuntimeError: if the index has not been built yet.
        """
        if self.index is None:
            raise RuntimeError("Index not initialized")
        # BUGFIX: FAISS Index.search requires a 2-D (n, d) array -- a 1-D
        # query previously reached FAISS unreshaped.  Normalize per row
        # instead of with one global norm so multi-row queries are correct.
        query = np.asarray(query, dtype=np.float32)
        if query.ndim == 1:
            query = query.reshape(1, -1)
        norms = np.linalg.norm(query, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # leave all-zero queries unscaled
        query = query / norms
        actual_k = min(k, self.get_current_count())
        if actual_k == 0:
            return np.array([]), np.array([])
        similarities, indices = self.index.search(query.astype(np.float32), actual_k)
        distances = 1.0 - similarities
        labels = np.array([[self.message_ids[idx] for idx in row] for row in indices])
        return distances, labels

    def save(self, path: str):
        """Persist the index to `{path}.index` and the id map to a pickle.

        Raises:
            RuntimeError: if the index has not been built yet.
        """
        if self.index is None:
            raise RuntimeError("Cannot save uninitialized index")
        if self.use_gpu:
            # GPU indexes must be copied back to CPU before serialization
            cpu_index = faiss.index_gpu_to_cpu(self.index)
            faiss.write_index(cpu_index, f"{path}.index")
        else:
            faiss.write_index(self.index, f"{path}.index")
        with open(f"{path}_message_ids.pkl", 'wb') as f:
            pickle.dump(self.message_ids, f)

    @classmethod
    def load(cls, path: str, use_gpu: bool = True) -> 'EmbeddingIndex':
        """Load an index previously written by save()."""
        # NOTE: pickle.load is only safe because we wrote this file ourselves.
        with open(f"{path}_message_ids.pkl", 'rb') as f:
            message_ids = pickle.load(f)
        index = cls(use_gpu=use_gpu)
        index.message_ids = message_ids
        cpu_index = faiss.read_index(f"{path}.index")
        # BUGFIX: record the stored dimension instead of keeping the default
        index.dim = cpu_index.d
        if index.use_gpu:
            try:
                index.index = faiss.index_cpu_to_gpu(index.gpu_resources, 0, cpu_index)
                logger.info("Loaded existing index and moved to GPU")
            except Exception as e:
                logger.warning(f"Failed to move loaded index to GPU: {e}. Using CPU.")
                index.index = cpu_index
                index.use_gpu = False
        else:
            index.index = cpu_index
            logger.info("Loaded existing index on CPU")
        return index


class EEGSemanticProcessor:
    """
    Process EEG data through autoencoder and semantic model pipeline,
    then lookup similar messages.
    """

    def __init__(
        self,
        autoencoder_model_path: str,
        semantic_model_path: str,
        nexus_db_path: str,
        embeddings_db_path: str,
        index_path: str = None,
        eeg_file_path: str = None,
        window_size: int = 624,
        stride: int = 64,
        batch_size: int = 32,
        normalize: bool = True,
        device: str = None,
        search_k: int = 180,
        final_k: int = 90,
        use_raw_eeg: bool = False,
        last_n_messages: int = 3,
        input_dim_override: int = None,
        save_vectors: bool = False,
        vector_output_path: str = None
    ):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        logger.info(f"Using device: {self.device}")

        # BUGFIX: main() forwards args.last_n, whose argparse default is
        # None; None would make the repetition-filter deque unbounded.
        self.last_n_messages = 3 if last_n_messages is None else last_n_messages
        self.use_raw_eeg = use_raw_eeg
        self.input_dim_override = input_dim_override

        # Initialize EEG stream
        self.eeg_stream = EEGEmbeddingStream(
            file_path=eeg_file_path if eeg_file_path else "",
            model_path=autoencoder_model_path,
            window_size=window_size,
            stride=stride,
            normalize=normalize,
            batch_size=batch_size,
            device=self.device
        )

        # Load traced semantic model
        logger.info(f"Loading traced semantic model from {semantic_model_path}")
        self.semantic_model = torch.jit.load(semantic_model_path, map_location=self.device)
        self.semantic_model.eval()

        # Probe to get input/output dims.
        # Try a few common input sizes to find the right one.
        self._semantic_input_dim = None
        self._semantic_output_dim = None
        for test_dim in [64, 10112]:
            try:
                dummy = torch.randn(1, test_dim, device=self.device)
                with torch.no_grad():
                    out = self.semantic_model(dummy)
                self._semantic_input_dim = test_dim
                self._semantic_output_dim = out.shape[1]
                logger.info(f"Semantic model: input_dim={self._semantic_input_dim}, output_dim={self._semantic_output_dim}")
                break
            except Exception:
                continue
        if self._semantic_input_dim is None:
            logger.warning("Could not auto-detect semantic model input dim. Will adapt at runtime.")

        self.log_file = setup_eeg_logger(eeg_file_path) if eeg_file_path else None

        # Initialize database connections
        self.nexus_conn = sqlite3.connect(nexus_db_path)
        self.embeddings_conn = sqlite3.connect(embeddings_db_path)

        # Message tracking system
        self.search_k = search_k
        self.final_k = final_k
        self.message_counts = defaultdict(int)
        self.recent_messages = deque(maxlen=10)
        self.repetition_penalty = 1.5

        logger.info("Creating embedding index")
        self.embedding_index = self._create_index(index_path)

        self.error_count = 0
        self.max_consecutive_errors = 5

        self.save_vectors = save_vectors
        self.vector_output_path = vector_output_path
        if self.save_vectors:
            self.vectors_list = []
            self.timestamps = []
            logger.info(f"Vector saving enabled. Output will be saved to {self.vector_output_path}")

        # Sets of lines from the last n batches, for the repetition filter
        self.previous_message_sets = deque(maxlen=self.last_n_messages)

    def _create_index(self, index_path: str = None) -> EmbeddingIndex:
        """Create or load the embedding index for similarity search.

        A saved index is reused only when its stored (count, max_message_id)
        metadata matches the embeddings database; otherwise it is rebuilt.

        Raises:
            ValueError: if the embeddings table is empty.
        """
        cursor = self.embeddings_conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM embeddings")
        db_count = cursor.fetchone()[0]
        cursor.execute("SELECT MAX(message_id) FROM embeddings")
        db_max_id = cursor.fetchone()[0]

        if index_path and os.path.exists(f"{index_path}.index"):
            try:
                logger.info(f"Checking existing index at {index_path}")
                index = EmbeddingIndex.load(index_path)
                metadata_path = f"{index_path}_metadata.npz"
                if os.path.exists(metadata_path):
                    metadata = np.load(metadata_path, allow_pickle=True)
                    saved_count = int(metadata.get('count', 0))
                    saved_max_id = int(metadata.get('max_message_id', 0))
                    logger.info(f"Saved index: {saved_count} items, max_id={saved_max_id}")
                    logger.info(f"Database: {db_count} items, max_id={db_max_id}")
                    if db_count != saved_count or db_max_id != saved_max_id:
                        logger.info("Database has changed. Recreating index...")
                    else:
                        logger.info("Database unchanged. Using existing index...")
                        return index
                # No metadata (or mismatch): fall through and rebuild below.
            except Exception as e:
                logger.warning(f"Error checking existing index: {str(e)}")
                logger.info("Will create new index")

        logger.info("Creating new index from database...")
        cursor.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")
        embeddings = []
        message_ids = []
        for message_id, emb in cursor.fetchall():
            embedding = np.frombuffer(emb, dtype=np.float32)
            embeddings.append(embedding)
            message_ids.append(message_id)
        if not embeddings:
            raise ValueError("No embeddings found in database")
        embeddings = np.vstack(embeddings)
        logger.info(f"Loaded {len(embeddings)} embeddings with shape: {embeddings.shape}")

        index = EmbeddingIndex(dim=embeddings.shape[1])
        index.add_embeddings(embeddings, message_ids)
        if index_path:
            logger.info(f"Saving index to {index_path}")
            index.save(index_path)
            metadata = {
                'count': db_count,
                'max_message_id': db_max_id
            }
            np.savez(f"{index_path}_metadata.npz", **metadata)
        return index

    def process_eeg_embedding(self, eeg_embedding: np.ndarray) -> torch.Tensor:
        """Convert EEG embedding to text embedding using the traced semantic model.

        Flattens the input to (batch, features) and zero-pads or truncates the
        feature axis to the probed model input dimension when they differ.
        """
        with torch.no_grad():
            tensor = torch.tensor(eeg_embedding, dtype=torch.float32).to(self.device)
            if len(tensor.shape) < 2:
                tensor = tensor.unsqueeze(0)
            batch_size = tensor.shape[0]
            tensor = tensor.reshape(batch_size, -1)
            # Adapt dimensions if needed
            if self._semantic_input_dim is not None:
                current_features = tensor.shape[1]
                if current_features != self._semantic_input_dim:
                    if current_features < self._semantic_input_dim:
                        padded = torch.zeros(batch_size, self._semantic_input_dim, device=self.device)
                        padded[:, :current_features] = tensor
                        tensor = padded
                    else:
                        tensor = tensor[:, :self._semantic_input_dim]
            return self.semantic_model(tensor)

    def find_similar_messages(self, embedding: torch.Tensor, assistant_only=False) -> List[str]:
        """Find similar messages using the embedding index.

        Returns up to ``self.final_k`` message contents; on any search or
        database error, logs the error and returns an empty list.
        """
        embedding_np = embedding.detach().cpu().numpy()
        # BUGFIX: always flatten to a single (1, dim) query -- the original
        # only reshaped when the array was ALREADY 2-D, so a 1-D embedding
        # reached FAISS unreshaped.
        embedding_np = embedding_np.reshape(1, -1)
        try:
            distances, indices = self.embedding_index.search(embedding_np, self.search_k)
            distances = distances.flatten()
            indices = indices.flatten()
            cursor = self.nexus_conn.cursor()
            candidates = []
            if assistant_only:
                query = """ SELECT content FROM messages WHERE id = ? AND role = 'assistant' """
            else:
                query = """ SELECT content FROM messages WHERE id = ? """
            for message_id, distance in zip(indices, distances):
                cursor.execute(query, (int(message_id),))
                if result := cursor.fetchone():
                    content = result[0]
                    candidates.append(content)
            return candidates[:self.final_k]
        except Exception as e:
            logger.error(f"Error during similarity search: {str(e)}")
            traceback.print_exc()
            return []

    def save_vectors_to_disk(self):
        """Save the collected vectors and timestamps to disk."""
        if not self.vectors_list:
            logger.warning("No vectors to save")
            return
        output_dir = os.path.dirname(self.vector_output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        vectors_array = np.vstack(self.vectors_list)
        timestamps_array = np.array(self.timestamps)
        logger.info(f"Saving {len(self.vectors_list)} vectors to {self.vector_output_path}")
        np.savez(
            self.vector_output_path,
            vectors=vectors_array,
            timestamps=timestamps_array
        )
        logger.info(f"Vectors saved successfully to {self.vector_output_path}")

    def process_streaming_embeddings(self, callback=None):
        """
        Process streaming EEG embeddings through the semantic model and find
        similar messages.

        Runs until KeyboardInterrupt or too many consecutive errors.  When
        ``save_vectors`` is set, only collects vectors (no lookup).  Results
        go to ``callback`` if given, otherwise to stdout via
        :meth:`_print_unique_lines`.
        """
        self.eeg_stream.start()
        try:
            consecutive_errors = 0
            while True:
                try:
                    for embedding_data in self.eeg_stream.get_embeddings(timeout=0.5):
                        try:
                            autoencoder_embedding = embedding_data['embedding']
                            semantic_embedding = self.process_eeg_embedding(autoencoder_embedding)

                            if self.save_vectors:
                                embedding_np = semantic_embedding.detach().cpu().numpy()
                                self.vectors_list.append(embedding_np)
                                self.timestamps.append({
                                    'start': embedding_data['start_timestamp'],
                                    'end': embedding_data['end_timestamp']
                                })
                                if len(self.vectors_list) % 100 == 0:
                                    logger.info(f"Collected {len(self.vectors_list)} vectors so far")
                                continue

                            similar_messages = self.find_similar_messages(semantic_embedding)
                            result = {
                                'start_timestamp': embedding_data['start_timestamp'],
                                'end_timestamp': embedding_data['end_timestamp'],
                                'processing_time': 0,
                                'similar_messages': similar_messages
                            }
                            if callback:
                                callback(result)
                            else:
                                self._print_unique_lines(result)
                            consecutive_errors = 0
                        except Exception as e:
                            print(f"Error: {str(e)}", file=sys.stderr)
                            consecutive_errors += 1
                            if consecutive_errors >= self.max_consecutive_errors:
                                raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")
                    time.sleep(0.01)
                except Exception as e:
                    # The inner RuntimeError must escape the retry loop
                    if "Too many consecutive errors" in str(e):
                        raise
                    print(f"Error: {str(e)}", file=sys.stderr)
                    consecutive_errors += 1
                    if consecutive_errors >= self.max_consecutive_errors:
                        raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")
                    time.sleep(1)
        except KeyboardInterrupt:
            pass
        except Exception as e:
            print(f"Fatal error: {str(e)}", file=sys.stderr)
        finally:
            if self.save_vectors and self.vectors_list:
                self.save_vectors_to_disk()
            self.eeg_stream.stop()
            # BUGFIX: release the session log handle (it was never closed)
            if self.log_file:
                try:
                    self.log_file.close()
                except Exception:
                    pass

    def _print_unique_lines(self, result):
        """Print only lines that aren't in common with the last n batches of messages."""
        if not result['similar_messages']:
            return
        sample_size = min(42, len(result['similar_messages']))
        current_messages = random.sample(result['similar_messages'], sample_size)

        current_lines = set()
        for message in current_messages:
            for line in message.splitlines():
                line = line.strip()
                if line:
                    current_lines.add(line)

        unique_lines = current_lines.copy()
        for previous_lines in self.previous_message_sets:
            unique_lines -= previous_lines
        self.previous_message_sets.append(current_lines)

        # (Removed dead code: a local flag `__uniq_log_empty = False` guarded
        # an `elif` branch that could never execute.)
        if unique_lines:
            if PRINT_DEBUG_HASH:
                # Avoid shadowing the builtin `hash` as the original did
                unique_lines = [
                    f"{hashlib.md5(line.encode()).hexdigest()[:8]} | {line}"
                    for line in unique_lines
                ]
            unique_lines = filter(lambda s: bool(s), map(fix_encoding, unique_lines))
            output_text = "\n".join(sorted(unique_lines))
            print(output_text)
            if self.log_file:
                try:
                    self.log_file.write(output_text + "\n")
                    self.log_file.flush()
                except Exception as e:
                    print(f"Error writing to log file: {str(e)}", file=sys.stderr)


def main():
    """CLI entry point: parse arguments and run the streaming pipeline."""
    parser = argparse.ArgumentParser(description='Process EEG data through semantic model and lookup similar messages')
    parser.add_argument('--autoencoder', '-a', type=str, required=True,
                        help='Path to the traced autoencoder encoder model')
    parser.add_argument('--semantic-model', '-s', type=str, required=True,
                        help='Path to the traced semantic model')
    parser.add_argument('--nexus-db', '-n', type=str,
                        default=os.path.expanduser('~/.nexus/data/nexus-new.db'),
                        help='Path to the nexus database')
    parser.add_argument('--embeddings-db', '-e', type=str, default='emb_full.db',
                        help='Path to the embeddings database')
    parser.add_argument('--index', '-i', type=str, default='embedding_index',
                        help='Path to save/load the FAISS index')
    parser.add_argument('--eeg-file', '-f', type=str, required=True,
                        help='Path to the EEG data file to monitor')
    parser.add_argument('--window-size', type=int, default=624,
                        help='Window size in samples')
    parser.add_argument('--stride', type=int, default=32,
                        help='Stride between windows')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for processing')
    parser.add_argument('--no-normalize', dest='normalize', action='store_false',
                        help='Disable normalization of EEG data')
    parser.add_argument('--search-k', type=int, default=180,
                        help='Number of candidates to retrieve for selection')
    parser.add_argument('--final-k', type=int, default=90,
                        help='Number of results to show')
    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda or cpu)')
    parser.add_argument('--last_n', type=int, default=None,
                        help='Window queue size for repetition filter')
    parser.add_argument('--use-raw-eeg', action='store_true',
                        help='Use raw EEG data with semantic model (skip autoencoder)')
    parser.add_argument('--input-dim', type=int,
                        help='Override the input dimension for the semantic model')
    parser.add_argument('--save-vectors', action='store_true',
                        help='Save semantic vectors to disk without generating output')
    parser.add_argument('--vector-output', type=str, default='semantic_vectors.npz',
                        help='Path to save the semantic vectors')
    args = parser.parse_args()

    processor = EEGSemanticProcessor(
        autoencoder_model_path=args.autoencoder,
        semantic_model_path=args.semantic_model,
        nexus_db_path=args.nexus_db,
        embeddings_db_path=args.embeddings_db,
        index_path=args.index,
        last_n_messages=args.last_n,  # None -> processor falls back to 3
        eeg_file_path=args.eeg_file,
        window_size=args.window_size,
        stride=args.stride,
        batch_size=args.batch_size,
        normalize=args.normalize,
        device=args.device,
        search_k=args.search_k,
        final_k=args.final_k,
        use_raw_eeg=args.use_raw_eeg,
        input_dim_override=args.input_dim,
        save_vectors=args.save_vectors,
        vector_output_path=args.vector_output
    )
    try:
        processor.process_streaming_embeddings()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Error: {str(e)}", file=sys.stderr)


if __name__ == "__main__":
    main()