Skip to content

Commit

Permalink
Add missing fields to JobsService (#614)
Browse files Browse the repository at this point in the history
This change adds missing request fields to JobsService and renames the argument `name_filter` to `name` to match the API.
  • Loading branch information
pietern authored Mar 21, 2023
1 parent d4909e0 commit dae9d9a
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 16 deletions.
6 changes: 3 additions & 3 deletions databricks_cli/jobs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def create_job(self, json, headers=None, version=None):
version=version)

def list_jobs(self, job_type=None, expand_tasks=None, offset=None, limit=None, headers=None,
              version=None, name=None):
    """Forward a jobs/list request to the client and normalize the response.

    All filter and paging options are passed through unchanged; a response
    that lacks a 'jobs' entry is normalized to contain an empty list so
    callers can always iterate ``resp['jobs']``.
    """
    response = self.client.list_jobs(job_type=job_type, expand_tasks=expand_tasks,
                                     offset=offset, limit=limit, headers=headers,
                                     version=version, name=name)
    # Guarantee the 'jobs' key even when the API omits it (no matches).
    response.setdefault('jobs', [])
    return response
Expand All @@ -57,6 +57,6 @@ def run_now(self, job_id, jar_params, notebook_params, python_params, spark_subm
idempotency_token, headers=headers, version=version)

