diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index a8cac3e6..7605da22 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -219,7 +219,7 @@ def columns(self) -> List[str]: self._handle_column_save(token=token, columns=columns) elif token.is_column_name_inside_insert_clause: - column = str(token.value).strip("`") + column = str(token.value) self._add_to_columns_subsection( keyword=token.last_keyword_normalized, column=column ) @@ -369,10 +369,8 @@ def tables(self) -> List[str]: and self.query_type == "INSERT" ): continue - - table_name = str(token.value.strip("`")) token.token_type = TokenType.TABLE - tables.append(table_name) + tables.append(str(token.value)) self._tables = tables - with_names return self._tables @@ -1013,6 +1011,8 @@ def _is_token_part_of_complex_identifier( Checks if token is a part of complex identifier like .. or
. """ + if token.is_keyword: + return False return str(token) == "." or ( index + 1 < self.tokens_length and str(self.non_empty_tokens[index + 1]) == "." @@ -1026,16 +1026,19 @@ def _combine_qualified_names(self, index: int, token: SQLToken) -> None: is_complex = True while is_complex: value, is_complex = self._combine_tokens(index=index, value=value) - index = index - 2 + index = index - 1 token.value = value def _combine_tokens(self, index: int, value: str) -> Tuple[str, bool]: """ Checks if complex identifier is longer and follows back until it's finished """ - if index > 1 and str(self.non_empty_tokens[index - 1]) == ".": - prev_value = self.non_empty_tokens[index - 2].value.strip("`").strip('"') - value = f"{prev_value}.{value}" + if index > 1: + prev_value = self.non_empty_tokens[index - 1] + if not self._is_token_part_of_complex_identifier(prev_value, index - 1): + return value, False + prev_value = str(prev_value).strip("`") + value = f"{prev_value}{value}" return value, True return value, False diff --git a/test/test_mssql_server.py b/test/test_mssql_server.py index 6c5e24cf..3510b632 100644 --- a/test/test_mssql_server.py +++ b/test/test_mssql_server.py @@ -1,6 +1,37 @@ +import pytest + from sql_metadata.parser import Parser +@pytest.mark.parametrize( + "query, expected", + [ + pytest.param( + "SELECT * FROM mydb..test_table", + ["mydb..test_table"], + id="Default schema, db qualified", + ), + pytest.param( + "SELECT * FROM ..test_table", + ["..test_table"], + id="Default schema, db unqualified", + ), + pytest.param( + "SELECT * FROM [mydb].[dbo].[test_table]", + ["[mydb].[dbo].[test_table]"], + id="With object identifier delimiters", + ), + pytest.param( + "SELECT * FROM [my_server].[mydb].[dbo].[test_table]", + ["[my_server].[mydb].[dbo].[test_table]"], + id="With linked-server and object identifier delimiters", + ), + ], +) +def test_simple_queries_tables(query, expected): + assert Parser(query).tables == expected + + def test_sql_server_cte(): """ Tests support for SQL Server's common table expression (CTE).