File size: 7,025 Bytes
edef444
 
 
dfa643b
edef444
 
 
c96208b
edef444
dfa643b
edef444
 
 
c96208b
edef444
dfa643b
 
c96208b
 
 
dfa643b
c96208b
 
 
edef444
 
 
 
c96208b
 
 
 
 
 
edef444
 
c96208b
 
edef444
c96208b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edef444
dfa643b
 
 
 
 
edef444
 
 
c96208b
 
 
edef444
 
 
dfa643b
c96208b
 
dfa643b
 
 
 
 
 
 
 
 
c96208b
 
 
dfa643b
 
 
 
 
 
 
edef444
dfa643b
edef444
 
c96208b
 
 
 
edef444
 
 
 
 
 
 
 
dfa643b
 
edef444
 
 
 
dfa643b
 
 
 
c96208b
edef444
 
c96208b
edef444
 
dfa643b
 
 
 
 
 
 
edef444
 
 
 
 
 
dfa643b
 
c96208b
 
 
 
dfa643b
c96208b
 
 
 
 
 
 
 
 
 
 
dfa643b
 
edef444
dfa643b
edef444
 
 
dfa643b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c96208b
dfa643b
 
 
edef444
c96208b
 
 
dfa643b
 
c96208b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Path: src/scripts/evaluation_mode.py
# Evaluation script for Hugging Face SQL generation.
import json
import sqlglot
from pathlib import Path
import pandas as pd
from src.database.db_manager import get_db_connection
from src.nl2sql.hf_engine import get_models
from src.nl2sql.sql_agent import nl2sql_agent
from src.scripts.taxonomy_report import print_taxonomyReport

TEST_CASES_PATH = Path("src/scripts/test_cases.json")

def _normalize_dataframe(dataframe: pd.DataFrame) -> list:
    # Normalize dataframe to ensure accurate comparison
    """
    Standardize dataframes for Execution Accuracy (EX).
    - Converts dataframe to a list of tuples to ignore column names.
    - Rounds floating points to 4 decimal places to avoid precision mismatch.
    - Sorts the final list to ensure order-agnostic comparison.
    """
    if dataframe is None or dataframe.empty:
        return []
    
    normalized = dataframe.copy()

    for column in normalized.columns:
        normalized[column] = normalized[column].map(
            lambda x: round(float(x), 4)
            if pd.api.types.is_numeric_dtype(type(x)) and isinstance(x, float)
            else x
            #lambda value: round(float(value), 6)
            #if isinstance(value, (float, int))
            #else value
        )

    # Convert to list of tuples for order-agnostic comparison
    data_tuples = [tuple(row) for row in normalized.to_numpy()]

    # Sort to ensure order agnoticism
    try:
        data_tuples.sort(key=lambda x: str(x))
    except Exception as e:
        pass
        
    return data_tuples

# Semantic safety net
def extract_tables(sql: str) -> set:
    """
    Collect the lower-cased names of all tables referenced by a SQL query.

    The evaluation uses this to flag false positives: cases where the result
    sets match although the generated query read from different tables.

    Args:
        sql: SQL text to inspect; may be empty or None.

    Returns:
        Set of lower-cased table names. Empty when `sql` is falsy or when
        parsing/inspection fails for any reason.
    """
    if not sql:
        return set()

    try:
        tree = sqlglot.parse_one(sql, read=None)

        # Table.name is the bare table identifier, so aliases are not collected
        names = set()
        for node in tree.find_all(sqlglot.exp.Table):
            identifier = node.name
            if identifier:
                names.add(identifier.lower())
        return names
    except Exception:
        # Best effort: an unparsable query simply contributes no tables
        return set()

# EX: Compare generated SQL results with expected results
def calculate_ex(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
    """
    Execution Accuracy (EX): do both queries return the same data?

    Both result sets are passed through `_normalize_dataframe` and the
    normalized forms are compared for equality.

    Args:
        df_generated: Result of executing the generated SQL.
        df_gold: Result of executing the gold (reference) SQL.

    Returns:
        True when the normalized results are equal. False when either input
        is None, when the results differ, or when normalization/comparison
        raises (the error is printed).
    """
    if df_generated is None:
        return False
    if df_gold is None:
        return False

    try:
        # Generated result is normalized first, then the gold result
        rows_generated, rows_gold = [
            _normalize_dataframe(frame) for frame in (df_generated, df_gold)
        ]
        return rows_generated == rows_gold
    except Exception as error:
        print(f"EX Evaluation Error: {error}")
        return False

def calculate_esm(generated_sql: str, gold_sql: str) -> bool:
    """
    Exact Set Match (ESM): compare the two queries structurally via sqlglot.

    Both statements are parsed into sqlglot expression trees; the queries
    match when the trees compare equal, so the check works on the parsed
    structure rather than on the raw SQL text.

    Args:
        generated_sql: SQL produced by the model.
        gold_sql: Reference SQL from the test case.

    Returns:
        True when the parsed expressions are equal. False when either query
        is empty/None, when they differ, or when parsing raises (the error
        is printed).
    """
    if not (generated_sql and gold_sql):
        return False

    try:
        # Parse in the same order as the arguments: generated first, then gold
        generated_tree, gold_tree = [
            sqlglot.parse_one(statement, read=None)
            for statement in (generated_sql, gold_sql)
        ]
        is_match = generated_tree == gold_tree
        return is_match
    except Exception as error:
        print(f"ESM Evaluation Error: {error}")
        return False

def run_evaluation(model_id: str):
    """
    Evaluate one model on the JSON test cases and report EX / ESM accuracy.

    For every test case the NL2SQL agent generates a query, which is scored
    with Exact Set Match (`calculate_esm`) and Execution Accuracy
    (`calculate_ex` plus a same-tables check via `extract_tables`). A summary
    is printed, the per-case results are written to
    ``sql_eval_<model>.json`` (relative to the working directory), and the
    results are handed to `print_taxonomyReport`.

    Returns early (after printing an error) when the test-case file is missing.

    Args:
        model_id: Identifier of the model to evaluate; forwarded to the agent
            and used to build the output file name.

    Raises:
        RuntimeError: If no database connection can be obtained.
    """
    print(f"\nRunning SQL evaluation for model: {model_id}")
    print("\n" + "-" * 50)

    if not TEST_CASES_PATH.exists():
        print(f"Error: Could not find test cases at {TEST_CASES_PATH}")
        return

    with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
        test_cases = json.load(handle)

    results = []
    ex_count = 0
    esm_count = 0

    print(f"Running evaluation on {len(test_cases)} test cases...\n")

    for case in test_cases:
        # `case_id` rather than `id`, to avoid shadowing the builtin
        case_id = case.get("id")
        question = case.get("question")
        gold_sql = case.get("gold_sql")
        taxonomy = case.get("taxonomy", "Unknown")

        # A case without a question must not crash the progress line
        question_preview = (question or "")[:40]
        print(f"Testing ID {case_id}: {question_preview}...")

        # The agent handles RAG retrieval and SQL generation
        agent_response = nl2sql_agent(user_question=question, model_id=model_id)
        generated_sql = agent_response.get("query", "")

        # ESM Evaluation
        esm_result = calculate_esm(generated_sql, gold_sql)
        if esm_result:
            esm_count += 1

        # EX Evaluation (a fresh connection per case, always closed below)
        ex_result = False
        connection = get_db_connection()
        if connection is None:
            raise RuntimeError("Unable to connect to the SQLite database.")

        try:
            # NOTE(review): the model-generated SQL is executed as-is;
            # presumably the evaluation database is disposable or read-only
            # -- confirm before pointing this at real data.
            df_generated = pd.read_sql_query(generated_sql, connection)
            df_gold = pd.read_sql_query(gold_sql, connection)

            # An empty gold result makes for a weak test case: any query
            # that also returns nothing would pass EX.
            if df_gold.empty:
                print(f"[!]WARNING: Gold SQL for ID {case_id} returned an empty response.")

            ex_result = calculate_ex(df_generated, df_gold)

            # Semantic safety net: matching data from different tables is
            # treated as a false positive, not a pass.
            if ex_result:
                gen_tables = extract_tables(generated_sql)
                gold_tables = extract_tables(gold_sql)

                if gen_tables != gold_tables:
                    print(f"[X] FALSE POSITIVE (ID{case_id}): Data matched, tables not")
                    print(f"\nGenerated SQL tables: {gen_tables} | Gold SQL tables: {gold_tables}")
                    ex_result = False

            if ex_result:
                ex_count += 1
        except Exception as error:
            # A failing query counts as an EX miss for this case only
            print(f"Error executing SQL for ID {case_id}: {error}")
        finally:
            connection.close()

        results.append({
            "id": case_id,
            "question": question,
            "taxonomy": taxonomy,
            "ex_pass": ex_result,
            "esm_pass": esm_result,
            "generated_sql": generated_sql,
            "gold_sql": gold_sql
        })

    # Summary Statistics
    total = len(test_cases)
    ex_accuracy = (ex_count / total) * 100 if total > 0 else 0
    esm_accuracy = (esm_count / total) * 100 if total > 0 else 0

    print("\nEVALUATION SUMMARY")
    print("-" * 40)
    print(f"Model Evaluated: {model_id}")
    print(f"Total Test Cases: {total}")
    print(f"Execution Accuracy (EX): {ex_accuracy:.2f}% ({ex_count}/{total})")
    print(f"Exact Set Match (ESM): {esm_accuracy:.2f}% ({esm_count}/{total})")

    # "/" and ":" in model ids are not safe in file names
    safe_model_name = model_id.replace("/", "_").replace(":", "_")
    output_file = Path(f"sql_eval_{safe_model_name}.json")
    with output_file.open("w", encoding="utf-8") as handle:
        json.dump(results, handle, indent=4)

    print_taxonomyReport(results)

if __name__ == "__main__":
    from dotenv import find_dotenv, load_dotenv

    # Load environment variables from the located .env file before any model is used
    load_dotenv(find_dotenv())

    # Evaluate every model returned by get_models(), one after another
    for model_id in get_models():
        run_evaluation(model_id)