Skip to content

Commit

Permalink
feat(IDX): create bot check code
Browse files Browse the repository at this point in the history
  • Loading branch information
cgundy committed Dec 4, 2024
1 parent fbdc94c commit 15a003d
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -1,67 +1,84 @@
import json
import os
import subprocess

import github3

from check_membership.check_membership import is_approved_bot
from shared.utils import download_gh_file, load_env_vars

BOT_APPROVED_FILES_PATH = ".github/repo_policies/bot_approved_files.json"
REQUIRED_VARS = [
"USER",
"GH_TOKEN",
"GH_ORG",
"REPO",
"MERGE_BASE_SHA",
"BRANCH_HEAD_SHA",
]

def get_changed_files() -> list[str]:
merge_base_sha = os.environ["MERGE_BASE_SHA"]
branch_head_sha = os.environ["BRANCH_HEAD_SHA"]

def get_changed_files(merge_base_sha: str, branch_head_sha: str) -> list[str]:
commit_range = f"{merge_base_sha}..{branch_head_sha}"
result = subprocess.run(['git', 'diff', '--name-only', commit_range], stdout=subprocess.PIPE, text=True)
changed_files = result.stdout.strip().split('\n')
result = subprocess.run(
["git", "diff", "--name-only", commit_range], stdout=subprocess.PIPE, text=True
)
changed_files = result.stdout.strip().split("\n")
return changed_files


def get_approved_files_config(repo: github3.github.repo) -> str:
bot_approved_files_list = [
".github/repo_policies/bot_approved_files.json",
]
for path in bot_approved_files_list:
try:
config_file = repo.file_contents(path)
return config_file
except github3.exceptions.NotFoundError:
print("No config file found")
raise Exception("No config file found")
try:
config_file = download_gh_file(repo, BOT_APPROVED_FILES_PATH)
return config_file
except github3.exceptions.NotFoundError:
raise Exception(
f"No config file found. Make sure you have a file saved at {BOT_APPROVED_FILES_PATH}"
)


def get_approved_files(config_file: str) -> list[str]:
with open(config_file, 'r') as f:
data = json.load(f)
try:
approved_files = data["approved_files"]
config = json.loads(config_file)
except json.JSONDecodeError:
raise Exception("Config file is not a valid JSON file")
try:
approved_files = config["approved_files"]
except KeyError:
raise Exception("No approved_files key found in config file")

if len(approved_files) == 0:
print("No approved files found in config file")
raise Exception("No approved files found in config file")
return approved_files
return approved_files


def pr_is_blocked(env_vars: dict) -> bool:
gh = github3.login(token=env_vars["GH_TOKEN"])
repo = gh.repository(owner=env_vars["GH_ORG"], repository=env_vars["REPO"])
changed_files = get_changed_files(
env_vars["MERGE_BASE_SHA"], env_vars["BRANCH_HEAD_SHA"]
)
config = get_approved_files_config(repo)
approved_files = get_approved_files(config)
block_pr = not all(file in approved_files for file in changed_files)
return block_pr


def main() -> None:
env_vars = load_env_vars(REQUIRED_VARS)
user = env_vars["USER"]

gh_token = os.environ["GH_TOKEN"]
org_name = os.environ["GH_ORG"]
repo_name = os.environ["REPO"]
user = os.environ["USER"]
is_bot = is_approved_bot(user)

if is_bot:
gh = github3.login(token=gh_token)
repo = gh.repository(owner=org_name, repository=repo_name)
changed_files = get_changed_files()
config = get_approved_files_config(repo)
approved_files = get_approved_files(config)
block_pr = not all(file in approved_files for file in changed_files)
block_pr = pr_is_blocked(env_vars)

else:
print(f"{user} is not an approved bot. Letting CLA check handle contribution decision.")
print(
f"{user} is not an approved bot. Letting CLA check handle contribution decision."
)
block_pr = False

os.system(f"""echo 'block_pr={block_pr}' >> $GITHUB_OUTPUT""")
subprocess.run(f"""echo 'block_pr={block_pr}' >> $GITHUB_OUTPUT""", shell=True)


if __name__ == "__main__":
Expand Down
11 changes: 11 additions & 0 deletions reusable_workflows/shared/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import time

import github3
Expand All @@ -22,3 +23,13 @@ def download_gh_file(repo: github3.github.repo, file_path: str) -> str:

file_decoded = file_content.decoded.decode()
return file_decoded


def load_env_vars(var_names: list[str]) -> dict:
env_vars = {}
for var in var_names:
try:
env_vars[var] = os.environ[var]
except KeyError:
raise Exception(f"Environment variable '{var}' is not set.")
return env_vars
161 changes: 161 additions & 0 deletions reusable_workflows/tests/test_repo_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import subprocess
from unittest import mock

import github3
import pytest

from repo_policies.bot_checks.check_bot_approved_files import (
BOT_APPROVED_FILES_PATH,
get_approved_files,
get_approved_files_config,
get_changed_files,
main,
pr_is_blocked,
)


@mock.patch("subprocess.run")
def test_get_changed_files(mock_subprocess_run):
mock_subprocess_run.return_value = mock.Mock(stdout="file1.py\nfile2.py\n")

changed_files = get_changed_files("merge_base_sha", "branch_head_sha")

