Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes: snowflake lowercase issue #19486

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
"""
Module to define overriden dialect methods
"""

import operator
from functools import reduce
from typing import Dict, Optional

import sqlalchemy.types as sqltypes
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect
from sqlalchemy import exc as sa_exc
from sqlalchemy import util as sa_util
from sqlalchemy.engine import reflection
Expand Down Expand Up @@ -52,6 +54,7 @@
get_table_comment_wrapper,
)

dialect = SnowflakeDialect()
Query = str
QueryMap = Dict[str, Query]

Expand Down Expand Up @@ -83,6 +86,20 @@
}


def _denormalize_quote_join(*idents):
ip = dialect.identifier_preparer
split_idents = reduce(
operator.add,
[ip._split_schema_by_dot(ids) for ids in idents if ids is not None],
)
quoted_identifiers = ip._quote_free_identifiers(*split_idents)
normalized_identifiers = (
item if item.startswith('"') and item.endswith('"') else f'"{item}"'
for item in quoted_identifiers
)
return ".".join(normalized_identifiers)


def _quoted_name(entity_name: Optional[str]) -> Optional[str]:
if entity_name:
return fqn.quote_name(entity_name)
Expand Down Expand Up @@ -256,17 +273,16 @@ def get_schema_columns(self, connection, schema, **kw):
None, as it is cacheable and is an unexpected return type for this function"""
ans = {}
current_database, _ = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(
current_database, fqn.quote_name(schema)
)
full_schema_name = _denormalize_quote_join(current_database, fqn.quote_name(schema))
try:
schema_primary_keys = self._get_schema_primary_keys(
connection, full_schema_name, **kw
)
# removing " " from schema name because schema name is in the WHERE clause of a query
table_schema = self.denormalize_name(fqn.unquote_name(schema))
table_schema = table_schema.lower() if schema.islower() else table_schema
result = connection.execute(
text(SNOWFLAKE_GET_SCHEMA_COLUMNS),
{"table_schema": self.denormalize_name(fqn.unquote_name(schema))}
# removing " " from schema name because schema name is in the WHERE clause of a query
text(SNOWFLAKE_GET_SCHEMA_COLUMNS), {"table_schema": table_schema}
)

except sa_exc.ProgrammingError as p_err:
Expand Down Expand Up @@ -362,9 +378,10 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
schema = schema or self.default_schema_name
schema = _quoted_name(entity_name=schema)
current_database, current_schema = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(
full_schema_name = _denormalize_quote_join(
current_database, schema if schema else current_schema
)

return self._get_schema_primary_keys(
connection, self.denormalize_name(full_schema_name), **kw
).get(table_name, {"constrained_columns": [], "name": None})
Expand All @@ -378,7 +395,7 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
schema = schema or self.default_schema_name
schema = _quoted_name(entity_name=schema)
current_database, current_schema = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(
full_schema_name = _denormalize_quote_join(
current_database, schema if schema else current_schema
)

Expand Down Expand Up @@ -452,9 +469,10 @@ def get_unique_constraints(self, connection, table_name, schema, **kw):
schema = schema or self.default_schema_name
schema = _quoted_name(entity_name=schema)
current_database, current_schema = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(
full_schema_name = _denormalize_quote_join(
current_database, schema if schema else current_schema
)

return self._get_schema_unique_constraints(
connection, self.denormalize_name(full_schema_name), **kw
).get(table_name, [])
Expand Down
Loading