| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import sqlite3 |
| | from nltk import word_tokenize |
| |
|
# Top-level clause keywords; e.g. a SELECT list is read until one of these.
CLAUSE_KEYWORDS = (
    'select', 'from', 'where', 'group', 'order', 'limit',
    'intersect', 'union', 'except',
)
# Keywords used by joins and aliases; they also terminate conditions and values.
JOIN_KEYWORDS = ('join', 'on', 'as')

# Comparison operators; a condition stores the index of its operator here.
WHERE_OPS = (
    'not', 'between', '=', '>', '<', '>=', '<=', '!=',
    'in', 'like', 'is', 'exists',
)
# Arithmetic operators between two column units; index 0 means "no operator".
UNIT_OPS = ('none', '-', '+', '*', '/')
# Aggregate functions; index 0 means "no aggregate".
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
# Tags for FROM-clause entries: a nested query or a named table.
TABLE_TYPE = dict(sql='sql', table_unit='table_unit')

# Connectives stored between consecutive conditions.
COND_OPS = ('and', 'or')
# Set operators that may join two SELECT statements.
SQL_OPS = ('intersect', 'union', 'except')
# Sort directions accepted after an ORDER BY item.
ORDER_OPS = ('desc', 'asc')
| |
|
| |
|
| |
|
class Schema:
    """
    Simple schema which maps table&column to a unique identifier.

    ``schema`` is a dict mapping each table name to a list of its column
    names. ``idMap`` maps:

    * ``"*"`` to ``"__all__"``,
    * every lowercased ``"table.column"`` to ``"__table.column__"``,
    * every lowercased table name to ``"__table__"``.
    """

    def __init__(self, schema):
        self._schema = schema
        self._idMap = self._map(self._schema)

    @property
    def schema(self):
        """The table -> column-list dict passed to the constructor."""
        return self._schema

    @property
    def idMap(self):
        """Lookup from lowercased table / ``table.column`` names to identifiers."""
        return self._idMap

    def _map(self, schema):
        """Build the identifier map for every column and table in ``schema``."""
        idMap = {'*': "__all__"}
        # Columns first, then tables, so dict iteration order is unchanged.
        for key, vals in schema.items():
            table = key.lower()
            for val in vals:
                col_key = table + "." + val.lower()
                idMap[col_key] = "__" + col_key + "__"

        for key in schema:
            idMap[key.lower()] = "__" + key.lower() + "__"

        return idMap
| |
|
| |
|
def get_schema(db):
    """
    Get database's schema, which is a dict with table name as key
    and list of column names as value.

    Table and column names are lowercased. The connection is always closed
    before returning, even if a query fails.

    :param db: path to an SQLite database file
    :return: schema dict mapping table name -> list of column names
    """
    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()

        # fetch table names (original case, needed for the PRAGMA below)
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        schema = {}
        for name in table_names:
            # Quote the identifier so names containing spaces, keywords or
            # double quotes do not break the PRAGMA statement.
            quoted = '"' + name.replace('"', '""') + '"'
            cursor.execute("PRAGMA table_info({})".format(quoted))
            schema[str(name.lower())] = [str(col[1].lower()) for col in cursor.fetchall()]
        return schema
    finally:
        conn.close()
| |
|
| |
|
def get_schema_from_json(fpath):
    """
    Build a schema dict from a JSON file.

    The file holds a list of entries, each with a ``'table'`` name and a
    ``'col_data'`` list whose items carry a ``'column_name'``. All names are
    lowercased.

    :param fpath: path to the UTF-8 JSON file
    :return: dict mapping table name -> list of column names
    """
    with open(fpath, encoding='utf8') as handle:
        entries = json.load(handle)

    return {
        str(entry['table'].lower()): [str(c['column_name'].lower()) for c in entry['col_data']]
        for entry in entries
    }
| |
|
| |
|
def tokenize(string):
    """
    Split a SQL query into lowercased tokens.

    Quoted literals are protected from the word tokenizer. Every single quote
    is first rewritten to a double quote. Each ``"..."`` span is then swapped
    for a placeholder, tokenized, and restored verbatim (original case, quotes
    kept). Finally a ``!`` / ``>`` / ``<`` token followed by ``=`` is merged
    into ``!=`` / ``>=`` / ``<=``.

    :param string: SQL query (converted with ``str()`` first)
    :return: list of tokens
    :raises AssertionError: if the query contains an odd number of quotes
    """
    string = str(string)
    string = string.replace("\'", "\"")
    quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
    assert len(quote_idxs) % 2 == 0, "Unexpected quote"

    # Replace quoted values with placeholder keys. Working right-to-left keeps
    # the recorded quote positions valid while the string is rewritten.
    vals = {}
    for i in range(len(quote_idxs) - 1, -1, -2):
        qidx1 = quote_idxs[i - 1]
        qidx2 = quote_idxs[i]
        val = string[qidx1: qidx2 + 1]
        key = "__val_{}_{}__".format(qidx1, qidx2)
        string = string[:qidx1] + key + string[qidx2 + 1:]
        vals[key] = val

    # Lowercase every token, then put the original quoted values back.
    toks = [word.lower() for word in word_tokenize(string)]
    toks = [vals.get(tok, tok) for tok in toks]

    # Merge two-character comparison operators the tokenizer split apart.
    # Walk right-to-left so earlier indices stay valid after each merge.
    prefix = ('!', '>', '<')
    eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
    for eq_idx in reversed(eq_idxs):
        # A leading '=' has no predecessor; toks[-1] would wrap to the end.
        if eq_idx == 0:
            continue
        pre_tok = toks[eq_idx - 1]
        if pre_tok in prefix:
            toks = toks[:eq_idx - 1] + [pre_tok + "="] + toks[eq_idx + 1:]

    return toks
| |
|
| |
|
def scan_alias(toks):
    """
    Collect every ``<name> as <alias>`` pair in the token list.

    :param toks: list of tokens
    :return: dict mapping each alias to the token just before its ``as``
    """
    return {
        toks[pos + 1]: toks[pos - 1]
        for pos, tok in enumerate(toks)
        if tok == 'as'
    }
| |
|
| |
|
def get_tables_with_alias(schema, toks):
    """
    Map every alias, and every real table name, to its table name.

    :param schema: dict keyed by table name
    :param toks: query tokens
    :return: dict mapping alias or table name -> table name
    :raises AssertionError: if an alias equals one of the table names
    """
    mapping = scan_alias(toks)
    for table_name in schema:
        assert table_name not in mapping, "Alias {} has the same name in table".format(table_name)
    # Every table also resolves to itself.
    mapping.update((table_name, table_name) for table_name in schema)
    return mapping
| |
|
| |
|
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse a single column reference at ``start_idx``.

    Handles ``*``, qualified ``alias.column``, and bare ``column`` tokens. A
    bare column is looked up in ``default_tables`` in order; the first table
    that has it wins.

    :returns next idx, column id
    """
    tok = toks[start_idx]
    next_idx = start_idx + 1

    if tok == "*":
        return next_idx, schema.idMap[tok]

    if '.' in tok:
        alias, col = tok.split('.')
        return next_idx, schema.idMap[tables_with_alias[alias] + "." + col]

    assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"

    for table in (tables_with_alias[alias] for alias in default_tables):
        if tok in schema.schema[table]:
            return next_idx, schema.idMap[table + "." + tok]

    assert False, "Error col: {}".format(tok)
| |
|
| |
|
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse a column unit.

    A column unit is an optional aggregate wrapping an optional DISTINCT and a
    column, possibly enclosed in parentheses, e.g. ``count(distinct t1.id)``
    or ``(count(*))``.

    :returns next idx, (agg_op id, col_id, isDistinct)
    """
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    isDistinct = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] in AGG_OPS:
        # aggregate form: agg ( [distinct] col )
        agg_id = AGG_OPS.index(toks[idx])
        idx += 1
        assert idx < len_ and toks[idx] == '('
        idx += 1
        if toks[idx] == "distinct":
            idx += 1
            isDistinct = True
        idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
        assert idx < len_ and toks[idx] == ')'
        idx += 1
    else:
        # plain form: [distinct] col
        if toks[idx] == "distinct":
            idx += 1
            isDistinct = True
        agg_id = AGG_OPS.index("none")
        idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)

    # Both forms must consume the enclosing ')' so that e.g. "(count(*))" is
    # fully read and the caller resumes after it.
    if isBlock:
        assert idx < len_ and toks[idx] == ')'
        idx += 1

    return idx, (agg_id, col_id, isDistinct)
| |
|
| |
|
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse a value unit.

    A value unit is one column unit, or two column units joined by an
    operator from UNIT_OPS, optionally wrapped in parentheses.

    :returns next idx, (unit_op id, col_unit1, col_unit2 or None)
    """
    len_ = len(toks)
    is_block = toks[start_idx] == '('
    idx = start_idx + 1 if is_block else start_idx

    idx, left = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)

    op_id = UNIT_OPS.index('none')
    right = None
    if idx < len_ and toks[idx] in UNIT_OPS:
        op_id = UNIT_OPS.index(toks[idx])
        idx, right = parse_col_unit(toks, idx + 1, tables_with_alias, schema, default_tables)

    if is_block:
        assert toks[idx] == ')'
        idx += 1

    return idx, (op_id, left, right)
| |
|
| |
|
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
    """
    Parse a table reference, skipping an optional ``as <alias>`` suffix.

    :returns next idx, table id, table name
    """
    table_name = tables_with_alias[toks[start_idx]]
    has_alias = start_idx + 1 < len(toks) and toks[start_idx + 1] == "as"
    next_idx = start_idx + (3 if has_alias else 1)
    return next_idx, schema.idMap[table_name], table_name
| |
|
| |
|
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse the right-hand side of a condition, optionally in parentheses.

    The value is, in order of preference:
    * a nested SELECT (parsed with ``parse_sql``),
    * a quoted literal (kept as the raw token),
    * a number (returned as ``float``),
    * otherwise a column unit.

    :returns next idx, value
    """
    idx = start_idx
    len_ = len(toks)

    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] == 'select':
        idx, val = parse_sql(toks, idx, tables_with_alias, schema)
    elif "\"" in toks[idx]:
        val = toks[idx]
        idx += 1
    else:
        try:
            val = float(toks[idx])
            idx += 1
        except ValueError:
            # Not a number: the tokens up to the next delimiter or keyword
            # form a column unit.
            end_idx = idx
            while (end_idx < len_
                   and toks[end_idx] not in (',', ')', 'and')
                   and toks[end_idx] not in CLAUSE_KEYWORDS
                   and toks[end_idx] not in JOIN_KEYWORDS):
                end_idx += 1

            # Slice from idx, not start_idx: for a parenthesised value,
            # start_idx points at '(' whose matching ')' is outside the slice.
            idx, val = parse_col_unit(toks[idx: end_idx], 0, tables_with_alias, schema, default_tables)
            idx = end_idx

    if isBlock:
        assert toks[idx] == ')'
        idx += 1

    return idx, val
