# Hugging Face Spaces page header (status: Sleeping)
import re
import warnings

# Suppress all warnings, including any raised while importing the libraries below.
warnings.filterwarnings("ignore")

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Limit PyTorch to a single intra-op CPU thread
# (presumably sized for a small CPU-only host — confirm).
torch.set_num_threads(1)

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("Loading model...")
# Load float32 weights and switch straight to evaluation mode;
# nn.Module.eval() returns the module itself, so the call chains.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,
).eval()
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
print("Model ready")
# ─────────────────────────
# SQL FILTER
# ─────────────────────────
# Lower-case phrases whose presence marks a request as SQL/database related;
# consumed by is_sql_related().
SQL_KEYWORDS = [
    # general terms
    "sql", "database", "table",
    # statements and clauses
    "select", "insert", "update", "delete", "join", "group by",
    # database engines
    "postgres", "mysql", "sqlite",
    # generic request word
    "query",
]
def is_sql_related(text, keywords=None):
    """Return True if *text* looks like a SQL/database request.

    Matching is case-insensitive and anchored at the start of a word. So
    "tables" and "PostgreSQL" still match "table" and "postgres", but words
    that merely contain a keyword ("comfortable", "vegetable") do not.
    Keywords buried mid-word (e.g. "subquery") are therefore no longer
    matched. Multi-word keywords such as "group by" accept any run of
    whitespace between their words.

    Args:
        text: The user's question. ``None`` or empty text is never SQL-related.
        keywords: Phrases to look for. Defaults to ``SQL_KEYWORDS``.

    Returns:
        True if at least one keyword starts a word somewhere in *text*.
    """
    if not text:
        return False
    if keywords is None:
        keywords = SQL_KEYWORDS

    # One regex alternative per keyword; the words of a multi-word keyword are
    # joined with \s+ so "group   by" or "group\nby" still count.
    alternatives = [
        r"\s+".join(re.escape(word) for word in kw.split())
        for kw in keywords
        if kw.strip()
    ]
    if not alternatives:
        # An empty alternation would match everywhere; mirror any([]) == False.
        return False

    # (?<!\w) = "not preceded by a word character", i.e. start of a word.
    pattern = r"(?<!\w)(?:" + "|".join(alternatives) + ")"
    return re.search(pattern, text, flags=re.IGNORECASE) is not None
# ─────────────────────────
# GENERATION
# ─────────────────────────
# Model instructions used by generate_sql(). The empty first and last items
# give the string a leading and a trailing newline.
SYSTEM_PROMPT = "\n".join([
    "",
    "You are an expert SQL generator.",
    "Rules:",
    "- Only respond to SQL or database related questions.",
    "- If the question is not about SQL or databases, refuse.",
    "- Output ONLY SQL query.",
    "- Do not explain.",
    "",
])
def _extract_sql(text):
    """Pull the SQL statement out of raw model output.

    Prefers the contents of a Markdown code fence (closed, or cut off by the
    token limit). Otherwise keeps only the first paragraph, which drops any
    trailing explanation.
    """
    fenced = re.search(
        r"```(?:sql)?\s*(.*?)(?:```|$)", text, flags=re.DOTALL | re.IGNORECASE
    )
    if fenced:
        return fenced.group(1).strip()
    return text.strip().split("\n\n")[0].strip()


def generate_sql(user_input):
    """Turn a natural-language database question into a SQL query.

    Args:
        user_input: The question typed into the UI.

    Returns:
        The generated SQL text. Returns a fixed message instead when the input
        is empty or is rejected by ``is_sql_related``.
    """
    if not user_input or not user_input.strip():
        return "Enter SQL question."

    # HARD GUARD: refuse off-topic requests before running the model at all.
    if not is_sql_related(user_input):
        return "I only respond to SQL and database related questions. If you want, I can craft helpful database queries for you."

    # Build the prompt through the tokenizer's chat template, so the rules and
    # the request sit in the system/user roles this chat model is formatted for.
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT.strip()},
        {"role": "user", "content": user_input.strip()},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=120,
            # Greedy decoding. A temperature would be ignored here, so none is passed.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, so no part of the prompt can
    # leak into (or be used to split) the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)
    return _extract_sql(completion)
# ─────────────────────────
# UI
# ─────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        lines=3,
        label="SQL Question",
        placeholder="Find duplicate emails in users table",
    ),
    outputs=gr.Textbox(
        lines=8,
        label="Generated SQL",
    ),
    title="AI SQL Generator (Portfolio Project)",
    description="This model ONLY responds to SQL/database queries.",
    # Examples meant to succeed must contain a word from SQL_KEYWORDS
    # (here "table"); otherwise generate_sql's keyword guard refuses them.
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees from the employees table"],
        ["Count orders per customer last month in the orders table"],
        ["Write a joke about cats"],  # will be blocked
    ],
)
demo.launch(server_name="0.0.0.0")