diff --git a/src/libkernelbot/launchers/github.py b/src/libkernelbot/launchers/github.py index 3049b882..e057fcf1 100644 --- a/src/libkernelbot/launchers/github.py +++ b/src/libkernelbot/launchers/github.py @@ -5,7 +5,9 @@ import io import json import math +import os import pprint +import threading import uuid import zipfile import zlib @@ -56,9 +58,35 @@ class GitHubLauncher(Launcher): def __init__(self, repo: str, token: str, branch: str): super().__init__(name="GitHub", gpus=GitHubGPU) self.repo = repo - self.token = token + self.tokens = self._load_github_tokens(token) + self._token_lock = threading.Lock() + self._token_idx = 0 self.branch = branch + @staticmethod + def _load_github_tokens(fallback_token: str) -> list[str]: + primary = (os.getenv("GITHUB_TOKEN") or fallback_token).strip() + backup = (os.getenv("GITHUB_TOKEN_BACKUP") or "").strip() + + tokens: list[str] = [] + for t in (primary, backup): + if t and t not in tokens: + tokens.append(t) + + if not tokens: + raise KernelBotError( + "No GitHub tokens configured. Set GITHUB_TOKEN " + "(and optionally GITHUB_TOKEN_BACKUP)." + ) + + return tokens + + def _next_token(self) -> str: + with self._token_lock: + token = self.tokens[self._token_idx] + self._token_idx = (self._token_idx + 1) % len(self.tokens) + return token + async def run_submission( # noqa: C901 self, config: dict, gpu_type: GPU, status: RunProgressReporter ) -> FullResult: @@ -87,7 +115,7 @@ async def run_submission( # noqa: C901 lang_name = {"py": "Python", "cu": "CUDA"}[lang] logger.info(f"Attempting to trigger GitHub action for {lang_name} on {selected_workflow}") - run = GitHubRun(self.repo, self.token, self.branch, selected_workflow) + run = GitHubRun(self.repo, self._next_token(), self.branch, selected_workflow) logger.info(f"Successfully created GitHub run: {run.run_id}") payload = base64.b64encode(zlib.compress(json.dumps(config).encode("utf-8"))).decode( @@ -174,7 +202,7 @@ class GitHubArtifact: public_download_url: str -_WORKFLOW_FILE_CACHE: dict[str, Workflow] = {} +_WORKFLOW_FILE_CACHE: dict[tuple[str, str], Workflow] = {} class GitHubRun: @@ -215,12 +243,13 @@ def elapsed_time(self): return datetime.datetime.now(datetime.timezone.utc) - self.start_time async def get_workflow(self) -> Workflow: - if self.workflow_file in _WORKFLOW_FILE_CACHE: + cache_key = (self.workflow_file, self.token) + if cache_key in _WORKFLOW_FILE_CACHE: logger.info(f"Returning cached workflow {self.workflow_file}") - return _WORKFLOW_FILE_CACHE[self.workflow_file] + return _WORKFLOW_FILE_CACHE[cache_key] logger.info(f"Fetching workflow {self.workflow_file} from GitHub") workflow = self.repo.get_workflow(self.workflow_file) - _WORKFLOW_FILE_CACHE[self.workflow_file] = workflow + _WORKFLOW_FILE_CACHE[cache_key] = workflow return workflow async def trigger(self, inputs: dict) -> bool: