Spaces:
Sleeping
Sleeping
| # Path: src/database/db_manager.py | |
| # This module provides a function to establish a connection to the SQLite database. | |
| # It also provides RAG-based dynamic schema retrieval. | |
| import os | |
| import sqlite3 | |
| from langchain_community.utilities import SQLDatabase | |
| from langchain_chroma import Chroma | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_core.documents import Document | |
# Location of the SQLite file; the SQLITE_DB_PATH environment variable
# overrides the default path.
DB_PATH = os.environ.get("SQLITE_DB_PATH", "src/database/Chinook_Sqlite.sqlite")
# URI form of DB_PATH, handed to SQLDatabase.from_uri
DB_URI = "sqlite:///" + DB_PATH
# Knowledge base for RAG: one short description per database table.
# Each (table name, summary) pair becomes a Document whose text is the summary
# and whose metadata records the table it describes.
_TABLE_SUMMARIES = (
    ("Album", "Contains music albums associated with artists"),
    ("Artist", "Contains information about artists"),
    ("Customer", "Contains customer information like name, address, phone, and email"),
    ("Employee", "Contains employee information such as name, title, hire date, and manager"),
    ("Genre", "Contains musical genres like Rock, Jazz, or Metal"),
    ("Invoice", "Contains details of invoices including billing information and total amount"),
    ("InvoiceLine", "Contains line items for each invoice, linking to the purchased tracks."),
    ("MediaType", "Contains media types like MPEG audio or AAC audio."),
    ("Playlist", "Contains custom playlists created by users."),
    ("PlaylistTrack", "Mapping table linking tracks to playlists."),
    ("Track", "Contains details of music tracks including name, album, genre, and composer"),
)
TABLE_DESCRIPTIONS = [
    Document(page_content=summary, metadata={"table_name": table})
    for table, summary in _TABLE_SUMMARIES
]
# Embedding model used to vectorise the table descriptions and the questions
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
# Build the Chroma vector store in a persistent directory.
# This runs on every import of the module and "chroma_db" survives between
# runs, so each document gets a stable id (its table name). Re-running then
# overwrites the existing entries instead of appending another copy of every
# description, which would otherwise let duplicates fill the top-k results.
# NOTE(review): entries written by earlier runs without explicit ids are not
# removed by this change; delete the "chroma_db" directory once to clear them.
vector_store = Chroma.from_documents(
    documents=TABLE_DESCRIPTIONS,
    embedding=embedding_model,
    ids=[doc.metadata["table_name"] for doc in TABLE_DESCRIPTIONS],
    persist_directory="chroma_db",
)
# Open a connection to the SQLite database
def get_db_connection():
    """Open a sqlite3 connection to the database at DB_PATH.

    Returns:
        The connection object, or None when sqlite3 raises an error while
        connecting (the error is printed to stdout in that case).
    """
    try:
        connection = sqlite3.connect(DB_PATH)
    except sqlite3.Error as exc:
        print(f"Error connecting to database: {exc}")
        return None
    return connection
# Retrieve relevant schema context based on the user's question
def get_relevant_tables(question: str, top_k: int = 4) -> list[str]:
    """Return the names of the tables most relevant to a question (RAG step).

    Runs a similarity search over the table descriptions held in the vector
    store and maps every hit back to the table it describes.

    Args:
        question: The user's natural-language question.
        top_k: Number of documents requested from the vector store.

    Returns:
        Table names in ranking order, each listed once. The list can be
        shorter than ``top_k`` when several hits describe the same table.
    """
    # Perform a similarity search in the vector store
    hits = vector_store.similarity_search(question, k=top_k)
    # The persisted collection can hold repeated copies of a description when
    # it was populated more than once; dict.fromkeys drops the repeats while
    # keeping the ranking order.
    return list(dict.fromkeys(hit.metadata["table_name"] for hit in hits))
# Build the schema text that describes the database
def get_schema_context(question=None):
    """Return the table info text (DDL and sample rows) for the database.

    When a non-empty question is given, only the tables selected by
    get_relevant_tables() are included; otherwise no table filter is applied.
    Up to 5 sample rows per table are requested.

    Returns:
        The string produced by SQLDatabase.get_table_info(), or, if anything
        raises along the way, a string beginning with
        "Error retrieving database schema:".
    """
    try:
        selected_tables = None
        if question:
            # RAG retrieval narrows the schema to the tables relevant to the question
            selected_tables = get_relevant_tables(question)
            print(f"RAG selected tables for question '{question}': {selected_tables}")
        # Connect to the database, restricted to the selected tables if any
        database = SQLDatabase.from_uri(
            DB_URI,
            include_tables=selected_tables,
            sample_rows_in_table_info=5,
        )
        schema_text = database.get_table_info()
    except Exception as exc:
        return f"Error retrieving database schema: {exc}"
    return schema_text