Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions evaluator/problem_suites/polyglot/polyglot_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from utils.git import init_local_repo_with_initial_commit
from evaluator.sandbox.sandbox_manager import SandboxManager
from evaluator.problem_suites.problem_suite import ProblemSuite, ProblemSuiteName
from utils.diff import get_file_diff, apply_diff_to_local_repo, validate_diff_for_local_repo
from utils.diff import get_file_diff, apply_diff_to_local_repo, validate_diff_for_local_repo, validate_patched_files_syntax



Expand Down Expand Up @@ -147,7 +147,13 @@ def _on_mount(temp_dir: str):
# Apply the patch
apply_diff_to_local_repo(patch, sandbox_repo_dir)


# Syntax-check the patched files
is_valid, error_message = validate_patched_files_syntax(sandbox_repo_dir)
if not is_valid:
raise EvaluationRunException(
EvaluationRunErrorCode.AGENT_INVALID_PATCH,
f"{EvaluationRunErrorCode.AGENT_INVALID_PATCH.get_error_message()}: {error_message}"
)

return sandbox_manager.initialize_sandbox(
name=f"eval-sandbox-{problem.name}-{evaluation_run_id}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import BaseModel
from utils.docker import get_docker_client
from typing import Any, Dict, List, Tuple, Optional
from utils.diff import validate_diff_for_local_repo
from utils.diff import validate_diff_for_local_repo, apply_diff_to_local_repo, validate_patched_files_syntax
from evaluator.models import EvaluationRunException
from swebench.harness.constants import SWEbenchInstance
from utils.temp import create_temp_dir, delete_temp_dir
Expand Down Expand Up @@ -184,7 +184,14 @@ def initialize_eval_sandbox(
f"{EvaluationRunErrorCode.AGENT_INVALID_PATCH.get_error_message()}: {error_message}"
)


# Syntax-check the patched files
apply_diff_to_local_repo(patch, temp_dir)
is_valid, error_message = validate_patched_files_syntax(temp_dir)
if not is_valid:
raise EvaluationRunException(
EvaluationRunErrorCode.AGENT_INVALID_PATCH,
f"{EvaluationRunErrorCode.AGENT_INVALID_PATCH.get_error_message()}: {error_message}"
)

swebench_instance = problem.userdata

Expand Down
90 changes: 58 additions & 32 deletions utils/diff.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utilities for computing diffs between files."""

import ast
import os
import tempfile
import subprocess
Expand All @@ -8,15 +9,14 @@
from typing import Tuple, Optional



def get_file_diff(old_path, new_path) -> str:
"""
Gets the diff between two files.

Args:
old_path: The path to the old file
new_path: The path to the new file

Returns:
The diff between the two files, expressed as a diff of the old file, as a string.
"""
Expand All @@ -28,13 +28,9 @@ def get_file_diff(old_path, new_path) -> str:
missing.append(new_path)
if missing:
logger.fatal(f"File(s) not found for diff: {', '.join(missing)}")

# Use diff command
result = subprocess.run(
["diff", "-u", old_path, new_path],
capture_output=True,
text=True
)
result = subprocess.run(["diff", "-u", old_path, new_path], capture_output=True, text=True)

# Check if the diff was generated successfully
# `diff -u` return codes:
Expand All @@ -53,51 +49,44 @@ def get_file_diff(old_path, new_path) -> str:
filename = os.path.basename(old_path)
lines[0] = f"--- {filename}"
lines[1] = f"+++ {filename}"

return "\n".join(lines)

return "\n".join(lines)


def validate_diff_for_local_repo(diff, local_repo_dir) -> Tuple[bool, Optional[str]]:
    """
    Validates if a diff string is valid and can be applied to a local repository.

    The diff is written to a temporary file and handed to `git apply --check`,
    which validates the patch without applying it.

    Args:
        diff: The diff string to validate
        local_repo_dir: The local repository directory

    Returns:
        (is_valid: bool, error_message: Optional[str])
        error_message is git's stderr output when the check fails, otherwise None.
    """

    # delete=False because git must be able to open the file by name after we close it;
    # that makes us responsible for removing it on every path.
    f = tempfile.NamedTemporaryFile(mode="w", suffix=".diff", delete=False)
    diff_file = f.name

    try:
        # Write diff to temp file (closing it flushes the content to disk)
        with f:
            f.write(diff)

        # Use `git apply --check` to validate without applying
        result = subprocess.run(
            ["git", "apply", "--check", diff_file],
            cwd=local_repo_dir,
            capture_output=True,
            text=True,
        )
    finally:
        # Delete the temp file even if writing or launching git raised
        os.unlink(diff_file)

    # Check if the diff would apply successfully
    if result.returncode == 0:
        return True, None
    return False, result.stderr.strip()



def apply_diff_to_local_repo(diff, local_repo_dir) -> None:
"""
Applies a diff string to files in the source directory.

Args:
diff: The diff string to apply
local_repo_dir: The local repository directory
Expand All @@ -107,18 +96,55 @@ def apply_diff_to_local_repo(diff, local_repo_dir) -> None:
with tempfile.NamedTemporaryFile(mode="w", suffix=".diff", delete=False) as f:
f.write(diff)
diff_file = f.name

# Use `git apply` to apply the diff
result = subprocess.run(
["git", "apply", diff_file],
cwd=local_repo_dir,
capture_output=True,
text=True
)
result = subprocess.run(["git", "apply", diff_file], cwd=local_repo_dir, capture_output=True, text=True)

# Delete the temp file
os.unlink(diff_file)

# Check if the diff was applied successfully
if result.returncode != 0:
logger.fatal(f"Failed to apply diff to {local_repo_dir}: {result.stderr.strip()}")
logger.fatal(f"Failed to apply diff to {local_repo_dir}: {result.stderr.strip()}")


def _python_syntax_error(full_path: str, display_path: str) -> Optional[str]:
    """Return a formatted error message if the Python file fails to parse, else None."""
    # Read raw bytes so the parser itself handles BOMs and coding cookies,
    # rather than decoding with whatever the locale encoding happens to be.
    with open(full_path, "rb") as f:
        source = f.read()

    try:
        ast.parse(source, filename=display_path)
    except SyntaxError as e:
        return f"{display_path}:{e.lineno}: {e.msg}"
    except ValueError as e:
        # Some interpreter versions raise ValueError instead of SyntaxError
        # (e.g. for source containing null bytes); report it the same way.
        return f"{display_path}: {e}"
    return None


def _javascript_syntax_error(full_path: str, display_path: str) -> Optional[str]:
    """Return a formatted error message if `node --check` rejects the file, else None."""
    # Feed the file to node on stdin as bytes, avoiding a decode/encode round trip in Python.
    with open(full_path, "rb") as f:
        check = subprocess.run(
            ["node", "--input-type=module", "--check"],
            stdin=f,
            capture_output=True,
            text=True,
        )

    if check.returncode != 0:
        return f"{display_path}: {check.stderr.strip()}"
    return None


def validate_patched_files_syntax(repo_dir: str) -> Tuple[bool, Optional[str]]:
    """
    After a patch has been applied, check that modified files have valid syntax.
    Supports Python (.py) and JavaScript (.js, .mjs) files.

    The set of files is taken from `git diff --name-only`, i.e. tracked files whose
    working-tree content differs from the index. Files with other extensions, and
    files that no longer exist (deleted by the patch), are skipped.

    Args:
        repo_dir: The repository directory where the patch was applied

    Returns:
        (is_valid: bool, error_message: Optional[str])
        error_message lists one line per offending file when is_valid is False.
    """
    # `-z` gives NUL-separated, unquoted paths. Without it git C-quotes names that
    # contain quotes, backslashes, control or non-ASCII characters, and the quoted
    # form does not exist on disk, so such files would be silently skipped.
    listing = subprocess.run(
        ["git", "diff", "--name-only", "-z"],
        cwd=repo_dir,
        capture_output=True,
        text=True,
    )
    # NOTE(review): a non-zero git exit code is not inspected; an empty listing is
    # treated as "nothing to check". Files newly created by the patch are untracked
    # and therefore do not appear in this listing either — confirm this is intended.
    modified_files = [name for name in listing.stdout.split("\0") if name]

    errors = []
    for filepath in modified_files:
        full_path = os.path.join(repo_dir, filepath)
        if not os.path.isfile(full_path):
            continue

        if filepath.endswith(".py"):
            error = _python_syntax_error(full_path, filepath)
        elif filepath.endswith((".js", ".mjs")):
            error = _javascript_syntax_error(full_path, filepath)
        else:
            error = None

        if error is not None:
            errors.append(error)

    if errors:
        return False, "Patched files have syntax errors:\n" + "\n".join(errors)
    return True, None