| |
|
| |
|
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse a sequence of conditions joined by ``and`` / ``or``.

    Each condition is stored as ``(not_op, op_id, val_unit, val1, val2)``:
    * ``op_id`` indexes WHERE_OPS;
    * ``val2`` is only filled for BETWEEN.

    The connective tokens are stored between conditions as plain strings.
    Parsing stops at a clause keyword, a join keyword, ``)`` or ``;``.

    :returns next idx, list of conditions and connectives
    """
    len_ = len(toks)
    idx = start_idx
    conds = []
    between_id = WHERE_OPS.index('between')

    while idx < len_:
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)

        not_op = toks[idx] == 'not'
        if not_op:
            idx += 1

        assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
        op_id = WHERE_OPS.index(toks[idx])
        idx += 1

        idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        val2 = None
        if op_id == between_id:
            # BETWEEN <val1> AND <val2>
            assert toks[idx] == 'and'
            idx, val2 = parse_value(toks, idx + 1, tables_with_alias, schema, default_tables)

        conds.append((not_op, op_id, val_unit, val1, val2))

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
            break

        if idx < len_ and toks[idx] in COND_OPS:
            conds.append(toks[idx])
            idx += 1

    return idx, conds
| |
|
| |
|
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse a SELECT list up to the next clause keyword.

    :returns next idx, (isDistinct, [(agg_id, val_unit), ...])
    """
    len_ = len(toks)
    assert toks[start_idx] == 'select', "'select' not found"
    idx = start_idx + 1

    isDistinct = idx < len_ and toks[idx] == 'distinct'
    if isDistinct:
        idx += 1

    no_agg = AGG_OPS.index("none")
    val_units = []
    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
        agg_id = no_agg
        if toks[idx] in AGG_OPS:
            agg_id = AGG_OPS.index(toks[idx])
            idx += 1
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append((agg_id, val_unit))
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip the separator

    return idx, (isDistinct, val_units)
| |
|
| |
|
def parse_from(toks, start_idx, tables_with_alias, schema):
    """
    Parse the first FROM clause found at or after ``start_idx``.

    Assume in the from clause, all table units are combined with join.
    Each table unit is either a parenthesised sub-query or a named table.
    A named table may be preceded by ``join`` and followed by an alias.
    The ``on`` conditions of all joins go into one list, separated by
    ``'and'``.

    :returns next idx, table units, join conditions, default table names
    """
    assert 'from' in toks[start_idx:], "'from' not found"

    len_ = len(toks)
    idx = toks.index('from', start_idx) + 1
    default_tables = []
    table_units = []
    conds = []

    while idx < len_:
        is_block = toks[idx] == '('
        if is_block:
            idx += 1

        if toks[idx] == 'select':
            idx, sub_sql = parse_sql(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['sql'], sub_sql))
        else:
            if idx < len_ and toks[idx] == 'join':
                idx += 1
            idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['table_unit'], table_unit))
            default_tables.append(table_name)

        if idx < len_ and toks[idx] == "on":
            idx, join_conds = parse_condition(toks, idx + 1, tables_with_alias, schema, default_tables)
            if conds:
                conds.append('and')
            conds.extend(join_conds)

        if is_block:
            assert toks[idx] == ')'
            idx += 1

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            break

    return idx, table_units, conds, default_tables
| |
|
| |
|
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
    """
    Parse an optional WHERE clause starting at ``start_idx``.

    :returns next idx, list of conditions ([] when there is no WHERE)
    """
    if start_idx < len(toks) and toks[start_idx] == 'where':
        return parse_condition(toks, start_idx + 1, tables_with_alias, schema, default_tables)
    return start_idx, []
| |
|
| |
|
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
    """
    Parse an optional ``GROUP BY col_unit, ...`` clause.

    :returns next idx, list of column units ([] when there is no GROUP BY)
    """
    len_ = len(toks)
    if start_idx >= len_ or toks[start_idx] != 'group':
        return start_idx, []

    assert toks[start_idx + 1] == 'by'
    idx = start_idx + 2

    col_units = []
    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS and toks[idx] not in (")", ";"):
        idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        col_units.append(col_unit)
        if idx >= len_ or toks[idx] != ',':
            break
        idx += 1  # skip ','

    return idx, col_units
| |
|
| |
|
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
    """
    Parse an optional ``ORDER BY val_unit [asc|desc], ...`` clause.

    One direction is recorded for the whole clause: the last ``asc`` /
    ``desc`` seen wins, defaulting to ``'asc'``.

    :returns next idx and (order_type, val_units); or (start_idx, []) when
        there is no ORDER BY
    """
    len_ = len(toks)
    if start_idx >= len_ or toks[start_idx] != 'order':
        return start_idx, []

    assert toks[start_idx + 1] == 'by'
    idx = start_idx + 2

    order_type = 'asc'
    val_units = []
    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS and toks[idx] not in (")", ";"):
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append(val_unit)
        if idx < len_ and toks[idx] in ORDER_OPS:
            order_type = toks[idx]
            idx += 1
        if idx >= len_ or toks[idx] != ',':
            break
        idx += 1  # skip ','

    return idx, (order_type, val_units)
