-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding SWE-Bench dataset and optimizer
- Loading branch information
1 parent
ce8936e
commit 9b66d87
Showing
5 changed files
with
370 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import os | ||
import subprocess | ||
import tempfile | ||
import pytest | ||
from unittest.mock import MagicMock | ||
from dspygen.utils.dspy_tools import init_dspy | ||
|
||
# Assuming FSM mixin and state definitions are as discussed earlier | ||
from dspygen.agents.coder_agent import CoderAgentState | ||
from dspygen.agents.coder_agent_v4 import CoderAgent | ||
from dspygen.mixin.fsm.fsm_mixin import trigger, FSMMixin | ||
|
||
|
||
class PytestAgent(FSMMixin): | ||
def __init__(self, code): | ||
super().setup_fsm(CoderAgentState, initial=CoderAgentState.ANALYZING_REQUIREMENTS) | ||
self.code = code | ||
self.errors = [] | ||
self.filename = None | ||
self.test_filename = None | ||
|
||
@trigger(source=CoderAgentState.ANALYZING_REQUIREMENTS, dest=CoderAgentState.WRITING_CODE) | ||
def start_coding(self): | ||
"""Write code and tests to files.""" | ||
self.filename = self.write_code_to_file(self.code) | ||
test_code = self.generate_mock_tests() | ||
self.test_filename = self.write_test_code_to_file(test_code) | ||
|
||
def write_code_to_file(self, code): | ||
"""Write code to a temporary file.""" | ||
with tempfile.NamedTemporaryFile(delete=False, suffix='.py', mode='w') as f: | ||
f.write(code) | ||
return f.name | ||
|
||
def write_test_code_to_file(self, test_code): | ||
"""Write test code to a temporary file.""" | ||
with tempfile.NamedTemporaryFile(delete=False, suffix='_test.py', mode='w') as f: | ||
f.write(test_code) | ||
return f.name | ||
|
||
def generate_mock_tests(self): | ||
"""Generate basic pytest tests with mocks.""" | ||
|
||
# Generate with dspygen | ||
from dspygen.typetemp.functional import render | ||
from dspygen.utils.dspy_tools import init_ol | ||
lm = init_ol() | ||
# source_code = bad_code | ||
source_code = example_code | ||
from dspygen.modules.pytest_module import pytest_call | ||
result = pytest_call(source_code=source_code) | ||
from dspygen.utils.file_tools import extract_code | ||
import_str = "from {{ module }} import fetch_user_name\n\n" | ||
# print(extract_code(result)) | ||
# print(lm.inspect_history(n=1)) | ||
result = import_str + extract_code(result) | ||
rcode = render(result, module=os.path.basename(self.filename)[:-3]) | ||
|
||
return rcode | ||
|
||
@trigger(source=CoderAgentState.WRITING_CODE, dest=CoderAgentState.TESTING_CODE) | ||
def test_code(self): | ||
"""Run tests using pytest.""" | ||
result = subprocess.run(['pytest', self.test_filename, '-v'], capture_output=True, text=True, timeout=30) | ||
if result.returncode != 0: | ||
self.errors.append(result.stderr) | ||
print("Test Failed:", result.stdout) | ||
else: | ||
print("Test Passed:", result.stdout) | ||
|
||
@trigger(source=CoderAgentState.TESTING_CODE, dest=CoderAgentState.COMPLETING_TASK, unless=['errors_detected']) | ||
def complete_task(self): | ||
"""Complete the task if tests pass.""" | ||
print("Task completed successfully.") | ||
os.remove(self.filename) | ||
os.remove(self.test_filename) | ||
|
||
def errors_detected(self): | ||
"""Check if there are any errors in the code.""" | ||
return len(self.errors) > 0 | ||
|
||
|
||
example_code = """def fetch_user_name(user_id): | ||
import requests | ||
response = requests.get(f'https://api.example.com/users/{user_id}') | ||
return response.json()['name'] | ||
""" | ||
|
||
def main(): | ||
init_dspy(max_tokens=3000) # Initialize the dspy environment | ||
code_agent = CoderAgent() | ||
agent = PytestAgent(code=example_code) | ||
print("Initial state:", agent.state) | ||
agent.start_coding() | ||
agent.test_code() | ||
if not agent.errors_detected(): | ||
agent.complete_task() | ||
print("Final state:", agent.state) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import random | ||
import tqdm | ||
from datasets import load_dataset | ||
import dspy | ||
|
||
|
||
from pydantic import BaseModel, Field | ||
from typing import List, Optional | ||
import json | ||
|
||
|
||
class SWEBenchData(BaseModel): | ||
instance_id: str = Field(description="A formatted instance identifier, usually as repo_owner__repo_name-PR-number.") | ||
patch: str = Field(description="The gold patch generated by the PR minus test-related code.") | ||
repo: str = Field(description="The repository owner/name identifier from GitHub.") | ||
base_commit: str = Field(description="The commit hash of the repository representing the HEAD of the repository before the solution PR is applied.") | ||
hints_text: str = Field(description="Comments made on the issue prior to the creation of the solution PR’s first commit creation date.") | ||
created_at: str = Field(description="The creation date of the pull request.") | ||
test_patch: str = Field(description="A test-file patch that was contributed by the solution PR.") | ||
problem_statement: str = Field(description="The issue title and body.") | ||
version: str = Field(description="Installation version to use for running evaluation.") | ||
environment_setup_commit: str = Field(description="Commit hash to use for environment setup and installation.") | ||
FAIL_TO_PASS: Optional[List[str]] = Field(default=None, description="A list of strings that represent the set of tests resolved by the PR and tied to the issue resolution.") | ||
PASS_TO_PASS: Optional[List[str]] = Field(default=None, description="A list of strings that represent tests that should pass before and after the PR application.") | ||
|
||
|
||
def get_instance_by_id(dataset, instance_id): | ||
""" | ||
Retrieve a model instance from the dataset by instance_id. | ||
Args: | ||
dataset (list): The dataset containing the instances. | ||
instance_id (str): The instance identifier to search for. | ||
Returns: | ||
SWEBenchData: The found model instance or None if not found. | ||
""" | ||
for item in dataset: | ||
if item['instance_id'] == instance_id: | ||
# Convert FAIL_TO_PASS and PASS_TO_PASS from JSON string to List | ||
item['FAIL_TO_PASS'] = json.loads(item['FAIL_TO_PASS']) if item['FAIL_TO_PASS'] else None | ||
item['PASS_TO_PASS'] = json.loads(item['PASS_TO_PASS']) if item['PASS_TO_PASS'] else None | ||
return SWEBenchData(**item) | ||
return None | ||
|
||
|
||
class SWEBench: | ||
def __init__(self, do_shuffle=True, shuffle_seed=0) -> None: | ||
super().__init__() | ||
|
||
# Load the dataset from the Hugging Face Hub | ||
dataset = load_dataset("princeton-nlp/SWE-bench_oracle", 'default') | ||
|
||
hf_official_train = dataset['train'] | ||
hf_official_test = dataset['test'] | ||
official_train = [] | ||
official_test = [] | ||
|
||
for example in tqdm.tqdm(hf_official_train): | ||
issue = example['problem_statement'] | ||
patch = example['patch'] | ||
test_patch = example['test_patch'] | ||
|
||
official_train.append(dict(issue=issue, patch=patch, test_patch=test_patch)) | ||
|
||
for example in tqdm.tqdm(hf_official_test): | ||
issue = example['problem_statement'] | ||
patch = example['patch'] | ||
test_patch = example['test_patch'] | ||
|
||
official_test.append(dict(issue=issue, patch=patch, test_patch=test_patch)) | ||
|
||
# Optionally shuffle datasets | ||
if do_shuffle: | ||
rng = random.Random(shuffle_seed) | ||
rng.shuffle(official_train) | ||
rng.shuffle(official_test) | ||
|
||
# Split the data | ||
trainset = official_train[:int(0.8 * len(official_train))] | ||
devset = official_train[int(0.8 * len(official_train)):] | ||
testset = official_test | ||
|
||
# Wrap data into dspy.Example format | ||
self.train = [dspy.Example(**x).with_inputs('issue') for x in trainset] | ||
self.dev = [dspy.Example(**x).with_inputs('issue') for x in devset] | ||
self.test = [dspy.Example(**x).with_inputs('issue') for x in testset] | ||
|
||
|
||
|
||
def main(): | ||
"""Main function""" | ||
# Example instantiation and use: | ||
swe_bench = SWEBench(do_shuffle=True, shuffle_seed=42) | ||
print(f"Trainset size: {len(swe_bench.train)}") | ||
print(f"Devset size: {len(swe_bench.dev)}") | ||
print(f"Testset size: {len(swe_bench.test)}") | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import dspy | ||
|
||
from dspygen.experiments.mock_gen.swe_bench import SWEBench | ||
from dspygen.utils.dspy_tools import init_ol | ||
|
||
|
||
class CoT(dspy.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.prog = dspy.ChainOfThought("issue -> patch") | ||
|
||
def forward(self, issue): | ||
return self.prog(issue=issue) | ||
|
||
|
||
def main(): | ||
"""Main function""" | ||
from dspy.teleprompt import BootstrapFewShot | ||
# Set up the LM | ||
lm = init_ol() | ||
|
||
# Load the SWE-bench dataset | ||
swe_bench = SWEBench() | ||
swe_bench_trainset, swe_bench_devset = swe_bench.train[:10], swe_bench.dev[:10] | ||
|
||
print(swe_bench_trainset) | ||
|
||
# Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 4-shot examples of our CoT program. | ||
config = dict(max_bootstrapped_demos=4, max_labeled_demos=4) | ||
|
||
# Define a custom metric for evaluating patches | ||
def swebench_metric(gold, pred, trace=None): | ||
# This is a placeholder metric; adjust based on actual evaluation needs | ||
return gold.patch == pred.patch | ||
|
||
teleprompter = BootstrapFewShot(metric=swebench_metric, **config) | ||
optimized_cot = teleprompter.compile(CoT(), trainset=swe_bench_trainset) | ||
|
||
from dspy.evaluate import Evaluate | ||
|
||
# Set up the evaluator, which can be used multiple times. | ||
evaluate = Evaluate(devset=swe_bench_devset, metric=swebench_metric, num_threads=4, display_progress=True, | ||
display_table=0) | ||
|
||
# Evaluate our `optimized_cot` program. | ||
evaluate(optimized_cot) | ||
|
||
lm.inspect_history(n=1) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import dspy | ||
|
||
|
||
class GenerateMockPytest(dspy.Signature): | ||
""" | ||
Generates a mocked pytest module for the provided Python source code. | ||
This class aims to create comprehensive and robust mock tests that simulate | ||
possible unit tests based on the functions and methods defined within the source code. | ||
Write the test like a FAANG Python architect at Meta. | ||
Only reply within ```python``` block. All other text needs to be in docstrings or comments. | ||
Use with patch() | ||
""" | ||
source_code = dspy.InputField(desc="Python source code for which to generate a mock test.") | ||
mocked_pytest = dspy.OutputField(desc="Generated mock pytest code. Within triple ba", prefix="```python\n") | ||
|
||
|
||
class PytestModule(dspy.Module): | ||
"""PytestModule""" | ||
|
||
def __init__(self, **forward_args): | ||
super().__init__() | ||
self.forward_args = forward_args | ||
self.output = None | ||
|
||
def forward(self, source_code): | ||
pred = dspy.Predict(GenerateMockPytest) | ||
self.output = pred(source_code=source_code).mocked_pytest | ||
return self.output | ||
|
||
|
||
def pytest_call(source_code): | ||
pytest = PytestModule() | ||
return pytest.forward(source_code=source_code) | ||
|
||
|
||
example_code = """def fetch_user_name(user_id): | ||
import requests | ||
response = requests.get(f'https://api.example.com/users/{user_id}') | ||
return response.json()['name'] | ||
""" | ||
|
||
|
||
bad_code = """def fetch_user_name(user_id): | ||
import requests | ||
response = requests.get(f'https://api.example.com/users/{user_id}') | ||
return response.json()['name'] | ||
``` | ||
import pytest | ||
from your_module import fetch_user_name | ||
@pytest.fixture | ||
def mocker(): | ||
return pytest.mockito() | ||
def test_fetch_user_name(mocker): | ||
mocked_requests_get = mocker.patch('requests.get') | ||
response_json = {'name': 'John Doe'} | ||
mocked_requests_get.return_value.json.return_value = response_json | ||
result = fetch_user_name(123) | ||
assert result == 'John Doe' | ||
# Verify that the requests.get call was not made | ||
assert not mocked_requests_get.called | ||
``` | ||
In this example, we're using the `mocker` fixture to create a mock object for the `requests.get` function. We then set up the mock to return a response with a JSON payload containing the user's name. Finally, we test that our `fetch_user_name` function returns the expected result without actually making a network request. | ||
Initial state: ANALYZING_REQUIREMENTS | ||
Test Failed: ============================= test session starts ============================== | ||
platform darwin -- Python 3.12.3, pytest-8.2.0, pluggy-1.5.0 -- /Users/sac/Library/Caches/pypoetry/virtualenvs/soc-FgW3JNy9-py3.12/bin/python | ||
cachedir: .pytest_cache | ||
rootdir: /var/folders/s6/jqyw48zs39z38b_3f6f_x2sc0000gn/T | ||
plugins: anyio-4.3.0, clarity-1.0.1, Faker-23.3.0, asyncio-0.23.6, mock-3.14.0, xdist-3.6.1 | ||
asyncio: mode=Mode.STRICT | ||
collecting ... collected 1 item | ||
../../../../../../../var/folders/s6/jqyw48zs39z38b_3f6f_x2sc0000gn/T/tmp880863oe_test.py::test_fetch_user_name ERROR [100%] | ||
==================================== ERRORS ==================================== | ||
____________________ ERROR at setup of test_fetch_user_name ____________________ | ||
@pytest.fixture | ||
def mocker(): | ||
> return pytest.mockito() | ||
E AttributeError: module 'pytest' has no attribute 'mockito' | ||
/var/folders/s6/jqyw48zs39z38b_3f6f_x2sc0000gn/T/tmp880863oe_test.py:6: AttributeError | ||
=========================== short test summary info ============================ | ||
ERROR ../../../../../../../var/folders/s6/jqyw48zs39z38b_3f6f_x2sc0000gn/T/tmp880863oe_test.py::test_fetch_user_name | ||
=============================== 1 error in 0.04s =============================== | ||
Final state: TESTING_CODE | ||
""" | ||
|
||
|
||
def main(): | ||
from dspygen.utils.dspy_tools import init_ol | ||
lm = init_ol() | ||
# source_code = bad_code | ||
source_code = example_code | ||
result = pytest_call(source_code=source_code) | ||
from dspygen.utils.file_tools import extract_code | ||
print(extract_code(result)) | ||
# print(lm.inspect_history(n=1)) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |