# Path: src/scripts/evaluation_mode.py
# Evaluation script for Hugging Face SQL generation.
import json
import sqlglot
from pathlib import Path
import pandas as pd
from src.database.db_manager import get_db_connection
from src.nl2sql.hf_engine import get_models
from src.nl2sql.sql_agent import nl2sql_agent
from src.scripts.taxonomy_report import print_taxonomyReport
# Location of the JSON test cases loaded by run_evaluation. The path is
# relative, so it resolves against the current working directory.
TEST_CASES_PATH = Path("src/scripts/test_cases.json")
def _normalize_dataframe(dataframe: pd.DataFrame) -> list:
# Normalize dataframe to ensure accurate comparison
"""
Standardize dataframes for Execution Accuracy (EX).
- Converts dataframe to a list of tuples to ignore column names.
- Rounds floating points to 4 decimal places to avoid precision mismatch.
- Sorts the final list to ensure order-agnostic comparison.
"""
if dataframe is None or dataframe.empty:
return []
normalized = dataframe.copy()
for column in normalized.columns:
normalized[column] = normalized[column].map(
lambda x: round(float(x), 4)
if pd.api.types.is_numeric_dtype(type(x)) and isinstance(x, float)
else x
#lambda value: round(float(value), 6)
#if isinstance(value, (float, int))
#else value
)
# Convert to list of tuples for order-agnostic comparison
data_tuples = [tuple(row) for row in normalized.to_numpy()]
# Sort to ensure order agnoticism
try:
data_tuples.sort(key=lambda x: str(x))
except Exception as e:
pass
return data_tuples
# Semantic safety net
def extract_tables(sql: str) -> set:
    """
    Extract the set of table names referenced by a SQL query.

    Used to catch false positives where EX passes but the model queried the
    wrong tables.

    Args:
        sql: SQL text to inspect; may be empty or None.

    Returns:
        Lower-cased table names found in the parsed query. An empty set is
        returned when `sql` is falsy or when parsing fails (best-effort: the
        error is deliberately not propagated).
    """
    if not sql:
        return set()
    try:
        parsed = sqlglot.parse_one(sql, read=None)
        # Collect every table expression by its name; entries with an empty
        # name are skipped.
        return {
            table.name.lower()
            for table in parsed.find_all(sqlglot.exp.Table)
            if table.name
        }
    except Exception:
        return set()
