Skip to content

Commit

Permalink
Added oauth CLI flow
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesd-db committed Jun 27, 2023
1 parent 8fe7aad commit 472e043
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 19 deletions.
164 changes: 147 additions & 17 deletions databricks_cli/unity_catalog/connection_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
# limitations under the License.

import functools
import http.server
import os
import socketserver
import ssl
import webbrowser
import oauthlib.oauth2
from oauthlib.common import generate_token


import click

Expand Down Expand Up @@ -51,8 +59,6 @@ def common_create_args(f):
help='Host of new connection')
@click.option('--port', default=None,
help='Port of new connection')
@click.option('--user', default=None,
help='Username for authorization of new connection')
@functools.wraps(f)
def wrapper(*args, **kwargs):
f(*args, **kwargs)
Expand All @@ -70,10 +76,64 @@ def wrapper(*args, **kwargs):
f(*args, **kwargs)
return wrapper


redirect_uri = 'https://localhost:8771'
return_query_res = ""

class AccessCodeRequestHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
self.send_response(200)
self.send_header('Content-type', 'text/html')
self.end_headers()

script_path = os.path.abspath(__file__)
html_file_path = os.path.join(os.path.dirname(script_path), 'response.html')

with open(html_file_path, 'rb') as file:
self.wfile.write(file.read())

global return_query_res
return_query_res = self.path

#For quiet log messages
def log_message(self, format, *args):
# Override log_message to suppress output
pass

def run_oauth_response_server():
server_address = ('', 8771)
httpd = socketserver.TCPServer(server_address, AccessCodeRequestHandler)

ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(certfile="localhost.pem", keyfile="localhost-key.pem")
httpd.socket = ssl_context.wrap_socket(httpd.socket, server_side=True)

try:
httpd.handle_request()
except KeyboardInterrupt:
pass

def get_auth_code(host, client_id, scope):
oauth = oauthlib.oauth2.WebApplicationClient(client_id)
state = generate_token()
verifier = oauth.create_code_verifier(96)
challenge = oauth.create_code_challenge(verifier, 'S256')

authorization_url = oauth.prepare_request_uri(
'https://' + host + '/oauth/authorize', redirect_uri = redirect_uri, scope = scope, state = state, code_challenge = challenge, code_challenge_method = 'S256')
webbrowser.open_new(authorization_url)
res = run_oauth_response_server()
parsed_result = oauth.parse_request_uri_response(redirect_uri + return_query_res, state = state)
parsed_result['code_verifier'] = verifier
return parsed_result


