Add --recursive option to fs ls command #513

Open · wants to merge 4 commits into base: main
14 changes: 12 additions & 2 deletions databricks_cli/dbfs/api.py
@@ -91,8 +91,18 @@ class DbfsApi(object):
def __init__(self, api_client):
self.client = DbfsService(api_client)

def list_files(self, dbfs_path, headers=None):
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
def _recursive_list(self, *args, **kwargs):
paths = self.client.list_files(*args, **kwargs)
files = [p for p in paths if not p.is_dir]
for p in paths:
files = files + self._recursive_list(p) if p.is_dir else files
Contributor
I had to read this a couple of times to understand what you're doing. I think it should be simplified to iterate over paths just once, with a single if p.is_dir branch that handles the directory case and the file case in one place; see the sketch after the next comment.

Contributor
Separate comment: this also needs to pass along the headers from **kwargs to be consistent with the list_files API today. Otherwise the headers are used on the first call but not on the later recursive calls.
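
To make both suggestions concrete, here is a rough sketch of a single-pass version that also forwards the headers at every level. Calling the existing non-recursive list_files and passing p.dbfs_path (rather than the FileInfo itself) are my assumptions; note that this version returns FileInfo objects directly, so the recursive branch of list_files would return its result as-is instead of going through the 'files' JSON check:

    def _recursive_list(self, dbfs_path, headers=None):
        # Single pass over the listing: keep plain files, descend into directories.
        files = []
        for p in self.list_files(dbfs_path, headers=headers):
            if p.is_dir:
                # Forward the same headers to every nested listing call.
                files.extend(self._recursive_list(p.dbfs_path, headers=headers))
            else:
                files.append(p)
        return files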

return files

def list_files(self, dbfs_path, headers=None, is_recursive=False):
if is_recursive:
list_response = self._recursive_list(dbfs_path, headers)
else:
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
if 'files' in list_response:
return [FileInfo.from_json(f) for f in list_response['files']]
else:
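As a side note for anyone trying this change locally: assuming the review comments above are addressed, the new recursive listing would be exercised roughly like this (the workspace URL and token are placeholders; the import paths follow this repo's layout):

    from databricks_cli.sdk.api_client import ApiClient
    from databricks_cli.dbfs.api import DbfsApi
    from databricks_cli.dbfs.dbfs_path import DbfsPath

    # Placeholder credentials; real values come from your Databricks workspace.
    client = ApiClient(host='https://<workspace-url>', token='<personal-access-token>')
    dbfs = DbfsApi(client)

    # With is_recursive=True the listing should also include files from subdirectories.
    for f in dbfs.list_files(DbfsPath('dbfs:/tmp'), is_recursive=True):
        print(f.dbfs_path.absolute_path)
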
9 changes: 7 additions & 2 deletions databricks_cli/dbfs/cli.py
@@ -38,12 +38,14 @@
@click.option('-l', is_flag=True, default=False,
help="""Displays full information including size, file type
and modification time since Epoch in milliseconds.""")
@click.option('--recursive', '-r', is_flag=True, default=False,
help='Displays all subdirectories and files.')
@click.argument('dbfs_path', nargs=-1, type=DbfsPathClickType())
@debug_option
@profile_option
@eat_exceptions
@provide_api_client
def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
def ls_cli(api_client, l, absolute, recursive, dbfs_path): # NOQA
"""
List files in DBFS.
"""
@@ -53,7 +55,10 @@ def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
dbfs_path = dbfs_path[0]
else:
error_and_quit('ls can take a maximum of one path.')
files = DbfsApi(api_client).list_files(dbfs_path)

files = DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive)
absolute = absolute or recursive

table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files],
tablefmt='plain')
click.echo(table)
44 changes: 34 additions & 10 deletions tests/dbfs/test_api.py
@@ -36,13 +36,26 @@

TEST_DBFS_PATH = DbfsPath('dbfs:/test')
DUMMY_TIME = 1613158406000
TEST_FILE_JSON = {
TEST_FILE_JSON0 = {
'path': '/test',
'is_dir': False,
'file_size': 1,
'modification_time': DUMMY_TIME
}
TEST_FILE_INFO = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
TEST_FILE_JSON2 = {
'path': '/dir/test',
'is_dir': False,
'file_size': 1,
'modification_time': DUMMY_TIME
}
TEST_DIR_JSON = {
'path': '/dir',
'is_dir': True,
'file_size': 0,
'modification_time': DUMMY_TIME
}
TEST_FILE_INFO0 = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
TEST_FILE_INFO1 = api.FileInfo(TEST_DBFS_PATH2, False, 1, DUMMY_TIME)


def get_resource_does_not_exist_exception():
@@ -74,7 +87,7 @@ def test_to_row_long_form_not_absolute(self):
assert TEST_DBFS_PATH.basename == row[2]

def test_from_json(self):
file_info = api.FileInfo.from_json(TEST_FILE_JSON)
file_info = api.FileInfo.from_json(TEST_FILE_JSON0)
assert file_info.dbfs_path == TEST_DBFS_PATH
assert not file_info.is_dir
assert file_info.file_size == 1
@@ -89,15 +102,26 @@ def dbfs_api():


class TestDbfsApi(object):
def test_list_files_recursive(self, dbfs_api):
json = {
'files': [TEST_FILE_JSON0, TEST_DIR_JSON, TEST_FILE_JSON1]
}
dbfs_api.client.list.return_value = json
files = dbfs_api.list_files("dbfs:/")

assert len(files) == 2
assert TEST_FILE_INFO0 == files[0]
assert TEST_FILE_INFO1 == files[1]

Contributor
There doesn't seem to be a test for a real recursive call.

The variable TEST_FILE_JSON2 is unused. When you add a test that uses it and makes a real recursive call, I expect it to fail for the reason I commented on above: the mismatch between FileInfo and DbfsPath.
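
For what it's worth, a rough sketch of the shape such a test could take. It assumes the recursive path ultimately goes through client.list, and that FileInfo.from_json maps '/dir/test' to DbfsPath('dbfs:/dir/test'); per the comment above, I would expect the current implementation to fail it:

    def test_list_files_recursive_descends(self, dbfs_api):
        # First listing returns a file and a directory; the second listing
        # (for the directory) returns the nested file.
        dbfs_api.client.list.side_effect = [
            {'files': [TEST_FILE_JSON0, TEST_DIR_JSON]},
            {'files': [TEST_FILE_JSON2]},
        ]
        files = dbfs_api.list_files(TEST_DBFS_PATH, is_recursive=True)

        assert len(files) == 2
        assert TEST_FILE_INFO0 == files[0]
        assert api.FileInfo(DbfsPath('dbfs:/dir/test'), False, 1, DUMMY_TIME) == files[1]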

def test_list_files_exists(self, dbfs_api):
json = {
'files': [TEST_FILE_JSON]
'files': [TEST_FILE_JSON0]
}
dbfs_api.client.list.return_value = json
files = dbfs_api.list_files(TEST_DBFS_PATH)
files = dbfs_api.list_files(TEST_DBFS_PATH, is_recursive=True)

assert len(files) == 1
assert TEST_FILE_INFO == files[0]
assert TEST_FILE_INFO0 == files[0]

def test_list_files_does_not_exist(self, dbfs_api):
json = {}
@@ -107,7 +131,7 @@ def test_list_files_does_not_exist(self, dbfs_api):
assert len(files) == 0

def test_file_exists_true(self, dbfs_api):
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
dbfs_api.client.get_status.return_value = TEST_FILE_JSON0
assert dbfs_api.file_exists(TEST_DBFS_PATH)

def test_file_exists_false(self, dbfs_api):
Expand All @@ -116,8 +140,8 @@ def test_file_exists_false(self, dbfs_api):
assert not dbfs_api.file_exists(TEST_DBFS_PATH)

def test_get_status(self, dbfs_api):
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO
dbfs_api.client.get_status.return_value = TEST_FILE_JSON0
assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO0

def test_get_status_fail(self, dbfs_api):
exception = get_resource_does_not_exist_exception()
@@ -164,7 +188,7 @@ def test_get_file_check_overwrite(self, dbfs_api, tmpdir):

def test_get_file(self, dbfs_api, tmpdir):
api_mock = dbfs_api.client
api_mock.get_status.return_value = TEST_FILE_JSON
api_mock.get_status.return_value = TEST_FILE_JSON0
api_mock.read.return_value = {
'bytes_read': 1,
'data': b64encode(b'x'),