# nl2sql-api / backend / src / database / db_manager.py
# Last commit 31c1d8c by dvwn: "Folder and Files Structure Refactor"
# Path: src/database/db_manager.py
# This module provides helpers to connect to the SQLite database and to build
# its schema context (DDL plus sample rows), including RAG-based retrieval of
# the tables relevant to a user's question.
import os
import sqlite3
from pathlib import Path

from langchain_chroma import Chroma
from langchain_community.utilities import SQLDatabase
from langchain_core.documents import Document
from langchain_huggingface import HuggingFaceEmbeddings
# Path to the SQLite database, overridable via the SQLITE_DB_PATH environment
# variable. The default is resolved relative to this module's own directory
# (where Chinook_Sqlite.sqlite lives) instead of the process's working
# directory, so the database is found no matter where the app is launched
# from. This matters because sqlite3 and SQLAlchemy silently create an empty
# database file when pointed at a path that does not exist.
DB_PATH = os.getenv(
    "SQLITE_DB_PATH",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "Chinook_Sqlite.sqlite"),
)
# SQLAlchemy URL for the same database, consumed by SQLDatabase.from_uri.
DB_URI = f"sqlite:///{DB_PATH}"
# Knowledge base for RAG table selection: one Document per Chinook table. The
# page_content is the text that gets embedded; the "table_name" metadata is
# what the retriever maps each hit back to.
TABLE_DESCRIPTIONS = [
    Document(page_content=description, metadata={"table_name": table})
    for table, description in (
        ("Album", "Contains music albums associated with artists"),
        ("Artist", "Contains information about artists"),
        ("Customer", "Contains customer information like name, address, phone, and email"),
        ("Employee", "Contains employee information such as name, title, hire date, and manager"),
        ("Genre", "Contains musical genres like Rock, Jazz, or Metal"),
        ("Invoice", "Contains details of invoices including billing information and total amount"),
        ("InvoiceLine", "Contains line items for each invoice, linking to the purchased tracks."),
        ("MediaType", "Contains media types like MPEG audio or AAC audio."),
        ("Playlist", "Contains custom playlists created by users."),
        ("PlaylistTrack", "Mapping table linking tracks to playlists."),
        ("Track", "Contains details of music tracks including name, album, genre, and composer"),
    )
]
# Sentence-transformer model used to embed both the table descriptions and the
# user questions searched against them.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
# Directory where Chroma persists its embeddings (relative to the working
# directory unless an absolute path is given via CHROMA_PERSIST_DIR).
CHROMA_PERSIST_DIR = os.getenv("CHROMA_PERSIST_DIR", "chroma_db")
# Build (or reopen) the persistent Chroma store holding the table descriptions.
# This runs at import time, i.e. on every process start. Passing stable ids (the
# table names) makes it idempotent: Chroma upserts by id, so the existing
# entries are updated in place. Without explicit ids, a fresh random id is
# generated per document on each run. A new copy of every description then
# piles up in the persisted collection, and similarity_search ends up
# returning the same table several times.
# NOTE: a store written by the earlier id-less version may already contain such
# duplicates; delete the persist directory once to rebuild it cleanly.
vector_store = Chroma.from_documents(
    documents=TABLE_DESCRIPTIONS,
    embedding=embedding_model,
    ids=[doc.metadata["table_name"] for doc in TABLE_DESCRIPTIONS],
    persist_directory=CHROMA_PERSIST_DIR,
)
# Get a connection to the SQLite database
def get_db_connection(db_path=None):
    """Open a connection to the application's SQLite database.

    Args:
        db_path: Database file to open. ``None`` (the default) means the
            module-level ``DB_PATH``. ``":memory:"`` opens a fresh in-memory
            database.

    Returns:
        An open ``sqlite3.Connection``, which the caller must close. Returns
        ``None`` if the database could not be opened; in that case the error
        is printed. A missing database file counts as a failure: no empty file
        is created.
    """
    path = DB_PATH if db_path is None else db_path
    if path == ":memory:":
        target, use_uri = path, False
    else:
        # Plain sqlite3.connect() silently creates an empty database when the
        # file is missing (e.g. a wrong SQLITE_DB_PATH), which later surfaces
        # as confusing "no such table" errors. Opening through a URI with
        # mode=rw turns a missing file into a connection error instead.
        # Path.as_uri() requires an absolute path and percent-encodes spaces,
        # '?' and '#', which would otherwise corrupt the URI.
        target, use_uri = Path(path).resolve().as_uri() + "?mode=rw", True
    try:
        return sqlite3.connect(target, uri=use_uri)
    except sqlite3.Error as e:
        print(f"Error connecting to database: {e}")
        return None
# Retrieve relevant schema context based on the user's question
def get_relevant_tables(question: str, top_k: int = 4) -> list[str]:
    """RAG step: return the names of the tables most relevant to *question*.

    Runs a similarity search of *question* against the table descriptions held
    in ``vector_store``. Each hit is mapped back to its ``table_name`` metadata.

    Args:
        question: Natural-language question from the user.
        top_k: Number of descriptions to retrieve from the vector store.

    Returns:
        Table names ordered from most to least similar, without duplicates,
        so at most ``top_k`` entries. Empty when ``top_k`` is not positive.
    """
    if top_k <= 0:
        return []
    docs = vector_store.similarity_search(question, k=top_k)
    # The persisted store may hold several copies of the same description, and
    # a hit may lack table metadata. Drop unnamed hits, then de-duplicate while
    # preserving similarity order (dicts keep insertion order).
    names = (doc.metadata.get("table_name") for doc in docs)
    return list(dict.fromkeys(name for name in names if name))
# Get database schema context
def get_schema_context(question=None):
    """
    Describe the database schema as text: table DDL plus up to 5 sample rows per table.

    When ``question`` is truthy, RAG retrieval first narrows the description to
    the tables judged relevant to it; otherwise no table filter is applied.
    Failures are not raised. The exception is instead returned as an
    "Error retrieving database schema: ..." string.
    """
    try:
        selected_tables = get_relevant_tables(question) if question else None
        if question:
            print(f"RAG selected tables for question '{question}': {selected_tables}")
        # include_tables=None leaves SQLDatabase's table filter off.
        database = SQLDatabase.from_uri(
            DB_URI,
            include_tables=selected_tables,
            sample_rows_in_table_info=5,
        )
        schema_text = database.get_table_info()
    except Exception as exc:
        return f"Error retrieving database schema: {exc}"
    return schema_text