diff --git a/app/(dashboard)/dashboard/pull-request.test.tsx b/app/(dashboard)/dashboard/pull-request.test.tsx
index 06ed0e6f..bf366543 100644
--- a/app/(dashboard)/dashboard/pull-request.test.tsx
+++ b/app/(dashboard)/dashboard/pull-request.test.tsx
@@ -8,6 +8,7 @@ import { generateTestsResponseSchema } from "@/app/api/generate-tests/schema";
vi.mock('@/lib/github', () => ({
getPullRequestInfo: vi.fn(),
commitChangesToPullRequest: vi.fn(),
+ getFailingTests: vi.fn(),
}));
vi.mock('@/hooks/use-toast', () => ({
@@ -56,7 +57,7 @@ describe('PullRequestItem', () => {
expect(screen.getByText('Build: success')).toBeInTheDocument();
});
- it('handles "Write new tests" button click', async () => {
+ it('handles "Write new tests" button click for successful build', async () => {
const { getPullRequestInfo } = await import('@/lib/github');
vi.mocked(getPullRequestInfo).mockResolvedValue({
diff: 'mock diff',
@@ -68,7 +69,7 @@ describe('PullRequestItem', () => {
json: () => Promise.resolve([{ name: 'generated_test.ts', content: 'generated content' }]),
} as Response);
- render();
+ render();
const writeTestsButton = screen.getByText('Write new tests');
fireEvent.click(writeTestsButton);
@@ -82,17 +83,23 @@ describe('PullRequestItem', () => {
});
});
- it('handles "Update tests to fix" button click', async () => {
+ it('handles "Update tests to fix" button click for failed build', async () => {
const failedPR = { ...mockPullRequest, buildStatus: 'failure' };
- const { getPullRequestInfo } = await import('@/lib/github');
+ const { getPullRequestInfo, getFailingTests } = await import('@/lib/github');
vi.mocked(getPullRequestInfo).mockResolvedValue({
diff: 'mock diff',
- testFiles: [{ name: 'test.ts', content: 'test content' }],
+ testFiles: [
+ { name: 'test1.ts', content: 'test content 1' },
+ { name: 'test2.ts', content: 'test content 2' },
+ ],
});
+ vi.mocked(getFailingTests).mockResolvedValue([
+ { name: 'test1.ts', content: 'failing test content' },
+ ]);
vi.mocked(global.fetch).mockResolvedValue({
ok: true,
- json: () => Promise.resolve([{ name: 'fixed_test.ts', content: 'fixed content' }]),
+ json: () => Promise.resolve([{ name: 'test1.ts', content: 'fixed content' }]),
} as Response);
render();
@@ -104,7 +111,8 @@ describe('PullRequestItem', () => {
});
await waitFor(() => {
- expect(screen.getByText('fixed_test.ts')).toBeInTheDocument();
+ expect(screen.getByText('test1.ts')).toBeInTheDocument();
+ expect(screen.queryByText('test2.ts')).not.toBeInTheDocument();
expect(screen.getByTestId('react-diff-viewer')).toBeInTheDocument();
});
});
@@ -120,10 +128,6 @@ describe('PullRequestItem', () => {
ok: false,
} as Response);
- const { useToast } = await import('@/hooks/use-toast');
- const mockToast = vi.fn();
- vi.mocked(useToast).mockReturnValue({ toast: mockToast });
-
render();
const writeTestsButton = screen.getByText('Write new tests');
fireEvent.click(writeTestsButton);
@@ -259,4 +263,25 @@ describe('PullRequestItem', () => {
}));
});
});
+
+ it('displays pending build status', () => {
+ const pendingPR = { ...mockPullRequest, buildStatus: 'pending' };
+ render();
+ expect(screen.getByText('Build: pending')).toBeInTheDocument();
+ });
+
+ it('disables buttons when loading', async () => {
+ vi.mocked(global.fetch).mockResolvedValue({
+ ok: true,
+ json: () => new Promise(resolve => setTimeout(() => resolve([]), 100)),
+ } as Response);
+
+ render();
+ const writeTestsButton = screen.getByText('Write new tests');
+ fireEvent.click(writeTestsButton);
+
+ await waitFor(() => {
+ expect(writeTestsButton).toBeDisabled();
+ });
+ });
});
\ No newline at end of file
diff --git a/app/(dashboard)/dashboard/pull-request.tsx b/app/(dashboard)/dashboard/pull-request.tsx
index 55a9591d..d46e2463 100644
--- a/app/(dashboard)/dashboard/pull-request.tsx
+++ b/app/(dashboard)/dashboard/pull-request.tsx
@@ -18,7 +18,7 @@ import dynamic from "next/dynamic";
import { PullRequest, TestFile } from "./types";
import { generateTestsResponseSchema } from "@/app/api/generate-tests/schema";
import { useToast } from "@/hooks/use-toast";
-import { commitChangesToPullRequest, getPullRequestInfo } from "@/lib/github";
+import { commitChangesToPullRequest, getPullRequestInfo, getFailingTests } from "@/lib/github";
const ReactDiffViewer = dynamic(() => import("react-diff-viewer"), {
ssr: false,
@@ -53,6 +53,19 @@ export function PullRequestItem({ pullRequest }: PullRequestItemProps) {
pr.number
);
+ let testFilesToUpdate = oldTestFiles;
+
+ if (mode === "update") {
+ const failingTests = await getFailingTests(
+ pr.repository.owner.login,
+ pr.repository.name,
+ pr.number
+ );
+ testFilesToUpdate = oldTestFiles.filter(file =>
+ failingTests.some(failingFile => failingFile.name === file.name)
+ );
+ }
+
const response = await fetch("/api/generate-tests", {
method: "POST",
headers: {
@@ -62,7 +75,7 @@ export function PullRequestItem({ pullRequest }: PullRequestItemProps) {
mode,
pr_id: pr.id,
pr_diff: diff,
- test_files: oldTestFiles,
+ test_files: testFilesToUpdate,
}),
});
diff --git a/app/api/generate-tests/route.ts b/app/api/generate-tests/route.ts
index 93510fd6..ee01f659 100644
--- a/app/api/generate-tests/route.ts
+++ b/app/api/generate-tests/route.ts
@@ -12,7 +12,7 @@ export async function POST(req: Request) {
const prompt = `You are an expert software engineer. ${
mode === "write"
? "Write entirely new tests and update relevant existing tests in order to reflect the added/edited/removed functionality."
- : "Update existing test files in order to get the PR build back to passing. Make updates to tests solely, do not add or remove tests."
+ : "Update the provided failing test files in order to get the PR build back to passing. Make updates to tests solely, do not add or remove tests."
}
PR Diff:
@@ -20,7 +20,7 @@ export async function POST(req: Request) {
${pr_diff}
- Existing test files:
+ ${mode === "update" ? "Failing test files:" : "Existing test files:"}
${test_files
.map((file) => `${file.name}\n${file.content ? `: ${file.content}` : ""}`)
diff --git a/lib/github.ts b/lib/github.ts
index 069ec40f..13be9e64 100644
--- a/lib/github.ts
+++ b/lib/github.ts
@@ -270,3 +270,59 @@ export async function getPullRequestInfo(
throw new Error("Failed to fetch PR info");
}
}
+
+export async function getFailingTests(
+ owner: string,
+ repo: string,
+ pullNumber: number
+): Promise {
+ const octokit = await getOctokit();
+
+ try {
+ const { data: checkRuns } = await octokit.checks.listForRef({
+ owner,
+ repo,
+ ref: `refs/pull/${pullNumber}/head`,
+ status: 'completed',
+ filter: 'latest',
+ });
+
+ const failedChecks = checkRuns.check_runs.filter(
+ (run) => run.conclusion === 'failure'
+ );
+
+ const failingTestFiles: TestFile[] = [];
+ for (const check of failedChecks) {
+ if (check.output.annotations_count > 0) {
+ const { data: annotations } = await octokit.checks.listAnnotations({
+ owner,
+ repo,
+ check_run_id: check.id,
+ });
+
+ for (const annotation of annotations) {
+ if (annotation.path.includes('test') || annotation.path.includes('spec')) {
+ const { data: fileContent } = await octokit.repos.getContent({
+ owner,
+ repo,
+ path: annotation.path,
+ ref: `refs/pull/${pullNumber}/head`,
+ });
+
+ if ('content' in fileContent) {
+ failingTestFiles.push({
+ name: annotation.path,
+ content: Buffer.from(fileContent.content, 'base64').toString('utf-8'),
+ });
+ }
+ }
+ }
+ }
+ }
+
+ return failingTestFiles;
+ } catch (error) {
+ console.error('Error fetching failing tests:', error);
+ throw error;
+ }
+}