125 changes: 125 additions & 0 deletions tests/benchmarks/test_benchmarks.py
@@ -242,6 +242,131 @@ def test_evaluate(self):
        assert 0 < result["f1"] < 1.0


class TestParallelExecution:
    """Tests for parallel execution safety."""

    def test_parallel_custom_fn_isolation(self):
        """Test that parallel executions of a custom function don't interfere."""
        import tempfile
        import threading
        import time
        from pathlib import Path

        from benchmarks.runner import BenchmarkRunner
        from benchmarks.tasks.niah import NIAHBenchmark

        # Track which thread processed which sample
        thread_ids = {}
        lock = threading.Lock()

        def custom_inference(sample):
            """Custom inference that tests isolation by writing temp files."""
            # Create a temp file unique to this invocation
            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".txt"
            ) as f:
                temp_path = f.name
                f.write(sample.id)

            # Small delay to increase the chance of exposing race conditions
            time.sleep(0.01)

            # Read back and verify it's still our data
            with open(temp_path) as f:
                read_id = f.read()

            # Clean up
            Path(temp_path).unlink()

            # Record thread info
            with lock:
                thread_ids[sample.id] = threading.current_thread().ident

            # Return the expected answer only if our temp file wasn't clobbered
            if read_id == sample.id:
                return sample.expected_answer
            else:
                return f"CLOBBERED: expected {sample.id}, got {read_id}"

        benchmark = NIAHBenchmark(context_length=1000)
        runner = BenchmarkRunner(max_workers=4)

        result = runner.run(
            benchmark,
            method="custom",
            custom_fn=custom_inference,
            num_samples=8,
            seed=42,
            max_workers=4,
        )

        # All samples should be correct (temp files weren't clobbered)
        assert result.accuracy == 1.0, (
            "Some samples were clobbered: "
            f"{[sr for sr in result.sample_results if not sr.is_correct]}"
        )

        # Verify multiple threads were used
        unique_threads = set(thread_ids.values())
        assert len(unique_threads) > 1, "Expected multiple threads to be used"
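
For context, a minimal sketch of the kind of thread-pool dispatch this test assumes the runner performs; the real BenchmarkRunner internals may differ, and `run_samples`/`infer` here are illustrative names, not framework API:

from concurrent.futures import ThreadPoolExecutor

def run_samples(samples, infer, max_workers=4):
    # Each sample is handed to a pool thread, so infer() runs concurrently
    # and must not share mutable state between invocations -- exactly the
    # property the test above checks with per-call temp files.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(infer, samples))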

    def test_parallel_preserves_order(self):
        """Test that parallel results maintain original sample order."""
        from benchmarks.runner import BenchmarkRunner
        from benchmarks.tasks.niah import NIAHBenchmark

        def custom_inference(sample):
            return sample.expected_answer

        benchmark = NIAHBenchmark(context_length=1000)
        runner = BenchmarkRunner(max_workers=4)

        result = runner.run(
            benchmark,
            method="custom",
            custom_fn=custom_inference,
            num_samples=10,
            seed=42,
            max_workers=4,
        )

        # Verify results are in the expected order
        expected_ids = [f"niah-{i:04d}" for i in range(10)]
        actual_ids = [sr.sample_id for sr in result.sample_results]
        assert actual_ids == expected_ids, f"Order mismatch: {actual_ids}"
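
One way a runner can keep results in sample order even when workers finish out of order is to record each future's index and write results back by position. This is a hedged sketch of that pattern, not the actual BenchmarkRunner code (`run_in_order` is an illustrative name):

from concurrent.futures import ThreadPoolExecutor, as_completed

def run_in_order(samples, infer, max_workers=4):
    results = [None] * len(samples)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to its sample's original position.
        future_to_index = {
            pool.submit(infer, sample): i for i, sample in enumerate(samples)
        }
        for future in as_completed(future_to_index):
            # Completion order is arbitrary; the index restores sample order.
            results[future_to_index[future]] = future.result()
    return results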

    def test_parallel_error_handling(self):
        """Test that errors in parallel execution are handled gracefully."""
        import threading

        from benchmarks.runner import BenchmarkRunner
        from benchmarks.tasks.niah import NIAHBenchmark

        call_count = {"count": 0}
        count_lock = threading.Lock()

        def flaky_inference(sample):
            # Guard the shared counter: it is incremented from worker threads,
            # and an unsynchronized += would be a race.
            with count_lock:
                call_count["count"] += 1
                count = call_count["count"]
            if count % 3 == 0:
                raise ValueError(f"Simulated error for sample {sample.id}")
            return sample.expected_answer

        benchmark = NIAHBenchmark(context_length=1000)
        runner = BenchmarkRunner(max_workers=2)

        result = runner.run(
            benchmark,
            method="custom",
            custom_fn=flaky_inference,
            num_samples=6,
            seed=42,
            max_workers=2,
        )

        # Should complete without crashing
        assert len(result.sample_results) == 6

        # Some should have errors
        errors = [sr for sr in result.sample_results if sr.error]
        assert len(errors) > 0, "Expected some errors"

        # Errored samples should be scored as incorrect
        for sr in errors:
            assert sr.is_correct is False
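
The graceful degradation asserted here is commonly implemented by wrapping each sample's inference in a try/except that records the failure instead of propagating it. A minimal sketch, assuming a per-sample result dict; `safe_infer` is an illustrative name, not framework API:

def safe_infer(sample, infer):
    """Run inference for one sample, capturing any exception as an error."""
    try:
        return {"answer": infer(sample), "error": None}
    except Exception as exc:
        # The sample still appears in the results; scoring treats it as
        # incorrect with zero metrics rather than aborting the whole run.
        return {"answer": None, "error": str(exc)}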


class TestBenchmarkIntegration:
"""Integration tests for benchmark framework."""
