Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
Upcoming Release (TBD)
======================

Bug Fixes
----------

* Let table-name extraction work on multi-statement inputs.


Internal
--------

* Work on passing `ruff check` linting.
* Remove backward-compatibility hacks.
* Pin more GitHub Actions and add Dependabot support.
Expand Down
37 changes: 37 additions & 0 deletions mycli/packages/parseutils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import sqlglot
import sqlparse
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation
Expand Down Expand Up @@ -166,6 +167,42 @@ def extract_tables(sql):
return list(extract_table_identifiers(stream))


def extract_tables_from_complete_statements(sql):
    """Extract the table names from a complete and valid series of SQL
    statements.

    The input is split into individual statements with sqlparse, and each
    statement is then parsed on its own with sqlglot using the MySQL
    dialect.  Statements that sqlglot cannot tokenize or parse are skipped
    rather than aborting the whole extraction.

    Returns a list of (schema, table, alias) tuples, in the order the tables
    are encountered.  The schema and alias entries are None when absent.

    """
    # sqlglot chokes entirely on things like "\T" that it doesn't know about,
    # but is much better at extracting table names from complete statements.
    # sqlparse can extract the series of statements, though it also doesn't
    # understand "\T".
    tables = []
    for statement in sqlparse.parse(sql):
        try:
            parsed = sqlglot.parse_one(str(statement), read='mysql')
        except sqlglot.errors.SqlglotError:
            # Catch the common base class: tokenizer failures raise
            # TokenError, which is a sibling of ParseError and would
            # otherwise escape this best-effort skip.
            continue

        for identifier in parsed.find_all(sqlglot.exp.Table):
            # parent_select is None when the table is not nested inside a
            # SELECT (for example DELETE / INSERT / DROP statements), so it
            # must be checked before calling .sql() on it.
            enclosing_select = identifier.parent_select
            if enclosing_select is not None and enclosing_select.sql().startswith('WITH'):
                # Skip table references whose enclosing SELECT renders
                # with a leading WITH clause.
                continue
            tables.append((
                None if identifier.db == '' else identifier.db,
                identifier.name,
                None if identifier.alias == '' else identifier.alias,
            ))

    return tables


def find_prev_keyword(sql):
"""Find the last sql keyword in an SQL statement

Expand Down
4 changes: 2 additions & 2 deletions mycli/packages/tabular_output/sql_format.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Format adapter for sql."""

from mycli.packages.parseutils import extract_tables
from mycli.packages.parseutils import extract_tables_from_complete_statements

supported_formats = (
"sql-insert",
Expand All @@ -20,7 +20,7 @@ def escape_for_sql_statement(value):


def adapter(data, headers, table_format=None, **kwargs):
tables = extract_tables(formatter.query)
tables = extract_tables_from_complete_statements(formatter.query)
if len(tables) > 0:
table = tables[0]
if table[0]:
Expand Down
17 changes: 17 additions & 0 deletions test/test_parseutils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from mycli.packages.parseutils import (
extract_tables,
extract_tables_from_complete_statements,
query_starts_with,
queries_start_with,
is_destructive,
Expand Down Expand Up @@ -107,6 +108,22 @@ def test_join_as_table():
assert tables == [(None, "my_table", "m")]


def test_extract_tables_from_complete_statements():
    """A single aliased SELECT yields one (schema, table, alias) tuple."""
    query = "SELECT * FROM my_table AS m WHERE m.a > 5"
    expected = [(None, "my_table", "m")]
    assert extract_tables_from_complete_statements(query) == expected


def test_extract_tables_from_complete_statements_cte():
    """Only the table read inside the CTE body is expected in the result."""
    query = "WITH my_cte (id, num) AS ( SELECT id, COUNT(1) FROM my_table GROUP BY id ) SELECT *"
    expected = [(None, "my_table", None)]
    assert extract_tables_from_complete_statements(query) == expected


# this would confuse plain extract_tables() per #1122
def test_extract_tables_from_multiple_complete_statements():
    """A leading "\\T" command before the SELECT must not hide its table."""
    query = r'\T sql-insert; SELECT * FROM my_table AS m WHERE m.a > 5'
    expected = [(None, "my_table", "m")]
    assert extract_tables_from_complete_statements(query) == expected


def test_query_starts_with():
query = "USE test;"
assert query_starts_with(query, ("use",)) is True
Expand Down