| | import json
|
| | import torch
|
| | from sentence_transformers import SentenceTransformer
|
| | from chromadb import Client, Settings, EmbeddingFunction
|
| | from tqdm import tqdm
|
| | import numpy as np
|
| | import os
|
| | import psutil
|
| | import time
|
| | import hashlib
|
| | from datetime import datetime
|
| | from concurrent.futures import ThreadPoolExecutor
|
| | from typing import List, Dict, Any
|
| |
|
| |
|
# On-disk location of the persistent Chroma database.
CHROMA_URI = "./Data/database"
# Biomedical sentence-embedding model (BioBERT fine-tuned on NLI/STS corpora).
EMBEDDING_MODEL_NAME = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
# Number of texts encoded per chunk in batch_embed().
BATCH_SIZE = 1024
# Embedding dimensionality produced by the model (BERT-base hidden size).
VECTOR_DIM = 768
# NOTE(review): appears unused in this file — confirm no external caller needs it.
INSERT_BATCH_SIZE = 1024
# Directory where computed embeddings are cached as .npy files keyed by corpus hash.
EMBEDDINGS_DIR = "./Data/Embeddings"
|
| |
|
class BioEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a biomedical SentenceTransformer.

    The model is loaded once at construction and moved to GPU when CUDA
    is available, falling back to CPU otherwise.
    """

    def __init__(self):
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        self.model.to(self.device)

    def __call__(self, input: list[str]) -> list[list[float]]:
        """Encode *input* texts into L2-normalized embedding vectors."""
        vectors = self.model.encode(
            input,
            normalize_embeddings=True,
            convert_to_numpy=True,
        )
        return vectors.tolist()
|
| |
|
| |
|
# Persistent Chroma client rooted at CHROMA_URI; anonymous telemetry disabled.
client = Client(
    Settings(
        persist_directory=CHROMA_URI,
        anonymized_telemetry=False,
        is_persistent=True
    )
)

# Shared embedding function instance; `model` aliases its underlying
# SentenceTransformer so module-level helpers can call encode() directly
# without constructing a second copy of the (large) model.
embedding_function = BioEmbeddingFunction()
model = embedding_function.model
|
| |
|
def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024
|
| |
|
def format_time(seconds):
    """Format a duration in *seconds* as ``HH:MM:SS``.

    Fixes the original ``time.strftime('%H:%M:%S', time.gmtime(seconds))``
    approach, which silently wrapped durations of 24 hours or more back to
    zero; here the hour field simply keeps growing (90000s -> "25:00:00").
    Fractional seconds are truncated, matching the original behavior.
    """
    total = int(seconds)
    hours, remainder = divmod(total, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
| |
|
def batch_embed(texts):
    """Embed *texts* in batches with the module-level SentenceTransformer.

    Returns a ``(len(texts), VECTOR_DIM)`` array of L2-normalized vectors.
    An empty input now yields an empty ``(0, VECTOR_DIM)`` array instead of
    letting ``np.concatenate`` raise on an empty list of chunks.
    """
    if not texts:
        return np.empty((0, VECTOR_DIM), dtype=np.float32)

    chunks = []
    for start in tqdm(range(0, len(texts), BATCH_SIZE), desc="生成文本向量"):
        batch_texts = texts[start:start + BATCH_SIZE]
        chunks.append(
            model.encode(
                batch_texts,
                batch_size=BATCH_SIZE,
                show_progress_bar=False,
                convert_to_numpy=True,
                normalize_embeddings=True,
            )
        )
    return np.concatenate(chunks, axis=0)
|
| |
|
def parallel_upsert(collection, start_idx: int, end_idx: int,
                    documents: List[str], embeddings: np.ndarray,
                    metadatas: List[Dict[str, Any]]) -> None:
    """Insert the ``[start_idx, end_idx)`` slice of the corpus into *collection*.

    Uses ``add`` rather than ``upsert`` because the collection runs in a
    temporary in-memory mode (original author's note).  Despite the name,
    this performs one synchronous insert; any parallelism must come from
    the caller scheduling multiple slices concurrently.
    """
    slice_ids = [str(i) for i in range(start_idx, end_idx)]
    collection.add(
        ids=slice_ids,
        embeddings=embeddings[start_idx:end_idx].tolist(),
        metadatas=metadatas[start_idx:end_idx],
        documents=documents[start_idx:end_idx],
    )
|
| |
|
def calculate_data_hash(documents: List[str]) -> str:
    """Return a stable MD5 fingerprint of *documents* for cache validation.

    The original ``"".join(documents)`` scheme was boundary-blind:
    ``["ab", "c"]`` and ``["a", "bc"]`` produced the same digest, so a
    re-chunked corpus could wrongly hit the embeddings cache.  Each
    document's byte length is now mixed in before its content, making the
    digest sensitive to document boundaries.  (Digests differ from the old
    scheme, so existing caches recompute once.)  MD5 is acceptable here:
    this is a cache key, not a security boundary.
    """
    digest = hashlib.md5()
    for doc in documents:
        encoded = doc.encode()
        # Length prefix + separator disambiguates document boundaries.
        digest.update(str(len(encoded)).encode())
        digest.update(b"\x00")
        digest.update(encoded)
    return digest.hexdigest()
|
| |
|
def save_embeddings(embeddings: np.ndarray, data_hash: str):
    """Persist *embeddings* under EMBEDDINGS_DIR, keyed by the corpus hash."""
    os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
    target = os.path.join(EMBEDDINGS_DIR, f"embeddings_{data_hash}.npy")
    np.save(target, embeddings)
    print(f"Embeddings已保存到: {target}")
|
| |
|
def load_embeddings(data_hash: str) -> "np.ndarray | None":
    """Load cached embeddings for *data_hash*, or ``None`` on a cache miss.

    The original annotation claimed a bare ``np.ndarray`` although the
    miss path returns ``None`` (callers already check for it); the
    annotation now matches the actual contract.  It is quoted so the
    module keeps importing on Python < 3.10.
    """
    cache_path = os.path.join(EMBEDDINGS_DIR, f"embeddings_{data_hash}.npy")
    if os.path.exists(cache_path):
        return np.load(cache_path)
    return None
|
| |
|
def vectorize_data(documents, embeddings, metadatas):
    """Upsert the whole corpus into the persistent "healthcare_qa" collection.

    Records are written in fixed-size chunks so Chroma persists
    incrementally instead of receiving one huge call.  Returns the
    collection handle so the caller can query counts afterwards.
    """
    collection = client.get_or_create_collection(
        name="healthcare_qa",
        embedding_function=embedding_function
    )
    PERSIST_BATCH_SIZE = 5000
    total_records = len(documents)

    with tqdm(total=total_records, desc="持久化进度") as pbar:
        start = 0
        while start < total_records:
            stop = min(start + PERSIST_BATCH_SIZE, total_records)
            collection.upsert(
                ids=[str(j) for j in range(start, stop)],
                embeddings=embeddings[start:stop],
                documents=documents[start:stop],
                metadatas=metadatas[start:stop],
            )
            pbar.update(stop - start)
            start = stop

    return collection
|
| |
|
if __name__ == "__main__":
    start_time = time.time()
    # Progress banners (user-facing strings are intentionally in Chinese).
    print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] 开始向量化处理...")
    print(f"使用设备: {model.device}")
    print(f"初始内存使用: {get_memory_usage():.2f} MB")

    # Ensure the database and embedding-cache directories exist up front.
    os.makedirs(CHROMA_URI, exist_ok=True)
    os.makedirs(EMBEDDINGS_DIR, exist_ok=True)

    # Step 1: load the keyword index and the cleaned QA database from disk.
    # NOTE(review): banners count "[1/5]".."[4/5]" but only four steps exist
    # in this script — confirm whether a fifth step was removed.
    print("\n[1/5] 加载数据文件...")
    loading_start = time.time()
    with open("./Data/Processed/keywords/keyword_index.json") as f:
        keyword_index = json.load(f)
    with open("./Data/Processed/cleaned_qa/qa_database.json") as f:
        qa_database = json.load(f)
    print(f"数据加载完成,用时: {format_time(time.time() - loading_start)}")
    print(f"当前内存使用: {get_memory_usage():.2f} MB")

    # Step 2: flatten the keyword index into parallel documents/metadatas lists.
    print("\n[2/5] 处理文档数据...")
    documents = []
    metadatas = []

    # Map QA id -> record for O(1) lookups while walking the keyword index.
    print("建立QA索引映射...")
    qa_map = {qa["id"]: qa for qa in qa_database}

    total_items = sum(len(item_ids) for item_ids in keyword_index.values())
    with tqdm(total=total_items, desc="处理文档") as pbar:
        for source, item_ids in keyword_index.items():
            for item_id in item_ids:
                qa = qa_map.get(item_id)
                # Skip ids present in the index but missing from the QA DB.
                if not qa:
                    pbar.update(1)
                    continue

                # One searchable document: question + answer + keywords.
                combined_text = f"Question: {qa['question']}\nAnswer: {qa['answer']}\nKeywords: {', '.join(qa.get('keywords', []))}"

                metadata = {
                    "source": source,
                    "item_id": item_id,
                    "keywords": ", ".join(qa.get("keywords", [])),
                    "type": "qa"
                }

                documents.append(combined_text)
                metadatas.append(metadata)
                pbar.update(1)

    print(f"文档处理完成,共处理 {len(documents)} 条记录")
    print(f"当前内存使用: {get_memory_usage():.2f} MB")

    # Disabled debug toggle: flip the 0 to 1 to run on a 1000-record sample.
    if 0:
        documents = documents[:1000]
        metadatas = metadatas[:1000]

    # Step 3: compute (or load cached) embeddings for every document.
    print("\n[3/5] 生成文本向量...")
    vector_start = time.time()

    # Corpus fingerprint used as the cache key for saved embeddings.
    data_hash = calculate_data_hash(documents)

    embeddings = load_embeddings(data_hash)

    if embeddings is not None:
        print("找到缓存的embeddings,直接加载使用")
    else:
        print("未找到缓存的embeddings,重新计算...")
        embeddings = batch_embed(documents)
        # Cache the freshly computed embeddings for future runs.
        save_embeddings(embeddings, data_hash)

    print(f"向量生成完成,用时: {format_time(time.time() - vector_start)}")
    print(f"当前内存使用: {get_memory_usage():.2f} MB")

    # Step 4: upsert everything into the persistent Chroma collection.
    print("\n[4/5] 创建数据库集合...")
    collection = vectorize_data(documents, embeddings, metadatas)

    total_time = time.time() - start_time
    print(f"\n数据库构建完成!")
    print(f"总用时: {format_time(total_time)}")
    print(f"总条目数: {collection.count()} 条")
    # NOTE(review): CHROMA_URI is a directory, so getsize() reports the
    # directory entry's size, not the database's total size — confirm, or
    # sum file sizes with a recursive walk if an accurate figure is needed.
    print(f"数据库大小: {os.path.getsize(CHROMA_URI) / 1024 / 1024:.2f} MB")
    print(f"最终内存使用: {get_memory_usage():.2f} MB")
|