| |
| """ |
| Text embedding script with SQLite storage (using numpy buffers) |
| Now with flexible text splitting modes! |
| |
| Usage: python embed_flex.py <directory_path> <db_path> [--split-mode MODE] |
| |
| Split modes: |
| - line (default): Each non-empty line becomes one embedding |
| - block: Double-newline separated blocks (paragraphs) |
| - sentence: Split on sentence boundaries (., !, ?) |
| - chunk: Fixed-size character chunks with overlap (for long docs) |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import sqlite3 |
| import numpy as np |
| from tqdm import tqdm |
| from transformers import AutoModel, AutoTokenizer |
| import torch |
| import gc |
| import random |
| import re |
|
|
# Batch size tried first; also the default value of the --batch-size option.
| INITIAL_BATCH_SIZE = 128 |
# Lower bound for the batch size when main() shrinks it after an out-of-memory failure.
| MIN_BATCH_SIZE = 1 |
# Seed used for the shuffle of text units in main().
| SHUFFLE_SEED = 42 |
|
|
| |
# Default chunk length, in characters, for split_by_chunks() and --chunk-size.
| DEFAULT_CHUNK_SIZE = 512 |
# Default overlap between consecutive chunks, for split_by_chunks() and --chunk-overlap.
| DEFAULT_CHUNK_OVERLAP = 64 |
|
|
|
|
def create_index_if_possible(cursor):
    """Best-effort creation of the idx_content index on messages(content).

    Args:
        cursor: An open sqlite3 cursor.

    An sqlite3.OperationalError (for example when the messages table does
    not exist yet) is deliberately ignored.
    """
    statement = """
        CREATE INDEX IF NOT EXISTS idx_content ON messages(content)
    """
    try:
        cursor.execute(statement)
    except sqlite3.OperationalError:
        # Deliberate best-effort: the index is optional.
        return
|
|
|
|
def get_existing_content(cursor):
    """Return the set of content values currently stored in messages.

    Args:
        cursor: An open sqlite3 cursor.

    Returns:
        A set with the first column of every row of
        ``SELECT content FROM messages``; an empty set when the query
        raises sqlite3.OperationalError.
    """
    try:
        cursor.execute("SELECT content FROM messages")
        rows = cursor.fetchall()
    except sqlite3.OperationalError:
        return set()
    return {record[0] for record in rows}
|
|
|
|
def clear_gpu_memory():
    """Empty the CUDA cache when CUDA is available, then run the garbage collector."""
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()
    gc.collect()
|
|
|
|
| |
| |
| |
|
|
def split_by_lines(text):
    """Return every non-blank line of *text*, stripped of surrounding whitespace."""
    stripped = (raw.strip() for raw in text.split('\n'))
    return [entry for entry in stripped if entry]
|
|
|
|
def split_by_blocks(text):
    """Split *text* into blocks separated by blank lines.

    Inside each block, every run of whitespace (including single newlines)
    is collapsed to one space; blocks that end up empty are dropped.
    """
    paragraphs = re.split(r'\n\s*\n+', text)
    collapsed = (' '.join(paragraph.split()) for paragraph in paragraphs)
    return [paragraph for paragraph in collapsed if paragraph]
|
|
|
|
def split_by_sentences(text):
    """
    Split *text* into sentences.

    All whitespace is first collapsed to single spaces. The text is then cut
    wherever '.', '!' or '?' is followed by whitespace and an uppercase ASCII
    letter, so a terminator followed by a lowercase word does not end a
    sentence. Empty pieces are dropped.
    """
    normalized = ' '.join(text.split())
    boundary = r'(?<=[.!?])\s+(?=[A-Z])'
    pieces = (piece.strip() for piece in re.split(boundary, normalized))
    return [piece for piece in pieces if piece]
|
|
|
|
def split_by_chunks(text, chunk_size=DEFAULT_CHUNK_SIZE, overlap=DEFAULT_CHUNK_OVERLAP):
    """
    Fixed-size character chunks with overlap.
    Good for long documents where you want sliding window coverage.

    Args:
        text: Text to split. All whitespace runs are collapsed to single
            spaces before chunking.
        chunk_size: Maximum chunk length in characters (must be >= 1).
        overlap: Number of characters by which each chunk starts before the
            end of the previous one (must be >= 0).

    Returns:
        A list of non-empty, stripped chunks; empty when *text* contains
        only whitespace.

    Raises:
        ValueError: If chunk_size is smaller than 1 or overlap is negative.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    if overlap < 0:
        raise ValueError("overlap must not be negative")

    text = ' '.join(text.split())
    if not text:
        return []

    text_len = len(text)
    if text_len <= chunk_size:
        return [text]

    chunks = []
    start = 0
    while start < text_len:
        end = start + chunk_size
        chunk = text[start:end]

        # Prefer to cut at a word boundary, but only when that keeps more
        # than half of the window; otherwise cut mid-word.
        if end < text_len:
            last_space = chunk.rfind(' ')
            if last_space > chunk_size // 2:
                chunk = chunk[:last_space]
                end = start + last_space

        chunk = chunk.strip()
        if chunk:
            chunks.append(chunk)

        # The window reached the end of the text: stepping back by the
        # overlap would only re-emit a tail already inside this chunk.
        if end >= text_len:
            break

        # Step back by the overlap, but always move forward; this keeps the
        # loop finite even when overlap >= chunk_size.
        next_start = end - overlap
        if next_start <= start:
            next_start = end
        start = next_start

    return chunks
|
|
|
|
def get_splitter(mode, chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP):
    """Look up the splitting callable for *mode*.

    Args:
        mode: One of 'line', 'block', 'sentence' or 'chunk'.
        chunk_size: Passed on to split_by_chunks in 'chunk' mode.
        chunk_overlap: Passed on to split_by_chunks in 'chunk' mode.

    Returns:
        A callable that takes one text argument and returns a list of units.

    Raises:
        ValueError: If *mode* is not one of the known modes.
    """
    def chunk_splitter(text):
        return split_by_chunks(text, chunk_size, chunk_overlap)

    registry = (
        ('line', split_by_lines),
        ('block', split_by_blocks),
        ('sentence', split_by_sentences),
        ('chunk', chunk_splitter),
    )
    for name, splitter in registry:
        if mode == name:
            return splitter
    raise ValueError(f"Unknown split mode: {mode}")
|
|
|
|
| |
| |
| |
|
|
def _embedding_to_blob(embedding):
    """Serialize one embedding (tensor, ndarray or sequence) to raw float32 bytes."""
    if torch.is_tensor(embedding):
        embedding = embedding.detach().cpu().numpy()
    elif not isinstance(embedding, np.ndarray):
        embedding = np.array(embedding)
    return embedding.astype(np.float32).tobytes()


def process_batch(model, batch_lines, cursor, task="text-matching", device="cuda"):
    """Encode a batch of texts and store each text with its embedding.

    Args:
        model: Object exposing ``encode(texts, task=..., device=...)`` that
            returns one embedding per input text.
        batch_lines: List of texts to embed.
        cursor: sqlite3 cursor on a database that has the ``messages`` and
            ``embeddings`` tables. Nothing is committed here.
        task: Task name forwarded to ``model.encode``.
        device: Device name forwarded to ``model.encode``.

    Returns:
        True when the batch was encoded (rows that hit a database error are
        reported and skipped); False when encoding ran out of memory, in
        which case the caller is expected to retry with a smaller batch.

    Raises:
        RuntimeError: Any runtime error that is not an out-of-memory error.
    """
    try:
        with torch.no_grad():
            batch_embeddings = model.encode(batch_lines, task=task, device=device)

        for line_text, embedding in zip(batch_lines, batch_embeddings):
            # Serialize before touching the database so a conversion problem
            # cannot leave a message row without its embedding.
            embedding_blob = _embedding_to_blob(embedding)
            message_id = None
            try:
                cursor.execute(
                    "INSERT INTO messages (content, role) VALUES (?, ?)",
                    (line_text, "system")
                )
                message_id = cursor.lastrowid
                cursor.execute(
                    "INSERT INTO embeddings (message_id, embedding) VALUES (?, ?)",
                    (message_id, embedding_blob)
                )
            except sqlite3.Error as e:
                print(f"Error processing entry: {e}")
                if message_id is not None:
                    # The message row went in but its embedding did not.
                    # Remove it, otherwise the text would count as already
                    # processed on the next run and never get an embedding.
                    try:
                        cursor.execute(
                            "DELETE FROM messages WHERE rowid = ?",
                            (message_id,)
                        )
                    except sqlite3.Error as cleanup_error:
                        print(f"Error removing incomplete entry: {cleanup_error}")
                continue

        return True

    except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
        if "out of memory" in str(e).lower():
            clear_gpu_memory()
            return False
        else:
            raise
