Skip to content

Commit

Permalink
SNOW-1212541-ordered-sequence: add support for creating sequences ord…
Browse files Browse the repository at this point in the history
…er (#473)
  • Loading branch information
sfc-gh-mraba authored Mar 11, 2024
1 parent 8ac9f33 commit f384432
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 141 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
exclude: '^(.*egg.info.*|.*/parameters.py).*$'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
exclude: .github/repo_meta.yaml
- id: debug-statements
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v2.37.3
rev: v3.15.1
hooks:
- id: pyupgrade
args: [--py37-plus]
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 24.2.0
hooks:
- id: black
args:
- --safe
language_version: python3
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
rev: v1.3.0
rev: v1.5.5
hooks:
- id: insert-license
name: insert-py-license
Expand All @@ -39,7 +39,7 @@ repos:
- --license-filepath
- license_header.txt
- repo: https://github.com/pycqa/flake8
rev: 5.0.4
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
Expand Down
4 changes: 4 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Source code is also available at:

# Release Notes

- 1.5.2

- Add support for sequence ordering in tests

- v1.5.1(November 03, 2023)

- Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057.
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[tool.ruff]
line-length = 88

[tool.black]
line-length = 88
65 changes: 47 additions & 18 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
@CompileState.plugin_for("default", "select")
class SnowflakeSelectState(SelectState):
def _setup_joins(self, args, raw_columns):
for (right, onclause, left, flags) in args:
for right, onclause, left, flags in args:
isouter = flags["isouter"]
full = flags["full"]

Expand Down Expand Up @@ -579,9 +579,11 @@ def visit_copy_into(self, copy_into, **kw):
[
"{} = {}".format(
n,
v._compiler_dispatch(self, **kw)
if getattr(v, "compiler_dispatch", False)
else str(v),
(
v._compiler_dispatch(self, **kw)
if getattr(v, "compiler_dispatch", False)
else str(v)
),
)
for n, v in options_list
]
Expand All @@ -604,20 +606,24 @@ def visit_copy_formatter(self, formatter, **kw):
return f"FILE_FORMAT=(format_name = {formatter.options['format_name']})"
return "FILE_FORMAT=(TYPE={}{})".format(
formatter.file_format,
" "
+ " ".join(
[
"{}={}".format(
name,
value._compiler_dispatch(self, **kw)
if hasattr(value, "_compiler_dispatch")
else formatter.value_repr(name, value),
)
for name, value in options_list
]
)
if formatter.options
else "",
(
" "
+ " ".join(
[
"{}={}".format(
name,
(
value._compiler_dispatch(self, **kw)
if hasattr(value, "_compiler_dispatch")
else formatter.value_repr(name, value)
),
)
for name, value in options_list
]
)
if formatter.options
else ""
),
)

def visit_aws_bucket(self, aws_bucket, **kw):
Expand Down Expand Up @@ -967,6 +973,29 @@ def visit_identity_column(self, identity, **kw):
text += f"({start},{increment})"
return text

def get_identity_options(self, identity_options):
text = []
if identity_options.increment is not None:
text.append(f"INCREMENT BY {identity_options.increment:d}")
if identity_options.start is not None:
text.append(f"START WITH {identity_options.start:d}")
if identity_options.minvalue is not None:
text.append(f"MINVALUE {identity_options.minvalue:d}")
if identity_options.maxvalue is not None:
text.append(f"MAXVALUE {identity_options.maxvalue:d}")
if identity_options.nominvalue is not None:
text.append("NO MINVALUE")
if identity_options.nomaxvalue is not None:
text.append("NO MAXVALUE")
if identity_options.cache is not None:
text.append(f"CACHE {identity_options.cache:d}")
if identity_options.cycle is not None:
text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
if identity_options.order is not None:
text.append("ORDER" if identity_options.order else "NOORDER")

return " ".join(text)


class SnowflakeTypeCompiler(compiler.GenericTypeCompiler):
def visit_BYTEINT(self, type_, **kw):
Expand Down
6 changes: 4 additions & 2 deletions src/snowflake/sqlalchemy/custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ def field_delimiter(self, deli_type):

def file_extension(self, ext):
"""String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is
responsible for specifying a valid file extension that can be read by the desired software or service."""
responsible for specifying a valid file extension that can be read by the desired software or service.
"""
if not isinstance(ext, (NoneType, string_types)):
raise TypeError("File extension should be a string")
self.options["FILE_EXTENSION"] = ext
Expand Down Expand Up @@ -386,7 +387,8 @@ def compression(self, comp_type):

def file_extension(self, ext):
"""String that specifies the extension for files unloaded to a stage. Accepts any extension. The user is
responsible for specifying a valid file extension that can be read by the desired software or service."""
responsible for specifying a valid file extension that can be read by the desired software or service.
"""
if not isinstance(ext, (NoneType, string_types)):
raise TypeError("File extension should be a string")
self.options["FILE_EXTENSION"] = ext
Expand Down
40 changes: 24 additions & 16 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,11 +595,13 @@ def _get_schema_columns(self, connection, schema, **kw):
"autoincrement": is_identity == "YES",
"comment": comment,
"primary_key": (
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False,
(
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False
),
}
)
if is_identity == "YES":
Expand Down Expand Up @@ -688,11 +690,13 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
"autoincrement": is_identity == "YES",
"comment": comment if comment != "" else None,
"primary_key": (
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False,
(
column_name
in schema_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False
),
}
)

