From 3330fafbb231a23731b8ad5a52ed9aacb17930c1 Mon Sep 17 00:00:00 2001 From: Justin Mai Date: Wed, 13 Jul 2022 12:19:59 -0700 Subject: [PATCH 1/4] !513 refactor recursive dbfs ls and add test --- databricks_cli/dbfs/api.py | 14 ++++++++++-- databricks_cli/dbfs/cli.py | 9 ++++++-- tests/dbfs/test_api.py | 47 +++++++++++++++++++++++++++++--------- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/databricks_cli/dbfs/api.py b/databricks_cli/dbfs/api.py index 0be47f57..1794f181 100644 --- a/databricks_cli/dbfs/api.py +++ b/databricks_cli/dbfs/api.py @@ -91,8 +91,18 @@ class DbfsApi(object): def __init__(self, api_client): self.client = DbfsService(api_client) - def list_files(self, dbfs_path, headers=None): - list_response = self.client.list(dbfs_path.absolute_path, headers=headers) + def _recursive_list(self, **kwargs): + paths = self.client.list_files(**kwargs) + files = [p for p in paths if not p.is_dir] + for p in paths: + files = files + self._recursive_list(p) if p.is_dir else files + return files + + def list_files(self, dbfs_path, headers=None, is_recursive=False): + if is_recursive: + list_response = self._recursive_list(dbfs_path, headers) + else: + list_response = self.client.list(dbfs_path.absolute_path, headers=headers) if 'files' in list_response: return [FileInfo.from_json(f) for f in list_response['files']] else: diff --git a/databricks_cli/dbfs/cli.py b/databricks_cli/dbfs/cli.py index 83ccf315..b6e50e03 100644 --- a/databricks_cli/dbfs/cli.py +++ b/databricks_cli/dbfs/cli.py @@ -38,12 +38,14 @@ @click.option('-l', is_flag=True, default=False, help="""Displays full information including size, file type and modification time since Epoch in milliseconds.""") +@click.option('--recursive', '-r', is_flag=True, default=False, + help='Displays all subdirectories and files.') @click.argument('dbfs_path', nargs=-1, type=DbfsPathClickType()) @debug_option @profile_option @eat_exceptions @provide_api_client -def ls_cli(api_client, l, absolute, dbfs_path): # NOQA +def ls_cli(api_client, l, absolute, recursive, dbfs_path): # NOQA """ List files in DBFS. """ @@ -53,7 +55,10 @@ def ls_cli(api_client, l, absolute, dbfs_path): # NOQA dbfs_path = dbfs_path[0] else: error_and_quit('ls can take a maximum of one path.') - files = DbfsApi(api_client).list_files(dbfs_path) + + DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive) + absolute = absolute or recursive + table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files], tablefmt='plain') click.echo(table) diff --git a/tests/dbfs/test_api.py b/tests/dbfs/test_api.py index ded0c976..76fadf83 100644 --- a/tests/dbfs/test_api.py +++ b/tests/dbfs/test_api.py @@ -36,13 +36,26 @@ TEST_DBFS_PATH = DbfsPath('dbfs:/test') DUMMY_TIME = 1613158406000 -TEST_FILE_JSON = { +TEST_FILE_JSON1 = { 'path': '/test', 'is_dir': False, 'file_size': 1, 'modification_time': DUMMY_TIME } -TEST_FILE_INFO = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME) +TEST_FILE_JSON2 = { + 'path': '/dir/test', + 'is_dir': False, + 'file_size': 1, + 'modification_time': DUMMY_TIME +} +TEST_DIR_JSON = { + 'path': '/dir', + 'is_dir': True, + 'file_size': 0, + 'modification_time': DUMMY_TIME +} +TEST_FILE_INFO0 = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME) +TEST_FILE_INFO1 = api.FileInfo(TEST_DBFS_PATH2, False, 1, DUMMY_TIME) def get_resource_does_not_exist_exception(): @@ -74,7 +87,7 @@ def test_to_row_long_form_not_absolute(self): assert TEST_DBFS_PATH.basename == row[2] def test_from_json(self): - file_info = api.FileInfo.from_json(TEST_FILE_JSON) + file_info = api.FileInfo.from_json(TEST_FILE_JSON0) assert file_info.dbfs_path == TEST_DBFS_PATH assert not file_info.is_dir assert file_info.file_size == 1 @@ -89,15 +102,26 @@ def dbfs_api(): class TestDbfsApi(object): + def test_list_files_recursive(self, dbfs_api): + json = { + 'files': [TEST_FILE_JSON0, TEST_DIR_JSON, TEST_FILE_JSON1] + } + dbfs_api.client.list.return_value = json + files = dbfs_api.list_files("dbfs:/") + + assert len(files) == 2 + assert TEST_FILE_INFO0 == files[0] + assert TEST_FILE_INFO1 == files[1] + def test_list_files_exists(self, dbfs_api): json = { - 'files': [TEST_FILE_JSON] + 'files': [TEST_FILE_JSON0] } dbfs_api.client.list.return_value = json - files = dbfs_api.list_files(TEST_DBFS_PATH) + files = dbfs_api.list_files(TEST_DBFS_PATH, is_recursive=True) assert len(files) == 1 - assert TEST_FILE_INFO == files[0] + assert TEST_FILE_INFO0 == files[0] def test_list_files_does_not_exist(self, dbfs_api): json = {} @@ -107,7 +131,7 @@ def test_list_files_does_not_exist(self, dbfs_api): assert len(files) == 0 def test_file_exists_true(self, dbfs_api): - dbfs_api.client.get_status.return_value = TEST_FILE_JSON + dbfs_api.client.get_status.return_value = TEST_FILE_JSON0 assert dbfs_api.file_exists(TEST_DBFS_PATH) def test_file_exists_false(self, dbfs_api): @@ -116,8 +140,8 @@ def test_file_exists_false(self, dbfs_api): assert not dbfs_api.file_exists(TEST_DBFS_PATH) def test_get_status(self, dbfs_api): - dbfs_api.client.get_status.return_value = TEST_FILE_JSON - assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO + dbfs_api.client.get_status.return_value = TEST_FILE_JSON0 + assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO0 def test_get_status_fail(self, dbfs_api): exception = get_resource_does_not_exist_exception() @@ -151,7 +175,8 @@ def test_put_large_file(self, dbfs_api, tmpdir): dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True) assert api_mock.add_block.call_count == 1 assert test_handle == api_mock.add_block.call_args[0][0] - assert b64encode(b'test').decode() == api_mock.add_block.call_args[0][1] + assert b64encode(b'test').decode( + ) == api_mock.add_block.call_args[0][1] assert api_mock.close.call_count == 1 assert test_handle == api_mock.close.call_args[0][0] @@ -164,7 +189,7 @@ def test_get_file_check_overwrite(self, dbfs_api, tmpdir): def test_get_file(self, dbfs_api, tmpdir): api_mock = dbfs_api.client - api_mock.get_status.return_value = TEST_FILE_JSON + api_mock.get_status.return_value = TEST_FILE_JSON0 api_mock.read.return_value = { 'bytes_read': 1, 'data': b64encode(b'x'), From 9389901886a45c5224003bf150284ef069710e15 Mon Sep 17 00:00:00 2001 From: Justin Mai Date: Wed, 13 Jul 2022 12:30:25 -0700 Subject: [PATCH 2/4] !513 fix args for list_files() --- databricks_cli/dbfs/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/databricks_cli/dbfs/api.py b/databricks_cli/dbfs/api.py index 1794f181..773aa7c0 100644 --- a/databricks_cli/dbfs/api.py +++ b/databricks_cli/dbfs/api.py @@ -91,8 +91,8 @@ class DbfsApi(object): def __init__(self, api_client): self.client = DbfsService(api_client) - def _recursive_list(self, **kwargs): - paths = self.client.list_files(**kwargs) + def _recursive_list(self, *args, **kwargs): + paths = self.client.list_files(*args, **kwargs) files = [p for p in paths if not p.is_dir] for p in paths: files = files + self._recursive_list(p) if p.is_dir else files From b0fb9cfa9360a87eae5d2c91e6cc076a9a184205 Mon Sep 17 00:00:00 2001 From: JM Date: Tue, 19 Jul 2022 10:56:58 -0700 Subject: [PATCH 3/4] Update cli.py fix Dbfs files assignment --- databricks_cli/dbfs/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks_cli/dbfs/cli.py b/databricks_cli/dbfs/cli.py index b6e50e03..dc99f276 100644 --- a/databricks_cli/dbfs/cli.py +++ b/databricks_cli/dbfs/cli.py @@ -56,7 +56,7 @@ def ls_cli(api_client, l, absolute, recursive, dbfs_path): # NOQA else: error_and_quit('ls can take a maximum of one path.') - DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive) + files = DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive) absolute = absolute or recursive table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files], From 7060ef042d0afa831769b7c00f41c94bc6a7ef81 Mon Sep 17 00:00:00 2001 From: JM Date: Wed, 3 Aug 2022 09:37:35 -0700 Subject: [PATCH 4/4] revert erroneous changes --- tests/dbfs/test_api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/dbfs/test_api.py b/tests/dbfs/test_api.py index 76fadf83..da5bb5e8 100644 --- a/tests/dbfs/test_api.py +++ b/tests/dbfs/test_api.py @@ -175,8 +175,7 @@ def test_put_large_file(self, dbfs_api, tmpdir): dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True) assert api_mock.add_block.call_count == 1 assert test_handle == api_mock.add_block.call_args[0][0] - assert b64encode(b'test').decode( - ) == api_mock.add_block.call_args[0][1] + assert b64encode(b'test').decode() == api_mock.add_block.call_args[0][1] assert api_mock.close.call_count == 1 assert test_handle == api_mock.close.call_args[0][0]