|
|
|
|
def convert_existing_pickles(cursor, conn):
    """Convert any existing pickle embeddings to numpy buffers.

    Every row of ``embeddings`` is inspected. A blob is treated as pickled
    when its length is not a multiple of 4 (so it cannot be a float32
    buffer) or when it carries the pickle protocol >= 2 framing. Blobs that
    unpickle successfully are rewritten in place as raw float32 bytes;
    everything else is left untouched.

    Args:
        cursor: sqlite3 cursor on a database containing ``embeddings``.
        conn: The connection owning *cursor*; committed once when at least
            one row was converted.

    Note:
        Pickles written with protocol 0 or 1 whose length happens to be a
        multiple of 4 cannot be told apart from float32 data and are skipped.
    """
    import pickle

    def looks_like_pickle(blob):
        # Protocol >= 2 streams start with the PROTO opcode (0x80) and every
        # pickle ends with the STOP opcode ('.').
        return len(blob) >= 2 and blob[:1] == b'\x80' and blob[-1:] == b'.'

    def unpickle_to_float32_bytes(blob):
        # SECURITY: pickle.loads can execute arbitrary code. Only run this
        # conversion on databases whose contents you trust.
        try:
            pickled_obj = pickle.loads(blob)
            if isinstance(pickled_obj, np.ndarray):
                array = pickled_obj
            elif torch.is_tensor(pickled_obj):
                array = pickled_obj.cpu().numpy()
            else:
                array = np.array(pickled_obj)
            return array.astype(np.float32).tobytes()
        except Exception:
            # Not a usable pickle (or not numeric): leave the row alone.
            return None

    cursor.execute("SELECT COUNT(*) FROM embeddings")
    total_embeddings = cursor.fetchone()[0]

    if total_embeddings == 0:
        return

    print(f"Checking {total_embeddings} existing embeddings for pickle->numpy conversion...")

    cursor.execute("SELECT message_id, embedding FROM embeddings")
    rows = cursor.fetchall()

    converted_count = 0
    for message_id, embedding_blob in rows:
        if not isinstance(embedding_blob, (bytes, bytearray, memoryview)):
            continue
        blob = bytes(embedding_blob)
        if not blob:
            continue

        # A length check alone is not enough: about one pickle in four has a
        # length divisible by 4 and would pass for a float32 buffer.
        if len(blob) % 4 == 0 and not looks_like_pickle(blob):
            continue

        np_buffer = unpickle_to_float32_bytes(blob)
        if np_buffer is None:
            continue

        cursor.execute(
            "UPDATE embeddings SET embedding = ? WHERE message_id = ?",
            (np_buffer, message_id)
        )
        converted_count += 1

    if converted_count > 0:
        conn.commit()
        print(f"Converted {converted_count} pickle embeddings to numpy buffers")
|
|
|
|
def _ensure_schema(cursor):
    """Create the messages and embeddings tables when they are missing."""
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT,
            role TEXT
        )
    """)
    # The parent key is messages.id. Referencing a column that does not
    # exist makes every insert fail once foreign keys are enforced.
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS embeddings (
            message_id INTEGER PRIMARY KEY,
            embedding BLOB,
            FOREIGN KEY (message_id) REFERENCES messages(id) ON DELETE CASCADE
        )
    """)


def _read_units(directory_path, txt_files, splitter):
    """Read each file and return the concatenated list of split units."""
    all_units = []
    for filename in txt_files:
        filepath = os.path.join(directory_path, filename)
        with open(filepath, "r", encoding="utf-8", errors="ignore") as f:
            all_units.extend(splitter(f.read()))
    return all_units


