Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions src/libkernelbot/launchers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import io
import json
import math
import os
import pprint
import threading
import uuid
import zipfile
import zlib
Expand Down Expand Up @@ -56,9 +58,35 @@ class GitHubLauncher(Launcher):
# Initialise the launcher: register it with the base class under the name "GitHub",
# remember the target repo and branch, and prepare the token list plus the lock and
# index that are stored alongside it.
def __init__(self, repo: str, token: str, branch: str):
super().__init__(name="GitHub", gpus=GitHubGPU)
self.repo = repo
# NOTE(review): this listing looks like a rendered diff with removed and added lines
# interleaved; `self.token` may be the pre-change line that the three assignments
# below replace - confirm against the real file.
self.token = token
# Token list is produced by _load_github_tokens from the constructor's `token` argument.
self.tokens = self._load_github_tokens(token)
# Lock and index stored next to the token list.
self._token_lock = threading.Lock()
self._token_idx = 0
self.branch = branch

@staticmethod
def _load_github_tokens(fallback_token: str) -> list[str]:
primary = (os.getenv("GITHUB_TOKEN") or fallback_token).strip()
backup = (os.getenv("GITHUB_TOKEN_BACKUP") or "").strip()

tokens: list[str] = []
for t in (primary, backup):
if t and t not in tokens:
tokens.append(t)

if not tokens:
raise KernelBotError(
"No GitHub tokens configured. Set GITHUB_TOKEN "
"(and optionally GITHUB_TOKEN_BACKUP)."
)

return tokens

def _next_token(self) -> str:
    """Return the token at the current index and advance the index, wrapping around.

    The read and the index update both happen while ``self._token_lock`` is held.
    """
    with self._token_lock:
        position = self._token_idx
        chosen = self.tokens[position]
        # Step to the following slot; modulo wraps back to 0 after the last token.
        self._token_idx = (position + 1) % len(self.tokens)
    return chosen

async def run_submission( # noqa: C901
self, config: dict, gpu_type: GPU, status: RunProgressReporter
) -> FullResult:
Expand Down Expand Up @@ -87,7 +115,7 @@ async def run_submission( # noqa: C901
lang_name = {"py": "Python", "cu": "CUDA"}[lang]

logger.info(f"Attempting to trigger GitHub action for {lang_name} on {selected_workflow}")
run = GitHubRun(self.repo, self.token, self.branch, selected_workflow)
run = GitHubRun(self.repo, self._next_token(), self.branch, selected_workflow)
logger.info(f"Successfully created GitHub run: {run.run_id}")

payload = base64.b64encode(zlib.compress(json.dumps(config).encode("utf-8"))).decode(
Expand Down Expand Up @@ -174,7 +202,7 @@ class GitHubArtifact:
public_download_url: str


# Module-level cache of Workflow objects, filled and read by GitHubRun.get_workflow.
# NOTE(review): two declarations appear back to back here, which looks like the old and
# new side of a rendered diff. Read literally, the second assignment rebinds the name,
# so the tuple-keyed form - (workflow file, token) -> Workflow - is the effective one.
_WORKFLOW_FILE_CACHE: dict[str, Workflow] = {}
_WORKFLOW_FILE_CACHE: dict[tuple[str, str], Workflow] = {}


class GitHubRun:
Expand Down Expand Up @@ -215,12 +243,13 @@ def elapsed_time(self):
return datetime.datetime.now(datetime.timezone.utc) - self.start_time

# Return the Workflow for self.workflow_file, consulting the module-level
# _WORKFLOW_FILE_CACHE first and calling self.repo.get_workflow() on a miss.
# NOTE(review): this listing looks like a rendered diff with removed and added lines
# interleaved (lookups keyed by self.workflow_file next to lookups keyed by cache_key);
# presumably only the cache_key variants are live - confirm against the real file.
async def get_workflow(self) -> Workflow:
if self.workflow_file in _WORKFLOW_FILE_CACHE:
# Cache entries are keyed by the pair (workflow file, token).
cache_key = (self.workflow_file, self.token)
if cache_key in _WORKFLOW_FILE_CACHE:
logger.info(f"Returning cached workflow {self.workflow_file}")
return _WORKFLOW_FILE_CACHE[self.workflow_file]
return _WORKFLOW_FILE_CACHE[cache_key]
# Cache miss: fetch from the repo object and store the result for later calls.
logger.info(f"Fetching workflow {self.workflow_file} from GitHub")
workflow = self.repo.get_workflow(self.workflow_file)
_WORKFLOW_FILE_CACHE[self.workflow_file] = workflow
_WORKFLOW_FILE_CACHE[cache_key] = workflow
return workflow

async def trigger(self, inputs: dict) -> bool:
Expand Down
Loading