From f9217fc9e527d942f7d389658968630722b06ec1 Mon Sep 17 00:00:00 2001 From: Chunhui Shi Date: Wed, 17 Mar 2021 10:05:13 -0700 Subject: [PATCH] SNOW-299875 Mock multipart threshold for tests instead of passing an additional threshold to exec() (#658) --- src/snowflake/connector/cursor.py | 6 +----- test/integ/test_large_put.py | 31 +++++++++++++++++++++++-------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py index b63404b34..ca0929f37 100644 --- a/src/snowflake/connector/cursor.py +++ b/src/snowflake/connector/cursor.py @@ -506,7 +506,6 @@ def execute( _is_put_get: Optional[bool] = None, _raise_put_get_error: bool = True, _force_put_overwrite: bool = False, - _multipart_threshold: Optional[int] = None, file_stream: Optional[IO[bytes]] = None, ): """Executes a command/query. @@ -533,7 +532,6 @@ def execute( _raise_put_get_error: Whether to raise PUT and GET errors. _force_put_overwrite: If the SQL query is a PUT, then this flag can force overwriting of an already existing file on stage. - _multipart_threshold: use internally to decide multipart threshold to overwrite threshold from server file_stream: File-like object to be uploaded with PUT Returns: @@ -654,9 +652,7 @@ def execute( force_put_overwrite=_force_put_overwrite or data.get("overwrite", False), source_from_stream=file_stream, - multipart_threshold=data.get("threshold") - if _multipart_threshold is None - else _multipart_threshold, + multipart_threshold=data.get("threshold"), ) sf_file_transfer_agent.execute() data = sf_file_transfer_agent.result() diff --git a/test/integ/test_large_put.py b/test/integ/test_large_put.py index 97823d68e..044af3925 100644 --- a/test/integ/test_large_put.py +++ b/test/integ/test_large_put.py @@ -7,6 +7,9 @@ import os import pytest +from mock import patch + +from snowflake.connector.file_transfer_agent import SnowflakeFileTransferAgent from ..generate_test_files import generate_k_lines_of_n_files @@ -50,14 +53,26 @@ def test_put_copy_large_files(tmpdir, conn_cnx, db_parameters): password=db_parameters["password"], ) as cnx: files = files.replace("\\", "\\\\") - cnx.cursor().execute( - "put 'file://{file}' @%{name}".format( - file=files, - name=db_parameters["name"], - ), - # add _multipart_threshold so the PUT will not use the threshold(200MB) from server. - _multipart_threshold=1000000, - ) + + def mocked_file_agent(*args, **kwargs): + newkwargs = kwargs.copy() + newkwargs.update(multipart_threshold=10000) + agent = SnowflakeFileTransferAgent(*args, **newkwargs) + mocked_file_agent.agent = agent + return agent + + with patch( + "snowflake.connector.cursor.SnowflakeFileTransferAgent", + side_effect=mocked_file_agent, + ): + cnx.cursor().execute( + "put 'file://{file}' @%{name}".format( + file=files, + name=db_parameters["name"], + ), + ) + assert mocked_file_agent.agent._multipart_threshold == 10000 + c = cnx.cursor() try: c.execute("copy into {}".format(db_parameters["name"]))