| |
| import os |
| import sys |
| import time |
| import hashlib |
| import numpy as np |
| import torch |
| import sqlite3 |
| import logging |
| import argparse |
| import random |
| import traceback |
| import faiss |
| import pickle |
| from datetime import datetime |
| from collections import deque, defaultdict |
| from typing import List, Dict, Tuple, Optional, Union, Any |
| from pathlib import Path |
|
|
| |
| from eegembed import EEGEmbeddingStream |
|
|
| PRINT_DEBUG_HASH = False |
|
|
def fix_encoding(s):
    """Best-effort repair of a possibly mojibake string (or raw bytes).

    Round-trips text through UTF-8 (surrogateescape on encode, replace on
    decode) so undecodable sequences become U+FFFD instead of raising.
    Lines containing tell-tale mojibake characters ('ì', 'í', 'ï') are
    dropped entirely by returning "".

    Returns the input unchanged when it is falsy (None, "", b"").
    """
    if not s:
        return s

    if isinstance(s, str):
        raw = s.encode('utf-8', 'surrogateescape')
    else:
        # Already bytes-like; decode directly.
        raw = s

    fixed = raw.decode('utf-8', 'replace')

    # BUGFIX: the marker check previously ran on the *input* `s`, which
    # raises TypeError for bytes input ("in <bytes> requires bytes").
    # Check the decoded text instead.
    if any(marker in fixed for marker in ('ì', 'í', 'ï')):
        return ""

    return fixed
