From ee989f5cf28e8d49a28f6319a30bef7f8b5a1c10 Mon Sep 17 00:00:00 2001 From: sesky4 Date: Thu, 12 Sep 2024 23:48:17 +0800 Subject: [PATCH] feat: support SSO (#86) * feat: support plugins * fix: add test plugin * fix: ServiceRegister -> register_service * doc: plugin * feat: auth * Update add.py * fix: login * feat: oauth * fix: oauth * Update oauth.py * fix: auth login * feat: logout * Update configure.py * fix: py3 compatibility * Update login.py * Update login.py * Update login.py --- tccli/command.py | 8 +- tccli/configure.py | 14 +++- tccli/loaders.py | 51 +++++++++-- tccli/oauth.py | 115 +++++++++++++++++++++++++ tccli/plugin.py | 30 +++++++ tccli/plugins/__init__.py | 0 tccli/plugins/auth/__init__.py | 59 +++++++++++++ tccli/plugins/auth/browser_flow.py | 88 +++++++++++++++++++ tccli/plugins/auth/login.py | 130 +++++++++++++++++++++++++++++ tccli/plugins/auth/logout.py | 21 +++++ tccli/plugins/auth/texts.py | 30 +++++++ tccli/plugins/test/__init__.py | 87 +++++++++++++++++++ tccli/plugins/test/add.py | 31 +++++++ 13 files changed, 652 insertions(+), 12 deletions(-) create mode 100644 tccli/oauth.py create mode 100644 tccli/plugin.py create mode 100644 tccli/plugins/__init__.py create mode 100644 tccli/plugins/auth/__init__.py create mode 100644 tccli/plugins/auth/browser_flow.py create mode 100644 tccli/plugins/auth/login.py create mode 100644 tccli/plugins/auth/logout.py create mode 100644 tccli/plugins/auth/texts.py create mode 100644 tccli/plugins/test/__init__.py create mode 100644 tccli/plugins/test/add.py diff --git a/tccli/command.py b/tccli/command.py index c313400a23..e364891b2a 100644 --- a/tccli/command.py +++ b/tccli/command.py @@ -6,6 +6,8 @@ import tccli.services as Services import tccli.options_define as Options_define from collections import OrderedDict + +from tccli import oauth from tccli.utils import Utils from tccli.argument import CLIArgument, CustomArgument, ListArgument, BooleanArgument from tccli.exceptions import UnknownArgumentError @@ -176,12 +178,15 @@ def _build_command_map(self): service_model = self._get_service_model() for action in service_model["actions"]: action_model = service_model["actions"][action] + action_caller = action_model.get("action_caller", None) + if not action_caller: + action_caller = Services.action_caller(self._service_name)()[action] command_map[action] = ActionCommand( service_name=self._service_name, version=self._version, action_name=action, action_model=action_model, - action_caller=Services.action_caller(self._service_name)()[action], + action_caller=action_caller, ) return command_map @@ -286,6 +291,7 @@ def __call__(self, args, parsed_globals): action_parameters = self.cli_unfold_argument.build_action_parameters(parsed_args) else: action_parameters = self._build_action_parameters(parsed_args, self.argument_map) + oauth.maybe_refresh_credential(parsed_globals.profile if parsed_globals.profile else "default") return self._action_caller(action_parameters, vars(parsed_globals)) def create_help_command(self): diff --git a/tccli/configure.py b/tccli/configure.py index d072995ca6..66552c0267 100644 --- a/tccli/configure.py +++ b/tccli/configure.py @@ -416,16 +416,24 @@ def _run_main(self, parsed_args, parsed_globals): def init_configures(self): config = {} - if not self._profile_existed("default.configure")[0]: + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--profile", type=str) + args, _ = parser.parse_known_args() + profile = args.profile or "default" + profile_file = "%s.configure" % profile + + if not self._profile_existed(profile_file)[0]: config = { "region": "ap-guangzhou", "output": "json", "arrayCount": 10, "warning": "off" } - self._init_configure("default.configure", config) + self._init_configure(profile_file, config) + for profile_name in os.listdir(self.cli_path): - if profile_name == "default.configure": + if profile_name == profile_file: continue if profile_name.endswith(".configure"): self._init_configure(profile_name, {}) diff --git a/tccli/loaders.py b/tccli/loaders.py index b4a02425bc..ef39b5e2f6 100644 --- a/tccli/loaders.py +++ b/tccli/loaders.py @@ -9,6 +9,7 @@ from tccli import __version__ from tccli.services import SERVICE_VERSIONS from collections import OrderedDict +import tccli.plugin as plugin BASE_TYPE = ["int64", "uint64", "string", "float", "bool", "date", "datetime", "datetime_iso", "binary"] CLI_BASE_TYPE = ["Integer", "String", "Float", "Timestamp", "Boolean", "Binary"] @@ -175,7 +176,15 @@ def _version_transform(self, version): return version[1:5] + "-" + version[5:7] + "-" + version[7:9] def get_available_services(self): - return SERVICE_VERSIONS + services = copy.deepcopy(SERVICE_VERSIONS) + for name, vers in plugin.import_plugins().items(): + if name not in services: + services[name] = [] + for ver, spec in vers.items(): + api_ver = spec["metadata"]["apiVersion"] + if api_ver not in services[name]: + services[name].append(api_ver) + return services def get_service_default_version(self, service): args = sys.argv[1:] @@ -194,15 +203,41 @@ def get_service_model(self, service, version): services_path = self.get_services_path() version = "v" + version.replace('-', '') apis_path = os.path.join(services_path, service, version, "api.json") - if not os.path.exists(apis_path): + model = { + "metadata": {}, + "actions": {}, + "objects": {}, + } + if os.path.exists(apis_path): + if six.PY2: + with open(apis_path, 'r') as f: + model = json.load(f) + else: + with open(apis_path, 'r', encoding='utf-8') as f: + model = json.load(f) + + # merge plugins + for plugin_name, vers in plugin.import_plugins().items(): + + if plugin_name != service: + continue + + for ver, spec in vers.items(): + + # 2017-03-12 -> v20170312 + compact_ver = 'v' + ver.replace('-', '') + + if compact_ver != version: + continue + + model["metadata"].update(spec["metadata"]) + model["actions"].update(spec["actions"]) + model["objects"].update(spec["objects"]) + + if not model: raise Exception("Not find service:%s version:%s model" % (service, version)) - if six.PY2: - with open(apis_path, 'r') as f: - return json.load(f) - else: - with open(apis_path, 'r', encoding='utf-8') as f: - return json.load(f) + return model def get_service_description(self, service, version): service_model = self.get_service_model(service, version) diff --git a/tccli/oauth.py b/tccli/oauth.py new file mode 100644 index 0000000000..34481b8cfd --- /dev/null +++ b/tccli/oauth.py @@ -0,0 +1,115 @@ +import json +import os +import time + +import requests +import uuid + +_API_ENDPOINT = "https://cli.cloud.tencent.com" +_CRED_REFRESH_SAFE_DUR = 60 * 5 +_ACCESS_REFRESH_SAFE_DUR = 60 * 5 + + +def maybe_refresh_credential(profile): + cred_path = cred_path_of_profile(profile) + try: + with open(cred_path, "r") as cred_file: + cred = json.load(cred_file) + except IOError: + # file not found, don't check + return + + if cred.get("type") != "oauth": + return + + try: + now = time.time() + + expires_at = cred["expiresAt"] + if expires_at - now > _CRED_REFRESH_SAFE_DUR: + return + + token_info = cred["oauth"] + site = token_info["site"] + access_expires = token_info["expiresAt"] + if access_expires - now < _ACCESS_REFRESH_SAFE_DUR: + refresh_token = token_info["refreshToken"] + open_id = token_info["openId"] + new_token = refresh_user_token(refresh_token, open_id, site) + token_info.update(new_token) + + access_token = token_info["accessToken"] + new_cred = get_temp_cred(access_token, site) + save_credential(token_info, new_cred, profile) + + except KeyError as e: + print("failed to refresh credential, your credential file(%s) is corrupted, %s" % (cred_path, e)) + + except Exception as e: + print("failed to refresh credential, %s" % e) + + +def refresh_user_token(ref_token, open_id, site): + api_endpoint = _API_ENDPOINT + "/refresh_user_token" + body = { + "TraceId": str(uuid.uuid4()), + "RefreshToken": ref_token, + "OpenId": open_id, + "Site": site, + } + http_response = requests.post(api_endpoint, json=body, verify=False) + resp = http_response.json() + + if "Error" in resp: + raise ValueError("refresh_user_token: %s" % json.dumps(resp)) + + return { + "accessToken": resp["AccessToken"], + "expiresAt": resp["ExpiresAt"], + } + + +def get_temp_cred(access_token, site): + api_endpoint = _API_ENDPOINT + "/get_temp_cred" + body = { + "TraceId": str(uuid.uuid4()), + "AccessToken": access_token, + "Site": site, + } + http_response = requests.post(api_endpoint, json=body, verify=False) + resp = http_response.json() + + if "Error" in resp: + raise ValueError("get_temp_key: %s" % json.dumps(resp)) + + return { + "secretId": resp["SecretId"], + "secretKey": resp["SecretKey"], + "token": resp["Token"], + "expiresAt": resp["ExpiresAt"], + } + + +def cred_path_of_profile(profile): + return os.path.join(os.path.expanduser("~"), ".tccli", profile + ".credential") + + +def save_credential(token, new_cred, profile): + cred_path = cred_path_of_profile(profile) + + cred = { + "type": "oauth", + "secretId": new_cred["secretId"], + "secretKey": new_cred["secretKey"], + "token": new_cred["token"], + "expiresAt": new_cred["expiresAt"], + "oauth": { + "openId": token["openId"], + "accessToken": token["accessToken"], + "expiresAt": token["expiresAt"], + "refreshToken": token["refreshToken"], + "site": token["site"], + }, + } + with open(cred_path, "w") as cred_file: + json.dump(cred, cred_file, indent=4) diff --git a/tccli/plugin.py b/tccli/plugin.py new file mode 100644 index 0000000000..81a10b28a7 --- /dev/null +++ b/tccli/plugin.py @@ -0,0 +1,30 @@ +import importlib +import logging +import pkgutil + +import tccli.plugins as plugins + +_plugins = {} +_imported = False + + +def import_plugins(): + global _imported + + if not _imported: + _reimport_plugins() + _imported = True + + return _plugins + + +def _reimport_plugins(): + for _, name, _ in pkgutil.iter_modules(plugins.__path__, plugins.__name__ + "."): + mod = importlib.import_module(name) + register_service = getattr(mod, "register_service", None) + if not register_service: + logging.warning("invalid module %s" % name) + continue + register_service(_plugins) + + return _plugins diff --git a/tccli/plugins/__init__.py b/tccli/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tccli/plugins/auth/__init__.py b/tccli/plugins/auth/__init__.py new file mode 100644 index 0000000000..8a8e4feaee --- /dev/null +++ b/tccli/plugins/auth/__init__.py @@ -0,0 +1,59 @@ +# encoding: utf-8 +from tccli.plugins.auth.login import login_command_entrypoint +from tccli.plugins.auth.logout import logout_command_entrypoint + +service_name = "auth" +service_version = "2024-08-20" + +_spec = { + "metadata": { + "serviceShortName": service_name, + "apiVersion": service_version, + "description": "auth related commands", + }, + "actions": { + "login": { + "name": "登陆", + "document": "login through sso", + "input": "loginRequest", + "output": "loginResponse", + "action_caller": login_command_entrypoint, + }, + "logout": { + "name": "登出", + "document": "remove local credential file", + "input": "logoutRequest", + "output": "logoutResponse", + "action_caller": logout_command_entrypoint, + }, + }, + "objects": { + "loginRequest": { + "members": [ + { + "name": "browser", + "member": "string", + "type": "string", + "required": False, + "document": "use browser=no to indicate no browser login mode", + }, + ], + }, + "loginResponse": { + "members": [], + }, + "logoutRequest": { + "members": [], + }, + "logoutResponse": { + "members": [], + }, + }, + "version": "1.0", +} + + +def register_service(specs): + specs[service_name] = { + service_version: _spec, + } diff --git a/tccli/plugins/auth/browser_flow.py b/tccli/plugins/auth/browser_flow.py new file mode 100644 index 0000000000..1fabc991ca --- /dev/null +++ b/tccli/plugins/auth/browser_flow.py @@ -0,0 +1,88 @@ +import time +import traceback +import socket +from threading import Thread +from tccli import oauth + +try: + from urlparse import urlparse, parse_qs + from Queue import Queue + from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler + from SocketServer import ThreadingTCPServer, TCPServer +except ImportError: + from urllib.parse import urlparse, parse_qs + from queue import Queue + from http.server import HTTPServer, BaseHTTPRequestHandler + from socketserver import ThreadingTCPServer, TCPServer + + +# chrome keeps previous connection alive, so use threading to avoid blocking +class ThreadingHTTPServer(ThreadingTCPServer): + allow_reuse_address = 1 + + def server_bind(self): + """Override server_bind to store the server name.""" + TCPServer.server_bind(self) + host, port = self.socket.getsockname()[:2] + self.server_name = socket.getfqdn(host) + self.server_port = port + + +class HTTPHandler(BaseHTTPRequestHandler): + result_queue = Queue(1) + + def do_GET(self): + try: + parsed_url = urlparse(self.path) + query_vals = parse_qs(parsed_url.query) + + open_id = query_vals.get("open_id")[0] + access_token = query_vals.get("access_token")[0] + refresh_token = query_vals.get("refresh_token")[0] + expires_at = int(query_vals.get("expires_at")[0]) + state = query_vals.get("state")[0] + redirect_url = query_vals.get("redirect_url")[0] + site = query_vals.get("site")[0] + token = { + "openId": open_id, + "accessToken": access_token, + "refreshToken": refresh_token, + "expiresAt": expires_at, + "state": state, + "site": site, + } + cred = oauth.get_temp_cred(token["accessToken"], token["site"]) + self.result_queue.put((token, cred)) + self.send_response(307) + self.send_header("Location", redirect_url) + self.end_headers() + except Exception: + err = traceback.format_exc() + print(err) + self.send_response(400) + self.end_headers() + self.wfile.write("login failed due to the following error:\n\n".encode("utf-8")) + self.wfile.write(err.encode("utf-8")) + self.wfile.flush() + + # suppress debug message + def log_message(self, format, *args): + pass + + +def try_run(start_search_port, end_search_port): + port = start_search_port + + while port <= end_search_port: + server_address = ('', port) + try: + ThreadingHTTPServer.daemon_threads = True + httpd = ThreadingHTTPServer(server_address, HTTPHandler) + t = Thread(target=httpd.serve_forever) + t.setDaemon(True) + t.start() + return port, HTTPHandler.result_queue + except socket.error: + port += 1 + + raise socket.error("no port available from range [%d, %d]" % (start_search_port, end_search_port)) diff --git a/tccli/plugins/auth/login.py b/tccli/plugins/auth/login.py new file mode 100644 index 0000000000..f37a8b9fb9 --- /dev/null +++ b/tccli/plugins/auth/login.py @@ -0,0 +1,130 @@ +# coding: utf-8 +import base64 +import json +import random +import string +import sys +import time +from six.moves.urllib.parse import urlencode +import webbrowser + +from tccli import oauth +from tccli.plugins.auth import texts + +_APP_ID = 100038427476 +_AUTH_URL = "https://cloud.tencent.com/open/authorize" +_REDIRECT_URL = "https://cli.cloud.tencent.com/oauth" +_SITE = "cn" +_DEFAULT_LANG = "zh-CN" + +_START_SEARCH_PORT = 9000 +_END_SEARCH_PORT = _START_SEARCH_PORT + 100 + + +def print_message(msg): + print(msg) + sys.stdout.flush() + + +def login_command_entrypoint(args, parsed_globals): + language = parsed_globals.get("language") + if not language: + language = _DEFAULT_LANG + texts.set_lang(language) + + profile = parsed_globals.get("profile", "default") + if not profile: + profile = "default" + + browser = args.get("browser") + + login(browser != "no", profile, language) + + +def login(use_browser, profile, language): + characters = string.ascii_letters + string.digits + state = ''.join(random.choice(characters) for _ in range(10)) + + if use_browser: + token, cred = _get_token(state, language) + else: + token, cred = _get_token_no_browser(state, language) + + if token["state"] != state: + raise ValueError("invalid state %s" % token["state"]) + + oauth.save_credential(token, cred, profile) + + print_message("") + print_message(texts.get("login_success") % oauth.cred_path_of_profile(profile)) + + +def _get_token(state, language): + from tccli.plugins.auth import browser_flow + + port, result_queue = browser_flow.try_run(_START_SEARCH_PORT, _END_SEARCH_PORT) + + redirect_params = { + "redirect_url": "http://localhost:%d" % port, + "lang": language, + "site": _SITE, + } + redirect_query = urlencode(redirect_params) + redirect_url = _REDIRECT_URL + "?" + redirect_query + url_params = { + "scope": "login", + "app_id": _APP_ID, + "redirect_url": redirect_url, + "state": state, + } + url_query = urlencode(url_params) + auth_url = _AUTH_URL + "?" + url_query + + if not webbrowser.open(auth_url): + print_message(texts.get("login_failed_due_to_no_browser")) + sys.exit(1) + + print_message(texts.get("login_prompt")) + print_message(auth_url) + + # use polling to avoid being unresponsive in python2 + while result_queue.empty(): + time.sleep(1) + + result = result_queue.get() + if isinstance(result, Exception): + raise result + + return result + + +def _get_token_no_browser(state, language): + redirect_params = { + "browser": "no", + "lang": language, + "site": _SITE, + } + redirect_query = urlencode(redirect_params) + redirect_url = _REDIRECT_URL + "?" + redirect_query + url_params = { + "scope": "login", + "app_id": _APP_ID, + "redirect_url": redirect_url, + "state": state, + } + url_query = urlencode(url_params) + auth_url = _AUTH_URL + "?" + url_query + + print_message(texts.get("login_prompt_no_browser")) + print_message("") + print_message(auth_url) + + try: + input_func = raw_input + except NameError: + input_func = input + + user_input = input_func(texts.get("login_prompt_code_no_browser")) + token = json.loads(base64.b64decode(user_input)) + cred = oauth.get_temp_cred(token["accessToken"], token["site"]) + return token, cred diff --git a/tccli/plugins/auth/logout.py b/tccli/plugins/auth/logout.py new file mode 100644 index 0000000000..744e003d22 --- /dev/null +++ b/tccli/plugins/auth/logout.py @@ -0,0 +1,21 @@ +# coding: utf-8 +import os + +from tccli import oauth +from tccli.plugins.auth import texts + + +def logout_command_entrypoint(args, parsed_globals): + language = parsed_globals.get("language") + if not language: + language = "zh-CN" + texts.set_lang(language) + + profile = parsed_globals.get("profile", "default") + if not profile: + profile = "default" + + cred_path = oauth.cred_path_of_profile(profile) + if os.path.exists(cred_path): + os.remove(cred_path) + print(texts.get("logout") % cred_path) diff --git a/tccli/plugins/auth/texts.py b/tccli/plugins/auth/texts.py new file mode 100644 index 0000000000..75f556bedf --- /dev/null +++ b/tccli/plugins/auth/texts.py @@ -0,0 +1,30 @@ +# encoding: utf-8 +_lang = "zh-CN" + +texts = { + "zh-CN": { + "login_prompt": "您的浏览器已打开, 请根据提示完成登录", + "login_prompt_no_browser": "在浏览器中转到以下链接, 并根据提示完成登录:", + "login_prompt_code_no_browser": "完成后,输入浏览器中提供的验证码:", + "login_failed_due_to_no_browser": "无法打开浏览器, 请尝试添加 '--browser no' 选项", + "login_success": "登陆成功, 密钥凭证已被写入: %s", + "logout": "登出成功, 密钥凭证已被删除: %s", + }, + "en-US": { + "login_prompt": "Your browser is open, please complete the login according to the prompts", + "login_prompt_no_browser": "Go to the following link in your browser, and complete the sign-in prompts:", + "login_prompt_code_no_browser": "Once finished, enter the verification code provided in your browser:", + "login_failed_due_to_no_browser": "Failed to launch browser, please try option '--browser no'", + "login_success": "Login succeed, credential has been written to %s", + "logout": "Logout succeed, credential has been removed: %s", + } +} + + +def set_lang(lang): + global _lang + _lang = lang + + +def get(key): + return texts[_lang][key] diff --git a/tccli/plugins/test/__init__.py b/tccli/plugins/test/__init__.py new file mode 100644 index 0000000000..59675b9c9d --- /dev/null +++ b/tccli/plugins/test/__init__.py @@ -0,0 +1,87 @@ +# encoding: utf-8 +""" +如何自定义插件 +1. 定义一个 spec 对象,参考以下代码 +2. export 一个函数,名字叫做 register_service,它负责向 cli 注册 spec + +注册完成后,用户可以通过 tccli {服务名} {接口名} --{参数1}={参数值} ... 的方式调用 + +如下所示: + 定义了一个服务名叫 test, 包含一个接口 add, add 接口接受 2 个参数 number1, number2 + 则用户可以以如下方式调用 tccli + tccli test add --number1=3 --number2=5 +""" +from tccli.plugins.test.add import add_command + +service_name = "test" +service_version = "2024-08-07" + +_spec = { + "metadata": { + # 产品名 + "serviceShortName": service_name, + # 产品版本号 + "apiVersion": service_version, + # 产品介绍 + "description": "this is a test module", + }, + # 产品所有支持的接口 + "actions": { + # 接口名推荐用小写,避免和云 API 的接口名冲突 + "add": { + # 接口中文名 + "name": "测试接口", + # 接口说明 + "document": "this is an test action", + # 入参对象名,在 objects 中详细定义入参结构 + "input": "addRequest", + # 出参对象名,在 objects 中详细定义出参结构 + "output": "addResponse", + # 实际调用函数 + "action_caller": add_command, # the function to call + }, + }, + "objects": { + "addRequest": { + "members": [ + { + # 参数名 + "name": "number1", + # int64, uint64, string, float, bool, date, datetime, datetime_iso, binary + "member": "int64", + # same as member + "type": "int64", + # 是否必传 + "required": True, + # 参数说明 + "document": "doc about number1", + }, + { + "name": "number2", + "member": "int64", + "type": "int64", + "required": True, + "document": "doc about number2", + }, + ], + }, + "addResponse": { + "members": [ + { + "name": "sum", + "member": "int64", + "type": "int64", + "required": True, + "document": "doc about arg1", + }, + ], + }, + }, + "version": "1.0", +} + + +def register_service(specs): + specs[service_name] = { + service_version: _spec, + } diff --git a/tccli/plugins/test/add.py b/tccli/plugins/test/add.py new file mode 100644 index 0000000000..d3475a6345 --- /dev/null +++ b/tccli/plugins/test/add.py @@ -0,0 +1,31 @@ +# coding: utf-8 +import json +import logging + +from tencentcloud.common import credential +from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException +from tencentcloud.cvm.v20170312 import cvm_client, models + + +def add_command(args, parsed_globals): + # get arguments from args + number1 = args["number1"] + number2 = args["number2"] + print("%d + %d = %d\n" % (number1, number2, number1 + number2)) + + # get secret key from parsed_globals + secret_id = parsed_globals["secretId"] + secret_key = parsed_globals["secretKey"] + token = parsed_globals["token"] + region = parsed_globals["region"] or "ap-guangzhou" + + # do api call with secret key + cred = credential.Credential(secret_id, secret_key, token) + cli = cvm_client.CvmClient(cred, region) + + req = models.DescribeInstancesRequest() + try: + resp = cli.DescribeInstances(req) + print(resp.to_json_string(indent=2)) + except TencentCloudSDKException as e: + logging.exception(e)