assert changed_files == ["file1.py", "file2.py"]
mock_subprocess_run.assert_called_once_with(
["git", "diff", "--name-only", "merge_base_sha..branch_head_sha"],
stdout=subprocess.PIPE,
text=True,
)


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.download_gh_file")
def test_get_approved_files_config(download_gh_file):
repo = mock.Mock()
config_file_mock = mock.Mock()
download_gh_file.return_value = config_file_mock

config_file = get_approved_files_config(repo)

download_gh_file.assert_called_once_with(repo, BOT_APPROVED_FILES_PATH)
assert config_file == config_file_mock


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.download_gh_file")
def test_get_approved_files_config_fails(download_gh_file):
repo = mock.Mock()
download_gh_file.side_effect = github3.exceptions.NotFoundError(mock.Mock())

with pytest.raises(Exception) as exc:
get_approved_files_config(repo)

assert (
str(exc.value)
== f"No config file found. Make sure you have a file saved at {BOT_APPROVED_FILES_PATH}"
)


def test_get_approved_files():
config_file = '{"approved_files": ["file1.py", "file2.py"]}'
approved_files = get_approved_files(config_file)

assert approved_files == ["file1.py", "file2.py"]


def test_get_approved_files_not_json():
config_file = "not a json file"

with pytest.raises(Exception) as exc:
get_approved_files(config_file)

assert str(exc.value) == "Config file is not a valid JSON file"


def test_get_approved_files_no_approved_files():
config_file = '{"another_key": ["file1.py", "file2.py"]}'

with pytest.raises(Exception) as exc:
get_approved_files(config_file)

assert str(exc.value) == "No approved_files key found in config file"


def test_get_approved_files_no_files():
config_file = '{"approved_files": []}'

with pytest.raises(Exception) as exc:
get_approved_files(config_file)

assert str(exc.value) == "No approved files found in config file"


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.get_changed_files")
@mock.patch(
"repo_policies.bot_checks.check_bot_approved_files.get_approved_files_config"
)
@mock.patch("github3.login")
def test_pr_is_blocked_false(gh_login, get_approved_files_config, get_changed_files):
env_vars = {
"GH_TOKEN": "token",
"GH_ORG": "org",
"REPO": "repo",
"MERGE_BASE_SHA": "base",
"BRANCH_HEAD_SHA": "head",
}
gh = mock.Mock()
gh_login.return_value = gh
repo = mock.Mock()
gh.repository.return_value = repo
get_changed_files.return_value = ["file1.py", "file2.py"]
get_approved_files_config.return_value = (
'{"approved_files": ["file1.py", "file2.py", "file3.py"]}'
)

blocked = pr_is_blocked(env_vars)

assert blocked == False
get_changed_files.assert_called_once_with("base", "head")
get_approved_files_config.assert_called_once_with(repo)


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.get_changed_files")
@mock.patch(
"repo_policies.bot_checks.check_bot_approved_files.get_approved_files_config"
)
@mock.patch("github3.login")
def test_pr_is_blocked_true(gh_login, get_approved_files_config, get_changed_files):
env_vars = {
"GH_TOKEN": "token",
"GH_ORG": "org",
"REPO": "repo",
"MERGE_BASE_SHA": "base",
"BRANCH_HEAD_SHA": "head",
}
gh = mock.Mock()
gh_login.return_value = gh
repo = mock.Mock()
gh.repository.return_value = repo
get_changed_files.return_value = ["file1.py", "file2.py"]
get_approved_files_config.return_value = '{"approved_files": ["file1.py"]}'

blocked = pr_is_blocked(env_vars)

assert blocked == True
get_changed_files.assert_called_once_with("base", "head")
get_approved_files_config.assert_called_once_with(repo)


@mock.patch("repo_policies.bot_checks.check_bot_approved_files.load_env_vars")
@mock.patch("repo_policies.bot_checks.check_bot_approved_files.is_approved_bot")
@mock.patch("repo_policies.bot_checks.check_bot_approved_files.pr_is_blocked")
@mock.patch("subprocess.run")
def test_main_succeeds(subprocess_run, pr_is_blocked, is_approved_bot, load_env_vars):
env_vars = {"GH_TOKEN": "token", "USER": "user"}
load_env_vars.return_value = env_vars
is_approved_bot.return_value = True
pr_is_blocked.return_value = False

main()

subprocess_run.assert_called_once_with(
"echo 'block_pr=False' >> $GITHUB_OUTPUT", shell=True
)
18 changes: 17 additions & 1 deletion reusable_workflows/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
from unittest import mock

import pytest

from shared.utils import download_gh_file
from shared.utils import download_gh_file, load_env_vars


def test_download_file_succeeds_first_try():
Expand Down Expand Up @@ -46,3 +47,18 @@ def test_download_file_fails(mock_get):

assert repo.file_contents.call_count == 5
file_content_obj.decoded.assert_not_called


@mock.patch.dict(os.environ, {"REPO": "repo-1", "GH_TOKEN": "token"})
def test_load_env_vars_succeeds(capfd):
env_vars = load_env_vars(["REPO", "GH_TOKEN"])

assert env_vars == {"REPO": "repo-1", "GH_TOKEN": "token"}


@mock.patch.dict(os.environ, {"REPO": "repo-1"}, clear=True)
def test_load_env_vars_fails(capfd):
with pytest.raises(Exception) as exc:
env_vars = load_env_vars(["REPO", "GH_TOKEN"])
print(env_vars)
assert str(exc.value) == "Environment variable 'GH_TOKEN' is not set."

0 comments on commit 15a003d

Please sign in to comment.