def _list_jobs_by_name(self, name, headers=None):
jobs = self.list_jobs(headers=headers, name_filter=name)['jobs']
jobs = self.list_jobs(headers=headers, name=name)['jobs']
result = list(filter(lambda job: job['settings']['name'] == name, jobs))
return result
8 changes: 4 additions & 4 deletions databricks_cli/jobs/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,15 @@ def _jobs_to_table(jobs_json):
@click.option('--all', '_all', is_flag=True,
help='Lists all jobs by executing sequential calls to the API ' +
'(only available in API 2.1).')
@click.option('--name', 'name_filter', default=None, type=str,
@click.option('--name', 'name', default=None, type=str,
help='If provided, only returns jobs that match the supplied ' +
'name (only available in API 2.1).')
@api_version_option
@debug_option
@profile_option
@eat_exceptions
@provide_api_client
def list_cli(api_client, output, job_type, version, expand_tasks, offset, limit, _all, name_filter):
def list_cli(api_client, output, job_type, version, expand_tasks, offset, limit, _all, name):
"""
Lists the jobs in the Databricks Job Service.
Expand All @@ -154,7 +154,7 @@ def list_cli(api_client, output, job_type, version, expand_tasks, offset, limit,
"""
check_version(api_client, version)
api_version = version or api_client.jobs_api_version
using_features_only_in_21 = expand_tasks or offset or limit or _all or name_filter
using_features_only_in_21 = expand_tasks or offset or limit or _all or name
if api_version != '2.1' and using_features_only_in_21:
click.echo(click.style('ERROR', fg='red') + ': the options --expand-tasks, ' +
'--offset, --limit, --all, and --name are only available in API 2.1', err=True)
Expand All @@ -168,7 +168,7 @@ def list_cli(api_client, output, job_type, version, expand_tasks, offset, limit,
while has_more:
jobs_json = jobs_api.list_jobs(job_type=job_type, expand_tasks=expand_tasks,
offset=offset, limit=limit, version=version,
name_filter=name_filter)
name=name)
jobs += jobs_json['jobs'] if 'jobs' in jobs_json else []
has_more = jobs_json.get('has_more', False) and _all
if has_more:
Expand Down
63 changes: 56 additions & 7 deletions databricks_cli/sdk/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def create_job(
access_control_list=None,
pipeline_task=None,
python_wheel_task=None,
sql_task=None,
webhook_notifications=None,
continuous=None,
):
_data = {}
if name is not None:
Expand Down Expand Up @@ -147,6 +150,22 @@ def create_job(
raise TypeError(
'Expected databricks.PythonWheelTask() or dict for field python_wheel_task'
)
if sql_task is not None:
_data['sql_task'] = sql_task
if not isinstance(sql_task, dict):
raise TypeError('Expected databricks.SqlTask() or dict for field sql_task')
if webhook_notifications is not None:
_data['webhook_notifications'] = webhook_notifications
if not isinstance(webhook_notifications, dict):
raise TypeError(
'Expected databricks.WebhookNotifications() or dict for field webhook_notifications'
)
if continuous is not None:
_data['continuous'] = continuous
if not isinstance(continuous, dict):
raise TypeError(
'Expected databricks.ContinuousSettings() or dict for field continuous'
)
return self.client.perform_query(
'POST', '/jobs/create', data=_data, headers=headers, version=version
)
Expand All @@ -172,6 +191,8 @@ def submit_run(
access_control_list=None,
pipeline_task=None,
python_wheel_task=None,
sql_task=None,
webhook_notifications=None,
):
_data = {}
if run_name is not None:
Expand Down Expand Up @@ -238,6 +259,16 @@ def submit_run(
raise TypeError(
'Expected databricks.PythonWheelTask() or dict for field python_wheel_task'
)
if sql_task is not None:
_data['sql_task'] = sql_task
if not isinstance(sql_task, dict):
raise TypeError('Expected databricks.SqlTask() or dict for field sql_task')
if webhook_notifications is not None:
_data['webhook_notifications'] = webhook_notifications
if not isinstance(webhook_notifications, dict):
raise TypeError(
'Expected databricks.WebhookNotifications() or dict for field webhook_notifications'
)
return self.client.perform_query(
'POST', '/jobs/runs/submit', data=_data, headers=headers, version=version
)
Expand Down Expand Up @@ -278,28 +309,35 @@ def delete_job(self, job_id, headers=None, version=None):
'POST', '/jobs/delete', data=_data, headers=headers, version=version
)

def get_job(self, job_id, headers=None, version=None, include_trigger_history=None):
    """Fetch a single job definition via GET /jobs/get.

    Only non-None options are placed in the request payload, so omitted
    arguments never appear on the wire.
    """
    payload = {}
    for key, value in (('job_id', job_id),
                       ('include_trigger_history', include_trigger_history)):
        if value is not None:
            payload[key] = value
    return self.client.perform_query(
        'GET', '/jobs/get', data=payload, headers=headers, version=version
    )

def list_jobs(
    self,
    job_type=None,
    expand_tasks=None,
    limit=None,
    offset=None,
    headers=None,
    version=None,
    name=None,
):
    """List jobs via GET /jobs/list.

    Builds the request payload from the non-None filter/paging options
    (job_type, expand_tasks, limit, offset, name) and forwards it to the
    client; headers and version pass through untouched.
    """
    options = (
        ('job_type', job_type),
        ('expand_tasks', expand_tasks),
        ('limit', limit),
        ('offset', offset),
        ('name', name),
    )
    _data = {key: value for key, value in options if value is not None}
    return self.client.perform_query(
        'GET', '/jobs/list', data=_data, headers=headers, version=version
    )
Expand Down Expand Up @@ -359,6 +397,8 @@ def repair(
version=None,
dbt_commands=None,
pipeline_params=None,
rerun_all_failed_tasks=None,
rerun_dependent_tasks=None,
):
_data = {}
if run_id is not None:
Expand All @@ -385,6 +425,10 @@ def repair(
raise TypeError(
'Expected databricks.PipelineParameters() or dict for field pipeline_params'
)
if rerun_all_failed_tasks is not None:
_data['rerun_all_failed_tasks'] = rerun_all_failed_tasks
if rerun_dependent_tasks is not None:
_data['rerun_dependent_tasks'] = rerun_dependent_tasks
return self.client.perform_query(
'POST', '/jobs/runs/repair', data=_data, headers=headers, version=version
)
Expand All @@ -402,6 +446,7 @@ def list_runs(
expand_tasks=None,
start_time_from=None,
start_time_to=None,
page_token=None,
):
_data = {}
if job_id is not None:
Expand All @@ -422,6 +467,8 @@ def list_runs(
_data['start_time_from'] = start_time_from
if start_time_to is not None:
_data['start_time_to'] = start_time_to
if page_token is not None:
_data['page_token'] = page_token
return self.client.perform_query(
'GET', '/jobs/runs/list', data=_data, headers=headers, version=version
)
Expand Down Expand Up @@ -452,10 +499,12 @@ def cancel_run(self, run_id, headers=None, version=None):
'POST', '/jobs/runs/cancel', data=_data, headers=headers, version=version
)

def cancel_all_runs(self, job_id=None, headers=None, version=None, all_queued_runs=None):
    """Cancel all active runs via POST /jobs/runs/cancel-all.

    Scope is optionally narrowed to one job (job_id) and optionally widened
    to queued runs (all_queued_runs); options left as None are omitted from
    the request body.
    """
    body = {
        key: value
        for key, value in (('job_id', job_id), ('all_queued_runs', all_queued_runs))
        if value is not None
    }
    return self.client.perform_query(
        'POST', '/jobs/runs/cancel-all', data=body, headers=headers, version=version
    )
Expand Down
2 changes: 1 addition & 1 deletion tests/jobs/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_list_jobs():
'GET', '/jobs/list', data={}, headers=None, version='3.0'
)

api.list_jobs(version='2.1', name_filter='foo')
api.list_jobs(version='2.1', name='foo')
api_client_mock.perform_query.assert_called_with(
'GET', '/jobs/list', data={'name':'foo'}, headers=None, version='2.1'
)
Expand Down
2 changes: 1 addition & 1 deletion tests/jobs/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_list_name(jobs_api_mock):
runner = CliRunner()
result = runner.invoke(cli.list_cli, ['--version=2.1', '--name', 'foo'])
assert result.exit_code == 0
assert jobs_api_mock.list_jobs.call_args[1]['name_filter'] == 'foo'
assert jobs_api_mock.list_jobs.call_args[1]['name'] == 'foo'
assert jobs_api_mock.list_jobs.call_args[1]['version'] == '2.1'

@provide_conf
Expand Down

0 comments on commit dae9d9a

Please sign in to comment.