| import streamlit as st |
| from llama_cpp import Llama |
| from sql import get_table_schema |
|
|
|
|
@st.cache_resource()
def load_llm(repo_id, filename):
    """Download (if needed) and load a GGUF model from the Hugging Face Hub.

    Wrapped in ``st.cache_resource`` so the model is loaded once per server
    process instead of on every Streamlit rerun.

    Args:
        repo_id (str): Hugging Face repository id of the model.
        filename (str): Name of the GGUF file inside that repository.

    Returns:
        Llama: The loaded llama-cpp model handle.
    """
    model = Llama.from_pretrained(
        repo_id=repo_id,
        filename=filename,
        verbose=True,
        use_mmap=True,   # memory-map the weights rather than reading them all in
        use_mlock=True,  # pin pages so the OS does not swap the model out
        n_threads=4,
        n_threads_batch=4,
        n_ctx=8000,
    )
    print(f"{repo_id} loaded successfully. ✅")
    return model
|
|
|
|
def generate_system_prompt(table_name, table_schema):
    """
    Build the system prompt that gives the LLM the schema context it needs to
    convert a natural-language question into a SQL query.

    Args:
        table_name (str): The name of the table.
        table_schema (list): Rows shaped like sqlite's ``PRAGMA table_info``
            output: tuples where index 1 is the column name and index 2 is the
            column type.

    Returns:
        str: The generated system prompt to be used by the LLM.
    """
    # NOTE: fixed the "specially sqlite" typo -> "especially sqlite".
    prompt = f"""You are an expert in writing SQL queries for relational databases, especially sqlite.
You will be provided with a database schema and a natural
language question, and your task is to generate an accurate SQL query.

The database has a table named '{table_name}' with the following schema:\n\n"""

    # One "- name (type)" line per column, joined in a single pass instead of
    # repeated string concatenation.
    prompt += "Columns:\n"
    prompt += "".join(f"- {col[1]} ({col[2]})\n" for col in table_schema)

    prompt += "\nGenerate a SQL query based on the following natural language question. ONLY return the SQL query and nothing else."

    return prompt
|
|
|
|
| |
def response_generator(llm, messages, question, table_name, db_name):
    """Generate a SQL query answering ``question`` against ``table_name``.

    Builds a chat history of [system schema prompt] + prior messages + the new
    question, asks the model for a completion, and strips any markdown code
    fences from the answer.

    Args:
        llm: A llama-cpp model exposing ``create_chat_completion``.
        messages (list): Prior chat turns, each a {"role", "content"} dict.
        question (str): The user's natural-language question.
        table_name (str): The table the query should target.
        db_name (str): Database file the schema is read from.

    Returns:
        str: The generated SQL query, with ```sql fences removed and
        surrounding whitespace stripped.
    """
    table_schema = get_table_schema(db_name, table_name)
    system_prompt = generate_system_prompt(table_name, table_schema)
    user_prompt = f"""Question: {question}"""

    # BUG FIX: the previous code called system_prompt.format(table_name=...),
    # but generate_system_prompt already interpolates the table name via an
    # f-string; .format() was a no-op at best and would raise KeyError /
    # ValueError whenever the schema text contained literal braces. The debug
    # print that dumped the full prompts to stdout was removed as well.
    history = [{"role": "system", "content": system_prompt}]
    history.extend(messages)
    history.append({"role": "user", "content": user_prompt})

    response = llm.create_chat_completion(
        messages=history,
        max_tokens=2048,
        temperature=0.7,
        top_p=0.95,
    )
    answer = response["choices"][0]["message"]["content"]

    # Models often wrap SQL in markdown fences; strip them before returning.
    query = answer.replace("```sql", "").replace("```", "")
    return query.strip()
|
|