diff --git a/.gitignore b/.gitignore index be2b6924..8a30529f 100644 --- a/.gitignore +++ b/.gitignore @@ -98,6 +98,7 @@ ENV/ generated_version.py *coverage.xml .DS_Store +junit.xml # Editor specific .idea/ diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 38cd70f7..9d86a86a 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -7,6 +7,10 @@ Snowflake Documentation is available at: Source code is also available at: +# Unreleased notes + +- Split large files into modules and update license headers + # Release Notes - v1.6.1(July 9, 2024) diff --git a/license_header.txt b/license_header.txt index eea2d7b3..8eb48d0f 100644 --- a/license_header.txt +++ b/license_header.txt @@ -1,2 +1,13 @@ +Copyright (c) 2024 Snowflake Inc. -Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 9df6aaa2..badf3dd7 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import sys diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 46af4454..f714b271 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .version import VERSION # parameters needed for usage tracking diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py deleted file mode 100644 index 1aaa881e..00000000 --- a/src/snowflake/sqlalchemy/base.py +++ /dev/null @@ -1,1068 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - -import itertools -import operator -import re - -from sqlalchemy import exc as sa_exc -from sqlalchemy import inspect, sql -from sqlalchemy import util as sa_util -from sqlalchemy.engine import default -from sqlalchemy.orm import context -from sqlalchemy.orm.context import _MapperEntity -from sqlalchemy.schema import Sequence, Table -from sqlalchemy.sql import compiler, expression, functions -from sqlalchemy.sql.base import CompileState -from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.selectable import Lateral, SelectState - -from .compat import IS_VERSION_20, args_reducer, string_types -from .custom_commands import AWSBucket, AzureContainer, ExternalStage -from .functions import flatten -from .util import ( - _find_left_clause_to_join_from, - _set_connection_interpolate_empty_sequences, - _Snowflake_ORMJoin, - _Snowflake_Selectable_Join, -) - -RESERVED_WORDS = frozenset( - [ - "ALL", # ANSI Reserved words - "ALTER", - "AND", - "ANY", - "AS", - "BETWEEN", - "BY", - "CHECK", - "COLUMN", - "CONNECT", - "COPY", - "CREATE", - "CURRENT", - "DELETE", - "DISTINCT", - "DROP", - "ELSE", - "EXISTS", - "FOR", - "FROM", - "GRANT", - "GROUP", - "HAVING", - "IN", - "INSERT", - "INTERSECT", - "INTO", - "IS", - "LIKE", - "NOT", - "NULL", - "OF", - "ON", - "OR", - "ORDER", - "REVOKE", - "ROW", - "ROWS", - "SAMPLE", - "SELECT", - "SET", - "START", - "TABLE", - "THEN", - "TO", - "TRIGGER", - "UNION", - "UNIQUE", - "UPDATE", - "VALUES", - "WHENEVER", - "WHERE", - "WITH", - "REGEXP", - "RLIKE", - "SOME", # Snowflake Reserved words - "MINUS", - "INCREMENT", # Oracle reserved words - ] -) - -# Snowflake DML: -# - UPDATE -# - INSERT -# - DELETE -# - MERGE -AUTOCOMMIT_REGEXP = re.compile( - r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE -) - - -""" -Overwrite methods to handle Snowflake BCR change: -https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 -- _join_determine_implicit_left_side -- _join_left_to_right -""" - - -# handle Snowflake BCR bcr-1057 -@CompileState.plugin_for("default", "select") -class SnowflakeSelectState(SelectState): - def _setup_joins(self, args, raw_columns): - for right, onclause, left, flags in args: - isouter = flags["isouter"] - full = flags["full"] - - if left is None: - ( - left, - replace_from_obj_index, - ) = self._join_determine_implicit_left_side( - raw_columns, left, right, onclause - ) - else: - (replace_from_obj_index) = self._join_place_explicit_left_side(left) - - if replace_from_obj_index is not None: - # splice into an existing element in the - # self._from_obj list - left_clause = self.from_clauses[replace_from_obj_index] - - self.from_clauses = ( - self.from_clauses[:replace_from_obj_index] - + ( - _Snowflake_Selectable_Join( # handle Snowflake BCR bcr-1057 - left_clause, - right, - onclause, - isouter=isouter, - full=full, - ), - ) - + self.from_clauses[replace_from_obj_index + 1 :] - ) - else: - self.from_clauses = self.from_clauses + ( - # handle Snowflake BCR bcr-1057 - _Snowflake_Selectable_Join( - left, right, onclause, isouter=isouter, full=full - ), - ) - - @sa_util.preload_module("sqlalchemy.sql.util") - def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause): - """When join conditions don't express the left side explicitly, - determine if an existing FROM or entity in this query - can serve as the left hand side. - - """ - - replace_from_obj_index = None - - from_clauses = self.from_clauses - - if from_clauses: - # handle Snowflake BCR bcr-1057 - indexes = _find_left_clause_to_join_from(from_clauses, right, onclause) - - if len(indexes) == 1: - replace_from_obj_index = indexes[0] - left = from_clauses[replace_from_obj_index] - else: - potential = {} - statement = self.statement - - for from_clause in itertools.chain( - itertools.chain.from_iterable( - [element._from_objects for element in raw_columns] - ), - itertools.chain.from_iterable( - [element._from_objects for element in statement._where_criteria] - ), - ): - - potential[from_clause] = () - - all_clauses = list(potential.keys()) - # handle Snowflake BCR bcr-1057 - indexes = _find_left_clause_to_join_from(all_clauses, right, onclause) - - if len(indexes) == 1: - left = all_clauses[indexes[0]] - - if len(indexes) > 1: - raise sa_exc.InvalidRequestError( - "Can't determine which FROM clause to join " - "from, there are multiple FROMS which can " - "join to this entity. Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explicit ON clause if not present already to " - "help resolve the ambiguity." - ) - elif not indexes: - raise sa_exc.InvalidRequestError( - "Don't know how to join to %r. " - "Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explicit ON clause if not present already to " - "help resolve the ambiguity." % (right,) - ) - return left, replace_from_obj_index - - -# handle Snowflake BCR bcr-1057 -@sql.base.CompileState.plugin_for("orm", "select") -class SnowflakeORMSelectCompileState(context.ORMSelectCompileState): - def _join_determine_implicit_left_side( - self, entities_collection, left, right, onclause - ): - """When join conditions don't express the left side explicitly, - determine if an existing FROM or entity in this query - can serve as the left hand side. - - """ - - # when we are here, it means join() was called without an ORM- - # specific way of telling us what the "left" side is, e.g.: - # - # join(RightEntity) - # - # or - # - # join(RightEntity, RightEntity.foo == LeftEntity.bar) - # - - r_info = inspect(right) - - replace_from_obj_index = use_entity_index = None - - if self.from_clauses: - # we have a list of FROMs already. So by definition this - # join has to connect to one of those FROMs. - - # handle Snowflake BCR bcr-1057 - indexes = _find_left_clause_to_join_from( - self.from_clauses, r_info.selectable, onclause - ) - - if len(indexes) == 1: - replace_from_obj_index = indexes[0] - left = self.from_clauses[replace_from_obj_index] - elif len(indexes) > 1: - raise sa_exc.InvalidRequestError( - "Can't determine which FROM clause to join " - "from, there are multiple FROMS which can " - "join to this entity. Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explicit ON clause if not present already " - "to help resolve the ambiguity." - ) - else: - raise sa_exc.InvalidRequestError( - "Don't know how to join to %r. " - "Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explicit ON clause if not present already " - "to help resolve the ambiguity." % (right,) - ) - - elif entities_collection: - # we have no explicit FROMs, so the implicit left has to - # come from our list of entities. - - potential = {} - for entity_index, ent in enumerate(entities_collection): - entity = ent.entity_zero_or_selectable - if entity is None: - continue - ent_info = inspect(entity) - if ent_info is r_info: # left and right are the same, skip - continue - - # by using a dictionary with the selectables as keys this - # de-duplicates those selectables as occurs when the query is - # against a series of columns from the same selectable - if isinstance(ent, context._MapperEntity): - potential[ent.selectable] = (entity_index, entity) - else: - potential[ent_info.selectable] = (None, entity) - - all_clauses = list(potential.keys()) - # handle Snowflake BCR bcr-1057 - indexes = _find_left_clause_to_join_from( - all_clauses, r_info.selectable, onclause - ) - - if len(indexes) == 1: - use_entity_index, left = potential[all_clauses[indexes[0]]] - elif len(indexes) > 1: - raise sa_exc.InvalidRequestError( - "Can't determine which FROM clause to join " - "from, there are multiple FROMS which can " - "join to this entity. Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explicit ON clause if not present already " - "to help resolve the ambiguity." - ) - else: - raise sa_exc.InvalidRequestError( - "Don't know how to join to %r. " - "Please use the .select_from() " - "method to establish an explicit left side, as well as " - "providing an explicit ON clause if not present already " - "to help resolve the ambiguity." % (right,) - ) - else: - raise sa_exc.InvalidRequestError( - "No entities to join from; please use " - "select_from() to establish the left " - "entity/selectable of this join" - ) - - return left, replace_from_obj_index, use_entity_index - - @args_reducer(positions_to_drop=(6, 7)) - def _join_left_to_right( - self, entities_collection, left, right, onclause, prop, outerjoin, full - ): - """given raw "left", "right", "onclause" parameters consumed from - a particular key within _join(), add a real ORMJoin object to - our _from_obj list (or augment an existing one) - - """ - - if left is None: - # left not given (e.g. no relationship object/name specified) - # figure out the best "left" side based on our existing froms / - # entities - assert prop is None - ( - left, - replace_from_obj_index, - use_entity_index, - ) = self._join_determine_implicit_left_side( - entities_collection, left, right, onclause - ) - else: - # left is given via a relationship/name, or as explicit left side. - # Determine where in our - # "froms" list it should be spliced/appended as well as what - # existing entity it corresponds to. - ( - replace_from_obj_index, - use_entity_index, - ) = self._join_place_explicit_left_side(entities_collection, left) - - if left is right: - raise sa_exc.InvalidRequestError( - "Can't construct a join from %s to %s, they " - "are the same entity" % (left, right) - ) - - # the right side as given often needs to be adapted. additionally - # a lot of things can be wrong with it. handle all that and - # get back the new effective "right" side - - if IS_VERSION_20: - r_info, right, onclause = self._join_check_and_adapt_right_side( - left, right, onclause, prop - ) - else: - r_info, right, onclause = self._join_check_and_adapt_right_side( - left, right, onclause, prop, False, False - ) - - if not r_info.is_selectable: - extra_criteria = self._get_extra_criteria(r_info) - else: - extra_criteria = () - - if replace_from_obj_index is not None: - # splice into an existing element in the - # self._from_obj list - left_clause = self.from_clauses[replace_from_obj_index] - - self.from_clauses = ( - self.from_clauses[:replace_from_obj_index] - + [ - _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 - left_clause, - right, - onclause, - isouter=outerjoin, - full=full, - _extra_criteria=extra_criteria, - ) - ] - + self.from_clauses[replace_from_obj_index + 1 :] - ) - else: - # add a new element to the self._from_obj list - if use_entity_index is not None: - # make use of _MapperEntity selectable, which is usually - # entity_zero.selectable, but if with_polymorphic() were used - # might be distinct - assert isinstance(entities_collection[use_entity_index], _MapperEntity) - left_clause = entities_collection[use_entity_index].selectable - else: - left_clause = left - - self.from_clauses = self.from_clauses + [ - _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 - left_clause, - r_info, - onclause, - isouter=outerjoin, - full=full, - _extra_criteria=extra_criteria, - ) - ] - - -class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = {x.lower() for x in RESERVED_WORDS} - - def __init__(self, dialect, **kw): - quote = '"' - - super().__init__(dialect, initial_quote=quote, escape_quote=quote) - - def _quote_free_identifiers(self, *ids): - """ - Unilaterally identifier-quote any number of strings. - """ - return tuple(self.quote(i) for i in ids if i is not None) - - def quote_schema(self, schema, force=None): - """ - Split schema by a dot and merge with required quotes - """ - idents = self._split_schema_by_dot(schema) - return ".".join(self._quote_free_identifiers(*idents)) - - def format_label(self, label, name=None): - n = name or label.name - s = n.replace(self.escape_quote, "") - - if not isinstance(n, quoted_name) or n.quote is None: - return self.quote(s) - - return self.quote_identifier(s) if n.quote else s - - def _split_schema_by_dot(self, schema): - ret = [] - idx = 0 - pre_idx = 0 - in_quote = False - while idx < len(schema): - if not in_quote: - if schema[idx] == "." and pre_idx < idx: - ret.append(schema[pre_idx:idx]) - pre_idx = idx + 1 - elif schema[idx] == '"': - in_quote = True - pre_idx = idx + 1 - else: - if schema[idx] == '"' and pre_idx < idx: - ret.append(schema[pre_idx:idx]) - in_quote = False - pre_idx = idx + 1 - idx += 1 - if pre_idx < len(schema) and schema[pre_idx] == ".": - pre_idx += 1 - if pre_idx < idx: - ret.append(schema[pre_idx:idx]) - - # convert the returning strings back to quoted_name types, and assign the original 'quote' attribute on it - quoted_ret = [ - quoted_name(value, quote=getattr(schema, "quote", None)) for value in ret - ] - - return quoted_ret - - -class SnowflakeCompiler(compiler.SQLCompiler): - def visit_sequence(self, sequence, **kw): - return self.dialect.identifier_preparer.format_sequence(sequence) + ".nextval" - - def visit_now_func(self, now, **kw): - return "CURRENT_TIMESTAMP" - - def visit_merge_into(self, merge_into, **kw): - clauses = " ".join( - clause._compiler_dispatch(self, **kw) for clause in merge_into.clauses - ) - return ( - f"MERGE INTO {merge_into.target} USING {merge_into.source} ON {merge_into.on}" - + (" " + clauses if clauses else "") - ) - - def visit_merge_into_clause(self, merge_into_clause, **kw): - case_predicate = ( - f" AND {str(merge_into_clause.predicate._compiler_dispatch(self, **kw))}" - if merge_into_clause.predicate is not None - else "" - ) - if merge_into_clause.command == "INSERT": - sets, sets_tos = zip(*merge_into_clause.set.items()) - sets, sets_tos = list(sets), list(sets_tos) - if kw.get("deterministic", False): - sets, sets_tos = zip( - *sorted(merge_into_clause.set.items(), key=operator.itemgetter(0)) - ) - return "WHEN NOT MATCHED{} THEN {} ({}) VALUES ({})".format( - case_predicate, - merge_into_clause.command, - ", ".join(sets), - ", ".join(map(lambda e: e._compiler_dispatch(self, **kw), sets_tos)), - ) - else: - set_list = list(merge_into_clause.set.items()) - if kw.get("deterministic", False): - set_list.sort(key=operator.itemgetter(0)) - sets = ( - ", ".join( - [ - f"{set[0]} = {set[1]._compiler_dispatch(self, **kw)}" - for set in set_list - ] - ) - if merge_into_clause.set - else "" - ) - return "WHEN MATCHED{} THEN {}{}".format( - case_predicate, - merge_into_clause.command, - " SET %s" % sets if merge_into_clause.set else "", - ) - - def visit_copy_into(self, copy_into, **kw): - if hasattr(copy_into, "formatter") and copy_into.formatter is not None: - formatter = copy_into.formatter._compiler_dispatch(self, **kw) - else: - formatter = "" - into = ( - copy_into.into - if isinstance(copy_into.into, Table) - else copy_into.into._compiler_dispatch(self, **kw) - ) - from_ = None - if isinstance(copy_into.from_, Table): - from_ = copy_into.from_ - # this is intended to catch AWSBucket and AzureContainer - elif ( - isinstance(copy_into.from_, AWSBucket) - or isinstance(copy_into.from_, AzureContainer) - or isinstance(copy_into.from_, ExternalStage) - ): - from_ = copy_into.from_._compiler_dispatch(self, **kw) - # everything else (selects, etc.) - else: - from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})" - credentials, encryption = "", "" - if isinstance(into, tuple): - into, credentials, encryption = into - elif isinstance(from_, tuple): - from_, credentials, encryption = from_ - options_list = list(copy_into.copy_options.items()) - if kw.get("deterministic", False): - options_list.sort(key=operator.itemgetter(0)) - options = ( - ( - " " - + " ".join( - [ - "{} = {}".format( - n, - ( - v._compiler_dispatch(self, **kw) - if getattr(v, "compiler_dispatch", False) - else str(v) - ), - ) - for n, v in options_list - ] - ) - ) - if copy_into.copy_options - else "" - ) - if credentials: - options += f" {credentials}" - if encryption: - options += f" {encryption}" - return f"COPY INTO {into} FROM {from_} {formatter}{options}" - - def visit_copy_formatter(self, formatter, **kw): - options_list = list(formatter.options.items()) - if kw.get("deterministic", False): - options_list.sort(key=operator.itemgetter(0)) - if "format_name" in formatter.options: - return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})" - return "FILE_FORMAT=(TYPE={}{})".format( - formatter.file_format, - ( - " " - + " ".join( - [ - "{}={}".format( - name, - ( - value._compiler_dispatch(self, **kw) - if hasattr(value, "_compiler_dispatch") - else formatter.value_repr(name, value) - ), - ) - for name, value in options_list - ] - ) - if formatter.options - else "" - ), - ) - - def visit_aws_bucket(self, aws_bucket, **kw): - credentials_list = list(aws_bucket.credentials_used.items()) - if kw.get("deterministic", False): - credentials_list.sort(key=operator.itemgetter(0)) - credentials = "CREDENTIALS=({})".format( - " ".join(f"{n}='{v}'" for n, v in credentials_list) - ) - encryption_list = list(aws_bucket.encryption_used.items()) - if kw.get("deterministic", False): - encryption_list.sort(key=operator.itemgetter(0)) - encryption = "ENCRYPTION=({})".format( - " ".join( - ("{}='{}'" if isinstance(v, string_types) else "{}={}").format(n, v) - for n, v in encryption_list - ) - ) - uri = "'s3://{}{}'".format( - aws_bucket.bucket, f"/{aws_bucket.path}" if aws_bucket.path else "" - ) - return ( - uri, - credentials if aws_bucket.credentials_used else "", - encryption if aws_bucket.encryption_used else "", - ) - - def visit_azure_container(self, azure_container, **kw): - credentials_list = list(azure_container.credentials_used.items()) - if kw.get("deterministic", False): - credentials_list.sort(key=operator.itemgetter(0)) - credentials = "CREDENTIALS=({})".format( - " ".join(f"{n}='{v}'" for n, v in credentials_list) - ) - encryption_list = list(azure_container.encryption_used.items()) - if kw.get("deterministic", False): - encryption_list.sort(key=operator.itemgetter(0)) - encryption = "ENCRYPTION=({})".format( - " ".join( - f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}" - for n, v in encryption_list - ) - ) - uri = "'azure://{}.blob.core.windows.net/{}{}'".format( - azure_container.account, - azure_container.container, - f"/{azure_container.path}" if azure_container.path else "", - ) - return ( - uri, - credentials if azure_container.credentials_used else "", - encryption if azure_container.encryption_used else "", - ) - - def visit_external_stage(self, external_stage, **kw): - if external_stage.file_format is None: - return ( - f"@{external_stage.namespace}{external_stage.name}{external_stage.path}" - ) - return f"@{external_stage.namespace}{external_stage.name}{external_stage.path} (file_format => {external_stage.file_format})" - - def delete_extra_from_clause( - self, delete_stmt, from_table, extra_froms, from_hints, **kw - ): - return "USING " + ", ".join( - t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) - for t in extra_froms - ) - - def update_from_clause( - self, update_stmt, from_table, extra_froms, from_hints, **kw - ): - return "FROM " + ", ".join( - t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) - for t in extra_froms - ) - - def _get_regexp_args(self, binary, kw): - string = self.process(binary.left, **kw) - pattern = self.process(binary.right, **kw) - flags = binary.modifiers["flags"] - if flags is not None: - flags = self.process(flags, **kw) - return string, pattern, flags - - def visit_regexp_match_op_binary(self, binary, operator, **kw): - string, pattern, flags = self._get_regexp_args(binary, kw) - if flags is None: - return f"REGEXP_LIKE({string}, {pattern})" - else: - return f"REGEXP_LIKE({string}, {pattern}, {flags})" - - def visit_regexp_replace_op_binary(self, binary, operator, **kw): - string, pattern, flags = self._get_regexp_args(binary, kw) - try: - replacement = self.process(binary.modifiers["replacement"], **kw) - except KeyError: - # in sqlalchemy 1.4.49, the internal structure of the expression is changed - # that binary.modifiers doesn't have "replacement": - # https://docs.sqlalchemy.org/en/20/changelog/changelog_14.html#change-1.4.49 - return f"REGEXP_REPLACE({string}, {pattern}{'' if flags is None else f', {flags}'})" - - if flags is None: - return f"REGEXP_REPLACE({string}, {pattern}, {replacement})" - else: - return f"REGEXP_REPLACE({string}, {pattern}, {replacement}, {flags})" - - def visit_not_regexp_match_op_binary(self, binary, operator, **kw): - return f"NOT {self.visit_regexp_match_op_binary(binary, operator, **kw)}" - - def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): - if from_linter: - from_linter.edges.update( - itertools.product(join.left._from_objects, join.right._from_objects) - ) - - if join.full: - join_type = " FULL OUTER JOIN " - elif join.isouter: - join_type = " LEFT OUTER JOIN " - else: - join_type = " JOIN " - - join_statement = ( - join.left._compiler_dispatch( - self, asfrom=True, from_linter=from_linter, **kwargs - ) - + join_type - + join.right._compiler_dispatch( - self, asfrom=True, from_linter=from_linter, **kwargs - ) - ) - - if join.onclause is None and isinstance(join.right, Lateral): - # in snowflake, onclause is not accepted for lateral due to BCR change: - # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 - # sqlalchemy only allows join with on condition. - # to adapt to snowflake syntax change, - # we make the change such that when oncaluse is None and the right part is - # Lateral, we do not append the on condition - return join_statement - - return ( - join_statement - + " ON " - # TODO: likely need asfrom=True here? - + join.onclause._compiler_dispatch(self, from_linter=from_linter, **kwargs) - ) - - def render_literal_value(self, value, type_): - # escape backslash - return super().render_literal_value(value, type_).replace("\\", "\\\\") - - -class SnowflakeExecutionContext(default.DefaultExecutionContext): - INSERT_SQL_RE = re.compile(r"^insert\s+into", flags=re.IGNORECASE) - - def fire_sequence(self, seq, type_): - return self._execute_scalar( - f"SELECT {self.identifier_preparer.format_sequence(seq)}.nextval", - type_, - ) - - def should_autocommit_text(self, statement): - return AUTOCOMMIT_REGEXP.match(statement) - - @sa_util.memoized_property - def should_autocommit(self): - autocommit = self.execution_options.get( - "autocommit", - not self.compiled - and self.statement - and expression.PARSE_AUTOCOMMIT - or False, - ) - - if autocommit is expression.PARSE_AUTOCOMMIT: - return self.should_autocommit_text(self.unicode_statement) - else: - return autocommit and not self.isddl - - def pre_exec(self): - if self.compiled and self.identifier_preparer._double_percents: - # for compiled statements, percent is doubled for escape, we turn on _interpolate_empty_sequences - _set_connection_interpolate_empty_sequences(self._dbapi_connection, True) - - # if the statement is executemany insert, setting _interpolate_empty_sequences to True is not enough, - # because executemany pre-processes the param binding and then pass None params to execute so - # _interpolate_empty_sequences condition not getting met for the command. - # Therefore, we manually revert the escape percent in the command here - if self.executemany and self.INSERT_SQL_RE.match(self.statement): - self.statement = self.statement.replace("%%", "%") - else: - # for other cases, do no interpolate empty sequences as "%" is not double escaped - _set_connection_interpolate_empty_sequences(self._dbapi_connection, False) - - def post_exec(self): - if self.compiled and self.identifier_preparer._double_percents: - # for compiled statements, percent is doubled for escapeafter execution - # we reset _interpolate_empty_sequences to false which is turned on in pre_exec - _set_connection_interpolate_empty_sequences(self._dbapi_connection, False) - - @property - def rowcount(self): - return self.cursor.rowcount - - -class SnowflakeDDLCompiler(compiler.DDLCompiler): - def denormalize_column_name(self, name): - if name is None: - return None - elif name.lower() == name and not self.preparer._requires_quotes(name.lower()): - # no quote as case insensitive - return name - return self.preparer.quote(name) - - def get_column_specification(self, column, **kwargs): - """ - Gets Column specifications - """ - colspec = [ - self.preparer.format_column(column), - self.dialect.type_compiler.process(column.type, type_expression=column), - ] - - has_identity = ( - column.identity is not None and self.dialect.supports_identity_columns - ) - - if not column.nullable: - colspec.append("NOT NULL") - - default = self.get_column_default_string(column) - if default is not None: - colspec.append("DEFAULT " + default) - - # TODO: This makes the first INTEGER column AUTOINCREMENT. - # But the column is not really considered so unless - # postfetch_lastrowid is enabled. But it is very unlikely to happen... - if ( - column.table is not None - and column is column.table._autoincrement_column - and column.server_default is None - ): - if isinstance(column.default, Sequence): - colspec.append( - f"DEFAULT {self.dialect.identifier_preparer.format_sequence(column.default)}.nextval" - ) - else: - colspec.append("AUTOINCREMENT") - - if has_identity: - colspec.append(self.process(column.identity)) - - return " ".join(colspec) - - def post_create_table(self, table): - """ - Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. - - Users can specify the `clusterby` property per table - using the dialect specific syntax. - For example, to specify a cluster by key you apply the following: - - >>> import sqlalchemy as sa - >>> from sqlalchemy.schema import CreateTable - >>> engine = sa.create_engine('snowflake://om1') - >>> metadata = sa.MetaData() - >>> user = sa.Table( - ... 'user', - ... metadata, - ... sa.Column('id', sa.Integer, primary_key=True), - ... sa.Column('name', sa.String), - ... snowflake_clusterby=['id', 'name'] - ... ) - >>> print(CreateTable(user).compile(engine)) - - CREATE TABLE "user" ( - id INTEGER NOT NULL AUTOINCREMENT, - name VARCHAR, - PRIMARY KEY (id) - ) CLUSTER BY (id, name) - - - """ - text = "" - info = table.dialect_options["snowflake"] - cluster = info.get("clusterby") - if cluster: - text += " CLUSTER BY ({})".format( - ", ".join(self.denormalize_column_name(key) for key in cluster) - ) - return text - - def visit_create_stage(self, create_stage, **kw): - """ - This visitor will create the SQL representation for a CREATE STAGE command. - """ - return "CREATE {or_replace}{temporary}STAGE {}{} URL={}".format( - create_stage.stage.namespace, - create_stage.stage.name, - repr(create_stage.container), - or_replace="OR REPLACE " if create_stage.replace_if_exists else "", - temporary="TEMPORARY " if create_stage.temporary else "", - ) - - def visit_create_file_format(self, file_format, **kw): - """ - This visitor will create the SQL representation for a CREATE FILE FORMAT - command. - """ - return "CREATE {}FILE FORMAT {} TYPE='{}' {}".format( - "OR REPLACE " if file_format.replace_if_exists else "", - file_format.format_name, - file_format.formatter.file_format, - " ".join( - [ - f"{name} = {file_format.formatter.value_repr(name, value)}" - for name, value in file_format.formatter.options.items() - ] - ), - ) - - def visit_drop_table_comment(self, drop, **kw): - """Snowflake does not support setting table comments as NULL. - - Reflection has to account for this and convert any empty comments to NULL. - """ - table_name = self.preparer.format_table(drop.element) - return f"COMMENT ON TABLE {table_name} IS ''" - - def visit_drop_column_comment(self, drop, **kw): - """Snowflake does not support directly setting column comments as NULL. - - Instead we are forced to use the ALTER COLUMN ... UNSET COMMENT instead. - """ - return "ALTER TABLE {} ALTER COLUMN {} UNSET COMMENT".format( - self.preparer.format_table(drop.element.table), - self.preparer.format_column(drop.element), - ) - - def visit_identity_column(self, identity, **kw): - text = "IDENTITY" - if identity.start is not None or identity.increment is not None: - start = 1 if identity.start is None else identity.start - increment = 1 if identity.increment is None else identity.increment - text += f"({start},{increment})" - if identity.order is not None: - order = "ORDER" if identity.order else "NOORDER" - text += f" {order}" - return text - - def get_identity_options(self, identity_options): - text = [] - if identity_options.increment is not None: - text.append("INCREMENT BY %d" % identity_options.increment) - if identity_options.start is not None: - text.append("START WITH %d" % identity_options.start) - if identity_options.minvalue is not None: - text.append("MINVALUE %d" % identity_options.minvalue) - if identity_options.maxvalue is not None: - text.append("MAXVALUE %d" % identity_options.maxvalue) - if identity_options.nominvalue is not None: - text.append("NO MINVALUE") - if identity_options.nomaxvalue is not None: - text.append("NO MAXVALUE") - if identity_options.cache is not None: - text.append("CACHE %d" % identity_options.cache) - if identity_options.cycle is not None: - text.append("CYCLE" if identity_options.cycle else "NO CYCLE") - if identity_options.order is not None: - text.append("ORDER" if identity_options.order else "NOORDER") - return " ".join(text) - - -class SnowflakeTypeCompiler(compiler.GenericTypeCompiler): - def visit_BYTEINT(self, type_, **kw): - return "BYTEINT" - - def visit_CHARACTER(self, type_, **kw): - return "CHARACTER" - - def visit_DEC(self, type_, **kw): - return "DEC" - - def visit_DOUBLE(self, type_, **kw): - return "DOUBLE" - - def visit_FIXED(self, type_, **kw): - return "FIXED" - - def visit_INT(self, type_, **kw): - return "INT" - - def visit_NUMBER(self, type_, **kw): - return "NUMBER" - - def visit_STRING(self, type_, **kw): - return "STRING" - - def visit_TINYINT(self, type_, **kw): - return "TINYINT" - - def visit_VARIANT(self, type_, **kw): - return "VARIANT" - - def visit_ARRAY(self, type_, **kw): - return "ARRAY" - - def visit_OBJECT(self, type_, **kw): - return "OBJECT" - - def visit_BLOB(self, type_, **kw): - return "BINARY" - - def visit_datetime(self, type_, **kw): - return "datetime" - - def visit_DATETIME(self, type_, **kw): - return "DATETIME" - - def visit_TIMESTAMP_NTZ(self, type_, **kw): - return "TIMESTAMP_NTZ" - - def visit_TIMESTAMP_TZ(self, type_, **kw): - return "TIMESTAMP_TZ" - - def visit_TIMESTAMP_LTZ(self, type_, **kw): - return "TIMESTAMP_LTZ" - - def visit_TIMESTAMP(self, type_, **kw): - return "TIMESTAMP" - - def visit_GEOGRAPHY(self, type_, **kw): - return "GEOGRAPHY" - - def visit_GEOMETRY(self, type_, **kw): - return "GEOMETRY" - - -construct_arguments = [(Table, {"clusterby": None})] - -functions.register_function("flatten", flatten) diff --git a/src/snowflake/sqlalchemy/base/__init__.py b/src/snowflake/sqlalchemy/base/__init__.py new file mode 100644 index 00000000..f990b66f --- /dev/null +++ b/src/snowflake/sqlalchemy/base/__init__.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.schema import Table +from sqlalchemy.sql import functions + +from ..functions import flatten +from .compiler.snowflake_compiler import SnowflakeCompiler +from .compiler.snowflake_ddl_compiler import SnowflakeDDLCompiler +from .compiler.snowflake_identifier_preparer import SnowflakeIdentifierPreparer +from .compiler.snowflake_type_compiler import SnowflakeTypeCompiler +from .constants import RESERVED_WORDS +from .snowflake_execution_context import SnowflakeExecutionContext +from .snowflake_orm_select_compile_state import SnowflakeORMSelectCompileState +from .snowflake_select_state import SnowflakeSelectState + +# The __all__ list is used to maintain backward compatibility with previous versions of the package. +# After splitting a large file into a module, __all__ explicitly defines the public API, ensuring +# that existing imports from the original file still work as expected. This approach prevents +# breaking changes by controlling which functions, classes, and variables are exposed when +# the module is imported. +__all__ = [ + "SnowflakeDDLCompiler", + "SnowflakeCompiler", + "SnowflakeIdentifierPreparer", + "SnowflakeTypeCompiler", + "SnowflakeExecutionContext", + "SnowflakeORMSelectCompileState", + "SnowflakeSelectState", + "RESERVED_WORDS", +] + + +construct_arguments = [(Table, {"clusterby": None})] + +functions.register_function("flatten", flatten) diff --git a/src/snowflake/sqlalchemy/base/compiler/__init__.py b/src/snowflake/sqlalchemy/base/compiler/__init__.py new file mode 100644 index 00000000..ada0a4e1 --- /dev/null +++ b/src/snowflake/sqlalchemy/base/compiler/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/snowflake/sqlalchemy/base/compiler/snowflake_compiler.py b/src/snowflake/sqlalchemy/base/compiler/snowflake_compiler.py new file mode 100644 index 00000000..da9411e3 --- /dev/null +++ b/src/snowflake/sqlalchemy/base/compiler/snowflake_compiler.py @@ -0,0 +1,319 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import operator + +from sqlalchemy.schema import Table +from sqlalchemy.sql import compiler +from sqlalchemy.sql.selectable import Lateral + +from snowflake.sqlalchemy.compat import string_types +from snowflake.sqlalchemy.custom_commands import ( + AWSBucket, + AzureContainer, + ExternalStage, +) + + +class SnowflakeCompiler(compiler.SQLCompiler): + def visit_sequence(self, sequence, **kw): + return self.dialect.identifier_preparer.format_sequence(sequence) + ".nextval" + + def visit_now_func(self, now, **kw): + return "CURRENT_TIMESTAMP" + + def visit_merge_into(self, merge_into, **kw): + clauses = " ".join( + clause._compiler_dispatch(self, **kw) for clause in merge_into.clauses + ) + return ( + f"MERGE INTO {merge_into.target} USING {merge_into.source} ON {merge_into.on}" + + (" " + clauses if clauses else "") + ) + + def visit_merge_into_clause(self, merge_into_clause, **kw): + case_predicate = ( + f" AND {str(merge_into_clause.predicate._compiler_dispatch(self, **kw))}" + if merge_into_clause.predicate is not None + else "" + ) + if merge_into_clause.command == "INSERT": + sets, sets_tos = zip(*merge_into_clause.set.items()) + sets, sets_tos = list(sets), list(sets_tos) + if kw.get("deterministic", False): + sets, sets_tos = zip( + *sorted(merge_into_clause.set.items(), key=operator.itemgetter(0)) + ) + return "WHEN NOT MATCHED{} THEN {} ({}) VALUES ({})".format( + case_predicate, + merge_into_clause.command, + ", ".join(sets), + ", ".join(map(lambda e: e._compiler_dispatch(self, **kw), sets_tos)), + ) + else: + set_list = list(merge_into_clause.set.items()) + if kw.get("deterministic", False): + set_list.sort(key=operator.itemgetter(0)) + sets = ( + ", ".join( + [ + f"{set[0]} = {set[1]._compiler_dispatch(self, **kw)}" + for set in set_list + ] + ) + if merge_into_clause.set + else "" + ) + return "WHEN MATCHED{} THEN {}{}".format( + case_predicate, + merge_into_clause.command, + " SET %s" % sets if merge_into_clause.set else "", + ) + + def visit_copy_into(self, copy_into, **kw): + if hasattr(copy_into, "formatter") and copy_into.formatter is not None: + formatter = copy_into.formatter._compiler_dispatch(self, **kw) + else: + formatter = "" + into = ( + copy_into.into + if isinstance(copy_into.into, Table) + else copy_into.into._compiler_dispatch(self, **kw) + ) + from_ = None + if isinstance(copy_into.from_, Table): + from_ = copy_into.from_ + # this is intended to catch AWSBucket and AzureContainer + elif ( + isinstance(copy_into.from_, AWSBucket) + or isinstance(copy_into.from_, AzureContainer) + or isinstance(copy_into.from_, ExternalStage) + ): + from_ = copy_into.from_._compiler_dispatch(self, **kw) + # everything else (selects, etc.) + else: + from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})" + credentials, encryption = "", "" + if isinstance(into, tuple): + into, credentials, encryption = into + elif isinstance(from_, tuple): + from_, credentials, encryption = from_ + options_list = list(copy_into.copy_options.items()) + if kw.get("deterministic", False): + options_list.sort(key=operator.itemgetter(0)) + options = ( + ( + " " + + " ".join( + [ + "{} = {}".format( + n, + ( + v._compiler_dispatch(self, **kw) + if getattr(v, "compiler_dispatch", False) + else str(v) + ), + ) + for n, v in options_list + ] + ) + ) + if copy_into.copy_options + else "" + ) + if credentials: + options += f" {credentials}" + if encryption: + options += f" {encryption}" + return f"COPY INTO {into} FROM {from_} {formatter}{options}" + + def visit_copy_formatter(self, formatter, **kw): + options_list = list(formatter.options.items()) + if kw.get("deterministic", False): + options_list.sort(key=operator.itemgetter(0)) + if "format_name" in formatter.options: + return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})" + return "FILE_FORMAT=(TYPE={}{})".format( + formatter.file_format, + ( + " " + + " ".join( + [ + "{}={}".format( + name, + ( + value._compiler_dispatch(self, **kw) + if hasattr(value, "_compiler_dispatch") + else formatter.value_repr(name, value) + ), + ) + for name, value in options_list + ] + ) + if formatter.options + else "" + ), + ) + + def visit_aws_bucket(self, aws_bucket, **kw): + credentials_list = list(aws_bucket.credentials_used.items()) + if kw.get("deterministic", False): + credentials_list.sort(key=operator.itemgetter(0)) + credentials = "CREDENTIALS=({})".format( + " ".join(f"{n}='{v}'" for n, v in credentials_list) + ) + encryption_list = list(aws_bucket.encryption_used.items()) + if kw.get("deterministic", False): + encryption_list.sort(key=operator.itemgetter(0)) + encryption = "ENCRYPTION=({})".format( + " ".join( + ("{}='{}'" if isinstance(v, string_types) else "{}={}").format(n, v) + for n, v in encryption_list + ) + ) + uri = "'s3://{}{}'".format( + aws_bucket.bucket, f"/{aws_bucket.path}" if aws_bucket.path else "" + ) + return ( + uri, + credentials if aws_bucket.credentials_used else "", + encryption if aws_bucket.encryption_used else "", + ) + + def visit_azure_container(self, azure_container, **kw): + credentials_list = list(azure_container.credentials_used.items()) + if kw.get("deterministic", False): + credentials_list.sort(key=operator.itemgetter(0)) + credentials = "CREDENTIALS=({})".format( + " ".join(f"{n}='{v}'" for n, v in credentials_list) + ) + encryption_list = list(azure_container.encryption_used.items()) + if kw.get("deterministic", False): + encryption_list.sort(key=operator.itemgetter(0)) + encryption = "ENCRYPTION=({})".format( + " ".join( + f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}" + for n, v in encryption_list + ) + ) + uri = "'azure://{}.blob.core.windows.net/{}{}'".format( + azure_container.account, + azure_container.container, + f"/{azure_container.path}" if azure_container.path else "", + ) + return ( + uri, + credentials if azure_container.credentials_used else "", + encryption if azure_container.encryption_used else "", + ) + + def visit_external_stage(self, external_stage, **kw): + if external_stage.file_format is None: + return ( + f"@{external_stage.namespace}{external_stage.name}{external_stage.path}" + ) + return f"@{external_stage.namespace}{external_stage.name}{external_stage.path} (file_format => {external_stage.file_format})" + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + return "USING " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def _get_regexp_args(self, binary, kw): + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) + flags = binary.modifiers["flags"] + if flags is not None: + flags = self.process(flags, **kw) + return string, pattern, flags + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + string, pattern, flags = self._get_regexp_args(binary, kw) + if flags is None: + return f"REGEXP_LIKE({string}, {pattern})" + else: + return f"REGEXP_LIKE({string}, {pattern}, {flags})" + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + string, pattern, flags = self._get_regexp_args(binary, kw) + try: + replacement = self.process(binary.modifiers["replacement"], **kw) + except KeyError: + # in sqlalchemy 1.4.49, the internal structure of the expression is changed + # that binary.modifiers doesn't have "replacement": + # https://docs.sqlalchemy.org/en/20/changelog/changelog_14.html#change-1.4.49 + return f"REGEXP_REPLACE({string}, {pattern}{'' if flags is None else f', {flags}'})" + + if flags is None: + return f"REGEXP_REPLACE({string}, {pattern}, {replacement})" + else: + return f"REGEXP_REPLACE({string}, {pattern}, {replacement}, {flags})" + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return f"NOT {self.visit_regexp_match_op_binary(binary, operator, **kw)}" + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.update( + itertools.product(join.left._from_objects, join.right._from_objects) + ) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " JOIN " + + join_statement = ( + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + join_type + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + ) + + if join.onclause is None and isinstance(join.right, Lateral): + # in snowflake, onclause is not accepted for lateral due to BCR change: + # https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 + # sqlalchemy only allows join with on condition. + # to adapt to snowflake syntax change, + # we make the change such that when oncaluse is None and the right part is + # Lateral, we do not append the on condition + return join_statement + + return ( + join_statement + + " ON " + # TODO: likely need asfrom=True here? + + join.onclause._compiler_dispatch(self, from_linter=from_linter, **kwargs) + ) + + def render_literal_value(self, value, type_): + # escape backslash + return super().render_literal_value(value, type_).replace("\\", "\\\\") diff --git a/src/snowflake/sqlalchemy/base/compiler/snowflake_ddl_compiler.py b/src/snowflake/sqlalchemy/base/compiler/snowflake_ddl_compiler.py new file mode 100644 index 00000000..6d1e3ad9 --- /dev/null +++ b/src/snowflake/sqlalchemy/base/compiler/snowflake_ddl_compiler.py @@ -0,0 +1,184 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import Sequence +from sqlalchemy.sql import compiler + + +class SnowflakeDDLCompiler(compiler.DDLCompiler): + def denormalize_column_name(self, name): + if name is None: + return None + elif name.lower() == name and not self.preparer._requires_quotes(name.lower()): + # no quote as case insensitive + return name + return self.preparer.quote(name) + + def get_column_specification(self, column, **kwargs): + """ + Gets Column specifications + """ + colspec = [ + self.preparer.format_column(column), + self.dialect.type_compiler.process(column.type, type_expression=column), + ] + + has_identity = ( + column.identity is not None and self.dialect.supports_identity_columns + ) + + if not column.nullable: + colspec.append("NOT NULL") + + default = self.get_column_default_string(column) + if default is not None: + colspec.append("DEFAULT " + default) + + # TODO: This makes the first INTEGER column AUTOINCREMENT. + # But the column is not really considered so unless + # postfetch_lastrowid is enabled. But it is very unlikely to happen... + if ( + column.table is not None + and column is column.table._autoincrement_column + and column.server_default is None + ): + if isinstance(column.default, Sequence): + colspec.append( + f"DEFAULT {self.dialect.identifier_preparer.format_sequence(column.default)}.nextval" + ) + else: + colspec.append("AUTOINCREMENT") + + if has_identity: + colspec.append(self.process(column.identity)) + + return " ".join(colspec) + + def post_create_table(self, table): + """ + Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. + + Users can specify the `clusterby` property per table + using the dialect specific syntax. + For example, to specify a cluster by key you apply the following: + + >>> import sqlalchemy as sa + >>> from sqlalchemy.schema import CreateTable + >>> engine = sa.create_engine('snowflake://om1') + >>> metadata = sa.MetaData() + >>> user = sa.Table( + ... 'user', + ... metadata, + ... sa.Column('id', sa.Integer, primary_key=True), + ... sa.Column('name', sa.String), + ... snowflake_clusterby=['id', 'name'] + ... ) + >>> print(CreateTable(user).compile(engine)) + + CREATE TABLE "user" ( + id INTEGER NOT NULL AUTOINCREMENT, + name VARCHAR, + PRIMARY KEY (id) + ) CLUSTER BY (id, name) + + + """ + text = "" + info = table.dialect_options["snowflake"] + cluster = info.get("clusterby") + if cluster: + text += " CLUSTER BY ({})".format( + ", ".join(self.denormalize_column_name(key) for key in cluster) + ) + return text + + def visit_create_stage(self, create_stage, **kw): + """ + This visitor will create the SQL representation for a CREATE STAGE command. + """ + return "CREATE {or_replace}{temporary}STAGE {}{} URL={}".format( + create_stage.stage.namespace, + create_stage.stage.name, + repr(create_stage.container), + or_replace="OR REPLACE " if create_stage.replace_if_exists else "", + temporary="TEMPORARY " if create_stage.temporary else "", + ) + + def visit_create_file_format(self, file_format, **kw): + """ + This visitor will create the SQL representation for a CREATE FILE FORMAT + command. + """ + return "CREATE {}FILE FORMAT {} TYPE='{}' {}".format( + "OR REPLACE " if file_format.replace_if_exists else "", + file_format.format_name, + file_format.formatter.file_format, + " ".join( + [ + f"{name} = {file_format.formatter.value_repr(name, value)}" + for name, value in file_format.formatter.options.items() + ] + ), + ) + + def visit_drop_table_comment(self, drop, **kw): + """Snowflake does not support setting table comments as NULL. + + Reflection has to account for this and convert any empty comments to NULL. + """ + table_name = self.preparer.format_table(drop.element) + return f"COMMENT ON TABLE {table_name} IS ''" + + def visit_drop_column_comment(self, drop, **kw): + """Snowflake does not support directly setting column comments as NULL. + + Instead we are forced to use the ALTER COLUMN ... UNSET COMMENT instead. + """ + return "ALTER TABLE {} ALTER COLUMN {} UNSET COMMENT".format( + self.preparer.format_table(drop.element.table), + self.preparer.format_column(drop.element), + ) + + def visit_identity_column(self, identity, **kw): + text = "IDENTITY" + if identity.start is not None or identity.increment is not None: + start = 1 if identity.start is None else identity.start + increment = 1 if identity.increment is None else identity.increment + text += f"({start},{increment})" + if identity.order is not None: + order = "ORDER" if identity.order else "NOORDER" + text += f" {order}" + return text + + def get_identity_options(self, identity_options): + text = [] + if identity_options.increment is not None: + text.append("INCREMENT BY %d" % identity_options.increment) + if identity_options.start is not None: + text.append("START WITH %d" % identity_options.start) + if identity_options.minvalue is not None: + text.append("MINVALUE %d" % identity_options.minvalue) + if identity_options.maxvalue is not None: + text.append("MAXVALUE %d" % identity_options.maxvalue) + if identity_options.nominvalue is not None: + text.append("NO MINVALUE") + if identity_options.nomaxvalue is not None: + text.append("NO MAXVALUE") + if identity_options.cache is not None: + text.append("CACHE %d" % identity_options.cache) + if identity_options.cycle is not None: + text.append("CYCLE" if identity_options.cycle else "NO CYCLE") + if identity_options.order is not None: + text.append("ORDER" if identity_options.order else "NOORDER") + return " ".join(text) diff --git a/src/snowflake/sqlalchemy/base/compiler/snowflake_identifier_preparer.py b/src/snowflake/sqlalchemy/base/compiler/snowflake_identifier_preparer.py new file mode 100644 index 00000000..b73808fa --- /dev/null +++ b/src/snowflake/sqlalchemy/base/compiler/snowflake_identifier_preparer.py @@ -0,0 +1,81 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from sqlalchemy.sql import compiler +from sqlalchemy.sql.elements import quoted_name + +from ..constants import RESERVED_WORDS + + +class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = {x.lower() for x in RESERVED_WORDS} + + def __init__(self, dialect, **kw): + quote = '"' + + super().__init__(dialect, initial_quote=quote, escape_quote=quote) + + def _quote_free_identifiers(self, *ids): + """ + Unilaterally identifier-quote any number of strings. + """ + return tuple(self.quote(i) for i in ids if i is not None) + + def quote_schema(self, schema, force=None): + """ + Split schema by a dot and merge with required quotes + """ + idents = self._split_schema_by_dot(schema) + return ".".join(self._quote_free_identifiers(*idents)) + + def format_label(self, label, name=None): + n = name or label.name + s = n.replace(self.escape_quote, "") + + if not isinstance(n, quoted_name) or n.quote is None: + return self.quote(s) + + return self.quote_identifier(s) if n.quote else s + + def _split_schema_by_dot(self, schema): + ret = [] + idx = 0 + pre_idx = 0 + in_quote = False + while idx < len(schema): + if not in_quote: + if schema[idx] == "." and pre_idx < idx: + ret.append(schema[pre_idx:idx]) + pre_idx = idx + 1 + elif schema[idx] == '"': + in_quote = True + pre_idx = idx + 1 + else: + if schema[idx] == '"' and pre_idx < idx: + ret.append(schema[pre_idx:idx]) + in_quote = False + pre_idx = idx + 1 + idx += 1 + if pre_idx < len(schema) and schema[pre_idx] == ".": + pre_idx += 1 + if pre_idx < idx: + ret.append(schema[pre_idx:idx]) + + # convert the returning strings back to quoted_name types, and assign the original 'quote' attribute on it + quoted_ret = [ + quoted_name(value, quote=getattr(schema, "quote", None)) for value in ret + ] + + return quoted_ret diff --git a/src/snowflake/sqlalchemy/base/compiler/snowflake_type_compiler.py b/src/snowflake/sqlalchemy/base/compiler/snowflake_type_compiler.py new file mode 100644 index 00000000..ba6a7545 --- /dev/null +++ b/src/snowflake/sqlalchemy/base/compiler/snowflake_type_compiler.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.sql import compiler + + +class SnowflakeTypeCompiler(compiler.GenericTypeCompiler): + def visit_BYTEINT(self, type_, **kw): + return "BYTEINT" + + def visit_CHARACTER(self, type_, **kw): + return "CHARACTER" + + def visit_DEC(self, type_, **kw): + return "DEC" + + def visit_DOUBLE(self, type_, **kw): + return "DOUBLE" + + def visit_FIXED(self, type_, **kw): + return "FIXED" + + def visit_INT(self, type_, **kw): + return "INT" + + def visit_NUMBER(self, type_, **kw): + return "NUMBER" + + def visit_STRING(self, type_, **kw): + return "STRING" + + def visit_TINYINT(self, type_, **kw): + return "TINYINT" + + def visit_VARIANT(self, type_, **kw): + return "VARIANT" + + def visit_ARRAY(self, type_, **kw): + return "ARRAY" + + def visit_OBJECT(self, type_, **kw): + return "OBJECT" + + def visit_BLOB(self, type_, **kw): + return "BINARY" + + def visit_datetime(self, type_, **kw): + return "datetime" + + def visit_DATETIME(self, type_, **kw): + return "DATETIME" + + def visit_TIMESTAMP_NTZ(self, type_, **kw): + return "TIMESTAMP_NTZ" + + def visit_TIMESTAMP_TZ(self, type_, **kw): + return "TIMESTAMP_TZ" + + def visit_TIMESTAMP_LTZ(self, type_, **kw): + return "TIMESTAMP_LTZ" + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP" + + def visit_GEOGRAPHY(self, type_, **kw): + return "GEOGRAPHY" + + def visit_GEOMETRY(self, type_, **kw): + return "GEOMETRY" diff --git a/src/snowflake/sqlalchemy/base/constants.py b/src/snowflake/sqlalchemy/base/constants.py new file mode 100644 index 00000000..c5208d5b --- /dev/null +++ b/src/snowflake/sqlalchemy/base/constants.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +RESERVED_WORDS = frozenset( + [ + "ALL", # ANSI Reserved words + "ALTER", + "AND", + "ANY", + "AS", + "BETWEEN", + "BY", + "CHECK", + "COLUMN", + "CONNECT", + "COPY", + "CREATE", + "CURRENT", + "DELETE", + "DISTINCT", + "DROP", + "ELSE", + "EXISTS", + "FOR", + "FROM", + "GRANT", + "GROUP", + "HAVING", + "IN", + "INSERT", + "INTERSECT", + "INTO", + "IS", + "LIKE", + "NOT", + "NULL", + "OF", + "ON", + "OR", + "ORDER", + "REVOKE", + "ROW", + "ROWS", + "SAMPLE", + "SELECT", + "SET", + "START", + "TABLE", + "THEN", + "TO", + "TRIGGER", + "UNION", + "UNIQUE", + "UPDATE", + "VALUES", + "WHENEVER", + "WHERE", + "WITH", + "REGEXP", + "RLIKE", + "SOME", # Snowflake Reserved words + "MINUS", + "INCREMENT", # Oracle reserved words + ] +) diff --git a/src/snowflake/sqlalchemy/base/snowflake_execution_context.py b/src/snowflake/sqlalchemy/base/snowflake_execution_context.py new file mode 100644 index 00000000..11373fcb --- /dev/null +++ b/src/snowflake/sqlalchemy/base/snowflake_execution_context.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from sqlalchemy import util as sa_util +from sqlalchemy.engine import default +from sqlalchemy.sql import expression + +from ..util import _set_connection_interpolate_empty_sequences + +# Snowflake DML: +# - UPDATE +# - INSERT +# - DELETE +# - MERGE +AUTOCOMMIT_REGEXP = re.compile( + r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE +) + + +class SnowflakeExecutionContext(default.DefaultExecutionContext): + INSERT_SQL_RE = re.compile(r"^insert\s+into", flags=re.IGNORECASE) + + def fire_sequence(self, seq, type_): + return self._execute_scalar( + f"SELECT {self.identifier_preparer.format_sequence(seq)}.nextval", + type_, + ) + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_REGEXP.match(statement) + + @sa_util.memoized_property + def should_autocommit(self): + autocommit = self.execution_options.get( + "autocommit", + not self.compiled + and self.statement + and expression.PARSE_AUTOCOMMIT + or False, + ) + + if autocommit is expression.PARSE_AUTOCOMMIT: + return self.should_autocommit_text(self.unicode_statement) + else: + return autocommit and not self.isddl + + def pre_exec(self): + if self.compiled and self.identifier_preparer._double_percents: + # for compiled statements, percent is doubled for escape, we turn on _interpolate_empty_sequences + _set_connection_interpolate_empty_sequences(self._dbapi_connection, True) + + # if the statement is executemany insert, setting _interpolate_empty_sequences to True is not enough, + # because executemany pre-processes the param binding and then pass None params to execute so + # _interpolate_empty_sequences condition not getting met for the command. + # Therefore, we manually revert the escape percent in the command here + if self.executemany and self.INSERT_SQL_RE.match(self.statement): + self.statement = self.statement.replace("%%", "%") + else: + # for other cases, do no interpolate empty sequences as "%" is not double escaped + _set_connection_interpolate_empty_sequences(self._dbapi_connection, False) + + def post_exec(self): + if self.compiled and self.identifier_preparer._double_percents: + # for compiled statements, percent is doubled for escapeafter execution + # we reset _interpolate_empty_sequences to false which is turned on in pre_exec + _set_connection_interpolate_empty_sequences(self._dbapi_connection, False) + + @property + def rowcount(self): + return self.cursor.rowcount diff --git a/src/snowflake/sqlalchemy/base/snowflake_orm_select_compile_state.py b/src/snowflake/sqlalchemy/base/snowflake_orm_select_compile_state.py new file mode 100644 index 00000000..4763cf28 --- /dev/null +++ b/src/snowflake/sqlalchemy/base/snowflake_orm_select_compile_state.py @@ -0,0 +1,237 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Overwrite methods to handle Snowflake BCR change: +https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057 +- _join_determine_implicit_left_side +- _join_left_to_right +""" + +from sqlalchemy import exc as sa_exc +from sqlalchemy import inspect, sql +from sqlalchemy.orm import context +from sqlalchemy.orm.context import _MapperEntity + +from ..compat import IS_VERSION_20, args_reducer +from ..util import _find_left_clause_to_join_from, _Snowflake_ORMJoin + + +# handle Snowflake BCR bcr-1057 +@sql.base.CompileState.plugin_for("orm", "select") +class SnowflakeORMSelectCompileState(context.ORMSelectCompileState): + def _join_determine_implicit_left_side( + self, entities_collection, left, right, onclause + ): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self.from_clauses: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from( + self.from_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self.from_clauses[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + + elif entities_collection: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. + + potential = {} + for entity_index, ent in enumerate(entities_collection): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, context._MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + + all_clauses = list(potential.keys()) + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + else: + raise sa_exc.InvalidRequestError( + "No entities to join from; please use " + "select_from() to establish the left " + "entity/selectable of this join" + ) + + return left, replace_from_obj_index, use_entity_index + + @args_reducer(positions_to_drop=(6, 7)) + def _join_left_to_right( + self, entities_collection, left, right, onclause, prop, outerjoin, full + ): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ + + if left is None: + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + ( + left, + replace_from_obj_index, + use_entity_index, + ) = self._join_determine_implicit_left_side( + entities_collection, left, right, onclause + ) + else: + # left is given via a relationship/name, or as explicit left side. + # Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + ( + replace_from_obj_index, + use_entity_index, + ) = self._join_place_explicit_left_side(entities_collection, left) + + if left is right: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they " + "are the same entity" % (left, right) + ) + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. handle all that and + # get back the new effective "right" side + + if IS_VERSION_20: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + else: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, False, False + ) + + if not r_info.is_selectable: + extra_criteria = self._get_extra_criteria(r_info) + else: + extra_criteria = () + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + [ + _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + # add a new element to the self._from_obj list + if use_entity_index is not None: + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance(entities_collection[use_entity_index], _MapperEntity) + left_clause = entities_collection[use_entity_index].selectable + else: + left_clause = left + + self.from_clauses = self.from_clauses + [ + _Snowflake_ORMJoin( # handle Snowflake BCR bcr-1057 + left_clause, + r_info, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] diff --git a/src/snowflake/sqlalchemy/base/snowflake_select_state.py b/src/snowflake/sqlalchemy/base/snowflake_select_state.py new file mode 100644 index 00000000..872dacea --- /dev/null +++ b/src/snowflake/sqlalchemy/base/snowflake_select_state.py @@ -0,0 +1,127 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +from sqlalchemy import exc as sa_exc +from sqlalchemy import util as sa_util +from sqlalchemy.sql.base import CompileState +from sqlalchemy.sql.selectable import SelectState + +from ..util import _find_left_clause_to_join_from, _Snowflake_Selectable_Join + + +# handle Snowflake BCR bcr-1057 +@CompileState.plugin_for("default", "select") +class SnowflakeSelectState(SelectState): + def _setup_joins(self, args, raw_columns): + for right, onclause, left, flags in args: + isouter = flags["isouter"] + full = flags["full"] + + if left is None: + ( + left, + replace_from_obj_index, + ) = self._join_determine_implicit_left_side( + raw_columns, left, right, onclause + ) + else: + (replace_from_obj_index) = self._join_place_explicit_left_side(left) + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + ( + _Snowflake_Selectable_Join( # handle Snowflake BCR bcr-1057 + left_clause, + right, + onclause, + isouter=isouter, + full=full, + ), + ) + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + self.from_clauses = self.from_clauses + ( + # handle Snowflake BCR bcr-1057 + _Snowflake_Selectable_Join( + left, right, onclause, isouter=isouter, full=full + ), + ) + + @sa_util.preload_module("sqlalchemy.custom_commands.util") + def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + replace_from_obj_index = None + + from_clauses = self.from_clauses + + if from_clauses: + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from(from_clauses, right, onclause) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = from_clauses[replace_from_obj_index] + else: + potential = {} + statement = self.statement + + for from_clause in itertools.chain( + itertools.chain.from_iterable( + [element._from_objects for element in raw_columns] + ), + itertools.chain.from_iterable( + [element._from_objects for element in statement._where_criteria] + ), + ): + + potential[from_clause] = () + + all_clauses = list(potential.keys()) + # handle Snowflake BCR bcr-1057 + indexes = _find_left_clause_to_join_from(all_clauses, right, onclause) + + if len(indexes) == 1: + left = all_clauses[indexes[0]] + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already to " + "help resolve the ambiguity." + ) + elif not indexes: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already to " + "help resolve the ambiguity." % (right,) + ) + return left, replace_from_obj_index diff --git a/src/snowflake/sqlalchemy/compat.py b/src/snowflake/sqlalchemy/compat.py index 9e97e574..53088334 100644 --- a/src/snowflake/sqlalchemy/compat.py +++ b/src/snowflake/sqlalchemy/compat.py @@ -1,5 +1,17 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import functools diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py deleted file mode 100644 index 15585bd5..00000000 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ /dev/null @@ -1,622 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# - -from collections.abc import Sequence -from typing import List - -from sqlalchemy import false, true -from sqlalchemy.sql.ddl import DDLElement -from sqlalchemy.sql.dml import UpdateBase -from sqlalchemy.sql.elements import ClauseElement -from sqlalchemy.sql.roles import FromClauseRole - -from .compat import string_types - -NoneType = type(None) - - -def translate_bool(bln): - if bln: - return true() - return false() - - -class MergeInto(UpdateBase): - __visit_name__ = "merge_into" - _bind = None - - def __init__(self, target, source, on): - self.target = target - self.source = source - self.on = on - self.clauses = [] - - class clause(ClauseElement): - __visit_name__ = "merge_into_clause" - - def __init__(self, command): - self.set = {} - self.predicate = None - self.command = command - - def __repr__(self): - case_predicate = ( - f" AND {str(self.predicate)}" if self.predicate is not None else "" - ) - if self.command == "INSERT": - sets, sets_tos = zip(*self.set.items()) - return "WHEN NOT MATCHED{} THEN {} ({}) VALUES ({})".format( - case_predicate, - self.command, - ", ".join(sets), - ", ".join(map(str, sets_tos)), - ) - else: - # WHEN MATCHED clause - sets = ( - ", ".join([f"{set[0]} = {set[1]}" for set in self.set.items()]) - if self.set - else "" - ) - return "WHEN MATCHED{} THEN {}{}".format( - case_predicate, - self.command, - f" SET {str(sets)}" if self.set else "", - ) - - def values(self, **kwargs): - self.set = kwargs - return self - - def where(self, expr): - self.predicate = expr - return self - - def __repr__(self): - clauses = " ".join([repr(clause) for clause in self.clauses]) - return f"MERGE INTO {self.target} USING {self.source} ON {self.on}" + ( - f" {clauses}" if clauses else "" - ) - - def when_matched_then_update(self): - clause = self.clause("UPDATE") - self.clauses.append(clause) - return clause - - def when_matched_then_delete(self): - clause = self.clause("DELETE") - self.clauses.append(clause) - return clause - - def when_not_matched_then_insert(self): - clause = self.clause("INSERT") - self.clauses.append(clause) - return clause - - -class FilesOption: - """ - Class to represent FILES option for the snowflake COPY INTO statement - """ - - def __init__(self, file_names: List[str]): - self.file_names = file_names - - def __str__(self): - the_files = ["'" + f.replace("'", "\\'") + "'" for f in self.file_names] - return f"({','.join(the_files)})" - - -class CopyInto(UpdateBase): - """Copy Into Command base class, for documentation see: - https://docs.snowflake.net/manuals/sql-reference/sql/copy-into-location.html""" - - __visit_name__ = "copy_into" - _bind = None - - def __init__(self, from_, into, formatter=None): - self.from_ = from_ - self.into = into - self.formatter = formatter - self.copy_options = {} - - def __repr__(self): - """ - repr for debugging / logging purposes only. For compilation logic, see - the corresponding visitor in base.py - """ - return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})" - - def bind(self): - return None - - def force(self, force): - if not isinstance(force, bool): - raise TypeError("Parameter force should be a boolean value") - self.copy_options.update({"FORCE": translate_bool(force)}) - return self - - def single(self, single_file): - if not isinstance(single_file, bool): - raise TypeError("Parameter single_file should be a boolean value") - self.copy_options.update({"SINGLE": translate_bool(single_file)}) - return self - - def maxfilesize(self, max_size): - if not isinstance(max_size, int): - raise TypeError("Parameter max_size should be an integer value") - self.copy_options.update({"MAX_FILE_SIZE": max_size}) - return self - - def files(self, file_names): - self.copy_options.update({"FILES": FilesOption(file_names)}) - return self - - def pattern(self, pattern): - self.copy_options.update({"PATTERN": pattern}) - return self - - -class CopyFormatter(ClauseElement): - """ - Base class for Formatter specifications inside a COPY INTO statement. May also - be used to create a named format. - """ - - __visit_name__ = "copy_formatter" - - def __init__(self, format_name=None): - self.options = dict() - if format_name: - self.options["format_name"] = format_name - - def __repr__(self): - """ - repr for debugging / logging purposes only. For compilation logic, see - the corresponding visitor in base.py - """ - return f"FILE_FORMAT=({self.options})" - - @staticmethod - def value_repr(name, value): - """ - Make a SQL-suitable representation of "value". This is called from - the corresponding visitor function (base.py/visit_copy_formatter()) - - in case of a format name: return it without quotes - - in case of a string: enclose in quotes: "value" - - in case of a tuple of length 1: enclose the only element in brackets: (value) - Standard stringification of Python would append a trailing comma: (value,) - which is not correct in SQL - - otherwise: just convert to str as is: value - """ - if name == "format_name": - return value - elif isinstance(value, str): - return f"'{value}'" - elif isinstance(value, tuple) and len(value) == 1: - return f"('{value[0]}')" - else: - return str(value) - - -class CSVFormatter(CopyFormatter): - file_format = "csv" - - def compression(self, comp_type): - """String (constant) that specifies to compresses the unloaded data files using the specified compression algorithm.""" - if isinstance(comp_type, string_types): - comp_type = comp_type.lower() - _available_options = [ - "auto", - "gzip", - "bz2", - "brotli", - "zstd", - "deflate", - "raw_deflate", - None, - ] - if comp_type not in _available_options: - raise TypeError(f"Compression type should be one of : {_available_options}") - self.options["COMPRESSION"] = comp_type - return self - - def _check_delimiter(self, delimiter, delimiter_txt): - """ - Check if a delimiter is either a string of length 1 or an integer. In case of - a string delimiter, take into account that the actual string may be longer, - but still evaluate to a single character (like "\\n" or r"\n" - """ - if isinstance(delimiter, NoneType): - return - if isinstance(delimiter, string_types): - delimiter_processed = delimiter.encode().decode("unicode_escape") - if len(delimiter_processed) == 1: - return - if isinstance(delimiter, int): - return - raise TypeError( - f"{delimiter_txt} should be a single character, that is either a string, or a number" - ) - - def record_delimiter(self, deli_type): - """Character that separates records in an unloaded file.""" - self._check_delimiter(deli_type, "Record delimiter") - if isinstance(deli_type, int): - self.options["RECORD_DELIMITER"] = hex(deli_type) - else: - self.options["RECORD_DELIMITER"] = deli_type - return self - - def field_delimiter(self, deli_type): - """Character that separates fields in an unloaded file.""" - self._check_delimiter(deli_type, "Field delimiter") - if isinstance(deli_type, int): - self.options["FIELD_DELIMITER"] = hex(deli_type) - else: - self.options["FIELD_DELIMITER"] = deli_type - return self - - def file_extension(self, ext): - """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service. - """ - if not isinstance(ext, (NoneType, string_types)): - raise TypeError("File extension should be a string") - self.options["FILE_EXTENSION"] = ext - return self - - def date_format(self, dt_frmt): - """String that defines the format of date values in the unloaded data files.""" - if not isinstance(dt_frmt, string_types): - raise TypeError("Date format should be a string") - self.options["DATE_FORMAT"] = dt_frmt - return self - - def time_format(self, tm_frmt): - """String that defines the format of time values in the unloaded data files.""" - if not isinstance(tm_frmt, string_types): - raise TypeError("Time format should be a string") - self.options["TIME_FORMAT"] = tm_frmt - return self - - def timestamp_format(self, tmstmp_frmt): - """String that defines the format of timestamp values in the unloaded data files.""" - if not isinstance(tmstmp_frmt, string_types): - raise TypeError("Timestamp format should be a string") - self.options["TIMESTAMP_FORMAT"] = tmstmp_frmt - return self - - def binary_format(self, bin_fmt): - """Character used as the escape character for any field values. The option can be used when unloading data - from binary columns in a table.""" - if isinstance(bin_fmt, string_types): - bin_fmt = bin_fmt.lower() - _available_options = ["hex", "base64", "utf8"] - if bin_fmt not in _available_options: - raise TypeError(f"Binary format should be one of : {_available_options}") - self.options["BINARY_FORMAT"] = bin_fmt - return self - - def escape(self, esc): - """Character used as the escape character for any field values.""" - self._check_delimiter(esc, "Escape") - if isinstance(esc, int): - self.options["ESCAPE"] = hex(esc) - else: - self.options["ESCAPE"] = esc - return self - - def escape_unenclosed_field(self, esc): - """Single character string used as the escape character for unenclosed field values only.""" - self._check_delimiter(esc, "Escape unenclosed field") - if isinstance(esc, int): - self.options["ESCAPE_UNENCLOSED_FIELD"] = hex(esc) - else: - self.options["ESCAPE_UNENCLOSED_FIELD"] = esc - return self - - def field_optionally_enclosed_by(self, enc): - """Character used to enclose strings. Either None, ', or \".""" - _available_options = [None, "'", '"'] - if enc not in _available_options: - raise TypeError(f"Enclosing string should be one of : {_available_options}") - self.options["FIELD_OPTIONALLY_ENCLOSED_BY"] = enc - return self - - def null_if(self, null_value): - """Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace - NULL values with the first string""" - if not isinstance(null_value, Sequence): - raise TypeError("Parameter null_value should be an iterable") - self.options["NULL_IF"] = tuple(null_value) - return self - - def skip_header(self, skip_header): - """ - Number of header rows to be skipped at the beginning of the file - """ - if not isinstance(skip_header, int): - raise TypeError("skip_header should be an int") - self.options["SKIP_HEADER"] = skip_header - return self - - def trim_space(self, trim_space): - """ - Remove leading or trailing white spaces - """ - if not isinstance(trim_space, bool): - raise TypeError("trim_space should be a bool") - self.options["TRIM_SPACE"] = trim_space - return self - - def error_on_column_count_mismatch(self, error_on_col_count_mismatch): - """ - Generate a parsing error if the number of delimited columns (i.e. fields) in - an input data file does not match the number of columns in the corresponding table. - """ - if not isinstance(error_on_col_count_mismatch, bool): - raise TypeError("skip_header should be a bool") - self.options["ERROR_ON_COLUMN_COUNT_MISMATCH"] = error_on_col_count_mismatch - return self - - -class JSONFormatter(CopyFormatter): - """Format specific functions""" - - file_format = "json" - - def compression(self, comp_type): - """String (constant) that specifies to compresses the unloaded data files using the specified compression algorithm.""" - if isinstance(comp_type, string_types): - comp_type = comp_type.lower() - _available_options = [ - "auto", - "gzip", - "bz2", - "brotli", - "zstd", - "deflate", - "raw_deflate", - None, - ] - if comp_type not in _available_options: - raise TypeError(f"Compression type should be one of : {_available_options}") - self.options["COMPRESSION"] = comp_type - return self - - def file_extension(self, ext): - """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is - responsible for specifying a valid file extension that can be read by the desired software or service. - """ - if not isinstance(ext, (NoneType, string_types)): - raise TypeError("File extension should be a string") - self.options["FILE_EXTENSION"] = ext - return self - - -class PARQUETFormatter(CopyFormatter): - """Format specific functions""" - - file_format = "parquet" - - def snappy_compression(self, comp): - """Enable, or disable snappy compression""" - if not isinstance(comp, bool): - raise TypeError("Comp should be a Boolean value") - self.options["SNAPPY_COMPRESSION"] = translate_bool(comp) - return self - - def compression(self, comp): - """ - Set compression type - """ - if not isinstance(comp, str): - raise TypeError("Comp should be a str value") - self.options["COMPRESSION"] = comp - return self - - def binary_as_text(self, value): - """Enable, or disable binary as text""" - if not isinstance(value, bool): - raise TypeError("binary_as_text should be a Boolean value") - self.options["BINARY_AS_TEXT"] = translate_bool(value) - return self - - -class ExternalStage(ClauseElement, FromClauseRole): - """External Stage descriptor""" - - __visit_name__ = "external_stage" - _hide_froms = () - - @staticmethod - def prepare_namespace(namespace): - return f"{namespace}." if not namespace.endswith(".") else namespace - - @staticmethod - def prepare_path(path): - return f"/{path}" if not path.startswith("/") else path - - def __init__(self, name, path=None, namespace=None, file_format=None): - self.name = name - self.path = self.prepare_path(path) if path else "" - self.namespace = self.prepare_namespace(namespace) if namespace else "" - self.file_format = file_format - - def __repr__(self): - return f"@{self.namespace}{self.name}{self.path} ({self.file_format})" - - @classmethod - def from_parent_stage(cls, parent_stage, path, file_format=None): - """ - Extend an existing parent stage (with or without path) with an - additional sub-path - """ - return cls( - parent_stage.name, - f"{parent_stage.path}/{path}", - parent_stage.namespace, - file_format, - ) - - -class CreateFileFormat(DDLElement): - """ - Encapsulates a CREATE FILE FORMAT statement; using a format description (as in - a COPY INTO statement) and a format name. - """ - - __visit_name__ = "create_file_format" - - def __init__(self, format_name, formatter, replace_if_exists=False): - super().__init__() - self.format_name = format_name - self.formatter = formatter - self.replace_if_exists = replace_if_exists - - -class CreateStage(DDLElement): - """ - Encapsulates a CREATE STAGE statement, using a container (physical base for the - stage) and the actual ExternalStage object. - """ - - __visit_name__ = "create_stage" - - def __init__(self, container, stage, replace_if_exists=False, *, temporary=False): - super().__init__() - self.container = container - self.temporary = temporary - self.stage = stage - self.replace_if_exists = replace_if_exists - - -class AWSBucket(ClauseElement): - """AWS S3 bucket descriptor""" - - __visit_name__ = "aws_bucket" - - def __init__(self, bucket, path=None): - self.bucket = bucket - self.path = path - self.encryption_used = {} - self.credentials_used = {} - - @classmethod - def from_uri(cls, uri): - if uri[0:5] != "s3://": - raise ValueError(f"Invalid AWS bucket URI: {uri}") - b = uri[5:].split("/", 1) - if len(b) == 1: - bucket, path = b[0], None - else: - bucket, path = b - return cls(bucket, path) - - def __repr__(self): - credentials = "CREDENTIALS=({})".format( - " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items()) - ) - encryption = "ENCRYPTION=({})".format( - " ".join( - f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}" - for n, v in self.encryption_used.items() - ) - ) - uri = "'s3://{}{}'".format(self.bucket, f"/{self.path}" if self.path else "") - return "{}{}{}".format( - uri, - f" {credentials}" if self.credentials_used else "", - f" {encryption}" if self.encryption_used else "", - ) - - def credentials( - self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None - ): - if aws_role is None and (aws_key_id is None and aws_secret_key is None): - raise ValueError( - "Either 'aws_role', or aws_key_id and aws_secret_key has to be supplied" - ) - if aws_role: - self.credentials_used = {"AWS_ROLE": aws_role} - else: - self.credentials_used = { - "AWS_SECRET_KEY": aws_secret_key, - "AWS_KEY_ID": aws_key_id, - } - if aws_token: - self.credentials_used["AWS_TOKEN"] = aws_token - return self - - def encryption_aws_cse(self, master_key): - self.encryption_used = {"TYPE": "AWS_CSE", "MASTER_KEY": master_key} - return self - - def encryption_aws_sse_s3(self): - self.encryption_used = {"TYPE": "AWS_SSE_S3"} - return self - - def encryption_aws_sse_kms(self, kms_key_id=None): - self.encryption_used = {"TYPE": "AWS_SSE_KMS"} - if kms_key_id: - self.encryption_used["KMS_KEY_ID"] = kms_key_id - return self - - -class AzureContainer(ClauseElement): - """Microsoft Azure Container descriptor""" - - __visit_name__ = "azure_container" - - def __init__(self, account, container, path=None): - self.account = account - self.container = container - self.path = path - self.encryption_used = {} - self.credentials_used = {} - - @classmethod - def from_uri(cls, uri): - if uri[0:8] != "azure://": - raise ValueError(f"Invalid Azure Container URI: {uri}") - account, uri = uri[8:].split(".", 1) - if uri[0:22] != "blob.core.windows.net/": - raise ValueError(f"Invalid Azure Container URI: {uri}") - b = uri[22:].split("/", 1) - if len(b) == 1: - container, path = b[0], None - else: - container, path = b - return cls(account, container, path) - - def __repr__(self): - credentials = "CREDENTIALS=({})".format( - " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items()) - ) - encryption = "ENCRYPTION=({})".format( - " ".join( - f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}" - for n, v in self.encryption_used.items() - ) - ) - uri = "'azure://{}.blob.core.windows.net/{}{}'".format( - self.account, self.container, f"/{self.path}" if self.path else "" - ) - return "{}{}{}".format( - uri, - f" {credentials}" if self.credentials_used else "", - f" {encryption}" if self.encryption_used else "", - ) - - def credentials(self, azure_sas_token): - self.credentials_used = {"AZURE_SAS_TOKEN": azure_sas_token} - return self - - def encryption_azure_cse(self, master_key): - self.encryption_used = {"TYPE": "AZURE_CSE", "MASTER_KEY": master_key} - return self - - -CopyIntoStorage = CopyInto diff --git a/src/snowflake/sqlalchemy/custom_commands/__init__.py b/src/snowflake/sqlalchemy/custom_commands/__init__.py new file mode 100644 index 00000000..8aeaa202 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .aws_bucket import AWSBucket +from .azure_container import AzureContainer +from .copy_formatter import CopyFormatter +from .copy_into import CopyInto, CopyIntoStorage, FilesOption +from .create_file_format import CreateFileFormat +from .create_stage import CreateStage +from .csv_formatter import CSVFormatter +from .external_stage import ExternalStage +from .json_formatter import JSONFormatter +from .merge_into import MergeInto +from .parquet_formater import PARQUETFormatter +from .utils import translate_bool + +__all__ = [ + "AWSBucket", + "AzureContainer", + "CopyFormatter", + "FilesOption", + "CopyInto", + "CopyIntoStorage", + "CreateFileFormat", + "CreateStage", + "CSVFormatter", + "ExternalStage", + "JSONFormatter", + "MergeInto", + "PARQUETFormatter", + "translate_bool", +] diff --git a/src/snowflake/sqlalchemy/custom_commands/aws_bucket.py b/src/snowflake/sqlalchemy/custom_commands/aws_bucket.py new file mode 100644 index 00000000..c9e616e1 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/aws_bucket.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.sql.elements import ClauseElement + +from snowflake.sqlalchemy.compat import string_types + + +class AWSBucket(ClauseElement): + """AWS S3 bucket descriptor""" + + __visit_name__ = "aws_bucket" + + def __init__(self, bucket, path=None): + self.bucket = bucket + self.path = path + self.encryption_used = {} + self.credentials_used = {} + + @classmethod + def from_uri(cls, uri): + if uri[0:5] != "s3://": + raise ValueError(f"Invalid AWS bucket URI: {uri}") + b = uri[5:].split("/", 1) + if len(b) == 1: + bucket, path = b[0], None + else: + bucket, path = b + return cls(bucket, path) + + def __repr__(self): + credentials = "CREDENTIALS=({})".format( + " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items()) + ) + encryption = "ENCRYPTION=({})".format( + " ".join( + f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}" + for n, v in self.encryption_used.items() + ) + ) + uri = "'s3://{}{}'".format(self.bucket, f"/{self.path}" if self.path else "") + return "{}{}{}".format( + uri, + f" {credentials}" if self.credentials_used else "", + f" {encryption}" if self.encryption_used else "", + ) + + def credentials( + self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None + ): + if aws_role is None and (aws_key_id is None and aws_secret_key is None): + raise ValueError( + "Either 'aws_role', or aws_key_id and aws_secret_key has to be supplied" + ) + if aws_role: + self.credentials_used = {"AWS_ROLE": aws_role} + else: + self.credentials_used = { + "AWS_SECRET_KEY": aws_secret_key, + "AWS_KEY_ID": aws_key_id, + } + if aws_token: + self.credentials_used["AWS_TOKEN"] = aws_token + return self + + def encryption_aws_cse(self, master_key): + self.encryption_used = {"TYPE": "AWS_CSE", "MASTER_KEY": master_key} + return self + + def encryption_aws_sse_s3(self): + self.encryption_used = {"TYPE": "AWS_SSE_S3"} + return self + + def encryption_aws_sse_kms(self, kms_key_id=None): + self.encryption_used = {"TYPE": "AWS_SSE_KMS"} + if kms_key_id: + self.encryption_used["KMS_KEY_ID"] = kms_key_id + return self diff --git a/src/snowflake/sqlalchemy/custom_commands/azure_container.py b/src/snowflake/sqlalchemy/custom_commands/azure_container.py new file mode 100644 index 00000000..18f08940 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/azure_container.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.sql.elements import ClauseElement + +from snowflake.sqlalchemy.compat import string_types + + +class AzureContainer(ClauseElement): + """Microsoft Azure Container descriptor""" + + __visit_name__ = "azure_container" + + def __init__(self, account, container, path=None): + self.account = account + self.container = container + self.path = path + self.encryption_used = {} + self.credentials_used = {} + + @classmethod + def from_uri(cls, uri): + if uri[0:8] != "azure://": + raise ValueError(f"Invalid Azure Container URI: {uri}") + account, uri = uri[8:].split(".", 1) + if uri[0:22] != "blob.core.windows.net/": + raise ValueError(f"Invalid Azure Container URI: {uri}") + b = uri[22:].split("/", 1) + if len(b) == 1: + container, path = b[0], None + else: + container, path = b + return cls(account, container, path) + + def __repr__(self): + credentials = "CREDENTIALS=({})".format( + " ".join(f"{n}='{v}'" for n, v in self.credentials_used.items()) + ) + encryption = "ENCRYPTION=({})".format( + " ".join( + f"{n}='{v}'" if isinstance(v, string_types) else f"{n}={v}" + for n, v in self.encryption_used.items() + ) + ) + uri = "'azure://{}.blob.core.windows.net/{}{}'".format( + self.account, self.container, f"/{self.path}" if self.path else "" + ) + return "{}{}{}".format( + uri, + f" {credentials}" if self.credentials_used else "", + f" {encryption}" if self.encryption_used else "", + ) + + def credentials(self, azure_sas_token): + self.credentials_used = {"AZURE_SAS_TOKEN": azure_sas_token} + return self + + def encryption_azure_cse(self, master_key): + self.encryption_used = {"TYPE": "AZURE_CSE", "MASTER_KEY": master_key} + return self diff --git a/src/snowflake/sqlalchemy/custom_commands/copy_formatter.py b/src/snowflake/sqlalchemy/custom_commands/copy_formatter.py new file mode 100644 index 00000000..1836bb31 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/copy_formatter.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from sqlalchemy.sql.elements import ClauseElement + + +class CopyFormatter(ClauseElement): + """ + Base class for Formatter specifications inside a COPY INTO statement. May also + be used to create a named format. + """ + + __visit_name__ = "copy_formatter" + + def __init__(self, format_name=None): + self.options = dict() + if format_name: + self.options["format_name"] = format_name + + def __repr__(self): + """ + repr for debugging / logging purposes only. For compilation logic, see + the corresponding visitor in base.py + """ + return f"FILE_FORMAT=({self.options})" + + @staticmethod + def value_repr(name, value): + """ + Make a SQL-suitable representation of "value". This is called from + the corresponding visitor function (base.py/visit_copy_formatter()) + - in case of a format name: return it without quotes + - in case of a string: enclose in quotes: "value" + - in case of a tuple of length 1: enclose the only element in brackets: (value) + Standard stringification of Python would append a trailing comma: (value,) + which is not correct in SQL + - otherwise: just convert to str as is: value + """ + if name == "format_name": + return value + elif isinstance(value, str): + return f"'{value}'" + elif isinstance(value, tuple) and len(value) == 1: + return f"('{value[0]}')" + else: + return str(value) diff --git a/src/snowflake/sqlalchemy/custom_commands/copy_into.py b/src/snowflake/sqlalchemy/custom_commands/copy_into.py new file mode 100644 index 00000000..0869011f --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/copy_into.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from sqlalchemy.sql.dml import UpdateBase + +from snowflake.sqlalchemy.custom_commands.utils import translate_bool + + +class FilesOption: + """ + Class to represent FILES option for the snowflake COPY INTO statement + """ + + def __init__(self, file_names: List[str]): + self.file_names = file_names + + def __str__(self): + the_files = ["'" + f.replace("'", "\\'") + "'" for f in self.file_names] + return f"({','.join(the_files)})" + + +class CopyInto(UpdateBase): + """Copy Into Command base class, for documentation see: + https://docs.snowflake.net/manuals/sql-reference/sql/copy-into-location.html""" + + __visit_name__ = "copy_into" + _bind = None + + def __init__(self, from_, into, formatter=None): + self.from_ = from_ + self.into = into + self.formatter = formatter + self.copy_options = {} + + def __repr__(self): + """ + repr for debugging / logging purposes only. For compilation logic, see + the corresponding visitor in base.py + """ + return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})" + + def bind(self): + return None + + def force(self, force): + if not isinstance(force, bool): + raise TypeError("Parameter force should be a boolean value") + self.copy_options.update({"FORCE": translate_bool(force)}) + return self + + def single(self, single_file): + if not isinstance(single_file, bool): + raise TypeError("Parameter single_file should be a boolean value") + self.copy_options.update({"SINGLE": translate_bool(single_file)}) + return self + + def maxfilesize(self, max_size): + if not isinstance(max_size, int): + raise TypeError("Parameter max_size should be an integer value") + self.copy_options.update({"MAX_FILE_SIZE": max_size}) + return self + + def files(self, file_names): + self.copy_options.update({"FILES": FilesOption(file_names)}) + return self + + def pattern(self, pattern): + self.copy_options.update({"PATTERN": pattern}) + return self + + +CopyIntoStorage = CopyInto diff --git a/src/snowflake/sqlalchemy/custom_commands/create_file_format.py b/src/snowflake/sqlalchemy/custom_commands/create_file_format.py new file mode 100644 index 00000000..ab993e86 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/create_file_format.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.sql.ddl import DDLElement + + +class CreateFileFormat(DDLElement): + """ + Encapsulates a CREATE FILE FORMAT statement; using a format description (as in + a COPY INTO statement) and a format name. + """ + + __visit_name__ = "create_file_format" + + def __init__(self, format_name, formatter, replace_if_exists=False): + super().__init__() + self.format_name = format_name + self.formatter = formatter + self.replace_if_exists = replace_if_exists diff --git a/src/snowflake/sqlalchemy/custom_commands/create_stage.py b/src/snowflake/sqlalchemy/custom_commands/create_stage.py new file mode 100644 index 00000000..97cb36b9 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/create_stage.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.sql.ddl import DDLElement + + +class CreateStage(DDLElement): + """ + Encapsulates a CREATE STAGE statement, using a container (physical base for the + stage) and the actual ExternalStage object. + """ + + __visit_name__ = "create_stage" + + def __init__(self, container, stage, replace_if_exists=False, *, temporary=False): + super().__init__() + self.container = container + self.temporary = temporary + self.stage = stage + self.replace_if_exists = replace_if_exists diff --git a/src/snowflake/sqlalchemy/custom_commands/csv_formatter.py b/src/snowflake/sqlalchemy/custom_commands/csv_formatter.py new file mode 100644 index 00000000..670b26a7 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/csv_formatter.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence + +from sqlalchemy.util import NoneType + +from snowflake.sqlalchemy.compat import string_types +from snowflake.sqlalchemy.custom_commands.copy_formatter import CopyFormatter + + +class CSVFormatter(CopyFormatter): + file_format = "csv" + + def compression(self, comp_type): + """String (constant) that specifies to compresses the unloaded data files using the specified compression algorithm.""" + if isinstance(comp_type, string_types): + comp_type = comp_type.lower() + _available_options = [ + "auto", + "gzip", + "bz2", + "brotli", + "zstd", + "deflate", + "raw_deflate", + None, + ] + if comp_type not in _available_options: + raise TypeError(f"Compression type should be one of : {_available_options}") + self.options["COMPRESSION"] = comp_type + return self + + def _check_delimiter(self, delimiter, delimiter_txt): + """ + Check if a delimiter is either a string of length 1 or an integer. In case of + a string delimiter, take into account that the actual string may be longer, + but still evaluate to a single character (like "\\n" or r"\n" + """ + if isinstance(delimiter, NoneType): + return + if isinstance(delimiter, string_types): + delimiter_processed = delimiter.encode().decode("unicode_escape") + if len(delimiter_processed) == 1: + return + if isinstance(delimiter, int): + return + raise TypeError( + f"{delimiter_txt} should be a single character, that is either a string, or a number" + ) + + def record_delimiter(self, deli_type): + """Character that separates records in an unloaded file.""" + self._check_delimiter(deli_type, "Record delimiter") + if isinstance(deli_type, int): + self.options["RECORD_DELIMITER"] = hex(deli_type) + else: + self.options["RECORD_DELIMITER"] = deli_type + return self + + def field_delimiter(self, deli_type): + """Character that separates fields in an unloaded file.""" + self._check_delimiter(deli_type, "Field delimiter") + if isinstance(deli_type, int): + self.options["FIELD_DELIMITER"] = hex(deli_type) + else: + self.options["FIELD_DELIMITER"] = deli_type + return self + + def file_extension(self, ext): + """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is + responsible for specifying a valid file extension that can be read by the desired software or service. + """ + if not isinstance(ext, (NoneType, string_types)): + raise TypeError("File extension should be a string") + self.options["FILE_EXTENSION"] = ext + return self + + def date_format(self, dt_frmt): + """String that defines the format of date values in the unloaded data files.""" + if not isinstance(dt_frmt, string_types): + raise TypeError("Date format should be a string") + self.options["DATE_FORMAT"] = dt_frmt + return self + + def time_format(self, tm_frmt): + """String that defines the format of time values in the unloaded data files.""" + if not isinstance(tm_frmt, string_types): + raise TypeError("Time format should be a string") + self.options["TIME_FORMAT"] = tm_frmt + return self + + def timestamp_format(self, tmstmp_frmt): + """String that defines the format of timestamp values in the unloaded data files.""" + if not isinstance(tmstmp_frmt, string_types): + raise TypeError("Timestamp format should be a string") + self.options["TIMESTAMP_FORMAT"] = tmstmp_frmt + return self + + def binary_format(self, bin_fmt): + """Character used as the escape character for any field values. The option can be used when unloading data + from binary columns in a table.""" + if isinstance(bin_fmt, string_types): + bin_fmt = bin_fmt.lower() + _available_options = ["hex", "base64", "utf8"] + if bin_fmt not in _available_options: + raise TypeError(f"Binary format should be one of : {_available_options}") + self.options["BINARY_FORMAT"] = bin_fmt + return self + + def escape(self, esc): + """Character used as the escape character for any field values.""" + self._check_delimiter(esc, "Escape") + if isinstance(esc, int): + self.options["ESCAPE"] = hex(esc) + else: + self.options["ESCAPE"] = esc + return self + + def escape_unenclosed_field(self, esc): + """Single character string used as the escape character for unenclosed field values only.""" + self._check_delimiter(esc, "Escape unenclosed field") + if isinstance(esc, int): + self.options["ESCAPE_UNENCLOSED_FIELD"] = hex(esc) + else: + self.options["ESCAPE_UNENCLOSED_FIELD"] = esc + return self + + def field_optionally_enclosed_by(self, enc): + """Character used to enclose strings. Either None, ', or \".""" + _available_options = [None, "'", '"'] + if enc not in _available_options: + raise TypeError(f"Enclosing string should be one of : {_available_options}") + self.options["FIELD_OPTIONALLY_ENCLOSED_BY"] = enc + return self + + def null_if(self, null_value): + """Copying into a table these strings will be replaced by a NULL, while copying out of Snowflake will replace + NULL values with the first string""" + if not isinstance(null_value, Sequence): + raise TypeError("Parameter null_value should be an iterable") + self.options["NULL_IF"] = tuple(null_value) + return self + + def skip_header(self, skip_header): + """ + Number of header rows to be skipped at the beginning of the file + """ + if not isinstance(skip_header, int): + raise TypeError("skip_header should be an int") + self.options["SKIP_HEADER"] = skip_header + return self + + def trim_space(self, trim_space): + """ + Remove leading or trailing white spaces + """ + if not isinstance(trim_space, bool): + raise TypeError("trim_space should be a bool") + self.options["TRIM_SPACE"] = trim_space + return self + + def error_on_column_count_mismatch(self, error_on_col_count_mismatch): + """ + Generate a parsing error if the number of delimited columns (i.e. fields) in + an input data file does not match the number of columns in the corresponding table. + """ + if not isinstance(error_on_col_count_mismatch, bool): + raise TypeError("skip_header should be a bool") + self.options["ERROR_ON_COLUMN_COUNT_MISMATCH"] = error_on_col_count_mismatch + return self diff --git a/src/snowflake/sqlalchemy/custom_commands/external_stage.py b/src/snowflake/sqlalchemy/custom_commands/external_stage.py new file mode 100644 index 00000000..ec71e64b --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/external_stage.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.sql.elements import ClauseElement +from sqlalchemy.sql.roles import FromClauseRole + + +class ExternalStage(ClauseElement, FromClauseRole): + """External Stage descriptor""" + + __visit_name__ = "external_stage" + _hide_froms = () + + @staticmethod + def prepare_namespace(namespace): + return f"{namespace}." if not namespace.endswith(".") else namespace + + @staticmethod + def prepare_path(path): + return f"/{path}" if not path.startswith("/") else path + + def __init__(self, name, path=None, namespace=None, file_format=None): + self.name = name + self.path = self.prepare_path(path) if path else "" + self.namespace = self.prepare_namespace(namespace) if namespace else "" + self.file_format = file_format + + def __repr__(self): + return f"@{self.namespace}{self.name}{self.path} ({self.file_format})" + + @classmethod + def from_parent_stage(cls, parent_stage, path, file_format=None): + """ + Extend an existing parent stage (with or without path) with an + additional sub-path + """ + return cls( + parent_stage.name, + f"{parent_stage.path}/{path}", + parent_stage.namespace, + file_format, + ) diff --git a/src/snowflake/sqlalchemy/custom_commands/json_formatter.py b/src/snowflake/sqlalchemy/custom_commands/json_formatter.py new file mode 100644 index 00000000..6dd92376 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/json_formatter.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.util import NoneType + +from ..compat import string_types +from .copy_formatter import CopyFormatter + + +class JSONFormatter(CopyFormatter): + """Format specific functions""" + + file_format = "json" + + def compression(self, comp_type): + """String (constant) that specifies to compresses the unloaded data files using the specified compression algorithm.""" + if isinstance(comp_type, string_types): + comp_type = comp_type.lower() + _available_options = [ + "auto", + "gzip", + "bz2", + "brotli", + "zstd", + "deflate", + "raw_deflate", + None, + ] + if comp_type not in _available_options: + raise TypeError(f"Compression type should be one of : {_available_options}") + self.options["COMPRESSION"] = comp_type + return self + + def file_extension(self, ext): + """String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is + responsible for specifying a valid file extension that can be read by the desired software or service. + """ + if not isinstance(ext, (NoneType, string_types)): + raise TypeError("File extension should be a string") + self.options["FILE_EXTENSION"] = ext + return self diff --git a/src/snowflake/sqlalchemy/custom_commands/merge_into.py b/src/snowflake/sqlalchemy/custom_commands/merge_into.py new file mode 100644 index 00000000..13551fd7 --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/merge_into.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.sql.dml import UpdateBase +from sqlalchemy.sql.elements import ClauseElement + + +class MergeInto(UpdateBase): + __visit_name__ = "merge_into" + _bind = None + + def __init__(self, target, source, on): + self.target = target + self.source = source + self.on = on + self.clauses = [] + + class clause(ClauseElement): + __visit_name__ = "merge_into_clause" + + def __init__(self, command): + self.set = {} + self.predicate = None + self.command = command + + def __repr__(self): + case_predicate = ( + f" AND {str(self.predicate)}" if self.predicate is not None else "" + ) + if self.command == "INSERT": + sets, sets_tos = zip(*self.set.items()) + return "WHEN NOT MATCHED{} THEN {} ({}) VALUES ({})".format( + case_predicate, + self.command, + ", ".join(sets), + ", ".join(map(str, sets_tos)), + ) + else: + # WHEN MATCHED clause + sets = ( + ", ".join([f"{set[0]} = {set[1]}" for set in self.set.items()]) + if self.set + else "" + ) + return "WHEN MATCHED{} THEN {}{}".format( + case_predicate, + self.command, + f" SET {str(sets)}" if self.set else "", + ) + + def values(self, **kwargs): + self.set = kwargs + return self + + def where(self, expr): + self.predicate = expr + return self + + def __repr__(self): + clauses = " ".join([repr(clause) for clause in self.clauses]) + return f"MERGE INTO {self.target} USING {self.source} ON {self.on}" + ( + f" {clauses}" if clauses else "" + ) + + def when_matched_then_update(self): + clause = self.clause("UPDATE") + self.clauses.append(clause) + return clause + + def when_matched_then_delete(self): + clause = self.clause("DELETE") + self.clauses.append(clause) + return clause + + def when_not_matched_then_insert(self): + clause = self.clause("INSERT") + self.clauses.append(clause) + return clause diff --git a/src/snowflake/sqlalchemy/custom_commands/parquet_formater.py b/src/snowflake/sqlalchemy/custom_commands/parquet_formater.py new file mode 100644 index 00000000..26c72a3b --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/parquet_formater.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from snowflake.sqlalchemy.custom_commands.copy_formatter import CopyFormatter +from snowflake.sqlalchemy.custom_commands.utils import translate_bool + + +class PARQUETFormatter(CopyFormatter): + """Format specific functions""" + + file_format = "parquet" + + def snappy_compression(self, comp): + """Enable, or disable snappy compression""" + if not isinstance(comp, bool): + raise TypeError("Comp should be a Boolean value") + self.options["SNAPPY_COMPRESSION"] = translate_bool(comp) + return self + + def compression(self, comp): + """ + Set compression type + """ + if not isinstance(comp, str): + raise TypeError("Comp should be a str value") + self.options["COMPRESSION"] = comp + return self + + def binary_as_text(self, value): + """Enable, or disable binary as text""" + if not isinstance(value, bool): + raise TypeError("binary_as_text should be a Boolean value") + self.options["BINARY_AS_TEXT"] = translate_bool(value) + return self diff --git a/src/snowflake/sqlalchemy/custom_commands/utils.py b/src/snowflake/sqlalchemy/custom_commands/utils.py new file mode 100644 index 00000000..e70bb62e --- /dev/null +++ b/src/snowflake/sqlalchemy/custom_commands/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import false, true + + +def translate_bool(bln): + if bln: + return true() + return false() diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index 802d1ce1..e402d7bd 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import sqlalchemy.types as sqltypes import sqlalchemy.util as util diff --git a/src/snowflake/sqlalchemy/functions.py b/src/snowflake/sqlalchemy/functions.py index c08aa734..4580fcc3 100644 --- a/src/snowflake/sqlalchemy/functions.py +++ b/src/snowflake/sqlalchemy/functions.py @@ -1,5 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import warnings diff --git a/src/snowflake/sqlalchemy/provision.py b/src/snowflake/sqlalchemy/provision.py index 2c8368fd..c2d30f58 100644 --- a/src/snowflake/sqlalchemy/provision.py +++ b/src/snowflake/sqlalchemy/provision.py @@ -1,6 +1,17 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from sqlalchemy.testing.provision import set_default_schema_on_connection diff --git a/src/snowflake/sqlalchemy/requirements.py b/src/snowflake/sqlalchemy/requirements.py index f2844804..c5cd8c7e 100644 --- a/src/snowflake/sqlalchemy/requirements.py +++ b/src/snowflake/sqlalchemy/requirements.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from sqlalchemy.testing import exclusions from sqlalchemy.testing.requirements import SuiteRequirements diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 04305a00..d2ed8890 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import operator from collections import defaultdict diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index a1aefff9..4fafbacb 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import re from itertools import chain @@ -133,7 +143,7 @@ def parse_url_integer(value: str) -> int: # handle Snowflake BCR bcr-1057 -# the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState +# the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.base.selectable.SelectState # which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that # can not handle the BCR change, we implement it in a way that lateral join does not need onclause def _find_left_clause_to_join_from(clauses, join_to, onclause): diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index d90f706b..346c7cc9 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # Update this for the versions # Don't change the forth version number from None VERSION = "1.6.1" diff --git a/tests/__init__.py b/tests/__init__.py index ef416f64..ada0a4e1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,13 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/conftest.py b/tests/conftest.py index d4dab3d1..e0d9be3f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import annotations import os diff --git a/tests/sqlalchemy_test_suite/__init__.py b/tests/sqlalchemy_test_suite/__init__.py index ef416f64..ada0a4e1 100644 --- a/tests/sqlalchemy_test_suite/__init__.py +++ b/tests/sqlalchemy_test_suite/__init__.py @@ -1,3 +1,13 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sqlalchemy_test_suite/conftest.py b/tests/sqlalchemy_test_suite/conftest.py index f0464c7d..08e4f270 100644 --- a/tests/sqlalchemy_test_suite/conftest.py +++ b/tests/sqlalchemy_test_suite/conftest.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import sqlalchemy.testing.config from sqlalchemy import util diff --git a/tests/sqlalchemy_test_suite/test_suite.py b/tests/sqlalchemy_test_suite/test_suite.py index 643d1559..a74003f2 100644 --- a/tests/sqlalchemy_test_suite/test_suite.py +++ b/tests/sqlalchemy_test_suite/test_suite.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from sqlalchemy import Integer, testing from sqlalchemy.schema import Column, Sequence, Table diff --git a/tests/sqlalchemy_test_suite/test_suite_20.py b/tests/sqlalchemy_test_suite/test_suite_20.py index 1f79c4e9..55c0b3af 100644 --- a/tests/sqlalchemy_test_suite/test_suite_20.py +++ b/tests/sqlalchemy_test_suite/test_suite_20.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from sqlalchemy import Integer, testing from sqlalchemy.schema import Column, Sequence, Table diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 40207b41..c96d98c8 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from sqlalchemy import Integer, String, and_, func, select from sqlalchemy.schema import DropColumnComment, DropTableComment diff --git a/tests/test_copy.py b/tests/test_copy.py index e0752d4f..cf02dab3 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table diff --git a/tests/test_core.py b/tests/test_core.py index 179133c8..4b98b244 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,17 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import decimal import json import os diff --git a/tests/test_create.py b/tests/test_create.py index 0b8b48fa..f31ff4b3 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from snowflake.sqlalchemy import ( AzureContainer, diff --git a/tests/test_custom_functions.py b/tests/test_custom_functions.py index 2a1e1cb5..8b1acdc1 100644 --- a/tests/test_custom_functions.py +++ b/tests/test_custom_functions.py @@ -1,5 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from sqlalchemy import func diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index a997ffe8..85fa3e89 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from snowflake.sqlalchemy import custom_types diff --git a/tests/test_geography.py b/tests/test_geography.py index 7168d2a3..ae95c8bb 100644 --- a/tests/test_geography.py +++ b/tests/test_geography.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from json import loads diff --git a/tests/test_geometry.py b/tests/test_geometry.py index 742b518e..b8235823 100644 --- a/tests/test_geometry.py +++ b/tests/test_geometry.py @@ -1,6 +1,17 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from json import loads from sqlalchemy import Column, Integer, MetaData, Table diff --git a/tests/test_multivalues_insert.py b/tests/test_multivalues_insert.py index 5a91d3da..0091bc5a 100644 --- a/tests/test_multivalues_insert.py +++ b/tests/test_multivalues_insert.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from sqlalchemy import Integer, Sequence, String from sqlalchemy.schema import Column, MetaData, Table diff --git a/tests/test_orm.py b/tests/test_orm.py index f53cd708..5ad010c8 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import enum import logging diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 63cd6d0e..ae77060c 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import operator import random diff --git a/tests/test_qmark.py b/tests/test_qmark.py index 3761181a..58f100a5 100644 --- a/tests/test_qmark.py +++ b/tests/test_qmark.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import sys diff --git a/tests/test_quote.py b/tests/test_quote.py index ca6f36dd..d2501e00 100644 --- a/tests/test_quote.py +++ b/tests/test_quote.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table, inspect diff --git a/tests/test_semi_structured_datatypes.py b/tests/test_semi_structured_datatypes.py index c6c15bd4..3288f054 100644 --- a/tests/test_semi_structured_datatypes.py +++ b/tests/test_semi_structured_datatypes.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import textwrap diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 32fc390e..8b0a292f 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from sqlalchemy import ( Column, diff --git a/tests/test_timestamp.py b/tests/test_timestamp.py index 72bc9173..f199da2f 100644 --- a/tests/test_timestamp.py +++ b/tests/test_timestamp.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from datetime import datetime diff --git a/tests/test_unit_core.py b/tests/test_unit_core.py index 27c7cf4a..3821ac98 100644 --- a/tests/test_unit_core.py +++ b/tests/test_unit_core.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from sqlalchemy.engine.url import URL diff --git a/tests/test_unit_cte.py b/tests/test_unit_cte.py index d3421653..9d114a27 100644 --- a/tests/test_unit_cte.py +++ b/tests/test_unit_cte.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. def test_cte(): diff --git a/tests/test_unit_types.py b/tests/test_unit_types.py index fc6d3c23..36e4ff48 100644 --- a/tests/test_unit_types.py +++ b/tests/test_unit_types.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import snowflake.sqlalchemy from snowflake.sqlalchemy.snowdialect import SnowflakeDialect diff --git a/tests/test_unit_url.py b/tests/test_unit_url.py index d2ca4e47..f3d801b0 100644 --- a/tests/test_unit_url.py +++ b/tests/test_unit_url.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import urllib.parse from snowflake.sqlalchemy import URL diff --git a/tests/util.py b/tests/util.py index db0b0c9c..87fe9400 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,6 +1,16 @@ +# Copyright (c) 2024 Snowflake Inc. # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import annotations