diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_cost_hash_map.py b/tests/test_cost_hash_map.py
new file mode 100644
index 00000000..1a5f0bda
--- /dev/null
+++ b/tests/test_cost_hash_map.py
@@ -0,0 +1,102 @@
+"""Tests for inference_gateway/cost_hash_map.py — in-memory cost tracking."""
+
+import time
+import pytest
+
+from uuid import uuid4
+from inference_gateway.cost_hash_map import CostHashMap, CostHashMapEntry, COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS
+
+
+class TestCostHashMapEntry:
+    """Tests for CostHashMapEntry model."""
+
+    def test_creation(self):
+        entry = CostHashMapEntry(cost=1.5, last_accessed_at=time.time())
+        assert entry.cost == 1.5
+
+    def test_zero_cost(self):
+        entry = CostHashMapEntry(cost=0.0, last_accessed_at=time.time())
+        assert entry.cost == 0.0
+
+
+class TestCostHashMapGetCost:
+    """Tests for CostHashMap.get_cost."""
+
+    def test_unknown_uuid_returns_zero(self):
+        chm = CostHashMap()
+        assert chm.get_cost(uuid4()) == 0
+
+    def test_known_uuid_returns_cost(self):
+        chm = CostHashMap()
+        uid = uuid4()
+        chm.add_cost(uid, 3.14)
+        assert chm.get_cost(uid) == 3.14
+
+    def test_multiple_uuids_independent(self):
+        chm = CostHashMap()
+        uid1, uid2 = uuid4(), uuid4()
+        chm.add_cost(uid1, 1.0)
+        chm.add_cost(uid2, 2.0)
+        assert chm.get_cost(uid1) == 1.0
+        assert chm.get_cost(uid2) == 2.0
+
+
+class TestCostHashMapAddCost:
+    """Tests for CostHashMap.add_cost."""
+
+    def test_add_cost_accumulates(self):
+        chm = CostHashMap()
+        uid = uuid4()
+        chm.add_cost(uid, 1.0)
+        chm.add_cost(uid, 2.5)
+        chm.add_cost(uid, 0.5)
+        assert chm.get_cost(uid) == 4.0
+
+    def test_add_cost_creates_entry(self):
+        chm = CostHashMap()
+        uid = uuid4()
+        chm.add_cost(uid, 5.0)
+        assert uid in chm.cost_hash_map
+        assert chm.cost_hash_map[uid].cost == 5.0
+
+    def test_add_cost_updates_last_accessed(self):
+        chm = CostHashMap()
+        uid = uuid4()
+        before = time.time()
+        chm.add_cost(uid, 1.0)
+        after = time.time()
+        assert before <= chm.cost_hash_map[uid].last_accessed_at <= after
+
+
+class TestCostHashMapCleanup:
+    """Tests for CostHashMap._cleanup method."""
+
+    def test_cleanup_removes_stale_entries(self):
+        chm = CostHashMap()
+        uid = uuid4()
+        chm.add_cost(uid, 1.0)
+        # Manually age the entry
+        chm.cost_hash_map[uid].last_accessed_at = time.time() - COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 10
+        # Force cleanup by setting last_cleanup_at to the past
+        chm.last_cleanup_at = time.time() - COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 10
+        chm._cleanup()
+        assert uid not in chm.cost_hash_map
+
+    def test_cleanup_preserves_fresh_entries(self):
+        chm = CostHashMap()
+        uid = uuid4()
+        chm.add_cost(uid, 1.0)
+        chm.last_cleanup_at = time.time() - COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 10
+        chm._cleanup()
+        assert uid in chm.cost_hash_map
+
+    def test_cleanup_skipped_when_recent(self):
+        chm = CostHashMap()
+        stale_uid = uuid4()
+        chm.add_cost(stale_uid, 1.0)
+        chm.cost_hash_map[stale_uid].last_accessed_at = time.time() - COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 10
+        # last_cleanup_at is recent, so cleanup should not run
+        chm.last_cleanup_at = time.time()
+        chm._cleanup()
+        # Stale entry should still be there since cleanup was skipped
+        assert stale_uid in chm.cost_hash_map
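Reviewer note: the tests above fully pin down CostHashMap's observable behavior. For readers without inference_gateway/cost_hash_map.py open, a minimal sketch consistent with these assertions could look like the following. It is illustrative only; the cleanup-interval value, the dataclass layout, and whether add_cost itself triggers _cleanup are assumptions, not the module's actual code.

```python
# Illustrative sketch inferred from the tests, not the real module.
import time
from dataclasses import dataclass, field
from uuid import UUID

COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS = 3600  # assumed value


@dataclass
class CostHashMapEntry:
    cost: float
    last_accessed_at: float


@dataclass
class CostHashMap:
    cost_hash_map: dict[UUID, CostHashMapEntry] = field(default_factory=dict)
    last_cleanup_at: float = field(default_factory=time.time)

    def get_cost(self, uid: UUID) -> float:
        # Unknown UUIDs report zero cost rather than raising.
        entry = self.cost_hash_map.get(uid)
        return entry.cost if entry else 0.0

    def add_cost(self, uid: UUID, cost: float) -> None:
        # Accumulate cost and refresh the access timestamp.
        entry = self.cost_hash_map.setdefault(
            uid, CostHashMapEntry(cost=0.0, last_accessed_at=time.time())
        )
        entry.cost += cost
        entry.last_accessed_at = time.time()
        self._cleanup()  # assumption: writes opportunistically trigger cleanup

    def _cleanup(self) -> None:
        # Skip if a cleanup ran recently; otherwise drop stale entries.
        now = time.time()
        if now - self.last_cleanup_at < COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS:
            return
        self.last_cleanup_at = now
        stale = [uid for uid, e in self.cost_hash_map.items()
                 if now - e.last_accessed_at > COST_HASH_MAP_CLEANUP_INTERVAL_SECONDS]
        for uid in stale:
            del self.cost_hash_map[uid]
```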
diff --git a/tests/test_diff.py b/tests/test_diff.py
new file mode 100644
index 00000000..e363b886
--- /dev/null
+++ b/tests/test_diff.py
@@ -0,0 +1,154 @@
+"""Tests for utils/diff.py — file diff computation, validation, and application."""
+
+import os
+import pytest
+import tempfile
+import subprocess
+
+from unittest.mock import patch, MagicMock
+from utils.diff import get_file_diff, validate_diff_for_local_repo, apply_diff_to_local_repo
+
+
+class TestGetFileDiff:
+    """Tests for the get_file_diff function."""
+
+    def _write_temp(self, content: str) -> str:
+        fd, path = tempfile.mkstemp(suffix=".txt")
+        with os.fdopen(fd, "w") as f:
+            f.write(content)
+        return path
+
+    def test_identical_files_returns_empty_diff(self):
+        path_a = self._write_temp("hello\nworld\n")
+        path_b = self._write_temp("hello\nworld\n")
+        try:
+            diff = get_file_diff(path_a, path_b)
+            assert diff.strip() == ""
+        finally:
+            os.unlink(path_a)
+            os.unlink(path_b)
+
+    def test_different_files_returns_unified_diff(self):
+        path_a = self._write_temp("line1\nline2\n")
+        path_b = self._write_temp("line1\nmodified\n")
+        try:
+            diff = get_file_diff(path_a, path_b)
+            assert "-line2" in diff
+            assert "+modified" in diff
+        finally:
+            os.unlink(path_a)
+            os.unlink(path_b)
+
+    def test_diff_header_uses_basename(self):
+        path_a = self._write_temp("a\n")
+        path_b = self._write_temp("b\n")
+        try:
+            diff = get_file_diff(path_a, path_b)
+            basename = os.path.basename(path_a)
+            assert f"--- {basename}" in diff
+            assert f"+++ {basename}" in diff
+        finally:
+            os.unlink(path_a)
+            os.unlink(path_b)
+
+    def test_missing_file_raises_exception(self):
+        existing = self._write_temp("content\n")
+        try:
+            with pytest.raises(Exception):
+                get_file_diff(existing, "/nonexistent/file.txt")
+        finally:
+            os.unlink(existing)
+
+    def test_both_files_missing_raises_exception(self):
+        with pytest.raises(Exception):
+            get_file_diff("/nonexistent/a.txt", "/nonexistent/b.txt")
+
+    def test_added_lines_in_diff(self):
+        path_a = self._write_temp("line1\n")
+        path_b = self._write_temp("line1\nline2\nline3\n")
+        try:
+            diff = get_file_diff(path_a, path_b)
+            assert "+line2" in diff
+            assert "+line3" in diff
+        finally:
+            os.unlink(path_a)
+            os.unlink(path_b)
+
+    def test_removed_lines_in_diff(self):
+        path_a = self._write_temp("line1\nline2\nline3\n")
+        path_b = self._write_temp("line1\n")
+        try:
+            diff = get_file_diff(path_a, path_b)
+            assert "-line2" in diff
+            assert "-line3" in diff
+        finally:
+            os.unlink(path_a)
+            os.unlink(path_b)
+
+
+class TestValidateDiffForLocalRepo:
+    """Tests for the validate_diff_for_local_repo function."""
+
+    def _create_git_repo(self, files: dict) -> str:
+        repo_dir = tempfile.mkdtemp()
+        subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, check=True)
+        subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_dir, capture_output=True)
+        subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_dir, capture_output=True)
+        for name, content in files.items():
+            filepath = os.path.join(repo_dir, name)
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+            with open(filepath, "w") as f:
+                f.write(content)
+        subprocess.run(["git", "add", "."], cwd=repo_dir, capture_output=True, check=True)
+        subprocess.run(["git", "commit", "-m", "init"], cwd=repo_dir, capture_output=True, check=True)
+        return repo_dir
+
+    def test_valid_diff_returns_true(self):
+        repo = self._create_git_repo({"hello.txt": "line1\nline2\n"})
+        diff = "--- a/hello.txt\n+++ b/hello.txt\n@@ -1,2 +1,2 @@\n line1\n-line2\n+modified\n"
+        is_valid, error = validate_diff_for_local_repo(diff, repo)
+        assert is_valid is True
+        assert error is None
+
+    def test_invalid_diff_returns_false(self):
+        repo = self._create_git_repo({"hello.txt": "line1\n"})
+        diff = "--- a/nonexistent.txt\n+++ b/nonexistent.txt\n@@ -1 +1 @@\n-old\n+new\n"
+        is_valid, error = validate_diff_for_local_repo(diff, repo)
+        assert is_valid is False
+        assert error is not None
+
+    def test_empty_diff_is_valid(self):
+        repo = self._create_git_repo({"hello.txt": "content\n"})
+        is_valid, error = validate_diff_for_local_repo("", repo)
+        assert is_valid is True
+
+
+class TestApplyDiffToLocalRepo:
+    """Tests for the apply_diff_to_local_repo function."""
+
+    def _create_git_repo(self, files: dict) -> str:
+        repo_dir = tempfile.mkdtemp()
+        subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, check=True)
+        subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_dir, capture_output=True)
+        subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_dir, capture_output=True)
+        for name, content in files.items():
+            filepath = os.path.join(repo_dir, name)
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+            with open(filepath, "w") as f:
+                f.write(content)
+        subprocess.run(["git", "add", "."], cwd=repo_dir, capture_output=True, check=True)
+        subprocess.run(["git", "commit", "-m", "init"], cwd=repo_dir, capture_output=True, check=True)
+        return repo_dir
+
+    def test_apply_valid_diff_modifies_file(self):
+        repo = self._create_git_repo({"hello.txt": "line1\nline2\n"})
+        diff = "--- a/hello.txt\n+++ b/hello.txt\n@@ -1,2 +1,2 @@\n line1\n-line2\n+modified\n"
+        apply_diff_to_local_repo(diff, repo)
+        with open(os.path.join(repo, "hello.txt")) as f:
+            assert f.read() == "line1\nmodified\n"
+
+    def test_apply_invalid_diff_raises(self):
+        repo = self._create_git_repo({"hello.txt": "content\n"})
+        diff = "--- a/nonexistent.txt\n+++ b/nonexistent.txt\n@@ -1 +1 @@\n-old\n+new\n"
+        with pytest.raises(Exception):
+            apply_diff_to_local_repo(diff, repo)
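Reviewer note: here is a sketch of the utils/diff.py surface these tests assume. The choice of difflib for diffing and `git apply` for validation and application is a guess at the implementation strategy. Note that test_diff_header_uses_basename expects both headers to carry the first file's basename, which the sketch reproduces.

```python
# Hypothetical sketch of utils/diff.py, inferred from the tests above.
import difflib
import os
import subprocess
from typing import Optional, Tuple


def get_file_diff(path_a: str, path_b: str) -> str:
    # open() raises FileNotFoundError for missing inputs, matching the
    # "raises" tests; both headers use the first file's basename.
    with open(path_a) as f:
        a_lines = f.readlines()
    with open(path_b) as f:
        b_lines = f.readlines()
    name = os.path.basename(path_a)
    return "".join(difflib.unified_diff(a_lines, b_lines, fromfile=name, tofile=name))


def validate_diff_for_local_repo(diff: str, repo_dir: str) -> Tuple[bool, Optional[str]]:
    # An empty diff is trivially valid; otherwise dry-run with `git apply --check`.
    if not diff.strip():
        return True, None
    result = subprocess.run(["git", "apply", "--check", "-"], cwd=repo_dir,
                            input=diff, capture_output=True, text=True)
    if result.returncode == 0:
        return True, None
    return False, result.stderr


def apply_diff_to_local_repo(diff: str, repo_dir: str) -> None:
    # check=True raises CalledProcessError when the patch does not apply.
    subprocess.run(["git", "apply", "-"], cwd=repo_dir, input=diff,
                   capture_output=True, text=True, check=True)
```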
diff --git a/tests/test_evaluation_models.py b/tests/test_evaluation_models.py
new file mode 100644
index 00000000..1b896af5
--- /dev/null
+++ b/tests/test_evaluation_models.py
@@ -0,0 +1,107 @@
+"""Tests for models/ — Pydantic model validation for core domain objects."""
+
+import pytest
+from uuid import uuid4
+from datetime import datetime
+
+from models.evaluation_run import EvaluationRun, EvaluationRunStatus, EvaluationRunErrorCode
+
+
+class TestEvaluationRunStatus:
+    """Tests for EvaluationRunStatus enum."""
+
+    def test_all_statuses_exist(self):
+        expected = {
+            "pending", "initializing_agent", "running_agent",
+            "initializing_eval", "running_eval", "finished", "error"
+        }
+        actual = {s.value for s in EvaluationRunStatus}
+        assert expected.issubset(actual)
+
+    def test_status_values_are_strings(self):
+        for status in EvaluationRunStatus:
+            assert isinstance(status.value, str)
+
+
+class TestEvaluationRunErrorCode:
+    """Tests for EvaluationRunErrorCode enum."""
+
+    def test_validator_internal_error_exists(self):
+        assert hasattr(EvaluationRunErrorCode, "VALIDATOR_INTERNAL_ERROR")
+
+    def test_validator_unknown_problem_exists(self):
+        assert hasattr(EvaluationRunErrorCode, "VALIDATOR_UNKNOWN_PROBLEM")
+
+    def test_get_error_message_returns_string(self):
+        for code in EvaluationRunErrorCode:
+            msg = code.get_error_message()
+            assert isinstance(msg, str)
+            assert len(msg) > 0
+
+    def test_error_code_categories(self):
+        assert EvaluationRunErrorCode.AGENT_EXCEPTION_RUNNING_AGENT.is_agent_error()
+        assert not EvaluationRunErrorCode.AGENT_EXCEPTION_RUNNING_AGENT.is_validator_error()
+        assert EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR.is_validator_error()
+        assert not EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR.is_agent_error()
+        assert EvaluationRunErrorCode.PLATFORM_RESTARTED_WHILE_PENDING.is_platform_error()
+        assert not EvaluationRunErrorCode.PLATFORM_RESTARTED_WHILE_PENDING.is_agent_error()
+
+    def test_all_agent_errors_in_1xxx_range(self):
+        for code in EvaluationRunErrorCode:
+            if code.is_agent_error():
+                assert 1000 <= code.value < 2000
+
+    def test_all_validator_errors_in_2xxx_range(self):
+        for code in EvaluationRunErrorCode:
+            if code.is_validator_error():
+                assert 2000 <= code.value < 3000
+
+    def test_all_platform_errors_in_3xxx_range(self):
+        for code in EvaluationRunErrorCode:
+            if code.is_platform_error():
+                assert 3000 <= code.value < 4000
+
+
+class TestEvaluationRun:
+    """Tests for EvaluationRun model."""
+
+    def test_minimal_creation(self):
+        run = EvaluationRun(
+            evaluation_run_id=uuid4(),
+            evaluation_id=uuid4(),
+            problem_name="test-problem",
+            status=EvaluationRunStatus.pending,
+            created_at=datetime.now(),
+        )
+        assert run.status == EvaluationRunStatus.pending
+        assert run.patch is None
+        assert run.error_code is None
+
+    def test_finished_run_with_results(self):
+        from models.problem import ProblemTestResult, ProblemTestResultStatus
+        run = EvaluationRun(
+            evaluation_run_id=uuid4(),
+            evaluation_id=uuid4(),
+            problem_name="test-problem",
+            status=EvaluationRunStatus.finished,
+            patch="--- a/file.py\n+++ b/file.py\n",
+            test_results=[
+                ProblemTestResult(name="test1", category="default", status=ProblemTestResultStatus.PASS),
+            ],
+            created_at=datetime.now(),
+            finished_or_errored_at=datetime.now(),
+        )
+        assert run.status == EvaluationRunStatus.finished
+        assert len(run.test_results) == 1
+
+    def test_error_run(self):
+        run = EvaluationRun(
+            evaluation_run_id=uuid4(),
+            evaluation_id=uuid4(),
+            problem_name="test-problem",
+            status=EvaluationRunStatus.error,
+            error_code=EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR,
+            error_message="Something went wrong",
+            created_at=datetime.now(),
+        )
+        assert run.error_code == EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR
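Reviewer note: the enum contract these tests encode, sketched below. The member values are placeholders chosen only to respect the 1xxx/2xxx/3xxx ranges asserted above; the real models/evaluation_run.py presumably has more members and specific error messages.

```python
# Illustrative sketch; member values are invented placeholders.
from enum import Enum, IntEnum


class EvaluationRunStatus(str, Enum):
    pending = "pending"
    initializing_agent = "initializing_agent"
    running_agent = "running_agent"
    initializing_eval = "initializing_eval"
    running_eval = "running_eval"
    finished = "finished"
    error = "error"


class EvaluationRunErrorCode(IntEnum):
    AGENT_EXCEPTION_RUNNING_AGENT = 1000      # agent errors: 1xxx
    VALIDATOR_INTERNAL_ERROR = 2000           # validator errors: 2xxx
    VALIDATOR_UNKNOWN_PROBLEM = 2001
    PLATFORM_RESTARTED_WHILE_PENDING = 3000   # platform errors: 3xxx

    def is_agent_error(self) -> bool:
        return 1000 <= self.value < 2000

    def is_validator_error(self) -> bool:
        return 2000 <= self.value < 3000

    def is_platform_error(self) -> bool:
        return 3000 <= self.value < 4000

    def get_error_message(self) -> str:
        # A generic fallback; the real module likely maps each code to
        # a specific human-readable message.
        return self.name.replace("_", " ").capitalize()
```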
diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py
new file mode 100644
index 00000000..f3da5d35
--- /dev/null
+++ b/tests/test_git_utils.py
@@ -0,0 +1,146 @@
+"""Tests for utils/git.py — Git repository operations."""
+
+import os
+import pytest
+import tempfile
+import subprocess
+
+from utils.git import (
+    clone_local_repo_at_commit,
+    verify_commit_exists_in_local_repo,
+    init_local_repo_with_initial_commit,
+    reset_local_repo,
+    get_local_repo_commit_hash,
+)
+
+
+def _create_repo_with_commits(num_commits: int = 2) -> tuple:
+    """Create a temporary git repo with multiple commits. Returns (repo_dir, list_of_commit_hashes)."""
+    repo_dir = tempfile.mkdtemp()
+    subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, check=True)
+    subprocess.run(["git", "config", "user.email", "test@test.com"], cwd=repo_dir, capture_output=True)
+    subprocess.run(["git", "config", "user.name", "Test"], cwd=repo_dir, capture_output=True)
+
+    hashes = []
+    for i in range(num_commits):
+        filepath = os.path.join(repo_dir, f"file_{i}.txt")
+        with open(filepath, "w") as f:
+            f.write(f"commit {i}\n")
+        subprocess.run(["git", "add", "."], cwd=repo_dir, capture_output=True, check=True)
+        subprocess.run(["git", "commit", "-m", f"commit {i}"], cwd=repo_dir, capture_output=True, check=True)
+        result = subprocess.run(["git", "rev-parse", "HEAD"], cwd=repo_dir, capture_output=True, text=True, check=True)
+        hashes.append(result.stdout.strip())
+
+    return repo_dir, hashes
+
+
+class TestGetLocalRepoCommitHash:
+    """Tests for get_local_repo_commit_hash."""
+
+    def test_returns_correct_hash(self):
+        repo_dir, hashes = _create_repo_with_commits(1)
+        assert get_local_repo_commit_hash(repo_dir) == hashes[0]
+
+    def test_returns_latest_commit(self):
+        repo_dir, hashes = _create_repo_with_commits(3)
+        assert get_local_repo_commit_hash(repo_dir) == hashes[-1]
+
+    def test_hash_is_40_char_hex(self):
+        repo_dir, _ = _create_repo_with_commits(1)
+        commit_hash = get_local_repo_commit_hash(repo_dir)
+        assert len(commit_hash) == 40
+        assert all(c in "0123456789abcdef" for c in commit_hash)
+
+
+class TestVerifyCommitExistsInLocalRepo:
+    """Tests for verify_commit_exists_in_local_repo."""
+
+    def test_existing_commit_returns_true(self):
+        repo_dir, hashes = _create_repo_with_commits(2)
+        assert verify_commit_exists_in_local_repo(repo_dir, hashes[0]) is True
+        assert verify_commit_exists_in_local_repo(repo_dir, hashes[1]) is True
+
+    def test_nonexistent_commit_returns_false(self):
+        repo_dir, _ = _create_repo_with_commits(1)
+        assert verify_commit_exists_in_local_repo(repo_dir, "a" * 40) is False
+
+    def test_nonexistent_directory_returns_false(self):
+        assert verify_commit_exists_in_local_repo("/nonexistent/dir", "abc123") is False
+
+
+class TestInitLocalRepoWithInitialCommit:
+    """Tests for init_local_repo_with_initial_commit."""
+
+    def test_creates_git_repo(self):
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, "file.txt"), "w") as f:
+            f.write("hello\n")
+        init_local_repo_with_initial_commit(temp_dir)
+        assert os.path.exists(os.path.join(temp_dir, ".git"))
+
+    def test_initial_commit_exists(self):
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, "file.txt"), "w") as f:
+            f.write("hello\n")
+        init_local_repo_with_initial_commit(temp_dir)
+        result = subprocess.run(["git", "log", "--oneline"], cwd=temp_dir, capture_output=True, text=True)
+        assert "Initial commit" in result.stdout
+
+    def test_custom_commit_message(self):
+        temp_dir = tempfile.mkdtemp()
+        with open(os.path.join(temp_dir, "file.txt"), "w") as f:
+            f.write("hello\n")
+        init_local_repo_with_initial_commit(temp_dir, "Custom message")
+        result = subprocess.run(["git", "log", "--oneline"], cwd=temp_dir, capture_output=True, text=True)
+        assert "Custom message" in result.stdout
+
+    def test_all_files_are_committed(self):
+        temp_dir = tempfile.mkdtemp()
+        for name in ["a.txt", "b.txt", "c.txt"]:
+            with open(os.path.join(temp_dir, name), "w") as f:
+                f.write(f"{name}\n")
+        init_local_repo_with_initial_commit(temp_dir)
+        result = subprocess.run(["git", "status", "--porcelain"], cwd=temp_dir, capture_output=True, text=True)
+        assert result.stdout.strip() == ""
+
+
+class TestCloneLocalRepoAtCommit:
+    """Tests for clone_local_repo_at_commit."""
+
+    def test_clone_at_first_commit(self):
+        repo_dir, hashes = _create_repo_with_commits(3)
+        target = tempfile.mkdtemp()
+        os.rmdir(target)
+        clone_local_repo_at_commit(repo_dir, hashes[0], target)
+        cloned_hash = get_local_repo_commit_hash(target)
+        assert cloned_hash == hashes[0]
+
+    def test_clone_at_latest_commit(self):
+        repo_dir, hashes = _create_repo_with_commits(2)
+        target = tempfile.mkdtemp()
+        os.rmdir(target)
+        clone_local_repo_at_commit(repo_dir, hashes[-1], target)
+        cloned_hash = get_local_repo_commit_hash(target)
+        assert cloned_hash == hashes[-1]
+
+    def test_clone_nonexistent_repo_raises(self):
+        target = tempfile.mkdtemp()
+        os.rmdir(target)
+        with pytest.raises(Exception):
+            clone_local_repo_at_commit("/nonexistent/repo", "abc123", target)
+
+
+class TestResetLocalRepo:
+    """Tests for reset_local_repo."""
+
+    def test_reset_to_earlier_commit(self):
+        repo_dir, hashes = _create_repo_with_commits(3)
+        reset_local_repo(repo_dir, hashes[0])
+        assert get_local_repo_commit_hash(repo_dir) == hashes[0]
+
+    def test_file_from_later_commit_is_gone(self):
+        repo_dir, hashes = _create_repo_with_commits(3)
+        reset_local_repo(repo_dir, hashes[0])
+        assert not os.path.exists(os.path.join(repo_dir, "file_1.txt"))
+        assert not os.path.exists(os.path.join(repo_dir, "file_2.txt"))
+        assert os.path.exists(os.path.join(repo_dir, "file_0.txt"))
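Reviewer note: the helpers under test read as thin subprocess wrappers around git. Below is a sketch consistent with the assertions above; passing the commit identity via `-c` flags is an assumption about how the real helper avoids relying on global git config.

```python
# Sketch of utils/git.py, inferred from the tests; bodies are illustrative.
import subprocess


def get_local_repo_commit_hash(repo_dir: str) -> str:
    result = subprocess.run(["git", "rev-parse", "HEAD"], cwd=repo_dir,
                            capture_output=True, text=True, check=True)
    return result.stdout.strip()


def verify_commit_exists_in_local_repo(repo_dir: str, commit: str) -> bool:
    try:
        # cat-file -e exits nonzero if the object is missing or malformed.
        result = subprocess.run(["git", "cat-file", "-e", f"{commit}^{{commit}}"],
                                cwd=repo_dir, capture_output=True)
        return result.returncode == 0
    except (FileNotFoundError, NotADirectoryError):
        return False  # repo_dir does not exist


def init_local_repo_with_initial_commit(repo_dir: str, message: str = "Initial commit") -> None:
    subprocess.run(["git", "init"], cwd=repo_dir, capture_output=True, check=True)
    subprocess.run(["git", "add", "."], cwd=repo_dir, capture_output=True, check=True)
    # Inline -c identity is an assumption about the real implementation.
    subprocess.run(["git", "-c", "user.email=ci@example.com", "-c", "user.name=CI",
                    "commit", "-m", message], cwd=repo_dir, capture_output=True, check=True)


def reset_local_repo(repo_dir: str, commit: str) -> None:
    # Hard reset discards files introduced by later commits.
    subprocess.run(["git", "reset", "--hard", commit], cwd=repo_dir,
                   capture_output=True, check=True)


def clone_local_repo_at_commit(source_dir: str, commit: str, target_dir: str) -> None:
    subprocess.run(["git", "clone", source_dir, target_dir],
                   capture_output=True, check=True)
    reset_local_repo(target_dir, commit)
```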
diff --git a/tests/test_inference_models.py b/tests/test_inference_models.py
new file mode 100644
index 00000000..1cece211
--- /dev/null
+++ b/tests/test_inference_models.py
@@ -0,0 +1,228 @@
+"""Tests for inference_gateway/models.py — model definitions, tool conversion, and cost calculations."""
+
+import json
+import pytest
+
+from inference_gateway.models import (
+    InferenceModelInfo,
+    EmbeddingModelInfo,
+    EmbeddingModelPricingMode,
+    InferenceToolCall,
+    InferenceToolCallArgument,
+    InferenceToolParameter,
+    InferenceToolParameterType,
+    InferenceTool,
+    InferenceToolMode,
+    InferenceMessage,
+    InferenceRequest,
+    InferenceResponse,
+    EmbeddingRequest,
+    EmbeddingResponse,
+    inference_tool_parameters_to_openai_parameters,
+    inference_tools_to_openai_tools,
+    inference_tool_mode_to_openai_tool_choice,
+    openai_tool_calls_to_inference_tool_calls,
+)
+from uuid import uuid4
+
+
+class TestInferenceModelInfo:
+    """Tests for InferenceModelInfo cost calculation."""
+
+    def test_cost_calculation_basic(self):
+        model = InferenceModelInfo(
+            name="test-model",
+            external_name="ext-test",
+            max_input_tokens=4096,
+            cost_usd_per_million_input_tokens=1.0,
+            cost_usd_per_million_output_tokens=2.0,
+        )
+        cost = model.get_cost_usd(num_input_tokens=1_000_000, num_output_tokens=1_000_000)
+        assert cost == 3.0
+
+    def test_cost_calculation_zero_tokens(self):
+        model = InferenceModelInfo(
+            name="test", external_name="test",
+            max_input_tokens=4096,
+            cost_usd_per_million_input_tokens=10.0,
+            cost_usd_per_million_output_tokens=20.0,
+        )
+        assert model.get_cost_usd(0, 0) == 0.0
+
+    def test_cost_calculation_fractional(self):
+        model = InferenceModelInfo(
+            name="test", external_name="test",
+            max_input_tokens=4096,
+            cost_usd_per_million_input_tokens=3.0,
+            cost_usd_per_million_output_tokens=6.0,
+        )
+        cost = model.get_cost_usd(500_000, 250_000)
+        assert abs(cost - 3.0) < 1e-10  # 1.5 + 1.5
+
+
+class TestEmbeddingModelInfo:
+    """Tests for EmbeddingModelInfo cost calculation."""
+
+    def test_per_token_pricing(self):
+        model = EmbeddingModelInfo(
+            name="embed", external_name="embed",
+            max_input_tokens=8192,
+            pricing_mode=EmbeddingModelPricingMode.PER_TOKEN,
+            cost_usd_per_million_input_tokens=0.1,
+        )
+        cost = model.get_cost_usd(num_input_tokens=2_000_000, num_seconds=0)
+        assert abs(cost - 0.2) < 1e-10
+
+    def test_per_second_pricing(self):
+        model = EmbeddingModelInfo(
+            name="embed", external_name="embed",
+            max_input_tokens=8192,
+            pricing_mode=EmbeddingModelPricingMode.PER_SECOND,
+            cost_usd_per_second=0.001,
+        )
+        cost = model.get_cost_usd(num_input_tokens=0, num_seconds=60)
+        assert abs(cost - 0.06) < 1e-10
+
+    def test_default_pricing_mode_is_per_token(self):
+        model = EmbeddingModelInfo(
+            name="embed", external_name="embed",
+            max_input_tokens=8192,
+            cost_usd_per_million_input_tokens=1.0,
+        )
+        assert model.pricing_mode == EmbeddingModelPricingMode.PER_TOKEN
+
+
+class TestInferenceToolConversions:
+    """Tests for tool parameter and tool conversion functions."""
+
+    def test_parameters_to_openai_format(self):
+        params = [
+            InferenceToolParameter(
+                type=InferenceToolParameterType.STRING,
+                name="query",
+                description="Search query",
+                required=True,
+            ),
+            InferenceToolParameter(
+                type=InferenceToolParameterType.INTEGER,
+                name="limit",
+                description="Max results",
+                required=False,
+            ),
+        ]
+        result = inference_tool_parameters_to_openai_parameters(params)
+        assert "query" in result["properties"]
+        assert result["properties"]["query"]["type"] == "string"
+        assert "query" in result["required"]
+        assert "limit" not in result["required"]
+
+    def test_tools_to_openai_tools(self):
+        tools = [
+            InferenceTool(
+                name="search",
+                description="Search the web",
+                parameters=[
+                    InferenceToolParameter(
+                        type=InferenceToolParameterType.STRING,
+                        name="q",
+                        description="Query",
+                        required=True,
+                    )
+                ],
+            )
+        ]
+        openai_tools = inference_tools_to_openai_tools(tools)
+        assert len(openai_tools) == 1
+        assert openai_tools[0]["type"] == "function"
+        assert openai_tools[0]["function"]["name"] == "search"
+
+    def test_tool_mode_none_to_openai(self):
+        assert inference_tool_mode_to_openai_tool_choice(InferenceToolMode.NONE) == "none"
+
+    def test_tool_mode_auto_to_openai(self):
+        assert inference_tool_mode_to_openai_tool_choice(InferenceToolMode.AUTO) == "auto"
+
+    def test_tool_mode_required_to_openai(self):
+        assert inference_tool_mode_to_openai_tool_choice(InferenceToolMode.REQUIRED) == "required"
+
+    def test_empty_tools_list_conversion(self):
+        assert inference_tools_to_openai_tools([]) == []
+
+
+class TestInferenceToolCallArgument:
+    """Tests for InferenceToolCallArgument and InferenceToolCall."""
+
+    def test_tool_call_argument_creation(self):
+        arg = InferenceToolCallArgument(name="param1", value="value1")
+        assert arg.name == "param1"
+        assert arg.value == "value1"
+
+    def test_tool_call_with_multiple_args(self):
+        tc = InferenceToolCall(
+            name="my_tool",
+            arguments=[
+                InferenceToolCallArgument(name="a", value=1),
+                InferenceToolCallArgument(name="b", value="hello"),
+            ],
+        )
+        assert tc.name == "my_tool"
+        assert len(tc.arguments) == 2
+
+
+class TestInferenceRequestResponse:
+    """Tests for request/response model validation."""
+
+    def test_inference_request_creation(self):
+        req = InferenceRequest(
+            evaluation_run_id=uuid4(),
+            model="gpt-4",
+            temperature=0.7,
+            messages=[InferenceMessage(role="user", content="Hello")],
+        )
+        assert req.model == "gpt-4"
+        assert len(req.messages) == 1
+
+    def test_inference_request_default_tool_mode(self):
+        req = InferenceRequest(
+            evaluation_run_id=uuid4(),
+            model="gpt-4",
+            temperature=0.0,
+            messages=[InferenceMessage(role="user", content="Hi")],
+        )
+        assert req.tool_mode == InferenceToolMode.NONE
+
+    def test_inference_response_creation(self):
+        resp = InferenceResponse(content="Hello!", tool_calls=[])
+        assert resp.content == "Hello!"
+        assert resp.tool_calls == []
+
+    def test_embedding_request_creation(self):
+        req = EmbeddingRequest(
+            evaluation_run_id=uuid4(),
+            model="text-embedding-ada-002",
+            input="Hello world",
+        )
+        assert req.input == "Hello world"
+
+    def test_embedding_response_creation(self):
+        resp = EmbeddingResponse(embedding=[0.1, 0.2, 0.3])
+        assert len(resp.embedding) == 3
+
+
+class TestAllToolParameterTypes:
+    """Ensure all parameter types are valid."""
+
+    def test_all_parameter_types_exist(self):
+        expected = {"boolean", "integer", "number", "string", "array", "object"}
+        actual = {t.value for t in InferenceToolParameterType}
+        assert actual == expected
+
+    def test_each_type_converts_to_openai(self):
+        for ptype in InferenceToolParameterType:
+            params = [
+                InferenceToolParameter(
+                    type=ptype, name="test", description="test", required=True
+                )
+            ]
+            result = inference_tool_parameters_to_openai_parameters(params)
+            assert result["properties"]["test"]["type"] == ptype.value
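Reviewer note: of everything imported above, the two behaviors these tests constrain most tightly are the parameter-to-JSON-Schema conversion and the per-million-token cost formula. A sketch of both follows, with a stand-in dataclass where the real module uses Pydantic.

```python
# Illustrative sketch of the conversion helper and cost formula.
from dataclasses import dataclass
from enum import Enum
from typing import List


class InferenceToolParameterType(str, Enum):
    BOOLEAN = "boolean"
    INTEGER = "integer"
    NUMBER = "number"
    STRING = "string"
    ARRAY = "array"
    OBJECT = "object"


@dataclass
class InferenceToolParameter:
    type: InferenceToolParameterType
    name: str
    description: str
    required: bool


def inference_tool_parameters_to_openai_parameters(params: List[InferenceToolParameter]) -> dict:
    # JSON Schema object: one property per parameter, with "required"
    # listing only the parameters flagged required=True.
    return {
        "type": "object",
        "properties": {p.name: {"type": p.type.value, "description": p.description}
                       for p in params},
        "required": [p.name for p in params if p.required],
    }


def get_cost_usd(cost_per_m_in: float, cost_per_m_out: float,
                 num_input_tokens: int, num_output_tokens: int) -> float:
    # Linear pricing quoted per million tokens: 1M in at $1 plus 1M out
    # at $2 yields $3, matching test_cost_calculation_basic.
    return (num_input_tokens * cost_per_m_in + num_output_tokens * cost_per_m_out) / 1_000_000
```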
diff --git a/tests/test_logger.py b/tests/test_logger.py
new file mode 100644
index 00000000..07c09049
--- /dev/null
+++ b/tests/test_logger.py
@@ -0,0 +1,69 @@
+"""Tests for utils/logger.py — logging utilities."""
+
+import os
+import pytest
+from unittest.mock import patch
+from io import StringIO
+
+import utils.logger as logger
+
+
+class TestLogLevels:
+    """Tests for different log levels."""
+
+    def test_info_prints_output(self, capsys):
+        logger.info("test info message")
+        captured = capsys.readouterr()
+        assert "INFO" in captured.out
+        assert "test info message" in captured.out
+
+    def test_warning_prints_output(self, capsys):
+        logger.warning("test warning")
+        captured = capsys.readouterr()
+        assert "WARNING" in captured.out
+        assert "test warning" in captured.out
+
+    def test_error_prints_output(self, capsys):
+        logger.error("test error")
+        captured = capsys.readouterr()
+        assert "ERROR" in captured.out
+        assert "test error" in captured.out
+
+    def test_fatal_raises_exception(self):
+        with pytest.raises(Exception, match="fatal message"):
+            logger.fatal("fatal message")
+
+    def test_fatal_prints_before_raising(self, capsys):
+        with pytest.raises(Exception):
+            logger.fatal("fatal msg")
+        captured = capsys.readouterr()
+        assert "FATAL" in captured.out
+
+    @patch.dict(os.environ, {"DEBUG": "true"})
+    def test_debug_prints_when_enabled(self, capsys):
+        logger.debug("debug message")
+        captured = capsys.readouterr()
+        assert "debug message" in captured.out
+
+    @patch.dict(os.environ, {"DEBUG": "false"})
+    def test_debug_silent_when_disabled(self, capsys):
+        logger.debug("should not appear")
+        captured = capsys.readouterr()
+        assert "should not appear" not in captured.out
+
+
+class TestLogFormat:
+    """Tests for log message formatting."""
+
+    def test_log_contains_timestamp(self, capsys):
+        logger.info("timestamp test")
+        captured = capsys.readouterr()
+        # Should contain date-like pattern
+        assert "-" in captured.out  # YYYY-MM-DD
+        assert ":" in captured.out  # HH:MM:SS
+
+    def test_log_contains_file_and_line(self, capsys):
+        logger.info("location test")
+        captured = capsys.readouterr()
+        # Should contain the test file name
+        assert "test_logger.py" in captured.out
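Reviewer note: the assertions above (level tag, timestamp, file:line location, a DEBUG environment gate, and fatal() raising after printing) suggest a logger shaped roughly like this sketch; the exact format string is an assumption.

```python
# Plausible shape of utils/logger.py, inferred from the tests.
import inspect
import os
from datetime import datetime


def _log(level: str, message: str) -> None:
    # stack()[2] is the caller of the public function that called _log.
    frame = inspect.stack()[2]
    location = f"{os.path.basename(frame.filename)}:{frame.lineno}"
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"{timestamp} {level} {location} {message}")


def info(message: str) -> None:
    _log("INFO", message)


def warning(message: str) -> None:
    _log("WARNING", message)


def error(message: str) -> None:
    _log("ERROR", message)


def debug(message: str) -> None:
    # Gated on the DEBUG environment variable, as the patch.dict tests imply.
    if os.environ.get("DEBUG", "").lower() == "true":
        _log("DEBUG", message)


def fatal(message: str) -> None:
    # Print first, then raise, so the FATAL line is visible in captured output.
    _log("FATAL", message)
    raise Exception(message)
```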
diff --git a/tests/test_temp_utils.py b/tests/test_temp_utils.py
new file mode 100644
index 00000000..cd54cdc0
--- /dev/null
+++ b/tests/test_temp_utils.py
@@ -0,0 +1,58 @@
+"""Tests for utils/temp.py — temporary directory management."""
+
+import os
+import pytest
+
+from utils.temp import create_temp_dir, delete_temp_dir
+
+
+class TestCreateTempDir:
+    """Tests for create_temp_dir."""
+
+    def test_creates_directory(self):
+        temp_dir = create_temp_dir()
+        try:
+            assert os.path.isdir(temp_dir)
+        finally:
+            delete_temp_dir(temp_dir)
+
+    def test_directory_is_unique(self):
+        dirs = [create_temp_dir() for _ in range(5)]
+        try:
+            assert len(set(dirs)) == 5
+        finally:
+            for d in dirs:
+                delete_temp_dir(d)
+
+    def test_directory_is_writable(self):
+        temp_dir = create_temp_dir()
+        try:
+            test_file = os.path.join(temp_dir, "test.txt")
+            with open(test_file, "w") as f:
+                f.write("hello")
+            assert os.path.exists(test_file)
+        finally:
+            delete_temp_dir(temp_dir)
+
+
+class TestDeleteTempDir:
+    """Tests for delete_temp_dir."""
+
+    def test_removes_directory(self):
+        temp_dir = create_temp_dir()
+        delete_temp_dir(temp_dir)
+        assert not os.path.exists(temp_dir)
+
+    def test_removes_directory_with_contents(self):
+        temp_dir = create_temp_dir()
+        # Create nested structure
+        nested = os.path.join(temp_dir, "sub", "deep")
+        os.makedirs(nested)
+        with open(os.path.join(nested, "file.txt"), "w") as f:
+            f.write("content")
+        delete_temp_dir(temp_dir)
+        assert not os.path.exists(temp_dir)
+
+    def test_nonexistent_directory_does_not_raise(self):
+        # Should not raise due to ignore_errors=True
+        delete_temp_dir("/nonexistent/temp/dir/12345")
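Reviewer note: utils/temp.py is almost certainly a thin stdlib wrapper; the comment in the final test even names ignore_errors=True. A sketch:

```python
# Likely shape of utils/temp.py.
import shutil
import tempfile


def create_temp_dir() -> str:
    # mkdtemp creates a fresh, uniquely named, writable directory.
    return tempfile.mkdtemp()


def delete_temp_dir(path: str) -> None:
    # ignore_errors=True makes deleting a nonexistent directory a no-op,
    # as test_nonexistent_directory_does_not_raise expects.
    shutil.rmtree(path, ignore_errors=True)
```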
diff --git a/tests/test_ttl_cache.py b/tests/test_ttl_cache.py
new file mode 100644
index 00000000..abe4a6e5
--- /dev/null
+++ b/tests/test_ttl_cache.py
@@ -0,0 +1,107 @@
+"""Tests for utils/ttl.py — TTL cache decorator."""
+
+import asyncio
+import pytest
+import time
+
+from utils.ttl import ttl_cache, TTLCacheEntry, _args_and_kwargs_to_ttl_cache_key
+from datetime import datetime, timezone, timedelta
+
+
+class TestTTLCacheKeyGeneration:
+    """Tests for cache key generation."""
+
+    def test_same_args_produce_same_key(self):
+        key1 = _args_and_kwargs_to_ttl_cache_key((1, 2), {"a": 3})
+        key2 = _args_and_kwargs_to_ttl_cache_key((1, 2), {"a": 3})
+        assert key1 == key2
+
+    def test_different_args_produce_different_keys(self):
+        key1 = _args_and_kwargs_to_ttl_cache_key((1,), {})
+        key2 = _args_and_kwargs_to_ttl_cache_key((2,), {})
+        assert key1 != key2
+
+    def test_kwargs_order_does_not_matter(self):
+        key1 = _args_and_kwargs_to_ttl_cache_key((), {"a": 1, "b": 2})
+        key2 = _args_and_kwargs_to_ttl_cache_key((), {"b": 2, "a": 1})
+        assert key1 == key2
+
+    def test_empty_args_and_kwargs(self):
+        key = _args_and_kwargs_to_ttl_cache_key((), {})
+        assert key == ((), ())
+
+
+class TestTTLCacheEntry:
+    """Tests for TTLCacheEntry model."""
+
+    def test_entry_creation(self):
+        entry = TTLCacheEntry(
+            expires_at=datetime.now(timezone.utc) + timedelta(seconds=60),
+            value="test_value",
+        )
+        assert entry.value == "test_value"
+
+    def test_entry_expiry(self):
+        entry = TTLCacheEntry(
+            expires_at=datetime.now(timezone.utc) - timedelta(seconds=1),
+            value="expired",
+        )
+        assert datetime.now(timezone.utc) >= entry.expires_at
+
+
+class TestTTLCacheDecorator:
+    """Tests for the ttl_cache decorator."""
+
+    @pytest.mark.asyncio
+    async def test_caches_result(self):
+        call_count = 0
+
+        @ttl_cache(ttl_seconds=10)
+        async def compute(x):
+            nonlocal call_count
+            call_count += 1
+            return x * 2
+
+        result1 = await compute(5)
+        result2 = await compute(5)
+        assert result1 == 10
+        assert result2 == 10
+        assert call_count == 1
+
+    @pytest.mark.asyncio
+    async def test_different_args_not_cached_together(self):
+        call_count = 0
+
+        @ttl_cache(ttl_seconds=10)
+        async def compute(x):
+            nonlocal call_count
+            call_count += 1
+            return x + 1
+
+        await compute(1)
+        await compute(2)
+        assert call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_max_entries_eviction(self):
+        @ttl_cache(ttl_seconds=60, max_entries=5)
+        async def compute(x):
+            return x
+
+        # Fill beyond max
+        for i in range(10):
+            await compute(i)
+
+        # Should still work (eviction is best-effort)
+        result = await compute(999)
+        assert result == 999
+
+    @pytest.mark.asyncio
+    async def test_returns_correct_type(self):
+        @ttl_cache(ttl_seconds=10)
+        async def get_dict():
+            return {"key": "value"}
+
+        result = await get_dict()
+        assert isinstance(result, dict)
+        assert result["key"] == "value"
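Reviewer note: an illustrative implementation of the decorator and key function that satisfies these tests. The eviction policy here (drop the oldest insertion once max_entries is reached) is a guess; the eviction test only requires best-effort behavior.

```python
# Illustrative sketch of utils/ttl.py, inferred from the tests.
import functools
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any


@dataclass
class TTLCacheEntry:
    expires_at: datetime
    value: Any


def _args_and_kwargs_to_ttl_cache_key(args: tuple, kwargs: dict) -> tuple:
    # Sorting kwargs makes the key order-insensitive; empty input maps
    # to ((), ()), as asserted above.
    return (args, tuple(sorted(kwargs.items())))


def ttl_cache(ttl_seconds: float, max_entries: int = 1024):
    def decorator(fn):
        cache: dict = {}

        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            key = _args_and_kwargs_to_ttl_cache_key(args, kwargs)
            now = datetime.now(timezone.utc)
            entry = cache.get(key)
            if entry is not None and entry.expires_at > now:
                return entry.value  # fresh hit: skip the wrapped call
            value = await fn(*args, **kwargs)
            if len(cache) >= max_entries:
                # Best-effort eviction: drop the oldest inserted key.
                cache.pop(next(iter(cache)))
            cache[key] = TTLCacheEntry(expires_at=now + timedelta(seconds=ttl_seconds),
                                       value=value)
            return value

        return wrapper

    return decorator
```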