1
1
import io
2
2
import json
3
+ import logging
3
4
import time
4
5
from pathlib import Path
5
6
from unittest .mock import create_autospec , patch , Mock
16
17
from databricks .sdk .service .iam import ComplexValue , User
17
18
from databricks .sdk .service .jobs import Run , RunResultState , RunState
18
19
from databricks .sdk .service .provisioning import Workspace
19
- from databricks .sdk .service .workspace import ImportFormat , ObjectInfo , ObjectType
20
+ from databricks .sdk .service .workspace import ExportFormat , ImportFormat , ObjectInfo , ObjectType
20
21
21
22
from databricks .labs .ucx .assessment .aws import AWSResources , AWSRoleAction
22
23
from databricks .labs .ucx .aws .access import AWSResourcePermissions
31
32
create_missing_principals ,
32
33
create_table_mapping ,
33
34
create_uber_principal ,
35
+ download ,
34
36
ensure_assessment_run ,
35
37
installations ,
36
38
join_collection ,
@@ -106,7 +108,7 @@ def create_workspace_client_mock(workspace_id: int) -> WorkspaceClient:
106
108
""" ,
107
109
}
108
110
109
- def download (path : str ) -> io .StringIO | io .BytesIO :
111
+ def mock_download (path : str , ** _ ) -> io .StringIO | io .BytesIO :
110
112
if path not in state :
111
113
raise NotFound (path )
112
114
if ".csv" in path or ".log" in path :
@@ -117,7 +119,7 @@ def download(path: str) -> io.StringIO | io.BytesIO:
117
119
workspace_client .get_workspace_id .return_value = workspace_id
118
120
workspace_client .config .host = 'https://localhost'
119
121
workspace_client .current_user .me .return_value = User (user_name = "foo" , groups = [ComplexValue (display = "admins" )])
120
- workspace_client .workspace .download = download
122
+ workspace_client .workspace .download . side_effect = mock_download
121
123
workspace_client .statement_execution .execute_statement .return_value = sql .StatementResponse (
122
124
status = sql .StatementStatus (state = sql .StatementState .SUCCEEDED ),
123
125
manifest = sql .ResultManifest (schema = sql .ResultSchema ()),
@@ -798,3 +800,72 @@ def test_join_collection():
798
800
w .workspace .download .return_value = io .StringIO (json .dumps ([{"workspace_id" : 123 , "workspace_name" : "some" }]))
799
801
join_collection (a , "123" )
800
802
w .workspace .download .assert_not_called ()
803
+
804
+
805
+ def test_download_raises_value_error_if_not_downloading_a_csv (ws1 ):
806
+ with pytest .raises (ValueError ) as e :
807
+ download (Path ("test.txt" ), ws1 )
808
+ assert "Command only supported for CSV files" in str (e )
809
+
810
+
811
+ @pytest .mark .parametrize ("run_as_collection" , [False , True ])
812
+ def test_download_calls_workspace_download (tmp_path , workspace_clients , acc_client , run_as_collection ):
813
+ if not run_as_collection :
814
+ workspace_clients = [workspace_clients [0 ]]
815
+
816
+ download (
817
+ tmp_path / "test.csv" ,
818
+ workspace_clients [0 ],
819
+ run_as_collection = run_as_collection ,
820
+ a = acc_client ,
821
+ )
822
+
823
+ for ws in workspace_clients :
824
+ ws .workspace .download .assert_called_with (
825
+ "/Users/foo/.ucx/test.csv" ,
826
+ format = ExportFormat .AUTO ,
827
+ )
828
+
829
+
830
+ def test_download_warns_if_file_not_found (caplog , ws1 , acc_client ):
831
+ ws1 .workspace .download .side_effect = NotFound ("test.csv" )
832
+ with caplog .at_level (logging .WARNING , logger = "databricks.labs.ucx.cli" ):
833
+ download (
834
+ Path ("test.csv" ),
835
+ ws1 ,
836
+ run_as_collection = False ,
837
+ a = acc_client ,
838
+ )
839
+ assert "File not found for https://localhost: /Users/foo/.ucx/test.csv" in caplog .messages
840
+ assert "No file(s) to download found" in caplog .messages
841
+
842
+
843
+ def test_download_deletes_empty_file (tmp_path , ws1 , acc_client ):
844
+ ws1 .workspace .download .side_effect = NotFound ("test.csv" )
845
+ mapping_path = tmp_path / "mapping.csv"
846
+ download (
847
+ mapping_path ,
848
+ ws1 ,
849
+ run_as_collection = False ,
850
+ a = acc_client ,
851
+ )
852
+ assert not mapping_path .is_file ()
853
+
854
+
855
+ def test_download_has_expected_content (tmp_path , workspace_clients , acc_client ):
856
+ expected = (
857
+ "workspace_name,catalog_name,src_schema,dst_schema,src_table,dst_table"
858
+ "\n test,test,test,test,test,test"
859
+ "\n test,test,test,test,test,test"
860
+ )
861
+ mapping_path = tmp_path / "mapping.csv"
862
+
863
+ download (
864
+ mapping_path ,
865
+ workspace_clients [0 ],
866
+ run_as_collection = True ,
867
+ a = acc_client ,
868
+ )
869
+
870
+ content = mapping_path .read_text ()
871
+ assert content == expected
0 commit comments