|
|
| |
| logging.basicConfig( |
| level=logging.INFO, |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
| ) |
| logger = logging.getLogger("EEGSemanticStream") |
|
|
def setup_eeg_logger(eeg_file_path):
    """Create a timestamped per-session log file named after the EEG file.

    The log lives in a ``session_logs`` directory next to this script and is
    named ``<eeg-basename>_<YYYYmmdd_HHMMSS>.log``. A session header line is
    written immediately and flushed.

    Returns the open text file handle; the caller owns it and must close it.
    """
    base_name = os.path.basename(eeg_file_path)
    file_name = os.path.splitext(base_name)[0]

    # Keep all session logs co-located with the script itself.
    logs_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "session_logs")
    # exist_ok avoids the check-then-create race of exists()+makedirs().
    os.makedirs(logs_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file_path = os.path.join(logs_dir, f"{file_name}_{timestamp}.log")

    log_file = open(log_file_path, "w", encoding="utf-8")
    log_file.write(f"--- Session started at {timestamp} for EEG file: {base_name} ---\n")
    log_file.flush()

    return log_file
|
|
|
|
class EmbeddingIndex:
    """Cosine-similarity index over message embeddings, backed by FAISS.

    Vectors are L2-normalized and stored in an inner-product flat index so
    that inner product equals cosine similarity; :meth:`search` reports
    cosine *distances* (1 - similarity). The index runs on GPU when FAISS
    sees one and ``use_gpu`` is requested, falling back to CPU otherwise.
    """

    def __init__(self, dim=1536, use_gpu=True):
        # dim: embedding width; corrected from disk by load().
        self.dim = dim
        # Only attempt GPU when FAISS actually reports a device.
        self.use_gpu = use_gpu and faiss.get_num_gpus() > 0
        self.index = None
        self.gpu_resources = None
        # message_ids[i] is the database message id for FAISS row i.
        self.message_ids = []

        if self.use_gpu:
            self.gpu_resources = faiss.StandardGpuResources()

    def add_embeddings(self, embeddings: np.ndarray, message_ids: List[int]):
        """Build a fresh index from row-vectors ``embeddings`` keyed by ``message_ids``."""
        logger.info(f"Building FAISS index with {len(embeddings)} embeddings")

        # Normalize so IP == cosine; epsilon guards against zero vectors.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / (norms + 1e-8)

        cpu_index = faiss.IndexFlatIP(self.dim)
        cpu_index.add(embeddings.astype(np.float32))

        if self.use_gpu:
            try:
                self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, cpu_index)
                logger.info("Using GPU FAISS index")
            except Exception as e:
                # Degrade gracefully to the already-built CPU index.
                logger.warning(f"GPU failed: {e}. Using CPU.")
                self.index = cpu_index
                self.use_gpu = False
        else:
            self.index = cpu_index
            logger.info("Using CPU FAISS index")

        self.message_ids = message_ids

    def get_current_count(self):
        """Return the number of vectors in the index (0 if not built)."""
        if self.index is None:
            return 0
        return self.index.ntotal

    def search(self, query: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(cosine_distances, message_id_labels)`` for the k nearest vectors.

        Accepts a 1-D query vector or a ``(1, dim)`` row. Raises
        ``RuntimeError`` if the index has not been built; returns empty
        arrays when the index holds no vectors.
        """
        if self.index is None:
            raise RuntimeError("Index not initialized")

        # BUGFIX: FAISS expects a 2-D (nq, dim) batch; accept bare vectors.
        if query.ndim == 1:
            query = query.reshape(1, -1)

        # Normalize the query so the inner product is a cosine similarity.
        norm = np.linalg.norm(query)
        if norm > 0:
            query = query / norm

        actual_k = min(k, self.get_current_count())
        if actual_k == 0:
            return np.array([]), np.array([])

        similarities, indices = self.index.search(query.astype(np.float32), actual_k)
        distances = 1.0 - similarities  # cosine distance

        # Map FAISS row numbers back to database message ids.
        labels = np.array([[self.message_ids[idx] for idx in row] for row in indices])

        return distances, labels

    def save(self, path: str):
        """Persist the index to ``<path>.index`` and ids to ``<path>_message_ids.pkl``."""
        if self.index is None:
            raise RuntimeError("Cannot save uninitialized index")

        # GPU indexes must be copied back to host memory before serialization.
        if self.use_gpu:
            cpu_index = faiss.index_gpu_to_cpu(self.index)
            faiss.write_index(cpu_index, f"{path}.index")
        else:
            faiss.write_index(self.index, f"{path}.index")

        with open(f"{path}_message_ids.pkl", 'wb') as f:
            pickle.dump(self.message_ids, f)

    @classmethod
    def load(cls, path: str, use_gpu: bool = True) -> 'EmbeddingIndex':
        """Load a previously :meth:`save`d index, optionally moving it to GPU."""
        # NOTE: pickle is only safe on files this pipeline wrote itself.
        with open(f"{path}_message_ids.pkl", 'rb') as f:
            message_ids = pickle.load(f)

        index = cls(use_gpu=use_gpu)
        index.message_ids = message_ids

        cpu_index = faiss.read_index(f"{path}.index")
        # BUGFIX: reflect the on-disk dimensionality instead of the
        # constructor default (1536), so later add_embeddings() matches.
        index.dim = cpu_index.d

        if index.use_gpu:
            try:
                index.index = faiss.index_cpu_to_gpu(index.gpu_resources, 0, cpu_index)
                logger.info("Loaded existing index and moved to GPU")
            except Exception as e:
                logger.warning(f"Failed to move loaded index to GPU: {e}. Using CPU.")
                index.index = cpu_index
                index.use_gpu = False
        else:
            index.index = cpu_index
            logger.info("Loaded existing index on CPU")

        return index
|
|
|
|
class EEGSemanticProcessor:
    """
    Process EEG data through autoencoder and semantic model pipeline,
    then lookup similar messages.

    Pipeline: EEGEmbeddingStream windows the EEG file and yields autoencoder
    embeddings -> a traced (TorchScript) semantic model maps them into the
    text-embedding space -> an EmbeddingIndex retrieves the nearest stored
    message embeddings -> message texts are fetched from the nexus SQLite DB.
    """

    def __init__(
        self,
        autoencoder_model_path: str,
        semantic_model_path: str,
        nexus_db_path: str,
        embeddings_db_path: str,
        index_path: str = None,
        eeg_file_path: str = None,
        window_size: int = 624,
        stride: int = 64,
        batch_size: int = 32,
        normalize: bool = True,
        device: str = None,
        search_k: int = 180,
        final_k: int = 90,
        use_raw_eeg: bool = False,
        last_n_messages: int = 3,
        input_dim_override: int = None,
        save_vectors: bool = False,
        vector_output_path: str = None
    ):
        """Wire up the EEG stream, models, databases, and similarity index.

        search_k / final_k: candidates retrieved vs. results returned.
        last_n_messages: how many recent batches the repetition filter spans
        (None falls back to the default of 3).
        save_vectors: collection-only mode — store semantic vectors, skip lookup.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        logger.info(f"Using device: {self.device}")

        # BUGFIX: callers (e.g. argparse) may pass None; keep the deque below
        # bounded by falling back to the documented default.
        self.last_n_messages = last_n_messages if last_n_messages is not None else 3
        self.use_raw_eeg = use_raw_eeg
        self.input_dim_override = input_dim_override

        # Streaming source of autoencoder embeddings from the EEG file.
        self.eeg_stream = EEGEmbeddingStream(
            file_path=eeg_file_path if eeg_file_path else "",
            model_path=autoencoder_model_path,
            window_size=window_size,
            stride=stride,
            normalize=normalize,
            batch_size=batch_size,
            device=self.device
        )

        logger.info(f"Loading traced semantic model from {semantic_model_path}")
        self.semantic_model = torch.jit.load(semantic_model_path, map_location=self.device)
        self.semantic_model.eval()

        # Probe the traced model with candidate widths to discover its
        # expected input/output dims (TorchScript does not expose them).
        self._semantic_input_dim = None
        self._semantic_output_dim = None
        for test_dim in [64, 10112]:
            try:
                dummy = torch.randn(1, test_dim, device=self.device)
                with torch.no_grad():
                    out = self.semantic_model(dummy)
                self._semantic_input_dim = test_dim
                self._semantic_output_dim = out.shape[1]
                logger.info(f"Semantic model: input_dim={self._semantic_input_dim}, output_dim={self._semantic_output_dim}")
                break
            except Exception:
                continue

        if self._semantic_input_dim is None:
            logger.warning("Could not auto-detect semantic model input dim. Will adapt at runtime.")

        # Per-session log file (None when no EEG file path was given).
        self.log_file = setup_eeg_logger(eeg_file_path) if eeg_file_path else None

        self.nexus_conn = sqlite3.connect(nexus_db_path)
        self.embeddings_conn = sqlite3.connect(embeddings_db_path)

        self.search_k = search_k
        self.final_k = final_k
        self.message_counts = defaultdict(int)
        self.recent_messages = deque(maxlen=10)
        self.repetition_penalty = 1.5

        logger.info("Creating embedding index")
        self.embedding_index = self._create_index(index_path)

        self.error_count = 0
        self.max_consecutive_errors = 5

        self.save_vectors = save_vectors
        self.vector_output_path = vector_output_path

        if self.save_vectors:
            self.vectors_list = []
            self.timestamps = []
            logger.info(f"Vector saving enabled. Output will be saved to {self.vector_output_path}")

        # Line-sets of the last N printed batches; used to suppress repeats.
        self.previous_message_sets = deque(maxlen=self.last_n_messages)

    def _create_index(self, index_path: str = None) -> EmbeddingIndex:
        """Create or load the embedding index for similarity search.

        A saved index is reused only when its stored (count, max_message_id)
        metadata matches the embeddings database; otherwise the index is
        rebuilt from scratch and re-saved along with fresh metadata.
        """
        cursor = self.embeddings_conn.cursor()

        cursor.execute("SELECT COUNT(*) FROM embeddings")
        db_count = cursor.fetchone()[0]

        cursor.execute("SELECT MAX(message_id) FROM embeddings")
        db_max_id = cursor.fetchone()[0]

        if index_path and os.path.exists(f"{index_path}.index"):
            try:
                logger.info(f"Checking existing index at {index_path}")

                index = EmbeddingIndex.load(index_path)

                metadata_path = f"{index_path}_metadata.npz"
                if os.path.exists(metadata_path):
                    metadata = np.load(metadata_path, allow_pickle=True)
                    saved_count = int(metadata.get('count', 0))
                    saved_max_id = int(metadata.get('max_message_id', 0))

                    logger.info(f"Saved index: {saved_count} items, max_id={saved_max_id}")
                    logger.info(f"Database: {db_count} items, max_id={db_max_id}")

                    if db_count != saved_count or db_max_id != saved_max_id:
                        # Fall through to the rebuild path below.
                        logger.info("Database has changed. Recreating index...")
                    else:
                        logger.info("Database unchanged. Using existing index...")
                        return index

            except Exception as e:
                logger.warning(f"Error checking existing index: {str(e)}")
                logger.info("Will create new index")

        logger.info("Creating new index from database...")

        cursor.execute("SELECT message_id, embedding FROM embeddings ORDER BY message_id")

        embeddings = []
        message_ids = []

        for message_id, emb in cursor.fetchall():
            # Embeddings are stored as raw float32 blobs.
            embedding = np.frombuffer(emb, dtype=np.float32)
            embeddings.append(embedding)
            message_ids.append(message_id)

        if not embeddings:
            raise ValueError("No embeddings found in database")

        embeddings = np.vstack(embeddings)
        logger.info(f"Loaded {len(embeddings)} embeddings with shape: {embeddings.shape}")

        index = EmbeddingIndex(dim=embeddings.shape[1])
        index.add_embeddings(embeddings, message_ids)

        if index_path:
            logger.info(f"Saving index to {index_path}")
            index.save(index_path)

            # Remember the DB fingerprint so staleness is detectable next run.
            metadata = {
                'count': db_count,
                'max_message_id': db_max_id
            }
            np.savez(f"{index_path}_metadata.npz", **metadata)

        return index

    def process_eeg_embedding(self, eeg_embedding: np.ndarray) -> torch.Tensor:
        """Convert EEG embedding to text embedding using the traced semantic model"""
        with torch.no_grad():
            tensor = torch.tensor(eeg_embedding, dtype=torch.float32).to(self.device)

            # Ensure a 2-D (batch, features) layout.
            if len(tensor.shape) < 2:
                tensor = tensor.unsqueeze(0)

            batch_size = tensor.shape[0]
            tensor = tensor.reshape(batch_size, -1)

            # Zero-pad or truncate to the model's detected input width.
            if self._semantic_input_dim is not None:
                current_features = tensor.shape[1]
                if current_features != self._semantic_input_dim:
                    if current_features < self._semantic_input_dim:
                        padded = torch.zeros(batch_size, self._semantic_input_dim, device=self.device)
                        padded[:, :current_features] = tensor
                        tensor = padded
                    else:
                        tensor = tensor[:, :self._semantic_input_dim]

            return self.semantic_model(tensor)

    def find_similar_messages(self, embedding: torch.Tensor, assistant_only=False) -> List[str]:
        """Return up to ``final_k`` message texts nearest to ``embedding``.

        When ``assistant_only`` is True only assistant-role messages qualify.
        Returns an empty list on any search/database error (logged).
        """
        embedding_np = embedding.detach().cpu().numpy()
        if len(embedding_np.shape) > 1:
            embedding_np = embedding_np.reshape(1, -1)

        try:
            distances, indices = self.embedding_index.search(embedding_np, self.search_k)
            distances = distances.flatten()
            indices = indices.flatten()

            cursor = self.nexus_conn.cursor()
            candidates = []

            # Parameterized queries keep the id lookup injection-safe.
            if assistant_only:
                query = """
                SELECT content FROM messages
                WHERE id = ? AND role = 'assistant'
                """
            else:
                query = """
                SELECT content FROM messages
                WHERE id = ?
                """

            # Results arrive nearest-first; distances are kept for parity
            # with search() but are not used for re-ranking here.
            for message_id, _distance in zip(indices, distances):
                cursor.execute(query, (int(message_id),))
                if result := cursor.fetchone():
                    candidates.append(result[0])

            return candidates[:self.final_k]

        except Exception as e:
            logger.error(f"Error during similarity search: {str(e)}")
            traceback.print_exc()
            return []

    def save_vectors_to_disk(self):
        """Save the collected vectors and timestamps to disk as one .npz archive."""
        if not self.vectors_list:
            logger.warning("No vectors to save")
            return

        output_dir = os.path.dirname(self.vector_output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        vectors_array = np.vstack(self.vectors_list)
        timestamps_array = np.array(self.timestamps)

        logger.info(f"Saving {len(self.vectors_list)} vectors to {self.vector_output_path}")
        np.savez(
            self.vector_output_path,
            vectors=vectors_array,
            timestamps=timestamps_array
        )
        logger.info(f"Vectors saved successfully to {self.vector_output_path}")

    def process_streaming_embeddings(self, callback=None):
        """
        Process streaming EEG embeddings through the semantic model
        and find similar messages.

        Each result dict is handed to ``callback`` when given, otherwise
        unique message lines are printed. In save-vectors mode, vectors are
        collected and the lookup is skipped. Runs until KeyboardInterrupt
        or ``max_consecutive_errors`` consecutive failures.
        """
        self.eeg_stream.start()

        try:
            consecutive_errors = 0
            while True:
                try:
                    for embedding_data in self.eeg_stream.get_embeddings(timeout=0.5):
                        try:
                            autoencoder_embedding = embedding_data['embedding']
                            semantic_embedding = self.process_eeg_embedding(autoencoder_embedding)

                            if self.save_vectors:
                                # Collection-only mode: stash and skip lookup.
                                embedding_np = semantic_embedding.detach().cpu().numpy()
                                self.vectors_list.append(embedding_np)
                                self.timestamps.append({
                                    'start': embedding_data['start_timestamp'],
                                    'end': embedding_data['end_timestamp']
                                })

                                if len(self.vectors_list) % 100 == 0:
                                    logger.info(f"Collected {len(self.vectors_list)} vectors so far")

                                continue

                            similar_messages = self.find_similar_messages(semantic_embedding)

                            result = {
                                'start_timestamp': embedding_data['start_timestamp'],
                                'end_timestamp': embedding_data['end_timestamp'],
                                'processing_time': 0,  # placeholder; not measured yet
                                'similar_messages': similar_messages
                            }

                            if callback:
                                callback(result)
                            else:
                                self._print_unique_lines(result)

                            # A good item resets the failure streak.
                            consecutive_errors = 0

                        except Exception as e:
                            print(f"Error: {str(e)}", file=sys.stderr)
                            consecutive_errors += 1

                            if consecutive_errors >= self.max_consecutive_errors:
                                raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")

                    time.sleep(0.01)

                except Exception as e:
                    # Propagate the give-up signal; otherwise back off and retry.
                    if "Too many consecutive errors" in str(e):
                        raise
                    print(f"Error: {str(e)}", file=sys.stderr)
                    consecutive_errors += 1
                    if consecutive_errors >= self.max_consecutive_errors:
                        raise RuntimeError(f"Too many consecutive errors ({consecutive_errors})")
                    time.sleep(1)

        except KeyboardInterrupt:
            pass
        except Exception as e:
            print(f"Fatal error: {str(e)}", file=sys.stderr)
        finally:
            if self.save_vectors and self.vectors_list:
                self.save_vectors_to_disk()

            self.eeg_stream.stop()

            # BUGFIX: the per-session log handle was previously never closed.
            if getattr(self, 'log_file', None):
                self.log_file.close()
                self.log_file = None

    def _print_unique_lines(self, result):
        """Print only lines that aren't in common with the last n batches of messages"""
        if not result['similar_messages']:
            return

        # Sub-sample candidates to keep console output manageable.
        sample_size = min(42, len(result['similar_messages']))
        current_messages = random.sample(result['similar_messages'], sample_size)

        # Collect the distinct non-empty stripped lines of this batch.
        current_lines = set()
        for message in current_messages:
            for line in message.splitlines():
                line = line.strip()
                if line:
                    current_lines.add(line)

        # Suppress anything already shown in the last N batches.
        unique_lines = current_lines.copy()
        for previous_lines in self.previous_message_sets:
            unique_lines -= previous_lines

        self.previous_message_sets.append(current_lines)

        if not unique_lines:
            # (A dead "__uniq_log_empty" debug branch used to live here.)
            return

        if PRINT_DEBUG_HASH:
            # Prefix each line with a short content hash for debugging.
            unique_lines = [
                f"{hashlib.md5(line.encode()).hexdigest()[:8]} | {line}"
                for line in unique_lines
            ]

        # Drop lines whose encoding could not be repaired.
        unique_lines = [s for s in map(fix_encoding, unique_lines) if s]

        output_text = "\n".join(sorted(unique_lines))
        print(output_text)

        if hasattr(self, 'log_file') and self.log_file:
            try:
                self.log_file.write(output_text + "\n")
                self.log_file.flush()
            except Exception as e:
                print(f"Error writing to log file: {str(e)}", file=sys.stderr)
|
|
|
|
def main():
    """CLI entry point: parse arguments, build an EEGSemanticProcessor, run it."""
    parser = argparse.ArgumentParser(description='Process EEG data through semantic model and lookup similar messages')

    # Model paths (both are traced/TorchScript artifacts).
    parser.add_argument('--autoencoder', '-a', type=str, required=True,
                        help='Path to the traced autoencoder encoder model')
    parser.add_argument('--semantic-model', '-s', type=str, required=True,
                        help='Path to the traced semantic model')

    # Databases and index storage.
    parser.add_argument('--nexus-db', '-n', type=str,
                        default=os.path.expanduser('~/.nexus/data/nexus-new.db'),
                        help='Path to the nexus database')
    parser.add_argument('--embeddings-db', '-e', type=str, default='emb_full.db',
                        help='Path to the embeddings database')
    parser.add_argument('--index', '-i', type=str, default='embedding_index',
                        help='Path to save/load the FAISS index')

    # EEG streaming parameters.
    parser.add_argument('--eeg-file', '-f', type=str, required=True,
                        help='Path to the EEG data file to monitor')
    parser.add_argument('--window-size', type=int, default=624,
                        help='Window size in samples')
    parser.add_argument('--stride', type=int, default=32,
                        help='Stride between windows')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size for processing')
    parser.add_argument('--no-normalize', dest='normalize', action='store_false',
                        help='Disable normalization of EEG data')

    # Retrieval parameters.
    parser.add_argument('--search-k', type=int, default=180,
                        help='Number of candidates to retrieve for selection')
    parser.add_argument('--final-k', type=int, default=90,
                        help='Number of results to show')

    parser.add_argument('--device', type=str, default=None,
                        help='Device to use (cuda or cpu)')

    # BUGFIX: default was None, which made the repetition-filter deque in
    # EEGSemanticProcessor unbounded; 3 matches the constructor default.
    parser.add_argument('--last_n', type=int, default=3,
                        help='Window queue size for repetition filter')

    parser.add_argument('--use-raw-eeg', action='store_true',
                        help='Use raw EEG data with semantic model (skip autoencoder)')
    parser.add_argument('--input-dim', type=int,
                        help='Override the input dimension for the semantic model')

    parser.add_argument('--save-vectors', action='store_true',
                        help='Save semantic vectors to disk without generating output')
    parser.add_argument('--vector-output', type=str, default='semantic_vectors.npz',
                        help='Path to save the semantic vectors')

    args = parser.parse_args()

    processor = EEGSemanticProcessor(
        autoencoder_model_path=args.autoencoder,
        semantic_model_path=args.semantic_model,
        nexus_db_path=args.nexus_db,
        embeddings_db_path=args.embeddings_db,
        index_path=args.index,
        last_n_messages=args.last_n,
        eeg_file_path=args.eeg_file,
        window_size=args.window_size,
        stride=args.stride,
        batch_size=args.batch_size,
        normalize=args.normalize,
        device=args.device,
        search_k=args.search_k,
        final_k=args.final_k,
        use_raw_eeg=args.use_raw_eeg,
        input_dim_override=args.input_dim,
        save_vectors=args.save_vectors,
        vector_output_path=args.vector_output
    )

    try:
        processor.process_streaming_embeddings()
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print(f"Error: {str(e)}", file=sys.stderr)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|