From 637260a32da078f9b8a5cf8fd69788cc56b9de53 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:30:17 +0000 Subject: [PATCH 1/3] Initial plan From a0f11f2822454cd5be213ccc48c7898f3ca6e743 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:46:13 +0000 Subject: [PATCH 2/3] Add checks command for CI state classification (#issue) Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- cmd/gh-aw/main.go | 3 + pkg/cli/checks_command.go | 356 +++++++++++++++++++++++++++++++++ pkg/cli/checks_command_test.go | 252 +++++++++++++++++++++++ 3 files changed, 611 insertions(+) create mode 100644 pkg/cli/checks_command.go create mode 100644 pkg/cli/checks_command_test.go diff --git a/cmd/gh-aw/main.go b/cmd/gh-aw/main.go index 590aace28a..04886f2f44 100644 --- a/cmd/gh-aw/main.go +++ b/cmd/gh-aw/main.go @@ -614,6 +614,7 @@ Use "` + string(constants.CLIExtensionPrefix) + ` help all" to show help for all completionCmd := cli.NewCompletionCommand() hashCmd := cli.NewHashCommand() projectCmd := cli.NewProjectCommand() + checksCmd := cli.NewChecksCommand() // Assign commands to groups // Setup Commands @@ -642,6 +643,7 @@ Use "` + string(constants.CLIExtensionPrefix) + ` help all" to show help for all logsCmd.GroupID = "analysis" auditCmd.GroupID = "analysis" healthCmd.GroupID = "analysis" + checksCmd.GroupID = "analysis" // Utilities mcpServerCmd.GroupID = "utilities" @@ -669,6 +671,7 @@ Use "` + string(constants.CLIExtensionPrefix) + ` help all" to show help for all rootCmd.AddCommand(logsCmd) rootCmd.AddCommand(auditCmd) rootCmd.AddCommand(healthCmd) + rootCmd.AddCommand(checksCmd) rootCmd.AddCommand(mcpCmd) rootCmd.AddCommand(mcpServerCmd) rootCmd.AddCommand(prCmd) diff --git a/pkg/cli/checks_command.go b/pkg/cli/checks_command.go new file mode 100644 index 0000000000..10971f12d9 --- /dev/null +++ b/pkg/cli/checks_command.go @@ -0,0 +1,356 @@ +package cli + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "strings" + + "github.com/github/gh-aw/pkg/console" + "github.com/github/gh-aw/pkg/constants" + "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/workflow" + "github.com/spf13/cobra" +) + +var checksLog = logger.New("cli:checks_command") + +// CheckState represents the normalized CI state for a PR. +type CheckState string + +const ( + // CheckStateFailed indicates one or more checks failed. + CheckStateFailed CheckState = "failed" + // CheckStatePending indicates checks are still running. + CheckStatePending CheckState = "pending" + // CheckStateNoChecks indicates no checks have been configured or triggered. + CheckStateNoChecks CheckState = "no_checks" + // CheckStatePolicyBlocked indicates policy or account gates are blocking the PR. + CheckStatePolicyBlocked CheckState = "policy_blocked" + // CheckStateSuccess indicates all checks passed. + CheckStateSuccess CheckState = "success" +) + +// ChecksConfig holds configuration for the checks command. +type ChecksConfig struct { + Repo string + PRNumber string + JSONOutput bool +} + +// PRCheckRun represents a single check run from the GitHub API. +type PRCheckRun struct { + Name string `json:"name"` + Status string `json:"status"` + Conclusion string `json:"conclusion"` + HTMLURL string `json:"html_url"` +} + +// PRCommitStatus represents a single commit status from the GitHub API. +type PRCommitStatus struct { + State string `json:"state"` + Description string `json:"description"` + Context string `json:"context"` + TargetURL string `json:"target_url"` +} + +// ChecksResult is the normalized output for the checks command. +type ChecksResult struct { + State CheckState `json:"state"` + PRNumber string `json:"pr_number"` + HeadSHA string `json:"head_sha"` + CheckRuns []PRCheckRun `json:"check_runs"` + Statuses []PRCommitStatus `json:"statuses"` + TotalCount int `json:"total_count"` +} + +// NewChecksCommand creates the checks command. +func NewChecksCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "checks ", + Short: "Classify CI check state for a pull request", + Long: `Classify CI check state for a pull request and emit a normalized result. + +Maps PR check rollups to one of the following normalized states: + success - all checks passed + failed - one or more checks failed + pending - checks are still running or queued + no_checks - no checks configured or triggered + policy_blocked - policy or account gates are blocking the PR + +` + "Raw check run and commit status signals are included in JSON output." + ` + +Examples: + ` + string(constants.CLIExtensionPrefix) + ` checks 42 # Classify checks for PR #42 + ` + string(constants.CLIExtensionPrefix) + ` checks 42 --repo owner/repo # Specify repository + ` + string(constants.CLIExtensionPrefix) + ` checks 42 --json # Output in JSON format`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + repo, _ := cmd.Flags().GetString("repo") + jsonOutput, _ := cmd.Flags().GetBool("json") + + config := ChecksConfig{ + Repo: repo, + PRNumber: args[0], + JSONOutput: jsonOutput, + } + + return RunChecks(config) + }, + } + + addRepoFlag(cmd) + addJSONFlag(cmd) + + return cmd +} + +// RunChecks executes the checks command with the given configuration. +func RunChecks(config ChecksConfig) error { + checksLog.Printf("Running checks: pr=%s, repo=%s", config.PRNumber, config.Repo) + + result, err := FetchChecksResult(config.Repo, config.PRNumber) + if err != nil { + fmt.Fprintln(os.Stderr, console.FormatErrorMessage(err.Error())) + return fmt.Errorf("failed to fetch check state for PR %s: %w", config.PRNumber, err) + } + + if config.JSONOutput { + return printChecksJSON(result) + } + + return printChecksText(result) +} + +// FetchChecksResult fetches check runs and statuses for a PR and returns a classified result. +// This function is exported for use in tests and other packages. +func FetchChecksResult(repoOverride string, prNumber string) (*ChecksResult, error) { + checksLog.Printf("Fetching checks result: repo=%s, pr=%s", repoOverride, prNumber) + + // Step 1: Resolve head SHA from PR + headSHA, err := fetchPRHeadSHA(repoOverride, prNumber) + if err != nil { + return nil, fmt.Errorf("failed to fetch PR head SHA: %w", err) + } + checksLog.Printf("Resolved head SHA: %s", headSHA) + + // Step 2: Fetch check runs + checkRuns, err := fetchCheckRuns(repoOverride, headSHA) + if err != nil { + // Non-fatal: continue with empty check runs + checksLog.Printf("Failed to fetch check runs: %v", err) + checkRuns = []PRCheckRun{} + } + + // Step 3: Fetch commit statuses + statuses, err := fetchCommitStatuses(repoOverride, headSHA) + if err != nil { + // Non-fatal: continue with empty statuses + checksLog.Printf("Failed to fetch commit statuses: %v", err) + statuses = []PRCommitStatus{} + } + + state := classifyCheckState(checkRuns, statuses) + + return &ChecksResult{ + State: state, + PRNumber: prNumber, + HeadSHA: headSHA, + CheckRuns: checkRuns, + Statuses: statuses, + TotalCount: len(checkRuns) + len(statuses), + }, nil +} + +// fetchPRHeadSHA fetches the head commit SHA for a given PR. +func fetchPRHeadSHA(repoOverride string, prNumber string) (string, error) { + args := []string{"api", "repos/{owner}/{repo}/pulls/" + prNumber, "--jq", ".head.sha"} + if repoOverride != "" { + args = append(args, "--repo", repoOverride) + } + + cmd := workflow.ExecGH(args...) + output, err := cmd.Output() + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + return "", fmt.Errorf("gh api call failed (exit %d): %s", exitErr.ExitCode(), strings.TrimSpace(string(exitErr.Stderr))) + } + return "", fmt.Errorf("gh api call failed: %w", err) + } + + sha := strings.TrimSpace(string(output)) + if sha == "" { + return "", fmt.Errorf("empty SHA returned for PR %s", prNumber) + } + return sha, nil +} + +// checkRunsAPIResponse is the envelope returned by the check-runs endpoint. +type checkRunsAPIResponse struct { + TotalCount int `json:"total_count"` + CheckRuns []PRCheckRun `json:"check_runs"` +} + +// fetchCheckRuns fetches check runs for a commit SHA. +func fetchCheckRuns(repoOverride string, sha string) ([]PRCheckRun, error) { + args := []string{"api", "repos/{owner}/{repo}/commits/" + sha + "/check-runs", "--paginate"} + if repoOverride != "" { + args = append(args, "--repo", repoOverride) + } + + cmd := workflow.ExecGH(args...) + output, err := cmd.Output() + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + return nil, fmt.Errorf("gh api call failed (exit %d): %s", exitErr.ExitCode(), strings.TrimSpace(string(exitErr.Stderr))) + } + return nil, fmt.Errorf("gh api call failed: %w", err) + } + + var resp checkRunsAPIResponse + if err := json.Unmarshal(output, &resp); err != nil { + return nil, fmt.Errorf("failed to parse check runs response: %w", err) + } + + return resp.CheckRuns, nil +} + +// commitStatusAPIResponse is the envelope returned by the statuses endpoint. +type commitStatusAPIResponse struct { + State string `json:"state"` + Statuses []PRCommitStatus `json:"statuses"` +} + +// fetchCommitStatuses fetches commit statuses (legacy Status API) for a commit SHA. +func fetchCommitStatuses(repoOverride string, sha string) ([]PRCommitStatus, error) { + args := []string{"api", "repos/{owner}/{repo}/commits/" + sha + "/status"} + if repoOverride != "" { + args = append(args, "--repo", repoOverride) + } + + cmd := workflow.ExecGH(args...) + output, err := cmd.Output() + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + return nil, fmt.Errorf("gh api call failed (exit %d): %s", exitErr.ExitCode(), strings.TrimSpace(string(exitErr.Stderr))) + } + return nil, fmt.Errorf("gh api call failed: %w", err) + } + + var resp commitStatusAPIResponse + if err := json.Unmarshal(output, &resp); err != nil { + return nil, fmt.Errorf("failed to parse commit status response: %w", err) + } + + return resp.Statuses, nil +} + +// policyCheckPatterns are patterns that indicate a policy/account-gate check rather than a +// product failure. These names come from GitHub's branch-protection rule enforcement. +var policyCheckPatterns = []string{ + "required status check", + "branch protection", + "mergeability", + "repo policy", + "policy check", + "access control", +} + +// isPolicyCheck returns true if the check run name looks like a policy/account-gate check. +func isPolicyCheck(name string) bool { + lower := strings.ToLower(name) + for _, pattern := range policyCheckPatterns { + if strings.Contains(lower, pattern) { + return true + } + } + return false +} + +// classifyCheckState derives a normalized CheckState from raw check runs and commit statuses. +func classifyCheckState(checkRuns []PRCheckRun, statuses []PRCommitStatus) CheckState { + if len(checkRuns) == 0 && len(statuses) == 0 { + return CheckStateNoChecks + } + + hasPending := false + hasFailed := false + hasPolicyBlocked := false + + for _, cr := range checkRuns { + switch cr.Status { + case "queued", "in_progress", "waiting", "requested", "pending": + hasPending = true + case "completed": + switch cr.Conclusion { + case "failure", "timed_out", "startup_failure": + if isPolicyCheck(cr.Name) { + hasPolicyBlocked = true + } else { + hasFailed = true + } + case "action_required": + hasPolicyBlocked = true + } + } + } + + for _, s := range statuses { + switch s.State { + case "pending": + hasPending = true + case "failure", "error": + if isPolicyCheck(s.Context) { + hasPolicyBlocked = true + } else { + hasFailed = true + } + } + } + + switch { + case hasPolicyBlocked && !hasFailed && !hasPending: + return CheckStatePolicyBlocked + case hasFailed: + return CheckStateFailed + case hasPending: + return CheckStatePending + default: + return CheckStateSuccess + } +} + +// printChecksJSON prints the result as JSON to stdout. +func printChecksJSON(result *ChecksResult) error { + enc := json.NewEncoder(os.Stdout) + enc.SetIndent("", " ") + if err := enc.Encode(result); err != nil { + return fmt.Errorf("failed to encode JSON output: %w", err) + } + return nil +} + +// printChecksText prints the result in human-readable form to stderr. +func printChecksText(result *ChecksResult) error { + switch result.State { + case CheckStateSuccess: + fmt.Fprintln(os.Stderr, console.FormatSuccessMessage(fmt.Sprintf("PR #%s: all checks passed (%d total)", result.PRNumber, result.TotalCount))) + case CheckStateFailed: + fmt.Fprintln(os.Stderr, console.FormatErrorMessage(fmt.Sprintf("PR #%s: checks failed (%d total)", result.PRNumber, result.TotalCount))) + case CheckStatePending: + fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("PR #%s: checks pending (%d total)", result.PRNumber, result.TotalCount))) + case CheckStateNoChecks: + fmt.Fprintln(os.Stderr, console.FormatWarningMessage(fmt.Sprintf("PR #%s: no checks configured or triggered", result.PRNumber))) + case CheckStatePolicyBlocked: + fmt.Fprintln(os.Stderr, console.FormatWarningMessage(fmt.Sprintf("PR #%s: blocked by policy or account gate (%d total)", result.PRNumber, result.TotalCount))) + } + + // Always print the normalized state to stdout for machine consumption. + fmt.Println(string(result.State)) + return nil +} diff --git a/pkg/cli/checks_command_test.go b/pkg/cli/checks_command_test.go new file mode 100644 index 0000000000..464f4d813b --- /dev/null +++ b/pkg/cli/checks_command_test.go @@ -0,0 +1,252 @@ +//go:build !integration + +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// classifyCheckState – fixture-based unit tests +// --------------------------------------------------------------------------- + +func TestClassifyCheckState_NoChecks(t *testing.T) { + state := classifyCheckState([]PRCheckRun{}, []PRCommitStatus{}) + assert.Equal(t, CheckStateNoChecks, state, "empty check runs and statuses should yield no_checks") +} + +func TestClassifyCheckState_AllSuccess(t *testing.T) { + runs := []PRCheckRun{ + {Name: "build", Status: "completed", Conclusion: "success"}, + {Name: "lint", Status: "completed", Conclusion: "success"}, + } + state := classifyCheckState(runs, nil) + assert.Equal(t, CheckStateSuccess, state, "all successful check runs should yield success") +} + +func TestClassifyCheckState_Failed(t *testing.T) { + runs := []PRCheckRun{ + {Name: "build", Status: "completed", Conclusion: "success"}, + {Name: "test", Status: "completed", Conclusion: "failure"}, + } + state := classifyCheckState(runs, nil) + assert.Equal(t, CheckStateFailed, state, "at least one failed check run should yield failed") +} + +func TestClassifyCheckState_Pending(t *testing.T) { + runs := []PRCheckRun{ + {Name: "build", Status: "completed", Conclusion: "success"}, + {Name: "test", Status: "in_progress", Conclusion: ""}, + } + state := classifyCheckState(runs, nil) + assert.Equal(t, CheckStatePending, state, "in-progress check run should yield pending") +} + +func TestClassifyCheckState_Queued(t *testing.T) { + runs := []PRCheckRun{ + {Name: "build", Status: "queued", Conclusion: ""}, + } + state := classifyCheckState(runs, nil) + assert.Equal(t, CheckStatePending, state, "queued check run should yield pending") +} + +func TestClassifyCheckState_PolicyBlocked(t *testing.T) { + runs := []PRCheckRun{ + {Name: "Branch protection rule check", Status: "completed", Conclusion: "failure"}, + } + state := classifyCheckState(runs, nil) + assert.Equal(t, CheckStatePolicyBlocked, state, "branch protection rule failure should yield policy_blocked") +} + +func TestClassifyCheckState_PolicyBlockedActionRequired(t *testing.T) { + runs := []PRCheckRun{ + {Name: "build", Status: "completed", Conclusion: "success"}, + {Name: "required status check", Status: "completed", Conclusion: "action_required"}, + } + state := classifyCheckState(runs, nil) + assert.Equal(t, CheckStatePolicyBlocked, state, "action_required on policy check should yield policy_blocked") +} + +func TestClassifyCheckState_PolicyBlockedWithFailures(t *testing.T) { + // If both a policy check and a real failure are present, failed takes priority. + runs := []PRCheckRun{ + {Name: "required status check", Status: "completed", Conclusion: "failure"}, + {Name: "test suite", Status: "completed", Conclusion: "failure"}, + } + state := classifyCheckState(runs, nil) + assert.Equal(t, CheckStateFailed, state, "real failure alongside policy check should yield failed, not policy_blocked") +} + +func TestClassifyCheckState_CommitStatusNoChecks(t *testing.T) { + state := classifyCheckState(nil, []PRCommitStatus{}) + assert.Equal(t, CheckStateNoChecks, state, "empty commit statuses should yield no_checks") +} + +func TestClassifyCheckState_CommitStatusPending(t *testing.T) { + statuses := []PRCommitStatus{ + {Context: "ci/circleci", State: "pending"}, + } + state := classifyCheckState(nil, statuses) + assert.Equal(t, CheckStatePending, state, "pending commit status should yield pending") +} + +func TestClassifyCheckState_CommitStatusFailed(t *testing.T) { + statuses := []PRCommitStatus{ + {Context: "ci/circleci", State: "failure"}, + } + state := classifyCheckState(nil, statuses) + assert.Equal(t, CheckStateFailed, state, "failure commit status should yield failed") +} + +func TestClassifyCheckState_CommitStatusError(t *testing.T) { + statuses := []PRCommitStatus{ + {Context: "ci/circleci", State: "error"}, + } + state := classifyCheckState(nil, statuses) + assert.Equal(t, CheckStateFailed, state, "error commit status should yield failed") +} + +func TestClassifyCheckState_CommitStatusSuccess(t *testing.T) { + statuses := []PRCommitStatus{ + {Context: "ci/circleci", State: "success"}, + } + state := classifyCheckState(nil, statuses) + assert.Equal(t, CheckStateSuccess, state, "success commit status should yield success") +} + +func TestClassifyCheckState_MixedRunsAndStatuses(t *testing.T) { + runs := []PRCheckRun{ + {Name: "build", Status: "completed", Conclusion: "success"}, + } + statuses := []PRCommitStatus{ + {Context: "ci/circleci", State: "pending"}, + } + state := classifyCheckState(runs, statuses) + assert.Equal(t, CheckStatePending, state, "pending status with successful run should yield pending") +} + +func TestClassifyCheckState_TimedOut(t *testing.T) { + runs := []PRCheckRun{ + {Name: "slow-test", Status: "completed", Conclusion: "timed_out"}, + } + state := classifyCheckState(runs, nil) + assert.Equal(t, CheckStateFailed, state, "timed_out should yield failed") +} + +// --------------------------------------------------------------------------- +// isPolicyCheck – pattern matching tests +// --------------------------------------------------------------------------- + +func TestIsPolicyCheck(t *testing.T) { + tests := []struct { + name string + checkName string + expected bool + }{ + { + name: "branch protection pattern", + checkName: "Branch protection rule check", + expected: true, + }, + { + name: "required status check pattern", + checkName: "Required status check", + expected: true, + }, + { + name: "mergeability pattern", + checkName: "Mergeability check", + expected: true, + }, + { + name: "policy check pattern", + checkName: "policy check for org", + expected: true, + }, + { + name: "normal test run", + checkName: "unit tests", + expected: false, + }, + { + name: "build check", + checkName: "build / linux", + expected: false, + }, + { + name: "empty string", + checkName: "", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isPolicyCheck(tt.checkName) + assert.Equal(t, tt.expected, got, "isPolicyCheck(%q) should return %v", tt.checkName, tt.expected) + }) + } +} + +// --------------------------------------------------------------------------- +// NewChecksCommand – command shape tests +// --------------------------------------------------------------------------- + +func TestChecksCommand(t *testing.T) { + cmd := NewChecksCommand() + require.NotNil(t, cmd, "checks command should not be nil") + assert.Equal(t, "checks", cmd.Name(), "command name should be 'checks'") + assert.True(t, cmd.HasAvailableFlags(), "command should expose flags") + + repoFlag := cmd.Flags().Lookup("repo") + require.NotNil(t, repoFlag, "should have --repo flag") + assert.Empty(t, repoFlag.DefValue, "--repo default should be empty") + + jsonFlag := cmd.Flags().Lookup("json") + require.NotNil(t, jsonFlag, "should have --json flag") + assert.Equal(t, "false", jsonFlag.DefValue, "--json default should be false") +} + +func TestChecksCommand_RequiresArg(t *testing.T) { + cmd := NewChecksCommand() + err := cmd.Args(cmd, []string{}) + assert.Error(t, err, "checks command should require exactly one argument") +} + +func TestChecksCommand_AcceptsOneArg(t *testing.T) { + cmd := NewChecksCommand() + err := cmd.Args(cmd, []string{"42"}) + assert.NoError(t, err, "checks command should accept exactly one argument") +} + +func TestChecksCommand_RejectsMultipleArgs(t *testing.T) { + cmd := NewChecksCommand() + err := cmd.Args(cmd, []string{"42", "43"}) + assert.Error(t, err, "checks command should reject more than one argument") +} + +// --------------------------------------------------------------------------- +// ChecksResult JSON serialization +// --------------------------------------------------------------------------- + +func TestChecksResultJSONShape(t *testing.T) { + result := &ChecksResult{ + State: CheckStateFailed, + PRNumber: "42", + HeadSHA: "abc123", + CheckRuns: []PRCheckRun{ + {Name: "build", Status: "completed", Conclusion: "failure", HTMLURL: "https://example.com"}, + }, + Statuses: []PRCommitStatus{}, + TotalCount: 1, + } + + require.Equal(t, CheckStateFailed, result.State, "state should be failed") + require.Equal(t, "42", result.PRNumber, "PR number should be preserved") + require.Equal(t, "abc123", result.HeadSHA, "head SHA should be preserved") + require.Len(t, result.CheckRuns, 1, "should have one check run") + assert.Equal(t, "build", result.CheckRuns[0].Name, "check run name should be preserved") +} From 3c806798906182c1bddb8ecec2fe47817060f66f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Feb 2026 15:12:29 +0000 Subject: [PATCH 3/3] Handle PR not found (404) and auth errors (403) with console helpers Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- pkg/cli/checks_command.go | 56 ++++++++++++++++++++++++++++++---- pkg/cli/checks_command_test.go | 49 +++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 6 deletions(-) diff --git a/pkg/cli/checks_command.go b/pkg/cli/checks_command.go index 10971f12d9..2601c3d1ef 100644 --- a/pkg/cli/checks_command.go +++ b/pkg/cli/checks_command.go @@ -113,8 +113,7 @@ func RunChecks(config ChecksConfig) error { result, err := FetchChecksResult(config.Repo, config.PRNumber) if err != nil { - fmt.Fprintln(os.Stderr, console.FormatErrorMessage(err.Error())) - return fmt.Errorf("failed to fetch check state for PR %s: %w", config.PRNumber, err) + return err } if config.JSONOutput { @@ -176,18 +175,61 @@ func fetchPRHeadSHA(repoOverride string, prNumber string) (string, error) { if err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { - return "", fmt.Errorf("gh api call failed (exit %d): %s", exitErr.ExitCode(), strings.TrimSpace(string(exitErr.Stderr))) + stderr := strings.TrimSpace(string(exitErr.Stderr)) + return "", classifyGHAPIError(exitErr.ExitCode(), stderr, prNumber, repoOverride) } return "", fmt.Errorf("gh api call failed: %w", err) } sha := strings.TrimSpace(string(output)) if sha == "" { - return "", fmt.Errorf("empty SHA returned for PR %s", prNumber) + return "", errors.New(console.FormatErrorWithSuggestions( + "PR #"+prNumber+" returned an empty SHA", + []string{ + "Verify that PR #" + prNumber + " exists and is accessible", + "Check that the --repo flag points to the correct repository", + }, + )) } return sha, nil } +// classifyGHAPIError converts a gh API exit error into a user-friendly, pre-formatted error. +func classifyGHAPIError(exitCode int, stderr string, prNumber string, repo string) error { + checksLog.Printf("API error: exitCode=%d, stderr=%s", exitCode, stderr) + + lower := strings.ToLower(stderr) + + switch { + case strings.Contains(lower, "404") || strings.Contains(lower, "not found"): + repoHint := "the current repository" + if repo != "" { + repoHint = repo + } + return errors.New(console.FormatErrorWithSuggestions( + fmt.Sprintf("PR #%s not found in %s", prNumber, repoHint), + []string{ + "Verify that the pull request number is correct", + "Use --repo owner/repo to specify the target repository explicitly", + "Ensure you have read access to the repository", + }, + )) + case strings.Contains(lower, "403") || strings.Contains(lower, "forbidden") || + strings.Contains(lower, "bad credentials") || strings.Contains(lower, "401") || + strings.Contains(lower, "unauthorized"): + return errors.New(console.FormatErrorWithSuggestions( + "GitHub API authentication failed", + []string{ + "Run 'gh auth login' to authenticate with GitHub", + "Ensure your token has the 'repo' scope for private repositories", + "Check that GH_TOKEN or GITHUB_TOKEN is set correctly if using environment variables", + }, + )) + default: + return fmt.Errorf("gh api call failed (exit %d): %s", exitCode, stderr) + } +} + // checkRunsAPIResponse is the envelope returned by the check-runs endpoint. type checkRunsAPIResponse struct { TotalCount int `json:"total_count"` @@ -206,7 +248,8 @@ func fetchCheckRuns(repoOverride string, sha string) ([]PRCheckRun, error) { if err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { - return nil, fmt.Errorf("gh api call failed (exit %d): %s", exitErr.ExitCode(), strings.TrimSpace(string(exitErr.Stderr))) + stderr := strings.TrimSpace(string(exitErr.Stderr)) + return nil, classifyGHAPIError(exitErr.ExitCode(), stderr, sha, repoOverride) } return nil, fmt.Errorf("gh api call failed: %w", err) } @@ -237,7 +280,8 @@ func fetchCommitStatuses(repoOverride string, sha string) ([]PRCommitStatus, err if err != nil { var exitErr *exec.ExitError if errors.As(err, &exitErr) { - return nil, fmt.Errorf("gh api call failed (exit %d): %s", exitErr.ExitCode(), strings.TrimSpace(string(exitErr.Stderr))) + stderr := strings.TrimSpace(string(exitErr.Stderr)) + return nil, classifyGHAPIError(exitErr.ExitCode(), stderr, sha, repoOverride) } return nil, fmt.Errorf("gh api call failed: %w", err) } diff --git a/pkg/cli/checks_command_test.go b/pkg/cli/checks_command_test.go index 464f4d813b..ac2a1eaa45 100644 --- a/pkg/cli/checks_command_test.go +++ b/pkg/cli/checks_command_test.go @@ -250,3 +250,52 @@ func TestChecksResultJSONShape(t *testing.T) { require.Len(t, result.CheckRuns, 1, "should have one check run") assert.Equal(t, "build", result.CheckRuns[0].Name, "check run name should be preserved") } + +// --------------------------------------------------------------------------- +// classifyGHAPIError – error classification tests +// --------------------------------------------------------------------------- + +func TestClassifyGHAPIError_NotFound(t *testing.T) { + err := classifyGHAPIError(1, "HTTP 404: Not Found", "42", "") + require.Error(t, err, "should return an error") + msg := err.Error() + assert.Contains(t, msg, "not found", "error should mention not found") + assert.Contains(t, msg, "#42", "error should mention PR number") + assert.Contains(t, msg, "current repository", "error should mention current repository when no repo override") +} + +func TestClassifyGHAPIError_NotFoundWithRepo(t *testing.T) { + err := classifyGHAPIError(1, "HTTP 404: Not Found", "99", "myorg/myrepo") + require.Error(t, err, "should return an error") + msg := err.Error() + assert.Contains(t, msg, "myorg/myrepo", "error should mention the specified repo") +} + +func TestClassifyGHAPIError_Forbidden(t *testing.T) { + err := classifyGHAPIError(1, "HTTP 403: Forbidden", "42", "") + require.Error(t, err, "should return an error") + msg := err.Error() + assert.Contains(t, msg, "authentication failed", "error should mention auth failure") + assert.Contains(t, msg, "gh auth login", "error should suggest running gh auth login") +} + +func TestClassifyGHAPIError_Unauthorized(t *testing.T) { + err := classifyGHAPIError(1, "HTTP 401: Unauthorized (Bad credentials)", "42", "") + require.Error(t, err, "should return an error") + msg := err.Error() + assert.Contains(t, msg, "authentication failed", "error should mention auth failure") +} + +func TestClassifyGHAPIError_BadCredentials(t *testing.T) { + err := classifyGHAPIError(1, "Bad credentials", "42", "") + require.Error(t, err, "should return an error") + msg := err.Error() + assert.Contains(t, msg, "authentication failed", "bad credentials should yield auth error") +} + +func TestClassifyGHAPIError_Generic(t *testing.T) { + err := classifyGHAPIError(1, "HTTP 500: Internal Server Error", "42", "") + require.Error(t, err, "should return an error") + msg := err.Error() + assert.Contains(t, msg, "gh api call failed", "generic errors should surface exit code message") +}