| import streamlit as st |
| import numpy as np |
| from llm import load_llm, response_generator |
| from sql import csv_to_sqlite, run_sql_query |
|
|
|
|
# Hugging Face repository and GGUF weights file of the model used to turn
# natural-language questions into SQL.
repo_id = "Qwen/Qwen2.5-Coder-3B-Instruct-GGUF"
filename = "qwen2.5-coder-3b-instruct-q6_k.gguf"


@st.cache_resource
def _load_llm_once(model_repo: str, model_file: str):
    """Load the model once and reuse it across Streamlit reruns.

    Streamlit re-executes this script top to bottom on every widget
    interaction; without caching, ``load_llm`` (and the model weights) would
    be reloaded each time. The cache is keyed on ``model_repo``/``model_file``.
    """
    return load_llm(model_repo, model_file)


# NOTE(review): st.cache_resource shares one object across all sessions;
# confirm the object returned by load_llm is safe to use concurrently.
llm = _load_llm_once(repo_id, filename)
|
|
st.title("CSV TO SQL")
st.write("To start, Upload your CSV below 👇")

if st.button("Example prompt"):
    # Bundled demo dataset: a canned question is run against it so the app can
    # be tried without uploading anything.
    example_csv = "./data/sales.csv"
    example_db = "sales"
    example_table = "sales"

    # Import the CSV *before* pointing session state at it, so a failed import
    # never leaves the app referring to a table that does not exist.
    csv_to_sqlite(example_csv, example_db, example_table)
    st.session_state.csv_file = example_csv
    st.session_state.db_name = example_db
    st.session_state.table_name = example_table

    # The chat history is otherwise only initialised further down the script;
    # create it here too so this branch does not depend on that ordering.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    prompt = "What is the sum, count and average sales?"
    st.session_state.messages.append({"role": "user", "content": prompt})
    response_sql = response_generator(
        db_name=st.session_state.db_name,
        table_name=st.session_state.table_name,
        llm=llm,
        messages=st.session_state.messages,
        question=prompt,
    )
    # NOTE(review): model-generated SQL is executed verbatim; consider a
    # read-only connection or a SELECT-only check inside run_sql_query.
    result = run_sql_query(db_name=st.session_state.db_name, query=response_sql)

    # Store the SQL and its result as two assistant turns; the history loop
    # below renders them (code block, then code + dataframe for the result).
    st.session_state.messages.append({"role": "assistant", "content": response_sql})
    st.session_state.messages.append(
        {"role": "assistant", "content": str(result), "result": result}
    )
|
|
|
|
with st.expander("Upload CSV"):
    # User-supplied CSV plus the database/table names it should be stored under.
    csv_file = st.file_uploader(
        "CSV",
    )
    db_name = st.text_input("DB Name")
    table_name = st.text_input("Table Name")
    if st.button("Save"):
        if csv_file and db_name and table_name:
            try:
                csv_to_sqlite(csv_file, db_name, table_name)
            except Exception as err:  # UI boundary: report instead of aborting the page
                st.error(f"Could not import the CSV: {err}")
            else:
                # Only point the app at the new table once it actually exists.
                st.session_state.csv_file = csv_file
                st.session_state.db_name = db_name
                st.session_state.table_name = table_name
                st.write("Saved ✅")
        else:
            st.write("Please enter all values")
|
|
| |
# Chat history kept across reruns: a list of dicts with "role" and "content"
# keys, plus a "result" key on assistant turns that carry a query result.
st.session_state.setdefault("messages", [])
|
|
| |
# Replay the stored conversation on every rerun: user turns render as
# markdown, all other turns as a code block, followed by a dataframe when the
# turn carries a query result.
for entry in st.session_state.messages:
    role = entry["role"]
    with st.chat_message(role):
        if "content" in entry:
            body = entry["content"]
            if role != "user":
                st.code(body)
            else:
                st.markdown(body)
            if "result" in entry:
                st.dataframe(entry["result"])
|
|
| |
# The chat box stays disabled until a database and table have been chosen
# (via "Example prompt" or the upload form).
if prompt := st.chat_input(
    "What is up?",
    disabled=(
        "db_name" not in st.session_state or "table_name" not in st.session_state
    ),
):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        response_sql = response_generator(
            db_name=st.session_state.db_name,
            table_name=st.session_state.table_name,
            llm=llm,
            messages=st.session_state.messages,
            question=prompt,
        )
        st.code(response_sql)
        # NOTE(review): model-generated SQL is executed verbatim against the
        # user's database; consider a read-only connection or SELECT-only check.
        result = run_sql_query(db_name=st.session_state.db_name, query=response_sql)
        # Render the result the same way the history loop does, so the page
        # does not change shape on the next rerun.
        st.code(str(result))
        st.dataframe(result)

    # Persist both the SQL and its result (same shape as the "Example prompt"
    # flow) so the result is still shown when history is replayed.
    st.session_state.messages.append({"role": "assistant", "content": response_sql})
    st.session_state.messages.append(
        {"role": "assistant", "content": str(result), "result": result}
    )
|
|
with st.sidebar:
    st.title("Data Previewer")
    st.write("You can see your CSV file content here")
    if (
        "csv_file" in st.session_state
        and "db_name" in st.session_state
        and "table_name" in st.session_state
    ):
        # The table name is free text typed by the user. Quote it as an SQL
        # identifier (doubling any embedded double quotes) instead of splicing
        # it in raw, so names with spaces work and the input cannot inject SQL.
        quoted_table = '"' + st.session_state.table_name.replace('"', '""') + '"'
        result = run_sql_query(
            db_name=st.session_state.db_name,
            query=f"SELECT * FROM {quoted_table}",
        )
        st.dataframe(result)
|
|