Skip to content

Commit

Permalink
SNOW-299875 Mock multipart threshold for tests instead of passing an …
Browse files Browse the repository at this point in the history
…additional threshold to exec() (#658)
  • Loading branch information
sfc-gh-cshi authored Mar 17, 2021
1 parent 63215d8 commit f9217fc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
6 changes: 1 addition & 5 deletions src/snowflake/connector/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,6 @@ def execute(
_is_put_get: Optional[bool] = None,
_raise_put_get_error: bool = True,
_force_put_overwrite: bool = False,
_multipart_threshold: Optional[int] = None,
file_stream: Optional[IO[bytes]] = None,
):
"""Executes a command/query.
Expand All @@ -533,7 +532,6 @@ def execute(
_raise_put_get_error: Whether to raise PUT and GET errors.
_force_put_overwrite: If the SQL query is a PUT, then this flag can force overwriting of an already
existing file on stage.
_multipart_threshold: use internally to decide multipart threshold to overwrite threshold from server
file_stream: File-like object to be uploaded with PUT
Returns:
Expand Down Expand Up @@ -654,9 +652,7 @@ def execute(
force_put_overwrite=_force_put_overwrite
or data.get("overwrite", False),
source_from_stream=file_stream,
multipart_threshold=data.get("threshold")
if _multipart_threshold is None
else _multipart_threshold,
multipart_threshold=data.get("threshold"),
)
sf_file_transfer_agent.execute()
data = sf_file_transfer_agent.result()
Expand Down
31 changes: 23 additions & 8 deletions test/integ/test_large_put.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import os

import pytest
from mock import patch

from snowflake.connector.file_transfer_agent import SnowflakeFileTransferAgent

from ..generate_test_files import generate_k_lines_of_n_files

Expand Down Expand Up @@ -50,14 +53,26 @@ def test_put_copy_large_files(tmpdir, conn_cnx, db_parameters):
password=db_parameters["password"],
) as cnx:
files = files.replace("\\", "\\\\")
cnx.cursor().execute(
"put 'file://{file}' @%{name}".format(
file=files,
name=db_parameters["name"],
),
# add _multipart_threshold so the PUT will not use the threshold(200MB) from server.
_multipart_threshold=1000000,
)

def mocked_file_agent(*args, **kwargs):
newkwargs = kwargs.copy()
newkwargs.update(multipart_threshold=10000)
agent = SnowflakeFileTransferAgent(*args, **newkwargs)
mocked_file_agent.agent = agent
return agent

with patch(
"snowflake.connector.cursor.SnowflakeFileTransferAgent",
side_effect=mocked_file_agent,
):
cnx.cursor().execute(
"put 'file://{file}' @%{name}".format(
file=files,
name=db_parameters["name"],
),
)
assert mocked_file_agent.agent._multipart_threshold == 10000

c = cnx.cursor()
try:
c.execute("copy into {}".format(db_parameters["name"]))
Expand Down

0 comments on commit f9217fc

Please sign in to comment.