diff --git a/tests/benchmarks/test_benchmarks.py b/tests/benchmarks/test_benchmarks.py
index a84d9d48..06258087 100644
--- a/tests/benchmarks/test_benchmarks.py
+++ b/tests/benchmarks/test_benchmarks.py
@@ -242,6 +242,131 @@ def test_evaluate(self):
         assert 0 < result["f1"] < 1.0
 
 
+class TestParallelExecution:
+    """Tests for parallel execution safety."""
+
+    def test_parallel_custom_fn_isolation(self):
+        """Test that parallel executions of a custom function don't interfere with each other."""
+        import tempfile
+        import threading
+        import time
+        from pathlib import Path
+
+        from benchmarks.runner import BenchmarkRunner
+        from benchmarks.tasks.niah import NIAHBenchmark
+
+        # Track which thread processed which sample
+        thread_ids = {}
+        lock = threading.Lock()
+
+        def custom_inference(sample):
+            """Custom inference that tests isolation by writing temp files."""
+            # Create a temp file unique to this invocation
+            with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
+                temp_path = f.name
+                f.write(sample.id)
+
+            # Small delay to increase chance of race conditions
+            time.sleep(0.01)
+
+            # Read back and verify it's still our data
+            with open(temp_path) as f:
+                read_id = f.read()
+
+            # Clean up
+            Path(temp_path).unlink()
+
+            # Record thread info
+            with lock:
+                thread_ids[sample.id] = threading.current_thread().ident
+
+            # Return expected answer if our temp file wasn't clobbered
+            if read_id == sample.id:
+                return sample.expected_answer
+            else:
+                return f"CLOBBERED: expected {sample.id}, got {read_id}"
+
+        benchmark = NIAHBenchmark(context_length=1000)
+        runner = BenchmarkRunner(max_workers=4)
+
+        result = runner.run(
+            benchmark,
+            method="custom",
+            custom_fn=custom_inference,
+            num_samples=8,
+            seed=42,
+            max_workers=4,
+        )
+
+        # All samples should be correct (temp files weren't clobbered)
+        assert result.accuracy == 1.0, f"Some samples were clobbered: {[sr for sr in result.sample_results if not sr.is_correct]}"
+
+        # Verify multiple threads were used
+        unique_threads = set(thread_ids.values())
+        assert len(unique_threads) > 1, "Expected multiple threads to be used"
+
+    def test_parallel_preserves_order(self):
+        """Test that parallel results maintain original sample order."""
+        from benchmarks.runner import BenchmarkRunner
+        from benchmarks.tasks.niah import NIAHBenchmark
+
+        def custom_inference(sample):
+            return sample.expected_answer
+
+        benchmark = NIAHBenchmark(context_length=1000)
+        runner = BenchmarkRunner(max_workers=4)
+
+        result = runner.run(
+            benchmark,
+            method="custom",
+            custom_fn=custom_inference,
+            num_samples=10,
+            seed=42,
+            max_workers=4,
+        )
+
+        # Verify results are in expected order
+        expected_ids = [f"niah-{i:04d}" for i in range(10)]
+        actual_ids = [sr.sample_id for sr in result.sample_results]
+        assert actual_ids == expected_ids, f"Order mismatch: {actual_ids}"
+
+    def test_parallel_error_handling(self):
+        """Test that errors in parallel execution are handled gracefully."""
+        from benchmarks.runner import BenchmarkRunner
+        from benchmarks.tasks.niah import NIAHBenchmark
+
+        call_count = {"count": 0}
+
+        def flaky_inference(sample):
+            call_count["count"] += 1
+            if call_count["count"] % 3 == 0:
+                raise ValueError(f"Simulated error for sample {sample.id}")
+            return sample.expected_answer
+
+        benchmark = NIAHBenchmark(context_length=1000)
+        runner = BenchmarkRunner(max_workers=2)
+
+        result = runner.run(
+            benchmark,
+            method="custom",
+            custom_fn=flaky_inference,
+            num_samples=6,
+            seed=42,
+            max_workers=2,
+        )
+
+        # Should complete without crashing
+        assert len(result.sample_results) == 6
+
+        # Some should have errors
+        errors = [sr for sr in result.sample_results if sr.error]
+        assert len(errors) > 0, "Expected some errors"
+
+        # Errors should have zero metrics
+        for sr in errors:
+            assert sr.is_correct is False
+
+
 class TestBenchmarkIntegration:
     """Integration tests for benchmark framework."""