def _embed_units(model, units, conn, cursor, max_batch_size, task):
    """Embed *units* in adaptive batches; return (processed, skipped) counts."""
    batch_size = max_batch_size
    total = len(units)
    idx = 0
    processed_count = 0
    skipped_count = 0

    with tqdm(total=total, desc="Processing") as pbar:
        while idx < total:
            end_idx = min(idx + batch_size, total)
            batch = units[idx:end_idx]

            if process_batch(model, batch, cursor, task):
                try:
                    conn.commit()
                except sqlite3.Error as e:
                    print(f"Error committing batch: {e}")

                pbar.update(len(batch))
                processed_count += len(batch)
                idx = end_idx

                # After an OOM shrink, grow the batch back every time the
                # number of handled units reaches a multiple of 10 batches.
                if batch_size < max_batch_size and idx % (batch_size * 10) == 0:
                    batch_size = min(batch_size * 2, max_batch_size)
            elif batch_size > MIN_BATCH_SIZE:
                batch_size = max(batch_size // 2, MIN_BATCH_SIZE)
                print(f"\nOOM - batch size -> {batch_size}")
            else:
                # Even the smallest batch does not fit: give up on this unit.
                print(f"\nSkipping: {batch[0][:100]}...")
                idx += 1
                pbar.update(1)
                skipped_count += 1

    return processed_count, skipped_count


def main():
    """Command-line entry point.

    Parses the arguments, splits every .txt file of the given directory into
    units, embeds the units that are not yet in the database and stores them
    as float32 buffers. Exits with status 1 when the directory does not
    exist and with status 2 on invalid option values.
    """
    parser = argparse.ArgumentParser(
        description='Generate embeddings for text files with flexible splitting modes',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Split Modes:
  line      Each non-empty line = one embedding (default, original behavior)
  block     Double-newline separated paragraphs = one embedding each
  sentence  Split on sentence boundaries (., !, ?)
  chunk     Fixed-size character chunks with overlap (good for long docs)

Examples:
  %(prog)s ~/docs embeddings.db                        # line mode (default)
  %(prog)s ~/docs embeddings.db --split-mode block     # paragraph mode
  %(prog)s ~/docs embeddings.db --split-mode sentence  # sentence mode
  %(prog)s ~/docs embeddings.db --split-mode chunk --chunk-size 1024 --chunk-overlap 128
"""
    )

    parser.add_argument('directory',
                        help='Directory containing .txt files to process')
    parser.add_argument('database',
                        help='SQLite database path (will be created if not exists)')
    parser.add_argument('--split-mode', '-s',
                        choices=['line', 'block', 'sentence', 'chunk'],
                        default='line',
                        help='Text splitting strategy (default: line)')
    parser.add_argument('--chunk-size', type=int, default=DEFAULT_CHUNK_SIZE,
                        help=f'Character chunk size for chunk mode (default: {DEFAULT_CHUNK_SIZE})')
    parser.add_argument('--chunk-overlap', type=int, default=DEFAULT_CHUNK_OVERLAP,
                        help=f'Overlap between chunks (default: {DEFAULT_CHUNK_OVERLAP})')
    parser.add_argument('--batch-size', type=int, default=INITIAL_BATCH_SIZE,
                        help=f'Initial batch size (default: {INITIAL_BATCH_SIZE})')
    parser.add_argument('--task', default='text-matching',
                        help='Encoding task (default: text-matching)')
    parser.add_argument('--model', default='jinaai/jina-embeddings-v3',
                        help='Model name (default: jinaai/jina-embeddings-v3)')
    parser.add_argument('--skip-conversion', action='store_true',
                        help='Skip checking/converting existing pickle embeddings')

    args = parser.parse_args()

    # Reject values that would stall or break the batching/chunking loops.
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.split_mode == 'chunk':
        if args.chunk_size < 1:
            parser.error("--chunk-size must be at least 1")
        if not 0 <= args.chunk_overlap < args.chunk_size:
            parser.error("--chunk-overlap must be >= 0 and smaller than --chunk-size")

    directory_path = os.path.expanduser(args.directory)
    db_path = os.path.expanduser(args.database)

    if not os.path.isdir(directory_path):
        print(f"Error: Directory '{directory_path}' does not exist")
        sys.exit(1)

    print(f"Processing directory: {directory_path}")
    print(f"Database: {db_path}")
    print(f"Split mode: {args.split_mode}")
    if args.split_mode == 'chunk':
        print(f"Chunk size: {args.chunk_size}, overlap: {args.chunk_overlap}")
    print(f"Initial batch size: {args.batch_size}")

    splitter = get_splitter(args.split_mode, args.chunk_size, args.chunk_overlap)

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        _ensure_schema(cursor)
        conn.commit()

        create_index_if_possible(cursor)
        conn.commit()

        if not args.skip_conversion:
            convert_existing_pickles(cursor, conn)

        existing_content = get_existing_content(cursor)
        print(f"Already processed: {len(existing_content)} entries")

        # os.listdir returns names in arbitrary order; sort them so the
        # seeded shuffle below yields the same order on every run.
        txt_files = sorted(
            f for f in os.listdir(directory_path) if f.lower().endswith(".txt")
        )

        if not txt_files:
            print(f"Warning: No .txt files found in {directory_path}")
            return

        print(f"Found {len(txt_files)} .txt files")

        all_units = _read_units(directory_path, txt_files, splitter)
        print(f"Total units from source ({args.split_mode} mode): {len(all_units)}")

        # A private generator gives the same order as seeding the module
        # generator, without altering global random state.
        random.Random(SHUFFLE_SEED).shuffle(all_units)

        # Content is the resume key, so identical units are embedded once:
        # drop repeats within this run as well as units already stored.
        new_units = [
            u for u in dict.fromkeys(all_units) if u not in existing_content
        ]

        print(f"Remaining to process: {len(new_units)}")

        if not new_units:
            print("Nothing new to process.")
            return

        # Load the model only now, so runs with nothing to do finish
        # without touching the GPU.
        # NOTE: trust_remote_code=True executes code shipped with the model
        # repository; only use model names you trust.
        print(f"Loading model: {args.model}")
        model = AutoModel.from_pretrained(args.model, trust_remote_code=True).cuda()
        model.eval()

        processed_count, skipped_count = _embed_units(
            model, new_units, conn, cursor, args.batch_size, args.task
        )
    finally:
        conn.close()

    print(f"\nProcessed {processed_count:,} entries total.")
    if skipped_count:
        print(f"Skipped {skipped_count:,} entries that did not fit in GPU memory.")
    print("All embeddings stored as numpy buffers (float32).")
|
|
|
|
# Run the command-line entry point only when this file is executed as a script.
| if __name__ == "__main__": |
| main() |
|
|