@click.command(context_settings=CONTEXT_SETTINGS,
short_help='Create mysql connection with CLI flags.')
@common_create_args
@create_update_common_options
@click.option('--user', default=None,
help='Username for authorization of new connection')
@click.option(
"--password", prompt=True, hide_input=True,
confirmation_prompt=True
Expand Down Expand Up @@ -104,6 +164,8 @@ def create_mysql_cli(api_client, name, host, port, user,
short_help='Create postgresql connection with CLI flags.')
@common_create_args
@create_update_common_options
@click.option('--user', default=None,
help='Username for authorization of new connection')
@click.option(
"--password", prompt=True, hide_input=True,
confirmation_prompt=True
Expand Down Expand Up @@ -135,6 +197,8 @@ def create_postgresql_cli(api_client, name, host, port, user,
@create_update_common_options
@click.option('--sfwarehouse', default=None,
help='Snowflake warehouse name of new connection')
@click.option('--user', default=None,
help='Username for authorization of new connection')
@click.option(
"--password", prompt=True, hide_input=True,
confirmation_prompt=True
Expand Down Expand Up @@ -167,6 +231,8 @@ def create_snowflake_cli(api_client, name, host, port, user, sfwarehouse,
short_help='Create redshift connection with CLI flags.')
@common_create_args
@create_update_common_options
@click.option('--user', default=None,
help='Username for authorization of new connection')
@click.option(
"--password", prompt=True, hide_input=True,
confirmation_prompt=True
Expand Down Expand Up @@ -199,6 +265,8 @@ def create_redshift_cli(api_client, name, host, port, user,
@create_update_common_options
@click.option('--trustservercert', is_flag=True, default=None,
help='Trust the server provided certificate')
@click.option('--user', default=None,
help='Username for authorization of new connection')
@click.option(
"--password", prompt=True, hide_input=True,
confirmation_prompt=True
Expand Down Expand Up @@ -233,6 +301,8 @@ def create_sqldw_cli(api_client, name, host, port, user, trustservercert,
@create_update_common_options
@click.option('--trustservercert', is_flag=True, default=None,
help='Trust the server provided certificate')
@click.option('--user', default=None,
help='Username for authorization of new connection')
@click.option(
"--password", prompt=True, hide_input=True,
confirmation_prompt=True
Expand Down Expand Up @@ -295,31 +365,46 @@ def create_databricks_cli(api_client, name, host, httppath, token,
con_json = UnityCatalogApi(api_client).create_connection(data)
click.echo(mc_pretty_format(con_json))

#OAuth CLIs below

@click.command(context_settings=CONTEXT_SETTINGS,
short_help='Create online catalog connection with CLI flags.')
short_help='Create snowflake oauth connection with CLI flags.')
@common_create_args
@create_update_common_options
@click.option('--sfwarehouse', default=None,
help='Snowflake warehouse name of new connection')
@click.option('--scope', default=None,
help='Scope of new OAuth connection. Should be a single \
quoted string with separate options separated by spaces')
@click.option('--client-id', default=None,
help='Client ID for new connection')
@click.option(
"--client-secret", prompt=True, hide_input=True)
@debug_option
@eat_exceptions
@profile_option
@provide_api_client
def create_online_catalog_cli(api_client, name, host, port, user,
read_only, comment):
def create_snowflake_oauth_cli(api_client, name, host, port, client_id, sfwarehouse,
read_only, comment, scope, client_secret):
"""
Create new online catalog connection.
Create new snowflake U2M oauth connection.
"""
if (name is None) or (host is None) or (port is None) or (user is None):
raise ValueError('Must provide all required connection parameters')
if (name is None) or (host is None) or (port is None) or (client_id is None) or \
(client_secret is None) or (sfwarehouse is None):
raise ValueError('Must provide all required oauth connection parameters')

code_dict = get_auth_code(host, client_id, scope)
data = {
'name': name,
'connection_type': 'ONLINE_CATALOG',
'options': {'host': host, 'port': port, 'user': user},
'connection_type': 'SNOWFLAKE',
'options': {'host': host, 'port': port, 'client_id': client_id, 'client_secret': client_secret, 'state': code_dict['state'],
'code': code_dict['code'], 'sfWarehouse': sfwarehouse, 'redirect_uri': redirect_uri, 'code_verifier': code_dict['code_verifier']},
'read_only': read_only,
'comment': comment,
'comment': comment
}
con_json = UnityCatalogApi(api_client).create_connection(data)
click.echo(mc_pretty_format(con_json))
click.echo(mc_pretty_format(data))
#con_json = UnityCatalogApi(api_client).create_connection(data)
#click.echo(mc_pretty_format(con_json))


@click.command(context_settings=CONTEXT_SETTINGS,
Expand All @@ -331,7 +416,7 @@ def create_online_catalog_cli(api_client, name, host, port, user,
@provide_api_client
def create_json(api_client, json_file, json):
'''
Create new connection with an inline JSON or JSON file input.
Create new connection with an inline JSON or JSON file path.
'''
if (json is None) and (json_file is None):
raise ValueError('Must either provide inline JSON or JSON file.')
Expand All @@ -348,7 +433,7 @@ def create_json(api_client, json_file, json):
@profile_option
@eat_exceptions
@provide_api_client
def list_connections_cli(api_client, ):
def list_connections_cli(api_client):
"""
List connections.
"""
Expand Down Expand Up @@ -387,6 +472,48 @@ def delete_connection_cli(api_client, name):
UnityCatalogApi(api_client).delete_connection(name)


@click.command(context_settings=CONTEXT_SETTINGS,
short_help='Update a connection.')
@click.option('--name', required=True,
help='Name of the connection to update.')
@click.option('--new-name', default=None, help='New name of the connection.')
@create_update_common_options
@click.option('--owner', default=None,
help='Owner of the connection.')

@click.option('--json-file', default=None, type=click.Path(),
help=json_file_help(method='PATCH', path='/connections/{name}'))
@click.option('--json', default=None, type=JsonClickType(),
help=json_string_help(method='PATCH', path='/connections/{name}'))
@debug_option
@profile_option
@eat_exceptions
@provide_api_client
def update_connection_cli(api_client, name, new_name, read_only,
comment, owner, json_file, json):
"""
Update an connection.
The public specification for the JSON request is in development.
"""
if ((new_name is not None) or
(read_only is not None) or (comment is not None)):
if (json_file is not None) or (json is not None):
raise ValueError('Cannot specify JSON if any other update flags are specified')
data = {
'name': new_name,
'read_only': read_only,
'comment': comment,
'owner': owner
}
loc_json = UnityCatalogApi(api_client).update_connnection(
name, data)
click.echo(mc_pretty_format(loc_json))
else:
json_cli_base(json_file, json,
lambda json: UnityCatalogApi(api_client).update_connnection(name, json),
encode_utf8=True)

@click.group()
def create_group(): # pragma: no cover
pass
Expand All @@ -402,14 +529,16 @@ def register_connection_commands(cmd_group):
cmd_group.add_command(hide(create_mysql_cli), name='create-mysql-connection')
cmd_group.add_command(hide(create_postgresql_cli), name='create-postgresql-connection')
cmd_group.add_command(hide(create_snowflake_cli), name='create-snowflake-connection')
cmd_group.add_command(hide(create_snowflake_oauth_cli), name='create-snowflake-oauth-connection')
cmd_group.add_command(hide(create_redshift_cli), name='create-redshift-connection')
cmd_group.add_command(hide(create_sqldw_cli), name='create-sqldw-connection')
cmd_group.add_command(hide(create_sqlserver_cli), name='create-sqlserver-connection')
cmd_group.add_command(hide(create_databricks_cli), name='create-databricks-connection')
cmd_group.add_command(hide(create_online_catalog_cli), name='create-online-catalog-connection')
cmd_group.add_command(hide(list_connections_cli), name='list-connections')
cmd_group.add_command(hide(get_connection_cli), name='get-connection')
cmd_group.add_command(hide(delete_connection_cli), name='delete-connection')
cmd_group.add_command(hide(update_connection_cli), name='update-connection')




Expand All @@ -418,15 +547,16 @@ def register_connection_commands(cmd_group):
create_group.add_command(create_mysql_cli, name='mysql')
create_group.add_command(create_postgresql_cli, name='postgresql')
create_group.add_command(create_snowflake_cli, name='snowflake')
create_group.add_command(create_snowflake_oauth_cli, name='snowflake-oauth')
create_group.add_command(create_redshift_cli, name='redshift')
create_group.add_command(create_sqldw_cli, name='sqldw')
create_group.add_command(create_sqlserver_cli, name='sqlserver')
create_group.add_command(create_databricks_cli, name='databricks')
create_group.add_command(create_online_catalog_cli, name='online-catalog')

connections_group.add_command(create_group, name='create')
connections_group.add_command(create_json, name='create-json')
connections_group.add_command(list_connections_cli, name='list')
connections_group.add_command(get_connection_cli, name='get')
connections_group.add_command(delete_connection_cli, name='delete')
connections_group.add_command(update_connection_cli, name='update')
cmd_group.add_command(connections_group, name='connection')
26 changes: 26 additions & 0 deletions databricks_cli/unity_catalog/response.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<!DOCTYPE html>
<html>
<head>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/sweetalert2@11.0.18/dist/sweetalert2.min.css">
<script src="https://cdn.jsdelivr.net/npm/sweetalert2@11.0.18/dist/sweetalert2.all.min.js"></script>
<style>
body {
background-color: #1B3139;
}
.swal2-title, .swal2-content {
font-family: "Arial", sans-serif;
}
</style>
<script>
document.addEventListener("DOMContentLoaded", () => {
Swal.fire({
icon: "success",
title: "<span style='font-family: Arial, sans-serif;'>Authorization code received</span>",
html: "<span style='font-family: Arial, sans-serif;'>Close this tab and return to the CLI</span>",
showConfirmButton: false
});
});
</script>
</head>
<body></body>
</html>
2 changes: 0 additions & 2 deletions tests/unity_catalog/test_con_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@
COMMENT = 'some_comment'

TESTHOST = "test_postgresql.fakedb.com"
TESTHOST2 = "postgresql.fakedb2.lan"
TESTPORT = "1234"
TESTPORT2 = "5678"
TEST_OPTIONS = {
"host": TESTHOST,
"port": TESTPORT,
Expand Down

0 comments on commit 472e043

Please sign in to comment.