Commit c1d7386

Refactor AgentCompletion to take in function parameters. (#282)
* Updating AgentCompletion to require updated parameters from issue.
* Adding function parameters and cleaning up TOML.
* Refactored UnitTestGenerator with PromptBuilder parameters.
* Working UnitTestValidator/Generator.
* Adding more detailed failure info.
* Working refactor. Moved to Claude for testing.
* Working call to PromptBuilder without init. Deleting PromptBuilder test.
* Removed more of PromptBuilder.
* Adding defaults to ABC.
* Adding defaults to DefaultAgentCompletion.
* Updated docstrings.
* Fixing path to test file.
* Fixed docstrings.
* Moving to tagged version of GPT-4o.
* Adding more logging.
1 parent 4120936 · commit c1d7386

16 files changed (+595 -580 lines)
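To make the shape of the refactor concrete before the per-file diffs, here is a hedged before/after sketch of the default wiring. The constructor keywords mirror the CoverAgent hunks below; the file paths and generate_tests argument values are illustrative placeholders, and the return shape follows the ABC docstrings in this commit.

from cover_agent.AICaller import AICaller
from cover_agent.DefaultAgentCompletion import DefaultAgentCompletion
from cover_agent.PromptBuilder import PromptBuilder

# Before (schematic, pre-refactor signature removed by this commit):
# PromptBuilder captured paths and context at construction time, so the
# completion methods themselves carried few inputs.
old_builder = PromptBuilder(
    source_file_path="app/calculator.py",       # placeholder path
    test_file_path="tests/test_calculator.py",  # placeholder path
    code_coverage_report="",
    language="",
    testing_framework="",
)

# After: PromptBuilder needs no up-front state; each call passes its own inputs.
completion = DefaultAgentCompletion(
    builder=PromptBuilder(),
    caller=AICaller(model="gpt-4o", api_base="", max_tokens=8192),
)
response, input_tokens, output_tokens, prompt = completion.generate_tests(
    source_file_name="calculator.py",
    max_tests=4,
    source_file_numbered="1: def add(a, b):\n2:     return a + b",
    code_coverage_report="line 2 uncovered",
    language="python",
    test_file="import pytest\n",
    test_file_name="tests/test_calculator.py",
    testing_framework="pytest",
)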

cover_agent/AICaller.py (+5 -5)

@@ -32,7 +32,7 @@ def retry_wrapper():
 
 
 class AICaller:
-    def __init__(self, model: str, api_base: str = "", enable_retry=True):
+    def __init__(self, model: str, api_base: str = "", enable_retry=True, max_tokens=16384):
         """
         Initializes an instance of the AICaller class.
 
@@ -43,15 +43,15 @@ def __init__(self, model: str, api_base: str = "", enable_retry=True):
         self.model = model
         self.api_base = api_base
         self.enable_retry = enable_retry
+        self.max_tokens = max_tokens
 
     @conditional_retry  # You can access self.enable_retry here
-    def call_model(self, prompt: dict, max_tokens=16384, stream=True):
+    def call_model(self, prompt: dict, stream=True):
         """
         Call the language model with the provided prompt and retrieve the response.
 
         Parameters:
             prompt (dict): The prompt to be sent to the language model.
-            max_tokens (int, optional): The maximum number of tokens to generate in the response. Defaults to 16384.
             stream (bool, optional): Whether to stream the response or not. Defaults to True.
 
         Returns:
@@ -84,15 +84,15 @@ def call_model(self, prompt: dict, max_tokens=16384, stream=True):
             "messages": messages,
             "stream": stream,  # Use the stream parameter passed to the method
             "temperature": 0.2,
-            "max_tokens": max_tokens,
+            "max_tokens": self.max_tokens,
         }
 
         # Model-specific adjustments
         if self.model in ["o1-preview", "o1-mini", "o1", "o3-mini"]:
             stream = False  # o1 doesn't support streaming
             completion_params["temperature"] = 1
             completion_params["stream"] = False  # o1 doesn't support streaming
-            completion_params["max_completion_tokens"] = 2*max_tokens
+            completion_params["max_completion_tokens"] = 2*self.max_tokens
             # completion_params["reasoning_effort"] = "high"
             completion_params.pop("max_tokens", None)  # Remove 'max_tokens' if present
 
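In practice this hunk moves max_tokens from a per-call argument to instance state. A minimal usage sketch follows; the prompt dict shape shown is an assumption, since this hunk does not show how messages are built from the prompt.

from cover_agent.AICaller import AICaller

# max_tokens is now fixed at construction time rather than per call_model() call.
caller = AICaller(model="gpt-4o", api_base="", enable_retry=True, max_tokens=8192)

# call_model() no longer accepts max_tokens; it reads self.max_tokens internally.
result = caller.call_model(
    prompt={"system": "You write unit tests.", "user": "Cover the untested lines."},
    stream=True,
)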
cover_agent/AgentCompletionABC.py (+143 -19)

@@ -1,68 +1,192 @@
 from abc import ABC, abstractmethod
 from typing import Tuple
 
-
 class AgentCompletionABC(ABC):
-    """Abstract base class for AI-driven prompt handling."""
+    """
+    Abstract base class for AI-driven prompt handling. Each method accepts
+    specific input parameters (e.g. source/test content, logs, coverage data)
+    and returns a tuple containing the AI response, along with additional
+    metadata (e.g. token usage and the generated prompt).
+    """
 
     @abstractmethod
     def generate_tests(
         self,
-        failed_tests: str,
+        source_file_name: str,
+        max_tests: int,
+        source_file_numbered: str,
+        code_coverage_report: str,
         language: str,
-        test_framework: str,
-        coverage_report: str,
+        test_file: str,
+        test_file_name: str,
+        testing_framework: str,
+        additional_instructions_text: str = None,
+        additional_includes_section: str = None,
+        failed_tests_section: str = None,
     ) -> Tuple[str, int, int, str]:
         """
-        Generates additional unit tests to improve test coverage.
+        Generates additional unit tests to improve coverage or handle edge cases.
+
+        Args:
+            source_file_name (str): Name of the source file under test.
+            max_tests (int): Maximum number of test functions to propose.
+            source_file_numbered (str): The source code with line numbers.
+            code_coverage_report (str): Coverage details highlighting untested lines.
+            language (str): The programming language (e.g. "python", "java").
+            test_file (str): Contents of the existing test file.
+            test_file_name (str): The name/path of the test file.
+            testing_framework (str): The test framework in use (e.g. "pytest", "junit").
+            additional_instructions_text (str, optional): Extra instructions or context.
+            additional_includes_section (str, optional): Additional code or includes.
+            failed_tests_section (str, optional): Details of failed tests to consider.
 
         Returns:
-            Tuple[str, int, int, str]: AI-generated test cases, input token count, output token count, and generated prompt.
+            Tuple[str, int, int, str]:
+                A 4-element tuple containing:
+                - The AI-generated test suggestions (string),
+                - The input token count (int),
+                - The output token count (int),
+                - The final constructed prompt (string).
         """
         pass
 
     @abstractmethod
     def analyze_test_failure(
-        self, stderr: str, stdout: str, processed_test_file: str
+        self,
+        source_file_name: str,
+        source_file: str,
+        processed_test_file: str,
+        stdout: str,
+        stderr: str,
+        test_file_name: str,
     ) -> Tuple[str, int, int, str]:
         """
-        Analyzes a test failure and returns insights.
+        Analyzes the output of a failed test to determine possible causes and
+        recommended fixes.
+
+        Args:
+            source_file_name (str): Name of the source file being tested.
+            source_file (str): Raw content of the source file.
+            processed_test_file (str): Content of the failing test file (pre-processed).
+            stdout (str): Captured standard output from the test run.
+            stderr (str): Captured standard error from the test run.
+            test_file_name (str): Name/path of the failing test file.
 
         Returns:
-            Tuple[str, int, int, str]: AI-generated analysis, input token count, output token count, and generated prompt.
+            Tuple[str, int, int, str]:
+                A 4-element tuple containing:
+                - The AI-generated analysis or explanation (string),
+                - The input token count (int),
+                - The output token count (int),
+                - The final constructed prompt (string).
         """
         pass
 
     @abstractmethod
-    def analyze_test_insert_line(self, test_file: str) -> Tuple[str, int, int, str]:
+    def analyze_test_insert_line(
+        self,
+        language: str,
+        test_file_numbered: str,
+        test_file_name: str,
+        additional_instructions_text: str = None,
+    ) -> Tuple[str, int, int, str]:
         """
-        Determines where to insert new test cases.
+        Determines the correct placement for inserting new test cases into
+        an existing test file.
+
+        Args:
+            language (str): The programming language of the test file.
+            test_file_numbered (str): The test file content, labeled with line numbers.
+            test_file_name (str): Name/path of the test file.
+            additional_instructions_text (str, optional): Any extra instructions or context.
 
         Returns:
-            Tuple[str, int, int, str]: Suggested insertion point, input token count, output token count, and generated prompt.
+            Tuple[str, int, int, str]:
+                A 4-element tuple containing:
+                - The AI-generated suggestion or instructions (string),
+                - The input token count (int),
+                - The output token count (int),
+                - The final constructed prompt (string).
         """
         pass
 
     @abstractmethod
     def analyze_test_against_context(
-        self, test_code: str, context: str
+        self,
+        language: str,
+        test_file_content: str,
+        test_file_name_rel: str,
+        context_files_names_rel: str,
     ) -> Tuple[str, int, int, str]:
         """
-        Validates whether a test is appropriate for its corresponding source code.
+        Evaluates a test file against a set of related context files to identify:
+        1. If it is a unit test,
+        2. Which context file the test is primarily targeting.
+
+        Args:
+            language (str): The programming language of the test file.
+            test_file_content (str): Raw content of the test file under review.
+            test_file_name_rel (str): Relative path/name of the test file.
+            context_files_names_rel (str): One or more file names related to the context.
 
         Returns:
-            Tuple[str, int, int, str]: AI validation result, input token count, output token count, and generated prompt.
+            Tuple[str, int, int, str]:
+                A 4-element tuple containing:
+                - The AI-generated classification or analysis (string),
+                - The input token count (int),
+                - The output token count (int),
+                - The final constructed prompt (string).
         """
         pass
 
     @abstractmethod
     def analyze_suite_test_headers_indentation(
-        self, test_file: str
+        self,
+        language: str,
+        test_file_name: str,
+        test_file: str,
+    ) -> Tuple[str, int, int, str]:
+        """
+        Analyzes an existing test suite to determine its indentation style,
+        the number of existing tests, and potentially the testing framework.
+
+        Args:
+            language (str): The programming language of the test file.
+            test_file_name (str): Name/path of the test file.
+            test_file (str): Raw content of the test file.
+
+        Returns:
+            Tuple[str, int, int, str]:
+                A 4-element tuple containing:
+                - The AI-generated suite analysis (string),
+                - The input token count (int),
+                - The output token count (int),
+                - The final constructed prompt (string).
+        """
+        pass
+
+    @abstractmethod
+    def adapt_test_command_for_a_single_test_via_ai(
+        self,
+        test_file_relative_path: str,
+        test_command: str,
+        project_root_dir: str,
     ) -> Tuple[str, int, int, str]:
         """
-        Determines the indentation style used in test suite headers.
+        Adapts an existing test command line to run only a single test file,
+        preserving other relevant flags and arguments where possible.
+
+        Args:
+            test_file_relative_path (str): Path to the specific test file to be isolated.
+            test_command (str): The original command line used for running multiple tests.
+            project_root_dir (str): Root directory of the project.
 
         Returns:
-            Tuple[str, int, int, str]: Suggested indentation style, input token count, output token count, and generated prompt.
+            Tuple[str, int, int, str]:
+                A 4-element tuple containing:
+                - The AI-generated modified command line (string),
+                - The input token count (int),
+                - The output token count (int),
+                - The final constructed prompt (string).
         """
         pass
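Because all six methods are abstract, any pluggable backend must implement each one against these signatures. Below is a minimal hypothetical stub that satisfies the ABC (canned responses, zero token counts); it is not the project's DefaultAgentCompletion, just a sketch useful for wiring tests or dry runs.

from typing import Tuple

from cover_agent.AgentCompletionABC import AgentCompletionABC


class StubAgentCompletion(AgentCompletionABC):
    """Hypothetical stub backend returning canned (response, tokens, prompt) tuples."""

    def _stub(self, prompt: str) -> Tuple[str, int, int, str]:
        # (AI response, input token count, output token count, constructed prompt)
        return "", 0, 0, prompt

    def generate_tests(self, source_file_name, max_tests, source_file_numbered,
                       code_coverage_report, language, test_file, test_file_name,
                       testing_framework, additional_instructions_text=None,
                       additional_includes_section=None, failed_tests_section=None):
        return self._stub(f"generate up to {max_tests} {testing_framework} tests "
                          f"for {source_file_name}")

    def analyze_test_failure(self, source_file_name, source_file, processed_test_file,
                             stdout, stderr, test_file_name):
        return self._stub(f"analyze failure of {test_file_name}")

    def analyze_test_insert_line(self, language, test_file_numbered, test_file_name,
                                 additional_instructions_text=None):
        return self._stub(f"find insertion line in {test_file_name}")

    def analyze_test_against_context(self, language, test_file_content,
                                     test_file_name_rel, context_files_names_rel):
        return self._stub(f"classify {test_file_name_rel} against its context files")

    def analyze_suite_test_headers_indentation(self, language, test_file_name, test_file):
        return self._stub(f"inspect indentation of {test_file_name}")

    def adapt_test_command_for_a_single_test_via_ai(self, test_file_relative_path,
                                                    test_command, project_root_dir):
        return self._stub(f"adapt '{test_command}' to run only {test_file_relative_path}")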

cover_agent/CoverAgent.py (+16 -21)

@@ -7,14 +7,14 @@
 from typing import List
 
 from cover_agent.CustomLogger import CustomLogger
-from cover_agent.PromptBuilder import adapt_test_command_for_a_single_test_via_ai
+from cover_agent.PromptBuilder import PromptBuilder, adapt_test_command_for_a_single_test_via_ai
 from cover_agent.UnitTestGenerator import UnitTestGenerator
 from cover_agent.UnitTestValidator import UnitTestValidator
 from cover_agent.UnitTestDB import UnitTestDB
 from cover_agent.AICaller import AICaller
-from cover_agent.PromptBuilder import PromptBuilder
 from cover_agent.AgentCompletionABC import AgentCompletionABC
 from cover_agent.DefaultAgentCompletion import DefaultAgentCompletion
+import cover_agent.utils
 
 
 class CoverAgent:
@@ -42,18 +42,8 @@ def __init__(self, args, agent_completion: AgentCompletionABC = None):
             self.agent_completion = agent_completion
         else:
             # Default to using the DefaultAgentCompletion object with the PromptBuilder and AICaller
-            self.ai_caller = AICaller(model=args.model, api_base=args.api_base)
-            self.prompt_builder = PromptBuilder(
-                source_file_path=args.source_file_path,
-                test_file_path=args.test_file_output_path,
-                code_coverage_report="",
-                included_files=UnitTestGenerator.get_included_files(args.included_files, args.project_root),
-                additional_instructions=args.additional_instructions,
-                failed_test_runs="",
-                language="",
-                testing_framework="",
-                project_root=args.project_root,
-            )
+            self.ai_caller = AICaller(model=args.model, api_base=args.api_base, max_tokens=8192)
+            self.prompt_builder = PromptBuilder()
             self.agent_completion = DefaultAgentCompletion(
                 builder=self.prompt_builder, caller=self.ai_caller
             )
@@ -231,13 +221,18 @@ def run_test_gen(
         )
 
         # Loop through each new test and validate it
-        for generated_test in generated_tests_dict.get("new_tests", []):
-            # Validate the test and record the result
-            test_result = self.test_validator.validate_test(generated_test)
-
-            # Insert the test result into the database
-            test_result["prompt"] = self.test_gen.prompt["user"]
-            self.test_db.insert_attempt(test_result)
+        try:
+            for generated_test in generated_tests_dict.get("new_tests", []):
+                # Validate the test and record the result
+                test_result = self.test_validator.validate_test(generated_test)
+
+                # Insert the test result into the database
+                test_result["prompt"] = self.test_gen.prompt["user"]
+                self.test_db.insert_attempt(test_result)
+        except AttributeError as e:
+            self.logger.error(
+                f"Failed to validate the test {generated_test} within {generated_tests_dict}. Error: {e}"
+            )
 
         # Increment the iteration count
         iteration_count += 1
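The constructor signature in the hunks above (agent_completion: AgentCompletionABC = None) is the injection point for a custom backend such as the StubAgentCompletion sketched after the ABC diff. A hedged wiring example follows; the args fields listed are only the ones this diff touches, and a real run needs the full CLI argument set.

from argparse import Namespace

from cover_agent.CoverAgent import CoverAgent

# Hypothetical subset of the CLI args; CoverAgent will also read fields that
# this diff does not show (test command, coverage paths, and so on).
args = Namespace(
    model="gpt-4o",
    api_base="",
    source_file_path="app/calculator.py",              # placeholder
    test_file_output_path="tests/test_calculator.py",  # placeholder
    project_root=".",
)

# Passing a custom backend bypasses the default AICaller/PromptBuilder wiring.
agent = CoverAgent(args, agent_completion=StubAgentCompletion())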
