File size: 3,873 Bytes
ccb82de
 
 
97dafec
0731ede
ccb82de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97dafec
 
ccb82de
 
 
97dafec
0731ede
ccb82de
 
 
 
0731ede
ccb82de
 
0731ede
ccb82de
 
 
0731ede
ccb82de
 
 
 
 
 
0731ede
ccb82de
 
 
 
 
 
 
 
 
 
 
a218c12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Path: src/database/db_manager.py
# This module provides helpers to connect to the SQLite database and to
# retrieve its schema, including RAG-based dynamic selection of relevant tables.
import os
import sqlite3
from langchain_community.utilities import SQLDatabase
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document

# Location of the SQLite database file; the SQLITE_DB_PATH environment
# variable overrides the bundled Chinook sample database.
DB_PATH = os.environ.get("SQLITE_DB_PATH", "src/database/Chinook_Sqlite.sqlite")
# URI form of the same path, as passed to SQLDatabase.from_uri
DB_URI = "sqlite:///" + DB_PATH

# Knowledge base for RAG: one (table name, natural-language description) pair
# per database table, in the order they are indexed.
_TABLE_SUMMARIES = (
    ("Album", "Contains music albums associated with artists"),
    ("Artist", "Contains information about artists"),
    ("Customer", "Contains customer information like name, address, phone, and email"),
    ("Employee", "Contains employee information such as name, title, hire date, and manager"),
    ("Genre", "Contains musical genres like Rock, Jazz, or Metal"),
    ("Invoice", "Contains details of invoices including billing information and total amount"),
    ("InvoiceLine", "Contains line items for each invoice, linking to the purchased tracks."),
    ("MediaType", "Contains media types like MPEG audio or AAC audio."),
    ("Playlist", "Contains custom playlists created by users."),
    ("PlaylistTrack", "Mapping table linking tracks to playlists."),
    ("Track", "Contains details of music tracks including name, album, genre, and composer"),
)

# Each description becomes a Document whose metadata carries the table name,
# so a similarity hit can be mapped back to the table it describes.
TABLE_DESCRIPTIONS = [
    Document(page_content=summary, metadata={"table_name": table})
    for table, summary in _TABLE_SUMMARIES
]

# Embedding model used to vectorize the table descriptions and the questions
# they are searched with.
embedding_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Build the Chroma vector store, persisted on disk under "chroma_db".
# The table names are passed as stable ids so that this import-time call is
# idempotent: without explicit ids every run stores another copy of each
# description under a fresh id in the persisted collection, and
# similarity_search then returns the same table several times.
# NOTE(review): a "chroma_db" directory populated before this change may still
# contain the old duplicate entries - delete it once so it is rebuilt cleanly.
vector_store = Chroma.from_documents(
    documents=TABLE_DESCRIPTIONS,
    embedding=embedding_model,
    ids=[doc.metadata["table_name"] for doc in TABLE_DESCRIPTIONS],
    persist_directory="chroma_db",
)

# Open a connection to the SQLite database
def get_db_connection():
    """
    Open a sqlite3 connection to the database file at DB_PATH.

    Returns the open connection, or None if sqlite3 raised an error while
    connecting (the error is printed to stdout). The caller is responsible
    for closing the returned connection.
    """
    try:
        connection = sqlite3.connect(DB_PATH)
    except sqlite3.Error as exc:
        print(f"Error connecting to database: {exc}")
        return None
    else:
        return connection

# Retrieve relevant schema context based on the user's question
def get_relevant_tables(question: str, top_k: int = 4) -> list:
    """
    RAG function: Retrieves the most relevant table names based on the user's question.

    Args:
        question: Natural-language question to match against the table descriptions.
        top_k: Number of nearest documents to request from the vector store.

    Returns:
        Table names taken from the "table_name" metadata of the retrieved
        documents, in retrieval order with duplicates removed. The list may
        therefore be shorter than top_k.
    """
    # Perform a similarity search in vector store
    docs = vector_store.similarity_search(question, k=top_k)

    # Skip documents that carry no table name instead of raising KeyError.
    names = (doc.metadata.get("table_name") for doc in docs)

    # dict.fromkeys keeps first-seen order, so the best-ranked hit for each
    # table survives while repeated hits for the same table are dropped.
    return list(dict.fromkeys(name for name in names if name))

# Build the schema context string for the database
def get_schema_context(question=None):
    """
    Return the table info (with up to 5 sample rows per table) for the database.

    When a non-empty question is given, the tables are first narrowed down via
    get_relevant_tables(question) and the selection is printed; otherwise no
    table filter is applied. Any exception raised along the way is not
    propagated: an error message string is returned instead.
    """
    try:
        if question:
            # RAG retrieval decides which tables are worth describing
            selected_tables = get_relevant_tables(question)
            print(f"RAG selected tables for question '{question}': {selected_tables}")
        else:
            selected_tables = None

        # Open the database restricted to the selected tables (None = no filter)
        database = SQLDatabase.from_uri(
            DB_URI,
            include_tables=selected_tables,
            sample_rows_in_table_info=5,
        )
        return database.get_table_info()
    except Exception as err:
        return f"Error retrieving database schema: {err}"