Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions benchmarks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,7 @@ def _update_progress(

# Call custom callback if provided
if self.config.progress_callback:
self.config.progress_callback(
stats.completed, stats.total, sample_result, stats
)
self.config.progress_callback(stats.completed, stats.total, sample_result, stats)

# Update display based on mode
if progress_mode == "tqdm" and pbar is not None:
Expand All @@ -282,7 +280,7 @@ def _update_progress(
if stats.completed % interval == 0 or stats.completed == stats.total:
print(
f" Progress: {stats.completed}/{stats.total} "
f"({stats.completed/stats.total:.0%}) | "
f"({stats.completed / stats.total:.0%}) | "
f"Acc: {stats.accuracy:.1%} | "
f"Errors: {stats.errors} | "
f"ETA: {self._format_eta(stats.eta_seconds)}"
Expand Down
4 changes: 1 addition & 3 deletions examples/oolong_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def load_random_oolong_row() -> dict:

def main():
# Parse command-line arguments
parser = argparse.ArgumentParser(
description="Run Oolong benchmark example with RLM"
)
parser = argparse.ArgumentParser(description="Run Oolong benchmark example with RLM")
parser.add_argument(
"--backend",
type=str,
Expand Down
1 change: 1 addition & 0 deletions examples/subprocess_repl_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def demonstrate_isolation():
print(f"Platform: {platform.system()}")
if platform.system() == "Linux":
import shutil

if not shutil.which("bwrap"):
print("Note: bubblewrap not installed - filesystem sandbox disabled")
print()
Expand Down
103 changes: 68 additions & 35 deletions rlm/environments/subprocess_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,22 @@ def __init__(

# Pre-approved packages (stdlib + user-specified)
self.allowed_packages: set[str] = {
"json", "re", "math", "collections", "itertools",
"functools", "datetime", "random", "string", "typing",
"os", "sys", "io", "time", "pathlib", "copy",
"json",
"re",
"math",
"collections",
"itertools",
"functools",
"datetime",
"random",
"string",
"typing",
"os",
"sys",
"io",
"time",
"pathlib",
"copy",
}
if allowed_packages:
self.allowed_packages.update(allowed_packages)
Expand Down Expand Up @@ -472,10 +485,12 @@ def _install_package(self, package: str):
capture_output=True,
)
self._installed_packages.add(package)
self._overhead_stats["package_installs"].append({
"package": package,
"time_ms": (time.perf_counter() - start) * 1000,
})
self._overhead_stats["package_installs"].append(
{
"package": package,
"time_ms": (time.perf_counter() - start) * 1000,
}
)

def _extract_missing_module(self, stderr: str) -> str | None:
"""Extract module name from ImportError/ModuleNotFoundError."""
Expand Down Expand Up @@ -580,24 +595,38 @@ def _linux_sandbox_wrap(self, cmd: list[str]) -> list[str]:

bwrap_cmd = [
"bwrap",
"--ro-bind", "/usr", "/usr",
"--ro-bind", "/lib", "/lib",
"--ro-bind", "/bin", "/bin",
"--ro-bind", "/sbin", "/sbin",
"--ro-bind",
"/usr",
"/usr",
"--ro-bind",
"/lib",
"/lib",
"--ro-bind",
"/bin",
"/bin",
"--ro-bind",
"/sbin",
"/sbin",
]

# Add /lib64 if it exists
if os.path.exists("/lib64"):
bwrap_cmd.extend(["--ro-bind", "/lib64", "/lib64"])

bwrap_cmd.extend([
"--ro-bind", self.venv_path, self.venv_path,
"--bind", self.temp_dir, self.temp_dir,
"--unshare-net",
"--unshare-pid",
"--die-with-parent",
"--",
])
bwrap_cmd.extend(
[
"--ro-bind",
self.venv_path,
self.venv_path,
"--bind",
self.temp_dir,
self.temp_dir,
"--unshare-net",
"--unshare-pid",
"--die-with-parent",
"--",
]
)

return bwrap_cmd + cmd

Expand All @@ -621,11 +650,13 @@ def execute_code(self, code: str) -> REPLResult:
total_time = time.perf_counter() - total_start

# Track overhead
self._overhead_stats["executions"].append({
"total_ms": total_time * 1000,
"code_ms": (result.execution_time or 0) * 1000,
"overhead_ms": (total_time - (result.execution_time or 0)) * 1000,
})
self._overhead_stats["executions"].append(
{
"total_ms": total_time * 1000,
"code_ms": (result.execution_time or 0) * 1000,
"overhead_ms": (total_time - (result.execution_time or 0)) * 1000,
}
)

return result

Expand All @@ -638,13 +669,15 @@ def _try_execute(self, code: str) -> REPLResult:

# Build environment - inherit essential vars for macOS compatibility
env = os.environ.copy()
env.update({
"PATH": os.path.join(self.venv_path, "bin") + ":/usr/bin:/bin",
"HOME": self.temp_dir,
"TMPDIR": self.temp_dir,
"PYTHONDONTWRITEBYTECODE": "1",
"PYTHONNOUSERSITE": "1",
})
env.update(
{
"PATH": os.path.join(self.venv_path, "bin") + ":/usr/bin:/bin",
"HOME": self.temp_dir,
"TMPDIR": self.temp_dir,
"PYTHONDONTWRITEBYTECODE": "1",
"PYTHONNOUSERSITE": "1",
}
)

try:
result = subprocess.run(
Expand Down Expand Up @@ -720,9 +753,7 @@ def add_context(
context_path = os.path.join(self.temp_dir, f"context_{context_index}.txt")
with open(context_path, "w") as f:
f.write(context_payload)
self.execute_code(
f"with open(r'{context_path}', 'r') as f:\n {var_name} = f.read()"
)
self.execute_code(f"with open(r'{context_path}', 'r') as f:\n {var_name} = f.read()")
else:
context_path = os.path.join(self.temp_dir, f"context_{context_index}.json")
with open(context_path, "w") as f:
Expand Down Expand Up @@ -798,7 +829,9 @@ def get_overhead_summary(self) -> dict:
"packages_installed": [i["package"] for i in installs],
"overhead_percentage": round(
(total_overhead / (total_overhead + total_code_time)) * 100, 1
) if total_code_time > 0 else 0,
)
if total_code_time > 0
else 0,
}

def print_overhead_summary(self):
Expand Down
18 changes: 11 additions & 7 deletions tests/benchmarks/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ def custom_inference(sample):
)

# All samples should be correct (temp files weren't clobbered)
assert result.accuracy == 1.0, f"Some samples were clobbered: {[sr for sr in result.sample_results if not sr.is_correct]}"
assert result.accuracy == 1.0, (
f"Some samples were clobbered: {[sr for sr in result.sample_results if not sr.is_correct]}"
)

# Verify multiple threads were used
unique_threads = set(thread_ids.values())
Expand Down Expand Up @@ -398,19 +400,21 @@ def test_progress_callback_invoked(self):
callback_invocations = []

def track_callback(completed, total, sample_result, stats):
callback_invocations.append({
"completed": completed,
"total": total,
"accuracy": stats.accuracy,
})
callback_invocations.append(
{
"completed": completed,
"total": total,
"accuracy": stats.accuracy,
}
)

benchmark = NIAHBenchmark(context_length=1000)
runner = BenchmarkRunner(
progress="none",
progress_callback=track_callback,
)

result = runner.run(
runner.run(
benchmark,
method="custom",
custom_fn=lambda s: s.expected_answer,
Expand Down
Loading