diff --git a/changelog.md b/changelog.md index 96652279..552df360 100644 --- a/changelog.md +++ b/changelog.md @@ -1,8 +1,15 @@ Upcoming Release (TBD) ====================== +Bug Fixes +---------- + +* Let table-name extraction work on multi-statement inputs. + + Internal -------- + * Work on passing `ruff check` linting. * Remove backward-compatibility hacks. * Pin more GitHub Actions and add Dependabot support. diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index 9acbcd5c..5eac267e 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -1,4 +1,5 @@ import re +import sqlglot import sqlparse from sqlparse.sql import IdentifierList, Identifier, Function from sqlparse.tokens import Keyword, DML, Punctuation @@ -166,6 +167,42 @@ def extract_tables(sql): return list(extract_table_identifiers(stream)) +def extract_tables_from_complete_statements(sql): + """Extract the table names from a complete and valid series of SQL + statements. + + Returns a list of (schema, table, alias) tuples + + """ + # sqlglot chokes entirely on things like "\T" that it doesn't know about, + # but is much better at extracting table names from complete statements. + # sqlparse can extract the series of statements, though it also doesn't + # understand "\T". + roughly_parsed = sqlparse.parse(sql) + if not roughly_parsed: + return [] + + finely_parsed = [] + for statement in roughly_parsed: + try: + finely_parsed.append(sqlglot.parse_one(str(statement), read='mysql')) + except sqlglot.errors.ParseError: + pass + + tables = [] + for statement in finely_parsed: + for identifier in statement.find_all(sqlglot.exp.Table): + if identifier.parent_select.sql().startswith('WITH'): + continue + tables.append(( + None if identifier.db == '' else identifier.db, + identifier.name, + None if identifier.alias == '' else identifier.alias, + )) + + return tables + + def find_prev_keyword(sql): """Find the last sql keyword in an SQL statement diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py index 828a4b38..008e4d43 100644 --- a/mycli/packages/tabular_output/sql_format.py +++ b/mycli/packages/tabular_output/sql_format.py @@ -1,6 +1,6 @@ """Format adapter for sql.""" -from mycli.packages.parseutils import extract_tables +from mycli.packages.parseutils import extract_tables_from_complete_statements supported_formats = ( "sql-insert", @@ -20,7 +20,7 @@ def escape_for_sql_statement(value): def adapter(data, headers, table_format=None, **kwargs): - tables = extract_tables(formatter.query) + tables = extract_tables_from_complete_statements(formatter.query) if len(tables) > 0: table = tables[0] if table[0]: diff --git a/test/test_parseutils.py b/test/test_parseutils.py index abc4a9c8..7f1aa4c5 100644 --- a/test/test_parseutils.py +++ b/test/test_parseutils.py @@ -1,6 +1,7 @@ import pytest from mycli.packages.parseutils import ( extract_tables, + extract_tables_from_complete_statements, query_starts_with, queries_start_with, is_destructive, @@ -107,6 +108,22 @@ def test_join_as_table(): assert tables == [(None, "my_table", "m")] +def test_extract_tables_from_complete_statements(): + tables = extract_tables_from_complete_statements("SELECT * FROM my_table AS m WHERE m.a > 5") + assert tables == [(None, "my_table", "m")] + + +def test_extract_tables_from_complete_statements_cte(): + tables = extract_tables_from_complete_statements("WITH my_cte (id, num) AS ( SELECT id, COUNT(1) FROM my_table GROUP BY id ) SELECT *") + assert tables == [(None, "my_table", None)] + + +# this would confuse plain extract_tables() per #1122 +def test_extract_tables_from_multiple_complete_statements(): + tables = extract_tables_from_complete_statements(r'\T sql-insert; SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, "my_table", "m")] + + def test_query_starts_with(): query = "USE test;" assert query_starts_with(query, ("use",)) is True