# EX: Compare generated SQL results with expected results
def calculate_ex(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
    """
    Execution Accuracy (EX): check whether two query results hold the same data.

    Both frames are passed through _normalize_dataframe and the normalized
    outputs are compared for equality.

    Returns:
        True when the normalized results are equal. False when either input is
        None, when the results differ, or when normalization raises (the error
        is printed rather than propagated).
    """
    if df_generated is None:
        return False
    if df_gold is None:
        return False

    try:
        rows_generated = _normalize_dataframe(df_generated)
        rows_gold = _normalize_dataframe(df_gold)
        outcome = rows_generated == rows_gold
    except Exception as error:
        print(f"EX Evaluation Error: {error}")
        outcome = False
    return outcome
def calculate_esm(generated_sql: str, gold_sql: str) -> bool:
    """
    Exact Set Match (ESM): compare two queries by their sqlglot syntax trees.

    Each query is parsed with sqlglot.parse_one and the resulting expressions
    are compared with ==, so the check is intended to be insensitive to
    formatting, capitalization, and minor syntactic sugar.

    Returns:
        True when the parsed expressions are equal. False when either query is
        empty/None, when they differ, or when parsing raises (the error is
        printed rather than propagated).
    """
    if not (generated_sql and gold_sql):
        return False

    try:
        # Parse the generated query first, then the gold query.
        candidate_ast, reference_ast = (
            sqlglot.parse_one(statement, read=None)
            for statement in (generated_sql, gold_sql)
        )
        is_match = candidate_ast == reference_ast
    except Exception as error:
        print(f"ESM Evaluation Error: {error}")
        return False
    return is_match
def run_evaluation(model_id: str):
    """
    Run every test case against one model and report EX / ESM accuracy.

    For each case in TEST_CASES_PATH the question is sent to nl2sql_agent, the
    returned query is scored with calculate_esm, and both the generated and the
    gold SQL are executed so calculate_ex can compare the results. An EX pass
    is downgraded to a failure when the two queries reference different tables.
    Per-case results are written to `sql_eval_<model>.json` in the current
    working directory and handed to print_taxonomyReport.

    Args:
        model_id: Identifier forwarded to nl2sql_agent; also used (with "/" and
            ":" replaced by "_") to name the output file.

    Raises:
        RuntimeError: If get_db_connection returns None.
    """
    print(f"\nRunning SQL evaluation for model: {model_id}")
    print("\n" + "-" *50)

    if not TEST_CASES_PATH.exists():
        print(f"Error: Could not find test cases at {TEST_CASES_PATH}")
        return

    with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
        test_cases = json.load(handle)

    results = []
    ex_count = 0
    esm_count = 0
    print(f"Running evaluation on {len(test_cases)} test cases...\n")

    for case in test_cases:
        case_id = case.get("id")
        question = case.get("question")
        gold_sql = case.get("gold_sql")
        taxonomy = case.get("taxonomy", "Unknown")

        # A case without a question must not crash the progress line.
        preview = (question or "")[:40]
        print(f"Testing ID {case_id}: {preview}...")

        # The agent handles RAG retrieval and SQL generation.
        agent_response = nl2sql_agent(user_question=question, model_id=model_id)
        generated_sql = agent_response.get("query", "")

        # ESM Evaluation
        esm_result = calculate_esm(generated_sql, gold_sql)
        if esm_result:
            esm_count += 1

        # EX Evaluation
        ex_result = False
        connection = get_db_connection()
        if connection is None:
            raise RuntimeError("Unable to connect to the SQLite database.")
        try:
            # NOTE(review): generated_sql is model output executed as-is;
            # confirm the connection cannot modify the database.
            df_generated = pd.read_sql_query(generated_sql, connection)
            df_gold = pd.read_sql_query(gold_sql, connection)

            # Trap the false positive (empty set): weak test case.
            if df_gold.empty:
                print(f"[!]WARNING: Gold SQL for ID {case_id} returned an empty response.")

            ex_result = calculate_ex(df_generated, df_gold)

            # Semantic safety net: matching data from different tables is
            # not counted as a pass.
            if ex_result:
                gen_tables = extract_tables(generated_sql)
                gold_tables = extract_tables(gold_sql)
                if gen_tables != gold_tables:
                    print(f"[X] FALSE POSITIVE (ID{case_id}): Data matched, tables not")
                    print(f"\nGenerated SQL tables: {gen_tables} | Gold SQL tables: {gold_tables}")
                    ex_result = False

            if ex_result:
                ex_count += 1
        except Exception as error:
            # A query that fails to execute counts as an EX failure.
            print(f"Error executing SQL for ID {case_id}: {error}")
        finally:
            connection.close()

        results.append({
            "id": case_id,
            "question": question,
            "taxonomy": taxonomy,
            "ex_pass": ex_result,
            "esm_pass": esm_result,
            "generated_sql": generated_sql,
            "gold_sql": gold_sql
        })

    # Summary Statistics
    total = len(test_cases)
    ex_accuracy = (ex_count / total) * 100 if total > 0 else 0
    esm_accuracy = (esm_count / total) * 100 if total > 0 else 0

    print("\nEVALUATION SUMMARY")
    print("-" * 40)
    print(f"Model Evaluated: {model_id}")
    print(f"Total Test Cases: {total}")
    print(f"Execution Accuracy (EX): {ex_accuracy:.2f}% ({ex_count}/{total})")
    print(f"Exact Set Match (ESM): {esm_accuracy:.2f}% ({esm_count}/{total})")

    # "/" and ":" are replaced so the model id can be used in a file name.
    safe_model_name = model_id.replace("/", "_").replace(":", "_")
    output_file = Path(f"sql_eval_{safe_model_name}.json")
    with output_file.open("w", encoding="utf-8") as handle:
        json.dump(results, handle, indent=4)

    print_taxonomyReport(results)
if __name__ == "__main__":
    from dotenv import find_dotenv, load_dotenv

    # Load the .env file located by find_dotenv before the models are resolved.
    load_dotenv(find_dotenv())

    # Evaluate each model returned by get_models, one after another.
    for candidate_model in get_models():
        run_evaluation(candidate_model)