Expand Down Expand Up @@ -876,18 +880,22 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
result = self._get_view_comment(connection, table_name, schema)

return {
"text": result._mapping["comment"]
if result and result._mapping["comment"]
else None
"text": (
result._mapping["comment"]
if result and result._mapping["comment"]
else None
)
}

def connect(self, *cargs, **cparams):
return (
super().connect(
*cargs,
**_update_connection_application_name(**cparams)
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
else cparams,
**(
_update_connection_application_name(**cparams)
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
else cparams
),
)
if _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME
else super().connect(*cargs, **cparams)
Expand Down
9 changes: 8 additions & 1 deletion src/snowflake/sqlalchemy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,14 @@ def __init__(
else:
adapt_from = left_selectable

(pj, sj, source, dest, secondary, target_adapter,) = prop._create_joins(
(
pj,
sj,
source,
dest,
secondary,
target_adapter,
) = prop._create_joins(
source_selectable=adapt_from,
dest_selectable=adapt_to,
source_polymorphic=True,
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/sqlalchemy/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
#
# Update this for the versions
# Don't change the forth version number from None
VERSION = (1, 5, 1, None)
VERSION = (1, 5, 2, None)
79 changes: 45 additions & 34 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Table,
UniqueConstraint,
dialects,
insert,
inspect,
text,
)
Expand Down Expand Up @@ -1424,47 +1425,51 @@ def test_special_schema_character(db_parameters, on_public_ci):


def test_autoincrement(engine_testaccount):
"""Snowflake does not guarantee generating sequence numbers without gaps.
The generated numbers are not necessarily contiguous.
https://docs.snowflake.com/en/user-guide/querying-sequences
"""
metadata = MetaData()
users = Table(
"users",
metadata,
Column("uid", Integer, Sequence("id_seq"), primary_key=True),
Column("uid", Integer, Sequence("id_seq", order=True), primary_key=True),
Column("name", String(39)),
)

try:
users.create(engine_testaccount)

with engine_testaccount.connect() as connection:
with connection.begin():
connection.execute(users.insert(), [{"name": "sf1"}])
assert connection.execute(select(users)).fetchall() == [(1, "sf1")]
connection.execute(users.insert(), [{"name": "sf2"}, {"name": "sf3"}])
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
]
connection.execute(users.insert(), {"name": "sf4"})
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
(4, "sf4"),
]

seq = Sequence("id_seq")
nextid = connection.execute(seq)
connection.execute(users.insert(), [{"uid": nextid, "name": "sf5"}])
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
(4, "sf4"),
(5, "sf5"),
]
metadata.create_all(engine_testaccount)

with engine_testaccount.begin() as connection:
connection.execute(insert(users), [{"name": "sf1"}])
assert connection.execute(select(users)).fetchall() == [(1, "sf1")]
connection.execute(insert(users), [{"name": "sf2"}, {"name": "sf3"}])
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
]
connection.execute(insert(users), {"name": "sf4"})
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
(4, "sf4"),
]

seq = Sequence("id_seq")
nextid = connection.execute(seq)
connection.execute(insert(users), [{"uid": nextid, "name": "sf5"}])
assert connection.execute(select(users)).fetchall() == [
(1, "sf1"),
(2, "sf2"),
(3, "sf3"),
(4, "sf4"),
(5, "sf5"),
]
finally:
users.drop(engine_testaccount)
metadata.drop_all(engine_testaccount)


@pytest.mark.skip(
Expand Down Expand Up @@ -1869,10 +1874,16 @@ def test_snowflake_sqlalchemy_as_valid_client_type():
)
snowflake.connector.connection.DEFAULT_CONFIGURATION[
"internal_application_name"
] = ("PythonConnector", (type(None), str))
] = (
"PythonConnector",
(type(None), str),
)
snowflake.connector.connection.DEFAULT_CONFIGURATION[
"internal_application_version"
] = ("3.0.0", (type(None), str))
] = (
"3.0.0",
(type(None), str),
)
engine = create_engine(
URL(
user=CONNECTION_PARAMETERS["user"],
Expand Down
Loading

0 comments on commit f384432

Please sign in to comment.