| |
|
| |
|
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
    """
    Parse an optional HAVING clause starting at ``start_idx``.

    :returns next idx, list of conditions ([] when there is no HAVING)
    """
    if start_idx >= len(toks) or toks[start_idx] != 'having':
        return start_idx, []
    next_idx, conds = parse_condition(toks, start_idx + 1, tables_with_alias, schema, default_tables)
    return next_idx, conds
| |
|
| |
|
def parse_limit(toks, start_idx):
    """
    Parse an optional ``LIMIT <int>`` clause.

    :returns next idx, limit value as int (None when there is no LIMIT)
    """
    if start_idx >= len(toks) or toks[start_idx] != 'limit':
        return start_idx, None
    return start_idx + 2, int(toks[start_idx + 1])
| |
|
| |
|
def parse_sql(toks, start_idx, tables_with_alias, schema):
    """
    Parse one (possibly parenthesised) SELECT statement at ``start_idx``.

    The statement may be followed by one INTERSECT / UNION / EXCEPT and
    another statement, which is parsed recursively. The FROM clause is parsed
    first so its tables serve as the default tables for resolving bare column
    names in every other clause.

    :returns next idx, sql dict with keys 'from', 'select', 'where',
        'groupBy', 'having', 'orderBy', 'limit', 'intersect', 'union',
        'except'
    """
    len_ = len(toks)
    is_block = toks[start_idx] == '('
    idx = start_idx + 1 if is_block else start_idx

    sql = {}

    # FROM first: it yields the default tables and where the FROM clause ends.
    from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
    sql['from'] = {'table_units': table_units, 'conds': conds}

    # SELECT is parsed from its own position; its end index is not needed
    # because parsing resumes right after FROM.
    _, sql['select'] = parse_select(toks, idx, tables_with_alias, schema, default_tables)
    idx = from_end_idx

    clause_parsers = (
        ('where', parse_where),
        ('groupBy', parse_group_by),
        ('having', parse_having),
        ('orderBy', parse_order_by),
    )
    for key, parser in clause_parsers:
        idx, sql[key] = parser(toks, idx, tables_with_alias, schema, default_tables)

    idx, sql['limit'] = parse_limit(toks, idx)

    idx = skip_semicolon(toks, idx)
    if is_block:
        assert toks[idx] == ')'
        idx += 1
    idx = skip_semicolon(toks, idx)

    # intersect/union/except: all default to None, at most one is filled.
    sql.update(dict.fromkeys(SQL_OPS))
    if idx < len_ and toks[idx] in SQL_OPS:
        sql_op = toks[idx]
        idx, sql[sql_op] = parse_sql(toks, idx + 1, tables_with_alias, schema)

    return idx, sql
| |
|
| |
|
def load_data(fpath):
    """
    Load a UTF-8 JSON file.

    :param fpath: path to the JSON file
    :return: the decoded JSON content
    """
    with open(fpath, encoding='utf8') as json_file:
        return json.load(json_file)
| |
|
| |
|
def get_sql(schema, query):
    """
    Tokenize ``query`` and parse it into the nested sql dict.

    :param schema: object whose ``schema`` attribute is the table -> columns
        dict (a ``Schema`` instance)
    :param query: SQL query string
    :return: sql dict produced by ``parse_sql``
    """
    toks = tokenize(query)
    return parse_sql(toks, 0, get_tables_with_alias(schema.schema, toks), schema)[1]
| |
|
| |
|
def skip_semicolon(toks, start_idx):
    """Return the first index at or after ``start_idx`` that is not a ``;`` token."""
    end = len(toks)
    pos = start_idx
    while pos < end:
        if toks[pos] != ";":
            break
        pos += 1
    return pos
| |
|