From 8a21836300bfd05d87429d4b6e69ac098013d21b Mon Sep 17 00:00:00 2001 From: T-Santos Date: Mon, 26 Aug 2024 13:15:22 -0400 Subject: [PATCH 1/6] Update test_truncate_table.py Add whitespace From 66707a60002afaf0fb93c5286740b0456a0d10e2 Mon Sep 17 00:00:00 2001 From: "Santos, Tyler (Boston)" Date: Wed, 23 Oct 2024 21:48:40 +0000 Subject: [PATCH 2/6] MSSQL Unqualified Schema Table Parsing --- sql_metadata/parser.py | 24 +++++++++++++++--------- test/test_mssql_server.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index a8cac3e6..68675cf1 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -219,7 +219,7 @@ def columns(self) -> List[str]: self._handle_column_save(token=token, columns=columns) elif token.is_column_name_inside_insert_clause: - column = str(token.value).strip("`") + column = str(token.value) self._add_to_columns_subsection( keyword=token.last_keyword_normalized, column=column ) @@ -350,6 +350,7 @@ def tables(self) -> List[str]: with_names = self.with_names for token in self._not_parsed_tokens: + #import ipdb; ipdb.set_trace() if token.is_potential_table_name: if ( token.is_alias_of_table_or_alias_of_subquery @@ -369,10 +370,8 @@ def tables(self) -> List[str]: and self.query_type == "INSERT" ): continue - - table_name = str(token.value.strip("`")) token.token_type = TokenType.TABLE - tables.append(table_name) + tables.append(str(token.value)) self._tables = tables - with_names return self._tables @@ -1026,17 +1025,24 @@ def _combine_qualified_names(self, index: int, token: SQLToken) -> None: is_complex = True while is_complex: value, is_complex = self._combine_tokens(index=index, value=value) - index = index - 2 + index = index - 1 token.value = value def _combine_tokens(self, index: int, value: str) -> Tuple[str, bool]: """ Checks if complex identifier is longer and follows back until it's finished """ - if index > 1 and str(self.non_empty_tokens[index - 1]) == ".": - prev_value = self.non_empty_tokens[index - 2].value.strip("`").strip('"') - value = f"{prev_value}.{value}" - return value, True + keep_going = False + #import ipdb; ipdb.set_trace() + if index > 1: + prev_value = self.non_empty_tokens[index - 1] + prev_value = str(prev_value).strip('`') + if prev_value == ".": + keep_going = True + if str(self.non_empty_tokens[index - 2]) == ".": + keep_going = True + value = f"{prev_value.strip('`')}{value}" + return value, keep_going return value, False def _get_sqlparse_tokens(self, parsed) -> None: diff --git a/test/test_mssql_server.py b/test/test_mssql_server.py index 6c5e24cf..303a06bc 100644 --- a/test/test_mssql_server.py +++ b/test/test_mssql_server.py @@ -1,5 +1,17 @@ +import pytest + from sql_metadata.parser import Parser +@pytest.mark.parametrize( + "query, expected", + [ + pytest.param("SELECT * FROM mydb..test_table", ["mydb..test_table"], id='Default schema, db qualified'), + #pytest.param("SELECT * FROM ..test_table", ["..test_table"], id='Default schema, db unqualified'), + ] +) +def test_simple_queries_tables(query, expected): + assert Parser(query).tables == expected + def test_sql_server_cte(): """ From 5612b06ce05ef86285375d66fba1260c16446cb6 Mon Sep 17 00:00:00 2001 From: "Santos, Tyler (Boston)" Date: Wed, 23 Oct 2024 22:23:04 +0000 Subject: [PATCH 3/6] tests working --- sql_metadata/parser.py | 12 +++++++----- test/test_mssql_server.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index 68675cf1..cf9993fc 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -144,7 +144,8 @@ def tokens(self) -> List[SQLToken]: # noqa: C901 combine_flag = False for index, tok in enumerate(self.non_empty_tokens): # combine dot separated identifiers - if self._is_token_part_of_complex_identifier(token=tok, index=index): + #import ipdb; ipdb.set_trace() + if not tok.is_keyword and self._is_token_part_of_complex_identifier(token=tok, index=index): combine_flag = True continue token = SQLToken( @@ -1033,15 +1034,16 @@ def _combine_tokens(self, index: int, value: str) -> Tuple[str, bool]: Checks if complex identifier is longer and follows back until it's finished """ keep_going = False - #import ipdb; ipdb.set_trace() if index > 1: prev_value = self.non_empty_tokens[index - 1] - prev_value = str(prev_value).strip('`') - if prev_value == ".": + if prev_value.is_keyword: + return value, False + if str(prev_value) == ".": keep_going = True if str(self.non_empty_tokens[index - 2]) == ".": keep_going = True - value = f"{prev_value.strip('`')}{value}" + prev_value = str(prev_value).strip('`') + value = f"{prev_value}{value}" return value, keep_going return value, False diff --git a/test/test_mssql_server.py b/test/test_mssql_server.py index 303a06bc..1500243f 100644 --- a/test/test_mssql_server.py +++ b/test/test_mssql_server.py @@ -6,7 +6,7 @@ "query, expected", [ pytest.param("SELECT * FROM mydb..test_table", ["mydb..test_table"], id='Default schema, db qualified'), - #pytest.param("SELECT * FROM ..test_table", ["..test_table"], id='Default schema, db unqualified'), + pytest.param("SELECT * FROM ..test_table", ["..test_table"], id='Default schema, db unqualified'), ] ) def test_simple_queries_tables(query, expected): From 760e9d48427faa2361996926dc9b5a9cf1786e15 Mon Sep 17 00:00:00 2001 From: "Santos, Tyler (Boston)" Date: Wed, 23 Oct 2024 22:28:18 +0000 Subject: [PATCH 4/6] simplify code --- sql_metadata/parser.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index cf9993fc..e683cecd 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -144,8 +144,7 @@ def tokens(self) -> List[SQLToken]: # noqa: C901 combine_flag = False for index, tok in enumerate(self.non_empty_tokens): # combine dot separated identifiers - #import ipdb; ipdb.set_trace() - if not tok.is_keyword and self._is_token_part_of_complex_identifier(token=tok, index=index): + if self._is_token_part_of_complex_identifier(token=tok, index=index): combine_flag = True continue token = SQLToken( @@ -351,7 +350,6 @@ def tables(self) -> List[str]: with_names = self.with_names for token in self._not_parsed_tokens: - #import ipdb; ipdb.set_trace() if token.is_potential_table_name: if ( token.is_alias_of_table_or_alias_of_subquery @@ -1013,6 +1011,8 @@ def _is_token_part_of_complex_identifier( Checks if token is a part of complex identifier like .. or
. """ + if token.is_keyword: + return False return str(token) == "." or ( index + 1 < self.tokens_length and str(self.non_empty_tokens[index + 1]) == "." @@ -1033,18 +1033,13 @@ def _combine_tokens(self, index: int, value: str) -> Tuple[str, bool]: """ Checks if complex identifier is longer and follows back until it's finished """ - keep_going = False if index > 1: prev_value = self.non_empty_tokens[index - 1] - if prev_value.is_keyword: + if not self._is_token_part_of_complex_identifier(prev_value, index-1): return value, False - if str(prev_value) == ".": - keep_going = True - if str(self.non_empty_tokens[index - 2]) == ".": - keep_going = True prev_value = str(prev_value).strip('`') value = f"{prev_value}{value}" - return value, keep_going + return value, True return value, False def _get_sqlparse_tokens(self, parsed) -> None: From e60efcc65f572bf7d50bc4ccaf6be7360174f501 Mon Sep 17 00:00:00 2001 From: "Santos, Tyler (Boston)" Date: Thu, 24 Oct 2024 12:23:23 +0000 Subject: [PATCH 5/6] more tests --- test/test_mssql_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_mssql_server.py b/test/test_mssql_server.py index 1500243f..7ba0d1f3 100644 --- a/test/test_mssql_server.py +++ b/test/test_mssql_server.py @@ -7,6 +7,8 @@ [ pytest.param("SELECT * FROM mydb..test_table", ["mydb..test_table"], id='Default schema, db qualified'), pytest.param("SELECT * FROM ..test_table", ["..test_table"], id='Default schema, db unqualified'), + pytest.param("SELECT * FROM [mydb].[dbo].[test_table]", ["[mydb].[dbo].[test_table]"], id='With object identifier delimiters'), + pytest.param("SELECT * FROM [my_server].[mydb].[dbo].[test_table]", ["[my_server].[mydb].[dbo].[test_table]"], id='With linked-server and object identifier delimiters'), ] ) def test_simple_queries_tables(query, expected): From 3aa9480f06ad7ae40e4fe2c366c00f94e06bd0ea Mon Sep 17 00:00:00 2001 From: "Santos, Tyler (Boston)" Date: Thu, 24 Oct 2024 12:27:42 +0000 Subject: [PATCH 6/6] formatting --- sql_metadata/parser.py | 4 ++-- test/test_mssql_server.py | 31 ++++++++++++++++++++++++------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index e683cecd..7605da22 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -1035,9 +1035,9 @@ def _combine_tokens(self, index: int, value: str) -> Tuple[str, bool]: """ if index > 1: prev_value = self.non_empty_tokens[index - 1] - if not self._is_token_part_of_complex_identifier(prev_value, index-1): + if not self._is_token_part_of_complex_identifier(prev_value, index - 1): return value, False - prev_value = str(prev_value).strip('`') + prev_value = str(prev_value).strip("`") value = f"{prev_value}{value}" return value, True return value, False diff --git a/test/test_mssql_server.py b/test/test_mssql_server.py index 7ba0d1f3..3510b632 100644 --- a/test/test_mssql_server.py +++ b/test/test_mssql_server.py @@ -2,14 +2,31 @@ from sql_metadata.parser import Parser + @pytest.mark.parametrize( - "query, expected", - [ - pytest.param("SELECT * FROM mydb..test_table", ["mydb..test_table"], id='Default schema, db qualified'), - pytest.param("SELECT * FROM ..test_table", ["..test_table"], id='Default schema, db unqualified'), - pytest.param("SELECT * FROM [mydb].[dbo].[test_table]", ["[mydb].[dbo].[test_table]"], id='With object identifier delimiters'), - pytest.param("SELECT * FROM [my_server].[mydb].[dbo].[test_table]", ["[my_server].[mydb].[dbo].[test_table]"], id='With linked-server and object identifier delimiters'), - ] + "query, expected", + [ + pytest.param( + "SELECT * FROM mydb..test_table", + ["mydb..test_table"], + id="Default schema, db qualified", + ), + pytest.param( + "SELECT * FROM ..test_table", + ["..test_table"], + id="Default schema, db unqualified", + ), + pytest.param( + "SELECT * FROM [mydb].[dbo].[test_table]", + ["[mydb].[dbo].[test_table]"], + id="With object identifier delimiters", + ), + pytest.param( + "SELECT * FROM [my_server].[mydb].[dbo].[test_table]", + ["[my_server].[mydb].[dbo].[test_table]"], + id="With linked-server and object identifier delimiters", + ), + ], ) def test_simple_queries_tables(query, expected): assert Parser(query).tables == expected