diff --git a/databricks_cli/dbfs/api.py b/databricks_cli/dbfs/api.py
index 0be47f57..773aa7c0 100644
--- a/databricks_cli/dbfs/api.py
+++ b/databricks_cli/dbfs/api.py
@@ -91,8 +91,25 @@ class DbfsApi(object):
     def __init__(self, api_client):
         self.client = DbfsService(api_client)
 
-    def list_files(self, dbfs_path, headers=None):
-        list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
+    def _recursive_list(self, dbfs_path, headers=None):
+        """
+        Returns the files (never the directories) found anywhere beneath dbfs_path.
+        """
+        entries = self.list_files(dbfs_path, headers=headers)
+        files = [entry for entry in entries if not entry.is_dir]
+        for entry in entries:
+            if entry.is_dir:
+                files.extend(self._recursive_list(entry.dbfs_path, headers=headers))
+        return files
+
+    def list_files(self, dbfs_path, headers=None, is_recursive=False):
+        """
+        Lists the entries at dbfs_path as FileInfo objects. With is_recursive, directories are
+        walked and only the files found beneath dbfs_path are returned.
+        """
+        if is_recursive:
+            return self._recursive_list(dbfs_path, headers=headers)
+        list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
         if 'files' in list_response:
             return [FileInfo.from_json(f) for f in list_response['files']]
         else:
diff --git a/databricks_cli/dbfs/cli.py b/databricks_cli/dbfs/cli.py
index 83ccf315..dc99f276 100644
--- a/databricks_cli/dbfs/cli.py
+++ b/databricks_cli/dbfs/cli.py
@@ -38,12 +38,14 @@
 @click.option('-l', is_flag=True, default=False,
               help="""Displays full information including size, file type
               and modification time since Epoch in milliseconds.""")
+@click.option('--recursive', '-r', is_flag=True, default=False,
+              help='Recursively lists the files in all subdirectories.')
 @click.argument('dbfs_path', nargs=-1, type=DbfsPathClickType())
 @debug_option
 @profile_option
 @eat_exceptions
 @provide_api_client
-def ls_cli(api_client, l, absolute, dbfs_path):  # NOQA
+def ls_cli(api_client, l, absolute, recursive, dbfs_path):  # NOQA
     """
     List files in DBFS.
     """
@@ -53,7 +55,11 @@ def ls_cli(api_client, l, absolute, dbfs_path):  # NOQA
         dbfs_path = dbfs_path[0]
     else:
         error_and_quit('ls can take a maximum of one path.')
-    files = DbfsApi(api_client).list_files(dbfs_path)
+
+    files = DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive)
+    # Basenames are ambiguous across directories, so recursive listings show absolute paths.
+    absolute = absolute or recursive
+
     table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files],
                      tablefmt='plain')
     click.echo(table)
diff --git a/tests/dbfs/test_api.py b/tests/dbfs/test_api.py
index ded0c976..da5bb5e8 100644
--- a/tests/dbfs/test_api.py
+++ b/tests/dbfs/test_api.py
@@ -36,13 +36,27 @@
 TEST_DBFS_PATH = DbfsPath('dbfs:/test')
 DUMMY_TIME = 1613158406000
 
-TEST_FILE_JSON = {
+TEST_FILE_JSON0 = {
     'path': '/test',
     'is_dir': False,
     'file_size': 1,
     'modification_time': DUMMY_TIME
 }
-TEST_FILE_INFO = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
+TEST_FILE_JSON1 = {
+    'path': '/dir/test',
+    'is_dir': False,
+    'file_size': 1,
+    'modification_time': DUMMY_TIME
+}
+TEST_DIR_JSON = {
+    'path': '/dir',
+    'is_dir': True,
+    'file_size': 0,
+    'modification_time': DUMMY_TIME
+}
+TEST_DBFS_PATH2 = DbfsPath('dbfs:/dir/test')
+TEST_FILE_INFO0 = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
+TEST_FILE_INFO1 = api.FileInfo(TEST_DBFS_PATH2, False, 1, DUMMY_TIME)
 
 
 def get_resource_does_not_exist_exception():
@@ -74,7 +88,7 @@ def test_to_row_long_form_not_absolute(self):
         assert TEST_DBFS_PATH.basename == row[2]
 
     def test_from_json(self):
-        file_info = api.FileInfo.from_json(TEST_FILE_JSON)
+        file_info = api.FileInfo.from_json(TEST_FILE_JSON0)
         assert file_info.dbfs_path == TEST_DBFS_PATH
         assert not file_info.is_dir
         assert file_info.file_size == 1
@@ -89,15 +103,27 @@ def dbfs_api():
 
 
 class TestDbfsApi(object):
+    def test_list_files_recursive(self, dbfs_api):
+        # One response per client.list call: the root listing, then the '/dir' listing.
+        dbfs_api.client.list.side_effect = [
+            {'files': [TEST_FILE_JSON0, TEST_DIR_JSON]},
+            {'files': [TEST_FILE_JSON1]},
+        ]
+        files = dbfs_api.list_files(DbfsPath('dbfs:/'), is_recursive=True)
+
+        assert len(files) == 2
+        assert TEST_FILE_INFO0 == files[0]
+        assert TEST_FILE_INFO1 == files[1]
+
     def test_list_files_exists(self, dbfs_api):
         json = {
-            'files': [TEST_FILE_JSON]
+            'files': [TEST_FILE_JSON0]
         }
         dbfs_api.client.list.return_value = json
         files = dbfs_api.list_files(TEST_DBFS_PATH)
 
         assert len(files) == 1
-        assert TEST_FILE_INFO == files[0]
+        assert TEST_FILE_INFO0 == files[0]
 
     def test_list_files_does_not_exist(self, dbfs_api):
         json = {}
@@ -107,7 +133,7 @@ def test_list_files_does_not_exist(self, dbfs_api):
         assert len(files) == 0
 
     def test_file_exists_true(self, dbfs_api):
-        dbfs_api.client.get_status.return_value = TEST_FILE_JSON
+        dbfs_api.client.get_status.return_value = TEST_FILE_JSON0
         assert dbfs_api.file_exists(TEST_DBFS_PATH)
 
     def test_file_exists_false(self, dbfs_api):
@@ -116,8 +142,8 @@ def test_file_exists_false(self, dbfs_api):
         assert not dbfs_api.file_exists(TEST_DBFS_PATH)
 
     def test_get_status(self, dbfs_api):
-        dbfs_api.client.get_status.return_value = TEST_FILE_JSON
-        assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO
+        dbfs_api.client.get_status.return_value = TEST_FILE_JSON0
+        assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO0
 
     def test_get_status_fail(self, dbfs_api):
         exception = get_resource_does_not_exist_exception()
@@ -164,7 +190,7 @@ def test_get_file_check_overwrite(self, dbfs_api, tmpdir):
 
     def test_get_file(self, dbfs_api, tmpdir):
         api_mock = dbfs_api.client
-        api_mock.get_status.return_value = TEST_FILE_JSON
+        api_mock.get_status.return_value = TEST_FILE_JSON0
         api_mock.read.return_value = {
             'bytes_read': 1,
             'data': b64encode(b'x'),