diff --git a/README.md b/README.md index 55cfa78d0..f65deb753 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Augur NEW Release v0.80.1 +# Augur NEW Release v0.81.0 Augur is primarily a data engineering tool that makes it possible for data scientists to gather open source software community data - less data carpentry for everyone else! The primary way of looking at Augur data is through [8Knot](https://github.com/oss-aspen/8knot), a public instance of 8Knot is available [here](https://metrix.chaoss.io) - this is tied to a public instance of [Augur](https://ai.chaoss.io). @@ -11,7 +11,7 @@ We follow the [First Timers Only](https://www.firsttimersonly.com/) philosophy o ## NEW RELEASE ALERT! **If you want to jump right in, the updated docker, docker-compose and bare metal installation instructions are available [here](docs/new-install.md)**. -Augur is now releasing a dramatically improved new version to the ```main``` branch. It is also available [here](https://github.com/chaoss/augur/releases/tag/v0.80.1). +Augur is now releasing a dramatically improved new version to the ```main``` branch. It is also available [here](https://github.com/chaoss/augur/releases/tag/v0.81.0). - The `main` branch is a stable version of our new architecture, which features: diff --git a/augur/application/cli/backend.py b/augur/application/cli/backend.py index f470675d1..2b5ec6904 100644 --- a/augur/application/cli/backend.py +++ b/augur/application/cli/backend.py @@ -18,6 +18,8 @@ from augur.tasks.start_tasks import augur_collection_monitor, create_collection_status_records from augur.tasks.git.facade_tasks import clone_repos +from augur.tasks.github.util.github_api_key_handler import GithubApiKeyHandler +from augur.tasks.gitlab.gitlab_api_key_handler import GitlabApiKeyHandler from augur.tasks.data_analysis.contributor_breadth_worker.contributor_breadth_worker import contributor_breadth_model from augur.tasks.init.redis_connection import redis_connection from augur.application.db.models import UserRepo @@ -27,6 +29,7 @@ from augur.application.cli import test_connection, test_db_connection, with_database, DatabaseContext import sqlalchemy as s +from keyman.KeyClient import KeyClient, KeyPublisher logger = AugurLogger("augur", reset_logfiles=True).get_logger() @@ -116,8 +119,27 @@ def start(ctx, disable_collection, development, pidfile, port): celery_beat_process = None celery_command = f"celery -A augur.tasks.init.celery_app.celery_app beat -l {log_level.lower()}" celery_beat_process = subprocess.Popen(celery_command.split(" ")) - + keypub = KeyPublisher() + if not disable_collection: + orchestrator = subprocess.Popen("python keyman/Orchestrator.py".split()) + + # Wait for orchestrator startup + if not keypub.wait(republish=True): + logger.critical("Key orchestrator did not respond in time") + return + + # load keys + ghkeyman = GithubApiKeyHandler(logger) + glkeyman = GitlabApiKeyHandler(logger) + + for key in ghkeyman.keys: + keypub.publish(key, "github_rest") + keypub.publish(key, "github_graphql") + + for key in glkeyman.keys: + keypub.publish(key, "gitlab_rest") + with DatabaseSession(logger, engine=ctx.obj.engine) as session: clean_collection_status(session) @@ -157,6 +179,7 @@ def start(ctx, disable_collection, development, pidfile, port): if not disable_collection: try: + keypub.shutdown() cleanup_after_collection_halt(logger, ctx.obj.engine) except RedisConnectionError: pass diff --git a/augur/application/cli/collection.py b/augur/application/cli/collection.py index 
84bbd5cba..b42f1f3fc 100644 --- a/augur/application/cli/collection.py +++ b/augur/application/cli/collection.py @@ -17,6 +17,8 @@ from augur.tasks.start_tasks import augur_collection_monitor, create_collection_status_records from augur.tasks.git.facade_tasks import clone_repos +from augur.tasks.github.util.github_api_key_handler import GithubApiKeyHandler +from augur.tasks.gitlab.gitlab_api_key_handler import GitlabApiKeyHandler from augur.tasks.data_analysis.contributor_breadth_worker.contributor_breadth_worker import contributor_breadth_model from augur.application.db.models import UserRepo from augur.application.db.session import DatabaseSession @@ -25,6 +27,8 @@ from augur.application.cli import test_connection, test_db_connection, with_database, DatabaseContext from augur.application.cli._cli_util import _broadcast_signal_to_processes, raise_open_file_limit, clear_redis_caches, clear_rabbitmq_messages +from keyman.KeyClient import KeyClient, KeyPublisher + logger = AugurLogger("augur", reset_logfiles=False).get_logger() @click.group('server', short_help='Commands for controlling the backend API server & data collection workers') @@ -51,6 +55,26 @@ def start(ctx, development): logger.error("Failed to raise open file limit!") raise e + keypub = KeyPublisher() + + orchestrator = subprocess.Popen("python keyman/Orchestrator.py".split()) + + # Wait for orchestrator startup + if not keypub.wait(republish=True): + logger.critical("Key orchestrator did not respond in time") + return + + # load keys + ghkeyman = GithubApiKeyHandler(logger) + glkeyman = GitlabApiKeyHandler(logger) + + for key in ghkeyman.keys: + keypub.publish(key, "github_rest") + keypub.publish(key, "github_graphql") + + for key in glkeyman.keys: + keypub.publish(key, "gitlab_rest") + if development: os.environ["AUGUR_DEV"] = "1" logger.info("Starting in development mode") @@ -94,6 +118,8 @@ def start(ctx, development): if p: p.terminate() + keypub.shutdown() + if celery_beat_process: logger.info("Shutting down celery beat process") celery_beat_process.terminate() diff --git a/augur/application/cli/github.py b/augur/application/cli/github.py new file mode 100644 index 000000000..cad13be79 --- /dev/null +++ b/augur/application/cli/github.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: MIT +import logging +import click +import sqlalchemy as s +from datetime import datetime +import httpx +from collections import Counter + +from augur.application.cli import test_connection, test_db_connection + +from augur.application.db.engine import DatabaseEngine +from augur.tasks.github.util.github_api_key_handler import GithubApiKeyHandler + + +logger = logging.getLogger(__name__) + +@click.group("github", short_help="Github utilities") +def cli(): + pass + +@cli.command("api-keys") +@test_connection +@test_db_connection +def update_api_key(): + """ + Get the ratelimit of Github API keys + """ + + with DatabaseEngine() as engine, engine.connect() as connection: + + get_api_keys_sql = s.sql.text( + """ + SELECT value as github_key from config Where section_name='Keys' AND setting_name='github_api_key' + UNION All + SELECT access_token as github_key from worker_oauth ORDER BY github_key DESC; + """ + ) + + result = connection.execute(get_api_keys_sql).fetchall() + keys = [x[0] for x in result] + + with httpx.Client() as client: + + invalid_keys = [] + valid_key_data = [] + for key in keys: + core_key_data, graphql_key_data = GithubApiKeyHandler.get_key_rate_limit(client, key) + if core_key_data is None or graphql_key_data is None: + 
invalid_keys.append(key) + else: + valid_key_data.append((key, core_key_data, graphql_key_data)) + + valid_key_data = sorted(valid_key_data, key=lambda x: x[1]["requests_remaining"]) + + core_request_header = "Core Requests Left" + core_reset_header = "Core Reset Time" + graphql_request_header = "Graphql Requests Left" + graphql_reset_header = "Graphql Reset Time" + print(f"{'Key'.center(40)} {core_request_header} {core_reset_header} {graphql_request_header} {graphql_reset_header}") + for key, core_key_data, graphql_key_data in valid_key_data: + core_requests = str(core_key_data['requests_remaining']).center(len(core_request_header)) + core_reset_time = str(epoch_to_local_time_with_am_pm(core_key_data["reset_epoch"])).center(len(core_reset_header)) + + graphql_requests = str(graphql_key_data['requests_remaining']).center(len(graphql_request_header)) + graphql_reset_time = str(epoch_to_local_time_with_am_pm(graphql_key_data["reset_epoch"])).center(len(graphql_reset_header)) + + print(f"{key} | {core_requests} | {core_reset_time} | {graphql_requests} | {graphql_reset_time} |") + + valid_key_list = [x[0] for x in valid_key_data] + duplicate_keys = find_duplicates(valid_key_list) + if len(duplicate_keys) > 0: + print("\n\nWARNING: There are duplicate keys; this will slow down collection") + print("Duplicate keys".center(40)) + for key in duplicate_keys: + print(key) + + + if len(invalid_keys) > 0: + invalid_key_header = "Invalid Keys".center(40) + print("\n") + print(invalid_key_header) + for key in invalid_keys: + print(key) + print("") + + + + engine.dispose() + + +def epoch_to_local_time_with_am_pm(epoch): + local_time = datetime.fromtimestamp(epoch) + formatted_time = local_time.strftime('%I:%M %p') # 12-hour clock time with AM/PM (time only, no date) + return formatted_time + + +def find_duplicates(lst): + counter = Counter(lst) + return [item for item, count in counter.items() if count > 1] + diff --git a/augur/tasks/github/events.py b/augur/tasks/github/events.py index 481fc0a42..cf7df5758 100644 --- a/augur/tasks/github/events.py +++ b/augur/tasks/github/events.py @@ -3,22 +3,24 @@ import sqlalchemy as s from sqlalchemy.sql import text from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone from augur.tasks.init.celery_app import celery_app as celery from augur.tasks.init.celery_app import AugurCoreRepoCollectionTask from augur.application.db.data_parse import * from augur.tasks.github.util.github_data_access import GithubDataAccess, UrlNotFoundException from augur.tasks.github.util.github_random_key_auth import GithubRandomKeyAuth +from augur.tasks.github.util.github_task_session import GithubTaskManifest from augur.tasks.github.util.util import get_owner_repo from augur.tasks.util.worker_util import remove_duplicate_dicts -from augur.application.db.models import PullRequestEvent, IssueEvent, Contributor, CollectionStatus -from augur.application.db.lib import get_repo_by_repo_git, bulk_insert_dicts, get_issues_by_repo_id, get_pull_requests_by_repo_id, update_issue_closed_cntrbs_by_repo_id, get_session, get_engine +from augur.application.db.models import PullRequestEvent, IssueEvent, Contributor, Repo +from augur.application.db.lib import get_repo_by_repo_git, bulk_insert_dicts, get_issues_by_repo_id, get_pull_requests_by_repo_id, update_issue_closed_cntrbs_by_repo_id, get_session, get_engine, get_core_data_last_collected platform_id = 1 @celery.task(base=AugurCoreRepoCollectionTask) -def collect_events(repo_git: str): +def collect_events(repo_git: str, 
full_collection: bool): logger = logging.getLogger(collect_events.__name__) @@ -26,6 +28,14 @@ def collect_events(repo_git: str): logger.debug(f"Collecting Github events for {owner}/{repo}") + if full_collection: + core_data_last_collected = None + else: + repo_id = get_repo_by_repo_git(repo_git).repo_id + + # subtract 2 days to ensure all data is collected + core_data_last_collected = (get_core_data_last_collected(repo_id) - timedelta(days=2)).replace(tzinfo=timezone.utc) + key_auth = GithubRandomKeyAuth(logger) if bulk_events_collection_endpoint_contains_all_data(key_auth, logger, owner, repo): @@ -33,7 +43,7 @@ def collect_events(repo_git: str): else: collection_strategy = ThoroughGithubEventCollection(logger) - collection_strategy.collect(repo_git, key_auth) + collection_strategy.collect(repo_git, key_auth, core_data_last_collected) def bulk_events_collection_endpoint_contains_all_data(key_auth, logger, owner, repo): @@ -60,7 +70,7 @@ def __init__(self, logger): self._data_source = "Github API" @abstractmethod - def collect(self, repo_git, key_auth): + def collect(self, repo_git, key_auth, since): pass def _insert_issue_events(self, events): @@ -97,7 +107,7 @@ def __init__(self, logger): super().__init__(logger) - def collect(self, repo_git, key_auth): + def collect(self, repo_git, key_auth, since): repo_obj = get_repo_by_repo_git(repo_git) repo_id = repo_obj.repo_id @@ -106,7 +116,7 @@ def collect(self, repo_git, key_auth): self.repo_identifier = f"{owner}/{repo}" events = [] - for event in self._collect_events(repo_git, key_auth): + for event in self._collect_events(repo_git, key_auth, since): events.append(event) # making this a decent size since process_events retrieves all the issues and prs each time @@ -117,7 +127,7 @@ def collect(self, repo_git, key_auth): if events: self._process_events(events, repo_id) - def _collect_events(self, repo_git: str, key_auth): + def _collect_events(self, repo_git: str, key_auth, since): owner, repo = get_owner_repo(repo_git) @@ -125,7 +135,13 @@ def _collect_events(self, repo_git: str, key_auth): github_data_access = GithubDataAccess(key_auth, self._logger) - return github_data_access.paginate_resource(url) + for event in github_data_access.paginate_resource(url): + + yield event + + # return if last event on the page was updated before the since date + if since and datetime.fromisoformat(event["created_at"].replace("Z", "+00:00")).replace(tzinfo=timezone.utc) < since: + return def _process_events(self, events, repo_id): @@ -248,7 +264,7 @@ class ThoroughGithubEventCollection(GithubEventCollection): def __init__(self, logger): super().__init__(logger) - def collect(self, repo_git, key_auth): + def collect(self, repo_git, key_auth, since): repo_obj = get_repo_by_repo_git(repo_git) repo_id = repo_obj.repo_id @@ -256,10 +272,10 @@ def collect(self, repo_git, key_auth): owner, repo = get_owner_repo(repo_git) self.repo_identifier = f"{owner}/{repo}" - self._collect_and_process_issue_events(owner, repo, repo_id, key_auth) - self._collect_and_process_pr_events(owner, repo, repo_id, key_auth) + self._collect_and_process_issue_events(owner, repo, repo_id, key_auth, since) + self._collect_and_process_pr_events(owner, repo, repo_id, key_auth, since) - def _collect_and_process_issue_events(self, owner, repo, repo_id, key_auth): + def _collect_and_process_issue_events(self, owner, repo, repo_id, key_auth, since): engine = get_engine() @@ -267,7 +283,11 @@ def _collect_and_process_issue_events(self, owner, repo, repo_id, key_auth): # TODO: Remove src id if it 
ends up not being needed query = text(f""" - select issue_id as issue_id, gh_issue_number as issue_number, gh_issue_id as gh_src_id from issues WHERE repo_id={repo_id} order by created_at desc; + select issue_id as issue_id, gh_issue_number as issue_number, gh_issue_id as gh_src_id + from issues + where repo_id={repo_id} + and updated_at > timestamptz(timestamp '{since}') + order by created_at desc; """) issue_result = connection.execute(query).fetchall() @@ -309,14 +329,18 @@ def _collect_and_process_issue_events(self, owner, repo, repo_id, key_auth): events.clear() - def _collect_and_process_pr_events(self, owner, repo, repo_id, key_auth): + def _collect_and_process_pr_events(self, owner, repo, repo_id, key_auth, since): engine = get_engine() with engine.connect() as connection: query = text(f""" - select pull_request_id, pr_src_number as gh_pr_number, pr_src_id from pull_requests WHERE repo_id={repo_id} order by pr_created_at desc; + select pull_request_id, pr_src_number as gh_pr_number, pr_src_id + from pull_requests + where repo_id={repo_id} + and pr_updated_at > timestamptz(timestamp '{since}') + order by pr_created_at desc; """) pr_result = connection.execute(query).fetchall() diff --git a/augur/tasks/github/pull_requests/tasks.py b/augur/tasks/github/pull_requests/tasks.py index c581ceb35..b65da7f4f 100644 --- a/augur/tasks/github/pull_requests/tasks.py +++ b/augur/tasks/github/pull_requests/tasks.py @@ -219,7 +219,7 @@ def process_pull_request_review_contributor(pr_review: dict, tool_source: str, t return pr_review_cntrb @celery.task(base=AugurSecondaryRepoCollectionTask) -def collect_pull_request_review_comments(repo_git: str) -> None: +def collect_pull_request_review_comments(repo_git: str, full_collection: bool) -> None: owner, repo = get_owner_repo(repo_git) @@ -230,6 +230,11 @@ def collect_pull_request_review_comments(repo_git: str) -> None: repo_id = get_repo_by_repo_git(repo_git).repo_id + if not full_collection: + # subtract 2 days to ensure all data is collected + core_data_last_collected = (get_core_data_last_collected(repo_id) - timedelta(days=2)).replace(tzinfo=timezone.utc) + review_msg_url += f"?since={core_data_last_collected.isoformat()}" + pr_reviews = get_pull_request_reviews_by_repo_id(repo_id) # maps the github pr_review id to the auto incrementing pk that augur stores as pr_review id diff --git a/augur/tasks/github/util/github_api_key_handler.py b/augur/tasks/github/util/github_api_key_handler.py index 4f8178e7c..9ca777a1b 100644 --- a/augur/tasks/github/util/github_api_key_handler.py +++ b/augur/tasks/github/util/github_api_key_handler.py @@ -9,6 +9,8 @@ from augur.application.db.lib import get_value, get_worker_oauth_keys from sqlalchemy import func +RATE_LIMIT_URL = "https://api.github.com/rate_limit" + class NoValidKeysError(Exception): pass @@ -152,12 +154,9 @@ def is_bad_api_key(self, client: httpx.Client, oauth_key: str) -> bool: True if key is bad. 
False if the key is good """ - # this endpoint allows us to check the rate limit, but it does not use one of our 5000 requests - url = "https://api.github.com/rate_limit" - headers = {'Authorization': f'token {oauth_key}'} - data = client.request(method="GET", url=url, headers=headers, timeout=180).json() + data = client.request(method="GET", url=RATE_LIMIT_URL, headers=headers, timeout=180).json() try: if data["message"] == "Bad credentials": @@ -165,4 +164,25 @@ def is_bad_api_key(self, client: httpx.Client, oauth_key: str) -> bool: except KeyError: pass - return False \ No newline at end of file + return False + + @staticmethod + def get_key_rate_limit(client, github_key): + + headers = {'Authorization': f'token {github_key}'} + + data = client.request(method="GET", url=RATE_LIMIT_URL, headers=headers, timeout=180).json() + + if "message" in data: + return None, None + + def convert_rate_limit_request(data): + return { + "requests_remaining": data["remaining"], + "reset_epoch": data["reset"] + } + + core_data = convert_rate_limit_request(data["resources"]["core"]) + graphql_data = convert_rate_limit_request(data["resources"]["graphql"]) + + return core_data, graphql_data \ No newline at end of file diff --git a/augur/tasks/github/util/github_data_access.py b/augur/tasks/github/util/github_data_access.py index 850336f53..a648f990c 100644 --- a/augur/tasks/github/util/github_data_access.py +++ b/augur/tasks/github/util/github_data_access.py @@ -3,6 +3,7 @@ import httpx from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception, RetryError from urllib.parse import urlparse, parse_qs, urlencode +from keyman.KeyClient import KeyClient class RatelimitException(Exception): @@ -21,7 +22,8 @@ class GithubDataAccess: def __init__(self, key_manager, logger: logging.Logger): self.logger = logger - self.key_manager = key_manager + self.key_client = KeyClient("github_rest", logger) + self.key = None def get_resource_count(self, url): @@ -93,7 +95,12 @@ def make_request(self, url, method="GET", timeout=100): with httpx.Client() as client: - response = client.request(method=method, url=url, auth=self.key_manager, timeout=timeout, follow_redirects=True) + if not self.key: + self.key = self.key_client.request() + + headers = {"Authorization": f"token {self.key}"} + + response = client.request(method=method, url=url, headers=headers, timeout=timeout, follow_redirects=True) if response.status_code in [403, 429]: raise RatelimitException(response) @@ -121,7 +128,7 @@ def __make_request_with_retries(self, url, method="GET", timeout=100): 1. Retires 10 times 2. Waits 5 seconds between retires 3. Does not rety UrlNotFoundException - 4. Catches RatelimitException and waits before raising exception + 4. Catches RatelimitException and waits or expires key before raising exception """ try: @@ -150,8 +157,9 @@ def __handle_github_ratelimit_response(self, response): self.logger.error(f"Key reset time was less than 0 setting it to 0.\nThe current epoch is {current_epoch} and the epoch that the key resets at is {epoch_when_key_resets}") key_reset_time = 0 - self.logger.info(f"\n\n\nAPI rate limit exceeded. Sleeping until the key resets ({key_reset_time} seconds)") - time.sleep(key_reset_time) + self.logger.info(f"\n\n\nAPI rate limit exceeded. Key resets in {key_reset_time} seconds. 
Informing key manager that key is expired") + self.key = self.key_client.expire(self.key, epoch_when_key_resets) + else: time.sleep(60) diff --git a/augur/tasks/gitlab/gitlab_api_key_handler.py b/augur/tasks/gitlab/gitlab_api_key_handler.py index 40b37d62c..03c0cf66c 100644 --- a/augur/tasks/gitlab/gitlab_api_key_handler.py +++ b/augur/tasks/gitlab/gitlab_api_key_handler.py @@ -128,8 +128,9 @@ def get_api_keys(self) -> List[str]: # add all the keys to redis self.redis_key_list.extend(valid_keys) - if not valid_keys: - raise NoValidKeysError("No valid gitlab api keys found in the config or worker oauth table") + # Removed because most people do not collect gitlab and this blows up on startup if they don't have any gitlab keys + # if not valid_keys: + # raise NoValidKeysError("No valid gitlab api keys found in the config or worker oauth table") # shuffling the keys so not all processes get the same keys in the same order diff --git a/augur/tasks/gitlab/issues_task.py b/augur/tasks/gitlab/issues_task.py index 391f981d3..dc5b8ffc2 100644 --- a/augur/tasks/gitlab/issues_task.py +++ b/augur/tasks/gitlab/issues_task.py @@ -288,7 +288,7 @@ def process_gitlab_issue_messages(data, task_name, repo_id, logger, session): # create mapping from mr number to pull request id of current mrs issue_number_to_id_map = {} - issues = session.session.query(Issue).filter(Issue.repo_id == repo_id).all() + issues = session.query(Issue).filter(Issue.repo_id == repo_id).all() for issue in issues: issue_number_to_id_map[issue.gh_issue_number] = issue.issue_id diff --git a/augur/tasks/start_tasks.py b/augur/tasks/start_tasks.py index 8aa767ece..3ba30ed70 100644 --- a/augur/tasks/start_tasks.py +++ b/augur/tasks/start_tasks.py @@ -73,7 +73,7 @@ def primary_repo_collect_phase(repo_git, full_collection): #Define secondary group that can't run until after primary jobs have finished. secondary_repo_jobs = group( - collect_events.si(repo_git),#*create_grouped_task_load(dataList=first_pass, task=collect_events).tasks, + collect_events.si(repo_git, full_collection),#*create_grouped_task_load(dataList=first_pass, task=collect_events).tasks, collect_github_messages.si(repo_git, full_collection), #*create_grouped_task_load(dataList=first_pass,task=collect_github_messages).tasks, collect_github_repo_clones_data.si(repo_git), ) @@ -120,7 +120,7 @@ def secondary_repo_collect_phase(repo_git, full_collection): repo_task_group = group( process_pull_request_files.si(repo_git, full_collection), process_pull_request_commits.si(repo_git, full_collection), - chain(collect_pull_request_reviews.si(repo_git, full_collection), collect_pull_request_review_comments.si(repo_git)), + chain(collect_pull_request_reviews.si(repo_git, full_collection), collect_pull_request_review_comments.si(repo_git, full_collection)), process_ossf_dependency_metrics.si(repo_git) ) diff --git a/keyman/KeyClient.py b/keyman/KeyClient.py new file mode 100644 index 000000000..05f2c4739 --- /dev/null +++ b/keyman/KeyClient.py @@ -0,0 +1,253 @@ +from augur.tasks.init.redis_connection import redis_connection as conn +from redis.client import PubSub +from logging import Logger +from os import getpid +import time, json + +from keyman.KeyOrchestrationAPI import spec, WaitKeyTimeout + +class KeyClient: + """ NOT THREAD SAFE! + + Only one KeyClient can exist at a time per *process*, as + the process ID is used for async communication between + the client and the orchestrator. 
+ + param platform: The default platform to use for key requests + """ + def __init__(self, platform: str, logger: Logger): + self.id = getpid() + + # Load channel names and IDs from the spec + for channel in spec["channels"]: + # IE: self.ANNOUNCE = "augur-oauth-announce" + setattr(self, channel["name"], channel["id"]) + + if not platform: + raise ValueError("Platform must not be empty") + + self.stdout = conn + self.stdin: PubSub = conn.pubsub(ignore_subscribe_messages = True) + self.stdin.subscribe(f"{self.REQUEST}-{self.id}") + self.platform = platform + self.logger = logger + + def _send(self, req_type, **kwargs): + kwargs["type"] = req_type + kwargs["requester_id"] = self.id + + self.stdout.publish(self.REQUEST, json.dumps(kwargs)) + + def _recv(self, timeout = None): + if timeout is not None: + return self.stdin.get_message(timeout = timeout) + + stream = self.stdin.listen() + + reply = next(stream) + + msg = json.loads(reply["data"]) + + if "wait" in msg: + raise WaitKeyTimeout(msg["wait"]) + else: + return msg + + """ Request a new key from the Orchestrator + + Will block until a key is available. Will block + *indefinitely* if no keys are available for the + requested platform. + + Optionally supply a platform string, if the default + one provided during initialization does not match + the desired platform for this request. + """ + def request(self, platform = None) -> str: + while True: + self._send("NEW", key_platform = platform or self.platform) + try: + msg = self._recv() + if "key" in msg: + return msg["key"] + + else: + raise Exception(f"Invalid response type: {msg}") + except WaitKeyTimeout as e: + self.logger.debug(f"NO FRESH KEYS: sleeping for {e.tiemout_seconds} seconds") + time.sleep(e.tiemout_seconds) + except Exception as e: + self.logger.exception("Error during key request") + time.sleep(20) + + """ Expire a key, and get a new key in return + + Will block until a key is available. Will block + *indefinitely* if no keys are available for the + requested platform. + + Optionally supply a platform string, if the default + one provided during initialization does not match + the desired platform for this request. The platform + given *must* match the old key, and also the new key. + """ + def expire(self, key: str, refresh_timestamp: int, platform: str = None) -> str: + message = { + "type": "EXPIRE", + "key_str": key, + "key_platform": platform or self.platform, + "refresh_time": refresh_timestamp, + "requester_id": self.id + } + + self.stdout.publish(self.REQUEST, json.dumps(message)) + time.sleep(0.1) + return self.request() + +class KeyPublisher: + """ NOT THREAD SAFE! + + Only one KeyPublisher can exist at a time per *process*, + as the process ID is used for async communication between + the publisher and the orchestrator. + """ + def __init__(self) -> None: + # Load channel names and IDs from the spec + for channel in spec["channels"]: + # IE: self.ANNOUNCE = "augur-oauth-announce" + setattr(self, channel["name"], channel["id"]) + + self.id = getpid() + self.stdin: PubSub = conn.pubsub(ignore_subscribe_messages = True) + self.stdin.subscribe(f"{self.ANNOUNCE}-{self.id}") + + """ Publish a key to the orchestration server + + No reply is sent, and keys are added or overwritten + silently. 
+ """ + def publish(self, key: str, platform: str): + message = { + "type": "PUBLISH", + "key_str": key, + "key_platform": platform + } + + conn.publish(self.ANNOUNCE, json.dumps(message)) + + """ Unpublish a key, and remove it from orchestration + + They key will remain in use by any workers that are currently + using it, but it will not be assigned to any new requests. + + No reply is sent, and non-existent keys or platforms are + ignored silently. + """ + def unpublish(self, key: str, platform: str): + message = { + "type": "UNPUBLISH", + "key_str": key, + "key_platform": platform + } + + conn.publish(self.ANNOUNCE, json.dumps(message)) + + """ Wait for ACK from the orchestrator + + If a lot of publish or unpublish messages are waiting to + be processed, this will block until all of them have been + read. If the timeout is reached, this returns False, or if + the orchestration server acknkowledges within the time + limit, this returns True. + + If republish is true, the initial ACK request will be resent + 10 times per second until the orchestrator responds. This + should only be used to wait for the orchestrator to come + online, as it could put a lot of unnecessary messages on the + queue if the orchestrator is running, but very busy. + """ + def wait(self, timeout_seconds = 30, republish = False): + if timeout_seconds < 0: + raise ValueError("timeout cannot be negative") + + message = { + "type": "ACK", + "requester_id": self.id + } + + listen_delta = 0.1 + conn.publish(self.ANNOUNCE, json.dumps(message)) + + # Just wait for and consume the next incoming message + while timeout_seconds >= 0: + # get_message supposedly takes a 'timeout' parameter, but that did not work + reply = self.stdin.get_message(ignore_subscribe_messages = True) + + if reply: + return True + elif timeout_seconds < listen_delta: + break + elif republish: + conn.publish(self.ANNOUNCE, json.dumps(message)) + + time.sleep(listen_delta) + timeout_seconds -= listen_delta + + return False + + """ Get a list of currently loaded orchestration platforms + + Will raise a ValueError if the orchestration server + returns a malformed response. + """ + def list_platforms(self): + message = { + "type": "LIST_PLATFORMS", + "requester_id": self.id + } + + conn.publish(self.ANNOUNCE, json.dumps(message)) + + reply = next(self.stdin.listen()) + + try: + reply = json.loads(reply["data"]) + except Exception as e: + raise ValueError("Exception during platform list decoding") + + if isinstance(reply, list): + return reply + + raise ValueError(f"Unexpected reply during list operation: {reply}") + + """ Get a list of currently loaded keys for the given platform + + Will raise a ValueError if the orchestration server + returns a malformed response, or if the platform does + not exist. 
+ """ + def list_keys(self, platform): + message = { + "type": "LIST_KEYS", + "requester_id": self.id, + "key_platform": platform + } + + conn.publish(self.ANNOUNCE, json.dumps(message)) + + reply = next(self.stdin.listen()) + + try: + reply = json.loads(reply["data"]) + except Exception as e: + raise ValueError("Exception during key list decoding") + + if isinstance(reply, list): + return reply + elif isinstance(reply, dict) and "status" in reply: + raise ValueError(f"Orchestration error: {reply['status']}") + else: + raise ValueError(f"Unexpected reply during list operation: {reply}") + + def shutdown(self): + conn.publish(self.ANNOUNCE, json.dumps({"type": "SHUTDOWN"})) diff --git a/keyman/KeyOrchestrationAPI.py b/keyman/KeyOrchestrationAPI.py new file mode 100644 index 000000000..c689e2c23 --- /dev/null +++ b/keyman/KeyOrchestrationAPI.py @@ -0,0 +1,79 @@ + +""" This is a hybrid-fixed specification + +The names of the channels *MUST NOT* change, +but the channel IDs are free to +""" +spec = { + "channels": [ + { + "name": "ANNOUNCE", + "id": "augur-oauth-announce", + "message_types": { + "PUBLISH": { + "key_str": str, + "key_platform": str + }, + "UNPUBLISH": { + "key_str": str, + "key_platform": str + }, + "ACK": { + "fields": { + "requester_id": { "required": str } + }, + "response": "" + }, + "LIST_PLATFORMS": { + "fields": { + "requester_id": { "required": str } + }, + "response": [ + { "required": list[str] } + ] + }, + "LIST_KEYS": { + "fields": { + "requester_id": { "required": str }, + "key_platform": { "required": str } + }, + "response": [ + { "optional": list[str] }, + { "optional": { "status": "error" } } + ] + }, + "SHUTDOWN": {} + } + }, { + "name": "REQUEST", + "id": "worker-oath-request", + "message_types": { + "NEW": { + "fields": { + "key_platform": { "required": str }, + "requester_id": { "required": str } + }, + "response": { + "key": { "optional": str }, + "wait": { "optional": int } + } + }, + "EXPIRE": { + "fields": { + "key_str": str, + "key_platform": str, + "refresh_time": int, + "requester_id": str + }, + "response": { + + } + } + } + } + ] +} + +class WaitKeyTimeout(Exception): + def __init__(self, timeout_seconds) -> None: + self.tiemout_seconds = timeout_seconds \ No newline at end of file diff --git a/keyman/Orchestrator.py b/keyman/Orchestrator.py new file mode 100644 index 000000000..f6e5245fe --- /dev/null +++ b/keyman/Orchestrator.py @@ -0,0 +1,165 @@ +from augur.tasks.init.redis_connection import redis_connection as conn +from augur.application.logs import AugurLogger +import json, random, time + +from keyman.KeyOrchestrationAPI import spec, WaitKeyTimeout + +class KeyOrchestrator: + def __init__(self) -> None: + self.stdin = conn.pubsub(ignore_subscribe_messages = True) + self.logger = AugurLogger("KeyOrchestrator").get_logger() + + # Load channel names and IDs from the spec + for channel in spec["channels"]: + # IE: self.ANNOUNCE = "augur-oauth-announce" + setattr(self, channel["name"], channel["id"]) + self.stdin.subscribe(channel["id"]) + + self.fresh_keys: dict[str, list[str]] = {} + self.expired_keys: dict[str, dict[str, int]] = {} + + def publish_key(self, key, platform): + if platform not in self.fresh_keys: + self.fresh_keys[platform] = [key] + self.expired_keys[platform] = {} + else: + self.fresh_keys[platform].append(key) + + def unpublish_key(self, key, platform): + if platform not in self.fresh_keys: + return + + if key in self.fresh_keys[platform]: + self.fresh_keys[platform].remove(key) + elif key in self.expired_keys[platform]: + 
self.expired_keys[platform].pop(key) + + def expire_key(self, key, platform, timeout): + if not platform in self.fresh_keys or not key in self.fresh_keys[platform]: + return + + self.fresh_keys[platform].remove(key) + + self.expired_keys[platform][key] = timeout + + def refresh_keys(self): + curr_time = time.time() + + for platform in self.expired_keys: + refreshed_keys = [] + + for key, timeout in self.expired_keys[platform].items(): + if timeout <= curr_time: + refreshed_keys.append(key) + + for key in refreshed_keys: + self.fresh_keys[platform].append(key) + self.expired_keys[platform].pop(key) + + def new_key(self, platform): + if not len(self.fresh_keys[platform]): + if not len(self.expired_keys[platform]): + self.logger.warning(f"Key was requested for {platform}, but none are published") + return + + min = 0 + for key, timeout in self.expired_keys[platform].items(): + if not min or timeout < min: + min = timeout + + delta = int(min - time.time()) + + raise WaitKeyTimeout(delta + 5 if delta > 0 else 5) + + return random.choice(self.fresh_keys[platform]) + + def run(self): + self.logger.info("Ready") + for msg in self.stdin.listen(): + try: + if msg.get("type") != "message": + # Filter out unwanted events + continue + elif not (channel := msg.get("channel")): + # The pub/sub API makes no guarantee that a channel will be specified + continue + + # The docs say that msg.channel is a bytes, but testing shows it's a str ? + channel: str = channel.decode() if isinstance(channel, bytes) else channel + + request = json.loads(msg.get("data")) + except Exception as e: + self.logger.error("Error during request decoding") + self.logger.exception(e) + continue + + """ For performance reasons: + + Instead of dynamically checking that the + given channel matches one that we're + listening for, just check against each + channel that we have actions prepared for. 
+ """ + if channel == self.ANNOUNCE: + if "requester_id" in request: + stdout = f"{self.ANNOUNCE}-{request['requester_id']}" + try: + if request["type"] == "PUBLISH": + self.publish_key(request["key_str"], request["key_platform"]) + elif request["type"] == "UNPUBLISH": + self.unpublish_key(request["key_str"], request["key_platform"]) + elif request["type"] == "ACK": + conn.publish(stdout, "") + self.logger.info(f"ACK; for: {request['requester_id']}") + elif request["type"] == "LIST_PLATFORMS": + platforms = [ p for p in self.fresh_keys.keys() ] + conn.publish(stdout, json.dumps(platforms)) + elif request["type"] == "LIST_KEYS": + keys = list(self.fresh_keys[request["key_platform"]]) + keys += list(self.expired_keys[request["key_platform"]].keys()) + conn.publish(stdout, json.dumps(keys)) + elif request["type"] == "SHUTDOWN": + self.logger.info("Shutting down") + # Close + return + except Exception as e: + # This is a bare exception, because we don't really care why failure happened + self.logger.exception("Error during ANNOUNCE") + continue + + elif channel == self.REQUEST: + self.refresh_keys() + stdout = f"{self.REQUEST}-{request['requester_id']}" + + try: + if request["type"] == "NEW": + new_key = self.new_key(request["key_platform"]) + elif request["type"] == "EXPIRE": + self.expire_key(request["key_str"], request["key_platform"], request["refresh_time"]) + self.logger.debug(f"EXPIRE; from: {request['requester_id']}, platform: {request['key_platform']}") + continue + except WaitKeyTimeout as w: + timeout = w.tiemout_seconds + conn.publish(stdout, json.dumps({ + "wait": timeout + })) + continue + except Exception as e: + # This is a bare exception, because we don't really care why failure happened + self.logger.exception("Error during REQUEST") + continue + + self.logger.debug(f"REPLY; for: {request['requester_id']}, platform: {request['key_platform']}") + conn.publish(stdout, json.dumps({ + "key": new_key + })) + +if __name__ == "__main__": + manager = KeyOrchestrator() + + try: + manager.run() + except KeyboardInterrupt: + # Exit silently on sigint + manager.logger.info("Interrupted") + pass diff --git a/metadata.py b/metadata.py index a8e71cd7a..2ac76c87d 100644 --- a/metadata.py +++ b/metadata.py @@ -5,8 +5,8 @@ __short_description__ = "Python 3 package for free/libre and open-source software community metrics, models & data collection" -__version__ = "0.80.1" -__release__ = "v0.80.1 (Data Monster)" +__version__ = "0.81.0" +__release__ = "v0.81.0 (Super Soaker)" __license__ = "MIT" -__copyright__ = "University of Missouri, University of Nebraska-Omaha, CHAOSS, Brian Warner & Augurlabs 2112" +__copyright__ = "University of Missouri, University of Nebraska-Omaha, CHAOSS, Sean Goggins, Brian Warner & Augurlabs 2112" diff --git a/tests/key_manager.py b/tests/key_manager.py new file mode 100644 index 000000000..b7b397ee8 --- /dev/null +++ b/tests/key_manager.py @@ -0,0 +1,96 @@ +from keyman.KeyClient import KeyClient, KeyPublisher +from augur.application.logs import AugurLogger + +from multiprocessing import Process, current_process +from subprocess import Popen, PIPE + +import random, time, atexit + +keys = { + "github": [ + "key1", + "key2", + "key3", + "key4" + ], + "gitlab": [ + "key5", + "key6", + "key7", + "key8" + ] +} + +def mp_consumer(platform): + if platform not in keys: + raise ValueError(f"Platform not valid for testing keys dict: {platform}") + + logger = AugurLogger(f"Keyman_test_consumer").get_logger() + logger.setLevel(1) + client = KeyClient(platform, logger) + + key = 
client.request() + for _ in range(len(keys[platform])): + if key not in keys[platform]: + raise AssertionError(f"Received key {platform}:{key} not valid") + + sleep_timeout = random.randint(10, 30) + + """ This is a fairly unrealistic scenario, because theoretically + every worker that would expire a key would expire it at the + exact same timestamp (as reported by the platform API). + + Whereas for this testing, each worker is just reporting a + random future timestamp. In effect, this means that two + workers might assign different timeouts to the same key, + making it appear as though the orchestrator is not assigning + keys properly, when in fact it is working as expected. + """ + logger.info(f"Expiring {platform}:{key} for {sleep_timeout} seconds") + key = client.expire(key, int(time.time()) + sleep_timeout) + +if __name__ == "__main__": + orchestrator = Popen("python keyman/Orchestrator.py".split()) + + publisher = KeyPublisher() + + if not publisher.wait(republish = True): + raise AssertionError("Orchestrator not reachable on startup") + + atexit.register(publisher.shutdown) + + for platform, key_list in keys.items(): + for key in key_list: + publisher.publish(key, platform) + + if not publisher.wait(): + raise AssertionError("Orchestrator did not ACK within the time limit") + + logger = AugurLogger("Keyman_test").get_logger() + logger.info("Keys loaded") + + platforms = publisher.list_platforms() + + logger.info(f"Loaded platforms: {platforms}") + + for platform in platforms: + key_list = publisher.list_keys(platform) + logger.info(f"Keys for {platform}: {key_list}") + + logger.info("Running expiration tests") + workers: list[Process] = [] + for platform, key_list in keys.items(): + num_workers = len(key_list) // 2 + + for i in range(num_workers): + workers.append(Process(target = mp_consumer, args = [platform])) + + try: + for worker in workers: + worker.start() + + for worker in workers: + worker.join() + except KeyboardInterrupt: + pass + \ No newline at end of file
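
For reviewers, the end-to-end flow this diff introduces is: the CLI starts `keyman/Orchestrator.py` as a subprocess, publishes API keys to it through `KeyPublisher`, and collection workers then obtain and rotate keys through `KeyClient` (as `GithubDataAccess.make_request` now does). Below is a minimal sketch of that flow using only the calls added in this diff; the placeholder token string and running the "worker" steps inline in one process are illustrative only, and it assumes the Redis instance used by `augur.tasks.init.redis_connection` is reachable.

```python
# Minimal keyman usage sketch (illustrative; placeholder token, inline "worker" steps).
import logging
import subprocess

from keyman.KeyClient import KeyClient, KeyPublisher

logger = logging.getLogger("keyman-example")

# 1. Start the key orchestrator, exactly as backend.py/collection.py do.
orchestrator = subprocess.Popen("python keyman/Orchestrator.py".split())

# 2. Publish keys and wait for the orchestrator to acknowledge.
publisher = KeyPublisher()
if not publisher.wait(republish=True):
    raise RuntimeError("Key orchestrator did not respond in time")
publisher.publish("ghp_example_token", "github_rest")  # placeholder token

# 3. What a worker does (normally in its own process): request a key,
#    and swap it for a fresh one when the platform reports rate limiting.
client = KeyClient("github_rest", logger)
key = client.request()          # blocks until a fresh key is available
# ... use `key` in an Authorization header; on a 403/429 response:
# key = client.expire(key, reset_epoch)   # expire key until reset_epoch, get another

# 4. Shut the orchestrator down when collection stops.
publisher.shutdown()
orchestrator.wait()
```

tests/key_manager.py in this diff exercises the same sequence with multiple consumer processes per platform, which is the closest thing to an executable reference for the protocol.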