From d5a6ee2b119186f049b798e2db2e6e2abab73731 Mon Sep 17 00:00:00 2001 From: dominik <01197296@pw.edu.pl> Date: Mon, 19 Jan 2026 16:08:00 +0100 Subject: [PATCH] feat: add unit test and metric generator --- generate_metrics.py | 346 +++++++++++++++++++ pytest.ini | 3 +- requirements.txt | Bin 2792 -> 1391 bytes tests/__init__.py | 1 + tests/cli/__init__.py | 1 + tests/cli/test_cli_app.py | 286 +++++++++++++++ tests/cli/test_command_tree.py | 203 +++++++++++ tests/cli/test_handlers.py | 498 +++++++++++++++++++++++++++ tests/cli/test_loop.py | 132 +++++++ tests/cli/test_main.py | 280 +++++++++++++++ tests/conftest.py | 100 ++++++ tests/fastapi/__init__.py | 1 + tests/fastapi/test_crud.py | 344 ++++++++++++++++++ tests/fastapi/test_main.py | 353 +++++++++++++++++++ tests/fastapi/test_models.py | 186 ++++++++++ tests/scraper/__init__.py | 1 + tests/scraper/test_analyzers.py | 303 ++++++++++++++++ tests/scraper/test_config_mapper.py | 144 ++++++++ tests/scraper/test_config_utils.py | 176 ++++++++++ tests/scraper/test_engine.py | 343 ++++++++++++++++++ tests/scraper/test_labeler.py | 181 ++++++++++ tests/scraper/test_scraper_config.py | 253 ++++++++++++++ tests/scraper/test_server.py | 179 ++++++++++ 23 files changed, 4313 insertions(+), 1 deletion(-) create mode 100644 generate_metrics.py create mode 100644 tests/__init__.py create mode 100644 tests/cli/__init__.py create mode 100644 tests/cli/test_cli_app.py create mode 100644 tests/cli/test_command_tree.py create mode 100644 tests/cli/test_handlers.py create mode 100644 tests/cli/test_loop.py create mode 100644 tests/cli/test_main.py create mode 100644 tests/conftest.py create mode 100644 tests/fastapi/__init__.py create mode 100644 tests/fastapi/test_crud.py create mode 100644 tests/fastapi/test_main.py create mode 100644 tests/fastapi/test_models.py create mode 100644 tests/scraper/__init__.py create mode 100644 tests/scraper/test_analyzers.py create mode 100644 tests/scraper/test_config_mapper.py create mode 100644 tests/scraper/test_config_utils.py create mode 100644 tests/scraper/test_engine.py create mode 100644 tests/scraper/test_labeler.py create mode 100644 tests/scraper/test_scraper_config.py create mode 100644 tests/scraper/test_server.py diff --git a/generate_metrics.py b/generate_metrics.py new file mode 100644 index 0000000..82845b0 --- /dev/null +++ b/generate_metrics.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python3 +""" +Generate repository metrics report. + +Calculates: +- Total File Count (source files) +- Lines of Code (LOC) +- Number of Unit Tests +- Code Coverage Percentage + +Usage: + python generate_metrics.py [--run-tests] [--output FORMAT] + +Options: + --run-tests Run pytest to get actual coverage (requires pytest-cov) + --output Output format: text, json, or markdown (default: text) +""" +import argparse +import json +import os +import re +import subprocess +import sys +from pathlib import Path +from typing import Dict, List, Tuple + + +# File extensions to count as source files +SOURCE_EXTENSIONS = {'.py', '.yaml', '.yml', '.sh', '.js', 'Dockerfile'} + +# Patterns to exclude from counting +EXCLUDE_PATTERNS = { + '.md', '.txt', '.pdf', '.png', '.jpg', '.jpeg', '.gif', + '.drawio', '.gitignore', '.dockerignore', 'LICENSE' +} + +# Directories to exclude +EXCLUDE_DIRS = { + '.git', '__pycache__', 'venv', '.venv', 'node_modules', + '.mypy_cache', '.pytest_cache', 'eggs', '*.egg-info' +} + + +def is_excluded_dir(path: Path) -> bool: + """Check if path contains an excluded directory.""" + parts = path.parts + for exclude in EXCLUDE_DIRS: + if exclude.startswith('*'): + # Pattern match + suffix = exclude[1:] + if any(part.endswith(suffix) for part in parts): + return True + elif exclude in parts: + return True + return False + + +def is_source_file(path: Path) -> bool: + """Check if file should be counted as a source file.""" + if is_excluded_dir(path): + return False + + name = path.name + suffix = path.suffix + + # Check excluded patterns + if suffix in EXCLUDE_PATTERNS or name in EXCLUDE_PATTERNS: + return False + + # Check for Dockerfile (no extension) + if name == 'Dockerfile': + return True + + # Check source extensions + if suffix in SOURCE_EXTENSIONS: + return True + + return False + + +def count_lines(file_path: Path) -> int: + """Count non-empty lines in a file.""" + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + return sum(1 for line in f if line.strip()) + except (IOError, OSError): + return 0 + + +def find_source_files(root_dir: Path) -> List[Path]: + """Find all source files in the repository.""" + source_files = [] + + for path in root_dir.rglob('*'): + if path.is_file() and is_source_file(path): + source_files.append(path) + + return source_files + + +def count_test_functions(test_dir: Path) -> Tuple[int, List[str]]: + """Count pytest test functions in test files.""" + test_count = 0 + test_names = [] + + test_pattern = re.compile(r'^\s*def\s+(test_\w+)\s*\(', re.MULTILINE) + class_pattern = re.compile(r'^\s*class\s+(Test\w+)', re.MULTILINE) + + for test_file in test_dir.rglob('test_*.py'): + if is_excluded_dir(test_file): + continue + + try: + with open(test_file, 'r', encoding='utf-8') as f: + content = f.read() + + # Count test functions + matches = test_pattern.findall(content) + test_count += len(matches) + test_names.extend(matches) + + # Also count test classes (for reference) + class_matches = class_pattern.findall(content) + test_names.extend([f"class:{cls}" for cls in class_matches]) + + except (IOError, OSError): + continue + + return test_count, test_names + + +def run_pytest_coverage(root_dir: Path) -> Tuple[float, str]: + """Run pytest with coverage and return coverage percentage.""" + try: + result = subprocess.run( + [ + sys.executable, '-m', 'pytest', + 'tests/', + '--cov=cli', + '--cov=fastapi_app', + '--cov=scraper', + '--cov-report=term-missing', + '--cov-report=json', + '-q' + ], + cwd=root_dir, + capture_output=True, + text=True, + timeout=300 + ) + + # Try to parse coverage from JSON report + coverage_json = root_dir / 'coverage.json' + if coverage_json.exists(): + with open(coverage_json) as f: + cov_data = json.load(f) + coverage_pct = cov_data.get('totals', {}).get('percent_covered', 0.0) + return coverage_pct, result.stdout + result.stderr + + # Fall back to parsing terminal output + output = result.stdout + result.stderr + coverage_match = re.search(r'TOTAL\s+\d+\s+\d+\s+(\d+)%', output) + if coverage_match: + return float(coverage_match.group(1)), output + + return 0.0, output + + except subprocess.TimeoutExpired: + return 0.0, "Test execution timed out" + except FileNotFoundError: + return 0.0, "pytest not found. Install with: pip install pytest pytest-cov" + except Exception as e: + return 0.0, f"Error running tests: {e}" + + +def generate_report(root_dir: Path, run_tests: bool = False) -> Dict: + """Generate the complete metrics report.""" + metrics = { + 'repository': root_dir.name, + 'source_files': [], + 'total_file_count': 0, + 'total_loc': 0, + 'test_count': 0, + 'coverage_percent': 0.0, + 'test_output': '' + } + + # Find and count source files + source_files = find_source_files(root_dir) + metrics['total_file_count'] = len(source_files) + + # Count lines of code + for file_path in source_files: + loc = count_lines(file_path) + rel_path = file_path.relative_to(root_dir) + metrics['source_files'].append({ + 'path': str(rel_path), + 'loc': loc + }) + metrics['total_loc'] += loc + + # Count test functions + tests_dir = root_dir / 'tests' + if tests_dir.exists(): + test_count, test_names = count_test_functions(tests_dir) + metrics['test_count'] = test_count + metrics['test_names'] = test_names + + # Run coverage if requested + if run_tests: + coverage, output = run_pytest_coverage(root_dir) + metrics['coverage_percent'] = coverage + metrics['test_output'] = output + + return metrics + + +def format_text(metrics: Dict) -> str: + """Format metrics as plain text.""" + lines = [ + "=" * 60, + f"Repository Metrics: {metrics['repository']}", + "=" * 60, + "", + f"๐Ÿ“ Total File Count: {metrics['total_file_count']}", + f"๐Ÿ“ Lines of Code (LOC): {metrics['total_loc']}", + f"๐Ÿงช Number of Unit Tests: {metrics['test_count']}", + f"๐Ÿ“Š Code Coverage: {metrics['coverage_percent']:.1f}%", + "", + "-" * 60, + "Files by Directory:", + "-" * 60, + ] + + # Group files by directory + by_dir: Dict[str, List[Dict]] = {} + for f in metrics['source_files']: + dir_name = str(Path(f['path']).parent) + if dir_name not in by_dir: + by_dir[dir_name] = [] + by_dir[dir_name].append(f) + + for dir_name in sorted(by_dir.keys()): + files = by_dir[dir_name] + dir_loc = sum(f['loc'] for f in files) + lines.append(f"\n{dir_name}/ ({len(files)} files, {dir_loc} LOC)") + for f in sorted(files, key=lambda x: x['path']): + lines.append(f" - {Path(f['path']).name}: {f['loc']} LOC") + + lines.extend([ + "", + "=" * 60, + ]) + + if metrics.get('test_output'): + lines.extend([ + "", + "Test Output:", + "-" * 60, + metrics['test_output'][:2000] # Limit output length + ]) + + return "\n".join(lines) + + +def format_markdown(metrics: Dict) -> str: + """Format metrics as Markdown.""" + lines = [ + f"# Repository Metrics: {metrics['repository']}", + "", + "## Summary", + "", + "| Metric | Value |", + "|--------|-------|", + f"| ๐Ÿ“ Total File Count | {metrics['total_file_count']} |", + f"| ๐Ÿ“ Lines of Code (LOC) | {metrics['total_loc']} |", + f"| ๐Ÿงช Number of Unit Tests | {metrics['test_count']} |", + f"| ๐Ÿ“Š Code Coverage | {metrics['coverage_percent']:.1f}% |", + "", + "## Files by Directory", + "", + ] + + # Group files by directory + by_dir: Dict[str, List[Dict]] = {} + for f in metrics['source_files']: + dir_name = str(Path(f['path']).parent) + if dir_name not in by_dir: + by_dir[dir_name] = [] + by_dir[dir_name].append(f) + + for dir_name in sorted(by_dir.keys()): + files = by_dir[dir_name] + dir_loc = sum(f['loc'] for f in files) + lines.append(f"### `{dir_name}/` ({len(files)} files, {dir_loc} LOC)") + lines.append("") + for f in sorted(files, key=lambda x: x['path']): + lines.append(f"- `{Path(f['path']).name}`: {f['loc']} LOC") + lines.append("") + + return "\n".join(lines) + + +def format_json(metrics: Dict) -> str: + """Format metrics as JSON.""" + # Remove verbose test output for JSON + output = {k: v for k, v in metrics.items() if k != 'test_output'} + return json.dumps(output, indent=2) + + +def main(): + parser = argparse.ArgumentParser(description='Generate repository metrics report') + parser.add_argument( + '--run-tests', + action='store_true', + help='Run pytest to get actual coverage' + ) + parser.add_argument( + '--output', + choices=['text', 'json', 'markdown'], + default='text', + help='Output format (default: text)' + ) + parser.add_argument( + '--root', + type=Path, + default=Path.cwd(), + help='Repository root directory' + ) + + args = parser.parse_args() + + # Generate report + metrics = generate_report(args.root, run_tests=args.run_tests) + + # Format output + if args.output == 'json': + print(format_json(metrics)) + elif args.output == 'markdown': + print(format_markdown(metrics)) + else: + print(format_text(metrics)) + + +if __name__ == '__main__': + main() diff --git a/pytest.ini b/pytest.ini index b893048..1c94b7c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,3 @@ [pytest] -pythonpath = src \ No newline at end of file +pythonpath = . fastapi_app cli scraper +asyncio_mode = auto \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 71d1a304682219e5b316366bf650f70bf08dd4d0..08a6f688ff532d86fd56629ceba5c6197a675bb3 100644 GIT binary patch literal 1391 zcmY*ZU60!!5Pavq*fLI>OCETrTB*{hS4C>2s!u0^iA{rn0#45N>pM$6+P-<0Wp;LU zuyoF6nRQ?FUJ23EO}%FFk<3fbiPBGXgY;s&5bK(ALU-4MBB6- z74Vj+V@UwAVQQ2y>0CJ(4sMK@0Iz z^HLC(ko^#47#C340DQ%Im%@@q@6aA>8m1>{Y;{JXUS^M0Ezlcwlqt&qk`3tby1geV zaJbISH~-}P+}}^A_0h_c-VzhS8cgmFBnDxz&nEA03{YCHJ2rA&`hX!Da}LT!OabBQ zU5T^z_NFW^i8m^5w`gQ~2LXVCPiK^OM>E_IHf>-^PwM)i3 zP)8m;gH$Iu7)K1d66Y}mUA)9wfCM`aLxCYFLiAI}uQdIM+(U%AmW4(6UZ=cR(84V+ zr(7V@l=ZbM#JzOc02u`>oO4gSUbRJv*B#q0pQ5q$dTocIvN!N0yEQn*uErUwBdc)4 z!=zotIZ$_)bH6`e~$Tl0gVq0}PWh_$S_| zl=WHnrC|rChS3iQ2B2YYQG&JeW9TA7+i~7(sTMpcgTH}RG&d)p#_Peu?CMWYop<85*Jju4NZ}NM4;Z`H32! L1c?F&H5>K^i#4G6 literal 2792 zcmZ{mU2hs^5QX3CO8F^R2IDlj$VHV(k)ouklt@*tELhgy#(b0oWBbRqJo;g4r^OBino>3UAfhR?~#0tigoZa z)PKZiq)~NTOZr$S{#JRn@-zqxwq>dB?&b42BH;N--pJetG$C!uw}?u%$cKHbwWMqL zoP?jm4{O)dNCsLvY0X1#Qx^Kh=)9I5cwdFLnKJ66`y}nf5s8uHM0E{q#JN+{L3!vi zpXoRLx8W6iGPM4!WSyiF`896D11m5yV&^&oAqXB>i}hX2|7otvJ{WXvRWZ7iw~5ee zRRuobO3iPsNITh2Lk4b{AHDn#(%)YoV&C-VI&fR4mf_NDAb5g+BDRVAG&0XKKU7TME{ zvSS)2nkUdpT-gM!6b&YbS!$(AmC35iXrmhTvYq#{NQ}yxjcQzL(wN(us0wa;3w%#8 zyY6dXX6f`y@tC$JX&YVoJ&bLVxuHph-a~a}NBtJS7ZbgbANWH&d*P>d(o3D2h`v3E zIabt*em@$^bvLXStDmDb3P{V-~J7tEjq5p>-cI+;L|{%{*MNYV#c4>Q$?l{aJn~ zpUbC^pI?l><=+*1USO4{%o%0%Fr(_a!)ROuv%E#&N294Rvh*GeXPB~qYVn?|v)~T; z4pR-fb9Iq>m`z- z;VIviJayUkL^h?JurNoz=Rsu)?>pSOOHE$&ZG71Net%%+u{%!RrQFY#Hb36E=nqe- z@Yd%dO@&hSt5&%e3chw*Wsz82#fC744l#2u49vmxe7k#!56b-@`AijRo*Ic+hn%M`Tqd<#*n@M diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..b19282a --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# FixMyCodeDB Test Suite diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 0000000..3d7483e --- /dev/null +++ b/tests/cli/__init__.py @@ -0,0 +1 @@ +# CLI Module Tests diff --git a/tests/cli/test_cli_app.py b/tests/cli/test_cli_app.py new file mode 100644 index 0000000..9fa769d --- /dev/null +++ b/tests/cli/test_cli_app.py @@ -0,0 +1,286 @@ +""" +Unit tests for cli/cli_app.py +Tests interactive CLI functions with mocked I/O. +""" +import pytest +from unittest.mock import patch, MagicMock +import json + + +class TestBuildApiPayload: + """Tests for build_api_payload function.""" + + def test_empty_params(self): + """Test with empty params returns empty filter.""" + from cli.cli_app import build_api_payload + + result = build_api_payload({}) + + assert result == {} + + def test_repo_url_param(self): + """Test repo_url is mapped correctly.""" + from cli.cli_app import build_api_payload + + result = build_api_payload({"repo_url": "https://github.com/test/repo"}) + + assert result["repo.url"] == "https://github.com/test/repo" + + def test_commit_hash_param(self): + """Test commit_hash is mapped correctly.""" + from cli.cli_app import build_api_payload + + result = build_api_payload({"commit_hash": "abc123"}) + + assert result["repo.commit_hash"] == "abc123" + + def test_code_hash_param(self): + """Test code_hash is passed through.""" + from cli.cli_app import build_api_payload + + result = build_api_payload({"code_hash": "a" * 64}) + + assert result["code_hash"] == "a" * 64 + + def test_boolean_true_variants(self): + """Test various true values for boolean flags.""" + from cli.cli_app import build_api_payload + + for true_val in ["true", "1", "yes", "y", "True", "YES"]: + result = build_api_payload({"has_memory_management": true_val}) + assert result.get("labels.groups.memory_management") is True + + def test_boolean_false_variants(self): + """Test various false values for boolean flags.""" + from cli.cli_app import build_api_payload + + for false_val in ["false", "0", "no", "n", "False", "NO"]: + result = build_api_payload({"has_memory_management": false_val}) + assert result.get("labels.groups.memory_management") is False + + def test_empty_boolean_ignored(self): + """Test empty boolean values are ignored.""" + from cli.cli_app import build_api_payload + + result = build_api_payload({"has_memory_management": ""}) + + assert "labels.groups.memory_management" not in result + + def test_multiple_boolean_flags(self): + """Test multiple boolean flags.""" + from cli.cli_app import build_api_payload + + result = build_api_payload({ + "has_memory_management": "true", + "has_logic_error": "true", + "has_concurrency": "false" + }) + + assert result["labels.groups.memory_management"] is True + assert result["labels.groups.logic_error"] is True + assert result["labels.groups.concurrency"] is False + + +class TestDoImport: + """Tests for do_import function.""" + + def test_import_success(self, tmp_path): + """Test successful import.""" + from cli.cli_app import do_import + + with patch('cli.cli_app.requests.post') as mock_post: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [{"_id": "123"}] + mock_post.return_value = mock_response + + params = { + "target file": str(tmp_path / "output.json"), + "limit": "100" + } + + do_import(params) + + # File should be created + assert (tmp_path / "output.json").exists() + + def test_import_connection_error(self, capsys): + """Test import with connection error.""" + from cli.cli_app import do_import + import requests + + with patch('cli.cli_app.requests.post') as mock_post: + mock_post.side_effect = requests.exceptions.ConnectionError() + + params = {"target file": "/tmp/test.json", "limit": "100"} + + do_import(params) + + captured = capsys.readouterr() + assert "Error" in captured.out + + +class TestDoScrape: + """Tests for do_scrape function.""" + + def test_scrape_success(self, capsys): + """Test successful scrape.""" + from cli.cli_app import do_scrape + + with patch('cli.cli_app.socket.socket') as mock_socket_class: + mock_socket = MagicMock() + mock_socket_class.return_value = mock_socket + mock_socket.recv.side_effect = [ + b"ACK: Scraping config.json", + b"ACK: Finished Scraping config.json\n" + ] + + params = {"config_file": "config.json"} + + do_scrape(params) + + captured = capsys.readouterr() + assert "SCRAPE" in captured.out + + def test_scrape_no_response(self, capsys): + """Test scrape with no response.""" + from cli.cli_app import do_scrape + + with patch('cli.cli_app.socket.socket') as mock_socket_class: + mock_socket = MagicMock() + mock_socket_class.return_value = mock_socket + mock_socket.recv.return_value = b"" + + params = {"config_file": "config.json"} + + do_scrape(params) + + captured = capsys.readouterr() + assert "No response" in captured.out + + +class TestDoExportAll: + """Tests for do_export_all function.""" + + def test_export_creates_directory(self, tmp_path): + """Test export creates directory if needed.""" + from cli.cli_app import do_export_all + + export_dir = tmp_path / "export_test" + + with patch('cli.cli_app.requests.get') as mock_get: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = [] + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + mock_get.return_value = mock_response + + with patch('cli.cli_app.Path') as mock_path: + mock_path_instance = MagicMock() + mock_path.return_value = mock_path_instance + + do_export_all({}) + + +class TestDoLabel: + """Tests for do_label function.""" + + def test_label_prints_message(self, capsys): + """Test label command prints message.""" + from cli.cli_app import do_label + + do_label({}) + + captured = capsys.readouterr() + assert "Labeling" in captured.out + + +class TestCLIApp: + """Tests for CLIApp class.""" + + def test_cli_app_inherits_command_tree(self): + """Test CLIApp inherits from CommandTree.""" + from cli.cli_app import CLIApp + from cli.command_tree import CommandTree + + app = CLIApp() + + assert isinstance(app, CommandTree) + + def test_cli_app_has_commands(self): + """Test CLIApp registers expected commands.""" + from cli.cli_app import CLIApp + + app = CLIApp() + + assert "scrape" in app.root.children + assert "import" in app.root.children + assert "import-all" in app.root.children + assert "export-all" in app.root.children + assert "label" in app.root.children + + def test_scrape_command_has_action(self): + """Test scrape command has action bound.""" + from cli.cli_app import CLIApp + + app = CLIApp() + + assert app.root.children["scrape"].action is not None + assert app.root.children["scrape"].is_command is True + + def test_scrape_command_has_params(self): + """Test scrape command has parameters.""" + from cli.cli_app import CLIApp + + app = CLIApp() + + assert "config_file" in app.root.children["scrape"].param_set + + +class TestSafeFilename: + """Tests for _safe_filename function.""" + + def test_alphanumeric(self): + """Test alphanumeric characters are preserved.""" + from cli.cli_app import _safe_filename + + result = _safe_filename("abc123") + + assert result == "abc123" + + def test_special_chars_removed(self): + """Test special characters are removed.""" + from cli.cli_app import _safe_filename + + result = _safe_filename("file/with\\special:chars") + + assert "/" not in result + assert "\\" not in result + assert ":" not in result + + def test_allowed_chars_preserved(self): + """Test allowed special chars are preserved.""" + from cli.cli_app import _safe_filename + + result = _safe_filename("file-name_v1.txt") + + assert result == "file-name_v1.txt" + + +class TestConstants: + """Tests for module constants.""" + + def test_filter_params_defined(self): + """Test FILTER_PARAMS dictionary is defined.""" + from cli.cli_app import FILTER_PARAMS + + assert "limit" in FILTER_PARAMS + assert "repo_url" in FILTER_PARAMS + assert "has_memory_management" in FILTER_PARAMS + + def test_api_base_default(self): + """Test API_BASE has default value.""" + from cli.cli_app import API_BASE + + assert "localhost" in API_BASE or "8000" in API_BASE diff --git a/tests/cli/test_command_tree.py b/tests/cli/test_command_tree.py new file mode 100644 index 0000000..4eb2a09 --- /dev/null +++ b/tests/cli/test_command_tree.py @@ -0,0 +1,203 @@ +""" +Unit tests for cli/command_tree.py +Tests the CommandNode and CommandTree classes. +""" +import pytest +from unittest.mock import MagicMock, patch + + +class TestCommandNode: + """Tests for CommandNode class.""" + + def test_create_node(self): + """Test creating a command node.""" + from cli.command_tree import CommandNode + + node = CommandNode("test") + + assert node.name == "test" + assert node.children == {} + assert node.parent is None + assert node.is_command is False + assert node.param_set == {} + assert node.action is None + + def test_create_node_with_parent(self): + """Test creating a node with parent.""" + from cli.command_tree import CommandNode + + parent = CommandNode("parent") + child = CommandNode("child", parent=parent) + + assert child.parent == parent + + def test_add_child(self): + """Test adding child node.""" + from cli.command_tree import CommandNode + + parent = CommandNode("parent") + child = CommandNode("child") + + parent.add_child(child) + + assert "child" in parent.children + assert parent.children["child"] == child + assert child.parent == parent + + def test_get_child(self): + """Test getting child by name.""" + from cli.command_tree import CommandNode + + parent = CommandNode("parent") + child = CommandNode("child") + parent.add_child(child) + + result = parent.get_child("child") + + assert result == child + + def test_get_child_not_found(self): + """Test getting non-existent child.""" + from cli.command_tree import CommandNode + + parent = CommandNode("parent") + + result = parent.get_child("missing") + + assert result is None + + def test_repr(self): + """Test node string representation.""" + from cli.command_tree import CommandNode + + node = CommandNode("test") + + result = repr(node) + + assert "" == result + + def test_execute_with_action(self): + """Test executing node with action.""" + from cli.command_tree import CommandNode + + node = CommandNode("test") + node.is_command = True + mock_action = MagicMock() + node.action = mock_action + + with patch('cli.command_tree.questionary.press_any_key_to_continue') as mock_press: + mock_press.return_value.ask.return_value = None + node.execute() + + mock_action.assert_called_once() + + def test_execute_without_action(self, capsys): + """Test executing node without action.""" + from cli.command_tree import CommandNode + + node = CommandNode("test") + node.is_command = True + + node.execute() + + captured = capsys.readouterr() + assert "Error: No action bound" in captured.out + + def test_collect_params_empty(self): + """Test collecting params when empty.""" + from cli.command_tree import CommandNode + + node = CommandNode("test") + + result = node.collect_params() + + assert result == {} + + def test_collect_params_with_values(self): + """Test collecting params with user input.""" + from cli.command_tree import CommandNode + + node = CommandNode("test") + node.param_set = {"path": "./data", "format": "json"} + + with patch('cli.command_tree.questionary.text') as mock_text: + mock_text.return_value.ask.side_effect = ["./output", "csv"] + result = node.collect_params() + + assert result["path"] == "./output" + assert result["format"] == "csv" + + +class TestCommandTree: + """Tests for CommandTree class.""" + + def test_create_tree(self): + """Test creating a command tree.""" + from cli.command_tree import CommandTree + + tree = CommandTree() + + assert tree.root is not None + assert tree.root.name == "root" + + def test_add_command_simple(self): + """Test adding a simple command.""" + from cli.command_tree import CommandTree + + tree = CommandTree() + mock_action = MagicMock() + + tree.add_command("Scrape", action=mock_action) + + assert "Scrape" in tree.root.children + assert tree.root.children["Scrape"].is_command is True + assert tree.root.children["Scrape"].action == mock_action + + def test_add_command_nested(self): + """Test adding a nested command.""" + from cli.command_tree import CommandTree + + tree = CommandTree() + mock_action = MagicMock() + + tree.add_command("Data Export", action=mock_action) + + assert "Data" in tree.root.children + data_node = tree.root.children["Data"] + assert "Export" in data_node.children + assert data_node.children["Export"].is_command is True + + def test_add_command_with_params(self): + """Test adding command with parameters.""" + from cli.command_tree import CommandTree + + tree = CommandTree() + params = {"path": "./data"} + + tree.add_command("Export", action=MagicMock(), param_set=params) + + export_node = tree.root.children["Export"] + assert export_node.param_set == params + + def test_add_multiple_commands_shared_parent(self): + """Test adding multiple commands with shared parent.""" + from cli.command_tree import CommandTree + + tree = CommandTree() + + tree.add_command("Data Import", action=MagicMock()) + tree.add_command("Data Export", action=MagicMock()) + + data_node = tree.root.children["Data"] + assert "Import" in data_node.children + assert "Export" in data_node.children + + +class TestCustomStyle: + """Tests for custom_style questionary style.""" + + def test_custom_style_exists(self): + """Test custom_style is defined.""" + from cli.command_tree import custom_style + + assert custom_style is not None diff --git a/tests/cli/test_handlers.py b/tests/cli/test_handlers.py new file mode 100644 index 0000000..d4efae6 --- /dev/null +++ b/tests/cli/test_handlers.py @@ -0,0 +1,498 @@ +""" +Unit tests for cli/handlers.py +Tests CLI handler functions with mocked requests and sockets. +""" +import pytest +from unittest.mock import MagicMock, patch, mock_open +from pathlib import Path +import json + + +# ============================================================================ +# Test Label Mapping +# ============================================================================ + +class TestLabelMapping: + """Tests for label mapping utilities.""" + + def test_labels_to_filter_group_labels(self): + """Test converting group labels to MongoDB filter.""" + from cli.handlers import labels_to_filter + + result = labels_to_filter(["MemError", "LogicError"]) + + assert "labels.groups.memory_management" in result + assert result["labels.groups.memory_management"] is True + assert "labels.groups.logic_error" in result + assert result["labels.groups.logic_error"] is True + + def test_labels_to_filter_cppcheck_labels(self): + """Test converting unknown labels as cppcheck labels.""" + from cli.handlers import labels_to_filter + + result = labels_to_filter(["nullPointer"]) + + assert "labels.cppcheck" in result + assert result["labels.cppcheck"] == {"$in": ["nullPointer"]} + + def test_labels_to_filter_mixed(self): + """Test mixed group and cppcheck labels.""" + from cli.handlers import labels_to_filter + + result = labels_to_filter(["MemError", "customLabel"]) + + assert "labels.groups.memory_management" in result + assert "labels.cppcheck" in result + + +# ============================================================================ +# Test Scrape Handler +# ============================================================================ + +class TestHandleScrape: + """Tests for handle_scrape function.""" + + def test_scrape_success(self): + """Test successful scrape command.""" + from cli.handlers import handle_scrape + + with patch('cli.handlers.socket.socket') as mock_socket_class: + mock_socket = MagicMock() + mock_socket_class.return_value = mock_socket + mock_socket.recv.side_effect = [ + b"ACK: Scraping config.json", + b"ACK: Finished Scraping config.json\n", + ] + + result = handle_scrape("config.json") + + assert result == 0 + mock_socket.connect.assert_called_once() + mock_socket.sendall.assert_called_once() + + def test_scrape_connection_error(self): + """Test scrape with connection error.""" + from cli.handlers import handle_scrape + import socket + + with patch('cli.handlers.socket.socket') as mock_socket_class: + mock_socket = MagicMock() + mock_socket_class.return_value = mock_socket + mock_socket.connect.side_effect = socket.gaierror("host not found") + + result = handle_scrape("config.json") + + assert result == 1 + + def test_scrape_timeout(self): + """Test scrape with timeout.""" + from cli.handlers import handle_scrape + import socket + + with patch('cli.handlers.socket.socket') as mock_socket_class: + mock_socket = MagicMock() + mock_socket_class.return_value = mock_socket + mock_socket.recv.side_effect = socket.timeout() + + result = handle_scrape("config.json") + + assert result == 1 + + def test_scrape_no_response(self): + """Test scrape with no initial response.""" + from cli.handlers import handle_scrape + + with patch('cli.handlers.socket.socket') as mock_socket_class: + mock_socket = MagicMock() + mock_socket_class.return_value = mock_socket + mock_socket.recv.return_value = b"" + + result = handle_scrape("config.json") + + assert result == 1 + + +# ============================================================================ +# Test List Handlers +# ============================================================================ + +class TestHandleListAll: + """Tests for handle_list_all function.""" + + def test_list_all_success(self): + """Test successful list all.""" + from cli.handlers import handle_list_all + + with patch('cli.handlers.requests.get') as mock_get: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {"_id": "123", "labels": {"groups": {"memory_management": True}, "cppcheck": []}} + ] + mock_get.return_value = mock_response + + result = handle_list_all() + + assert result == 0 + + def test_list_all_connection_error(self): + """Test list all with connection error.""" + from cli.handlers import handle_list_all + import requests + + with patch('cli.handlers.requests.get') as mock_get: + mock_get.side_effect = requests.exceptions.ConnectionError() + + result = handle_list_all() + + assert result == 1 + + def test_list_all_http_error(self): + """Test list all with HTTP error.""" + from cli.handlers import handle_list_all + import requests + + with patch('cli.handlers.requests.get') as mock_get: + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( + response=MagicMock(text="Server Error") + ) + mock_get.return_value = mock_response + + result = handle_list_all() + + assert result == 1 + + +class TestHandleListLabels: + """Tests for handle_list_labels function.""" + + def test_list_labels_success(self): + """Test successful list by labels.""" + from cli.handlers import handle_list_labels + + with patch('cli.handlers.requests.post') as mock_post: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {"_id": "123", "labels": {"groups": {"memory_management": True}, "cppcheck": []}} + ] + mock_post.return_value = mock_response + + result = handle_list_labels(["MemError"]) + + assert result == 0 + # Verify filter was constructed correctly + call_args = mock_post.call_args + assert "json" in call_args.kwargs + assert "filter" in call_args.kwargs["json"] + + def test_list_labels_empty_result(self): + """Test list labels with no matching entries.""" + from cli.handlers import handle_list_labels + + with patch('cli.handlers.requests.post') as mock_post: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [] + mock_post.return_value = mock_response + + result = handle_list_labels(["NonexistentLabel"]) + + assert result == 0 + + +# ============================================================================ +# Test Import Handler +# ============================================================================ + +class TestHandleImportAll: + """Tests for handle_import_all function.""" + + def test_import_folder_not_exists(self, tmp_path): + """Test import from non-existent folder.""" + from cli.handlers import handle_import_all + + result = handle_import_all("/nonexistent/path", "JSON") + + assert result == 1 + + def test_import_no_files(self, tmp_path): + """Test import from empty folder.""" + from cli.handlers import handle_import_all + + result = handle_import_all(str(tmp_path), "JSON") + + assert result == 1 + + def test_import_json_success(self, tmp_path, sample_code_entry_dict): + """Test successful JSON import.""" + from cli.handlers import handle_import_all + + # Create a test JSON file + json_file = tmp_path / "test.json" + entry = sample_code_entry_dict.copy() + entry.pop("_id", None) + json_file.write_text(json.dumps(entry)) + + with patch('cli.handlers.requests.post') as mock_post: + mock_response = MagicMock() + mock_response.status_code = 201 + mock_post.return_value = mock_response + + result = handle_import_all(str(tmp_path), "JSON") + + assert result == 0 + + def test_import_json_duplicate(self, tmp_path, sample_code_entry_dict): + """Test import with duplicate entry.""" + from cli.handlers import handle_import_all + + json_file = tmp_path / "test.json" + entry = sample_code_entry_dict.copy() + entry.pop("_id", None) + json_file.write_text(json.dumps(entry)) + + with patch('cli.handlers.requests.post') as mock_post: + mock_response = MagicMock() + mock_response.status_code = 409 + mock_post.return_value = mock_response + + result = handle_import_all(str(tmp_path), "JSON") + + # Errors occurred but function completes + assert result == 1 + + +# ============================================================================ +# Test Export Handler +# ============================================================================ + +class TestHandleExportAll: + """Tests for handle_export_all function.""" + + def test_export_json_success(self, tmp_path, sample_code_entry_dict): + """Test successful JSON export.""" + from cli.handlers import handle_export_all + + with patch('cli.handlers.requests.get') as mock_get: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = [ + json.dumps(sample_code_entry_dict) + ] + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + mock_get.return_value = mock_response + + result = handle_export_all(str(tmp_path), "JSON") + + assert result == 0 + + def test_export_csv_success(self, tmp_path, sample_code_entry_dict): + """Test successful CSV export.""" + from cli.handlers import handle_export_all + + with patch('cli.handlers.requests.get') as mock_get: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = [ + json.dumps(sample_code_entry_dict) + ] + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + mock_get.return_value = mock_response + + result = handle_export_all(str(tmp_path), "CSV") + + assert result == 0 + + def test_export_with_labels_filter(self, tmp_path, sample_code_entry_dict): + """Test export with labels filter.""" + from cli.handlers import handle_export_all + + with patch('cli.handlers.requests.post') as mock_post: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [sample_code_entry_dict] + mock_post.return_value = mock_response + + result = handle_export_all(str(tmp_path), "JSON", labels=["MemError"]) + + assert result == 0 + mock_post.assert_called_once() + + def test_export_connection_error(self, tmp_path): + """Test export with connection error.""" + from cli.handlers import handle_export_all + import requests + + with patch('cli.handlers.requests.get') as mock_get: + mock_get.side_effect = requests.exceptions.ConnectionError() + + result = handle_export_all(str(tmp_path), "JSON") + + assert result == 1 + + def test_export_empty(self, tmp_path): + """Test export with no entries.""" + from cli.handlers import handle_export_all + + with patch('cli.handlers.requests.get') as mock_get: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.iter_lines.return_value = [] + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + mock_get.return_value = mock_response + + result = handle_export_all(str(tmp_path), "JSON") + + assert result == 0 + + +# ============================================================================ +# Test Edit Handler +# ============================================================================ + +class TestHandleEditLabels: + """Tests for handle_edit_labels function.""" + + def test_edit_add_labels_success(self): + """Test successful label addition.""" + from cli.handlers import handle_edit_labels + + with patch('cli.handlers.requests.patch') as mock_patch: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"_id": "123"} + mock_patch.return_value = mock_response + + result = handle_edit_labels("123", add_labels=["MemError"]) + + assert result == 0 + + def test_edit_remove_labels_success(self): + """Test successful label removal.""" + from cli.handlers import handle_edit_labels + + with patch('cli.handlers.requests.patch') as mock_patch: + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"_id": "123"} + mock_patch.return_value = mock_response + + result = handle_edit_labels("123", remove_labels=["LogicError"]) + + assert result == 0 + + def test_edit_no_labels_specified(self): + """Test edit with no labels specified.""" + from cli.handlers import handle_edit_labels + + result = handle_edit_labels("123") + + assert result == 1 + + def test_edit_entry_not_found(self): + """Test edit on non-existent entry.""" + from cli.handlers import handle_edit_labels + + with patch('cli.handlers.requests.patch') as mock_patch: + mock_response = MagicMock() + mock_response.status_code = 404 + mock_patch.return_value = mock_response + + result = handle_edit_labels("nonexistent", add_labels=["MemError"]) + + assert result == 1 + + def test_edit_connection_error(self): + """Test edit with connection error.""" + from cli.handlers import handle_edit_labels + import requests + + with patch('cli.handlers.requests.patch') as mock_patch: + mock_patch.side_effect = requests.exceptions.ConnectionError() + + result = handle_edit_labels("123", add_labels=["MemError"]) + + assert result == 1 + + +# ============================================================================ +# Test Helper Functions +# ============================================================================ + +class TestHelperFunctions: + """Tests for helper functions.""" + + def test_safe_filename(self): + """Test safe filename generation.""" + from cli.handlers import _safe_filename + + assert _safe_filename("abc123") == "abc123" + assert _safe_filename("abc/def") == "abcdef" + assert _safe_filename("test.json") == "test.json" + assert _safe_filename("file with spaces") == "filewithspaces" + + def test_flatten_entry(self, sample_code_entry_dict): + """Test entry flattening for CSV.""" + from cli.handlers import _flatten_entry + + flat = _flatten_entry(sample_code_entry_dict) + + assert "_id" in flat + assert "code_original" in flat + assert "repo_url" in flat + assert "labels_memory_management" in flat + + def test_unflatten_csv_row(self): + """Test CSV row unflattening.""" + from cli.handlers import _unflatten_csv_row + + row = { + "code_original": "int main() {}", + "code_fixed": "", + "code_hash": "a" * 64, + "repo_url": "https://github.com/test", + "repo_commit_hash": "abc", + "repo_commit_date": "2024-01-01T00:00:00", + "ingest_timestamp": "2024-01-01T00:00:00", + "labels_cppcheck": '["nullPointer"]', + "labels_clang": "{}", + "labels_memory_management": "True", + "labels_invalid_access": "False", + "labels_uninitialized": "False", + "labels_concurrency": "False", + "labels_logic_error": "False", + "labels_resource_leak": "False", + "labels_security_portability": "False", + "labels_code_quality_performance": "False", + } + + entry = _unflatten_csv_row(row) + + assert entry["code_original"] == "int main() {}" + assert entry["repo"]["url"] == "https://github.com/test" + assert entry["labels"]["cppcheck"] == ["nullPointer"] + assert entry["labels"]["groups"]["memory_management"] is True + + def test_print_entries_table_empty(self, capsys): + """Test printing empty entries table.""" + from cli.handlers import _print_entries_table + + _print_entries_table([]) + + captured = capsys.readouterr() + assert "No entries found" in captured.out + + def test_print_entries_table_with_entries(self, capsys, sample_code_entry_dict): + """Test printing entries table with data.""" + from cli.handlers import _print_entries_table + + _print_entries_table([sample_code_entry_dict]) + + captured = capsys.readouterr() + assert "507f1f77bcf86cd799439011" in captured.out + assert "Total: 1 entries" in captured.out diff --git a/tests/cli/test_loop.py b/tests/cli/test_loop.py new file mode 100644 index 0000000..457e1c3 --- /dev/null +++ b/tests/cli/test_loop.py @@ -0,0 +1,132 @@ +""" +Unit tests for cli/loop.py +Tests the interactive menu loop functions. +""" +import pytest +from unittest.mock import MagicMock, patch + + +class TestGetBreadcrumbs: + """Tests for get_breadcrumbs function.""" + + def test_breadcrumbs_root(self): + """Test breadcrumbs at root returns Root.""" + from cli.loop import get_breadcrumbs + + mock_node = MagicMock() + mock_node.name = "root" + mock_node.parent = None + + result = get_breadcrumbs(mock_node) + + assert result == "Root" + + def test_breadcrumbs_one_level(self): + """Test breadcrumbs one level deep.""" + from cli.loop import get_breadcrumbs + + mock_parent = MagicMock() + mock_parent.name = "root" + mock_parent.parent = None + + mock_node = MagicMock() + mock_node.name = "Scrape" + mock_node.parent = mock_parent + + result = get_breadcrumbs(mock_node) + + assert result == "Scrape" + + def test_breadcrumbs_multi_level(self): + """Test breadcrumbs multiple levels deep.""" + from cli.loop import get_breadcrumbs + + mock_root = MagicMock() + mock_root.name = "root" + mock_root.parent = None + + mock_parent = MagicMock() + mock_parent.name = "Data" + mock_parent.parent = mock_root + + mock_node = MagicMock() + mock_node.name = "Export" + mock_node.parent = mock_parent + + result = get_breadcrumbs(mock_node) + + assert result == "Data / Export" + + +class TestRunMenuLoop: + """Tests for run_menu_loop function.""" + + def test_menu_loop_exit(self): + """Test menu loop exits on EXIT selection.""" + from cli.loop import run_menu_loop + + with patch('cli.loop.CLIApp') as mock_app_class, \ + patch('cli.loop.questionary.select') as mock_select: + + mock_app = MagicMock() + mock_root = MagicMock() + mock_root.children = {} + mock_root.parent = None + mock_root.name = "root" + mock_root.is_command = False + mock_app.root = mock_root + mock_app_class.return_value = mock_app + + # Simulate EXIT selection + mock_select.return_value.ask.return_value = "EXIT" + + run_menu_loop() + + mock_select.assert_called() + + def test_menu_loop_keyboard_interrupt(self): + """Test menu loop handles keyboard interrupt.""" + from cli.loop import run_menu_loop + + with patch('cli.loop.CLIApp') as mock_app_class, \ + patch('cli.loop.questionary.select') as mock_select: + + mock_app = MagicMock() + mock_root = MagicMock() + mock_root.children = {} + mock_root.parent = None + mock_root.name = "root" + mock_root.is_command = False + mock_app.root = mock_root + mock_app_class.return_value = mock_app + + # Simulate KeyboardInterrupt + mock_select.side_effect = KeyboardInterrupt() + + run_menu_loop() + + # Should exit gracefully + mock_select.assert_called() + + def test_menu_loop_back_navigation(self): + """Test menu loop back navigation.""" + from cli.loop import run_menu_loop + + with patch('cli.loop.CLIApp') as mock_app_class, \ + patch('cli.loop.questionary.select') as mock_select: + + mock_app = MagicMock() + mock_root = MagicMock() + mock_root.children = {} + mock_root.parent = None + mock_root.name = "root" + mock_root.is_command = False + mock_app.root = mock_root + mock_app_class.return_value = mock_app + + # Simulate BACK then EXIT + mock_select.return_value.ask.side_effect = ["BACK", "EXIT"] + + run_menu_loop() + + assert mock_select.call_count >= 1 diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py new file mode 100644 index 0000000..ab57918 --- /dev/null +++ b/tests/cli/test_main.py @@ -0,0 +1,280 @@ +""" +Unit tests for cli/main.py +Tests argument parsing and command execution. +""" +import pytest +from unittest.mock import patch, MagicMock +import sys + + +class TestArgumentParser: + """Tests for argument parser configuration.""" + + def test_create_parser(self): + """Test parser creation.""" + from cli.main import create_parser + + parser = create_parser() + + assert parser.prog == "fixmycodedb" + assert parser.description is not None + + def test_parse_scrape_argument(self): + """Test parsing --scrape argument.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--scrape", "config.json"]) + + assert args.scrape == "config.json" + + def test_parse_list_all_argument(self): + """Test parsing --list-all argument.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--list-all"]) + + assert args.list_all is True + + def test_parse_list_labels_argument(self): + """Test parsing --list-labels argument.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--list-labels", "MemError", "LogicError"]) + + assert args.list_labels == ["MemError", "LogicError"] + + def test_parse_import_json(self): + """Test parsing --import-all with --JSON.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--import-all", "./data", "--JSON"]) + + assert args.import_all == "./data" + assert args.json_format is True + assert args.csv_format is False + + def test_parse_import_csv(self): + """Test parsing --import-all with --CSV.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--import-all", "./data", "--CSV"]) + + assert args.import_all == "./data" + assert args.csv_format is True + assert args.json_format is False + + def test_parse_export_all(self): + """Test parsing --export-all argument.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--export-all", "./backup"]) + + assert args.export_all == "./backup" + + def test_parse_export_all_default(self): + """Test parsing --export-all with default folder.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--export-all"]) + + assert args.export_all == "exported_files" + + def test_parse_export_with_labels(self): + """Test parsing --export-all with --labels.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--export-all", "./backup", "--labels", "MemError"]) + + assert args.export_all == "./backup" + assert args.labels == ["MemError"] + + def test_parse_edit_add_label(self): + """Test parsing --edit with --add-label.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--edit", "123", "--add-label", "MemError"]) + + assert args.edit == "123" + assert args.add_label == ["MemError"] + + def test_parse_edit_remove_label(self): + """Test parsing --edit with --remove-label.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--edit", "123", "--remove-label", "LogicError"]) + + assert args.edit == "123" + assert args.remove_label == ["LogicError"] + + def test_parse_no_infra(self): + """Test parsing --no-infra flag.""" + from cli.main import create_parser + + parser = create_parser() + args = parser.parse_args(["--list-all", "--no-infra"]) + + assert args.no_infra is True + + def test_mutually_exclusive_format(self): + """Test --JSON and --CSV are mutually exclusive.""" + from cli.main import create_parser + + parser = create_parser() + + with pytest.raises(SystemExit): + parser.parse_args(["--import-all", "./data", "--JSON", "--CSV"]) + + +class TestValidateArgs: + """Tests for argument validation.""" + + def test_validate_conflicting_commands(self): + """Test validation catches conflicting commands.""" + from cli.main import create_parser, validate_args + + parser = create_parser() + + # Parse with list-all as it's valid + args = parser.parse_args(["--list-all"]) + # Manually set conflicting values + args.scrape = "config.json" + + # parser.error() causes SystemExit + with pytest.raises(SystemExit): + validate_args(args, parser) + + def test_validate_edit_without_labels(self): + """Test validation catches --edit without label flags.""" + from cli.main import create_parser, validate_args + + parser = create_parser() + args = parser.parse_args(["--edit", "123"]) + + # parser.error() causes SystemExit + with pytest.raises(SystemExit): + validate_args(args, parser) + + def test_validate_labels_without_export(self): + """Test validation catches --labels without --export-all.""" + from cli.main import create_parser, validate_args + + parser = create_parser() + args = parser.parse_args(["--list-all"]) + args.labels = ["MemError"] # Manually set + + # parser.error() causes SystemExit + with pytest.raises(SystemExit): + validate_args(args, parser) + + +class TestManageInfrastructure: + """Tests for infrastructure management.""" + + def test_manage_infrastructure_success(self, tmp_path): + """Test successful docker compose command.""" + from cli.main import manage_infrastructure + + with patch('cli.main.subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0) + + # Should not raise + manage_infrastructure("up -d", str(tmp_path)) + + mock_run.assert_called_once() + + def test_manage_infrastructure_docker_error(self, tmp_path): + """Test docker command failure.""" + from cli.main import manage_infrastructure + import subprocess + + with patch('cli.main.subprocess.run') as mock_run: + mock_run.side_effect = subprocess.CalledProcessError( + 1, "docker", stderr=b"error message" + ) + + with pytest.raises(SystemExit): + manage_infrastructure("up -d", str(tmp_path)) + + def test_manage_infrastructure_docker_not_found(self, tmp_path): + """Test docker not installed.""" + from cli.main import manage_infrastructure + + with patch('cli.main.subprocess.run') as mock_run: + mock_run.side_effect = FileNotFoundError() + + with pytest.raises(SystemExit): + manage_infrastructure("up -d", str(tmp_path)) + + +class TestMainExecution: + """Tests for main function execution.""" + + def test_main_no_args_interactive(self): + """Test main with no args starts interactive mode.""" + from cli.main import main + + with patch('cli.main.run_menu_loop') as mock_loop, \ + patch('cli.main.manage_infrastructure') as mock_infra, \ + patch('cli.main.os.path.dirname', return_value="/test"): + + with patch.object(sys, 'argv', ['fixmycodedb', '--no-infra']): + with pytest.raises(SystemExit) as exc_info: + main() + + assert exc_info.value.code == 0 + mock_loop.assert_called_once() + + def test_main_list_all(self): + """Test main with --list-all.""" + from cli.main import main + + with patch('cli.main.handle_list_all', return_value=0) as mock_handler, \ + patch('cli.main.manage_infrastructure'), \ + patch('cli.main.os.path.dirname', return_value="/test"): + + with patch.object(sys, 'argv', ['fixmycodedb', '--list-all', '--no-infra']): + with pytest.raises(SystemExit) as exc_info: + main() + + assert exc_info.value.code == 0 + mock_handler.assert_called_once() + + def test_main_scrape(self): + """Test main with --scrape.""" + from cli.main import main + + with patch('cli.main.handle_scrape', return_value=0) as mock_handler, \ + patch('cli.main.manage_infrastructure'), \ + patch('cli.main.os.path.dirname', return_value="/test"): + + with patch.object(sys, 'argv', ['fixmycodedb', '--scrape', 'config.json', '--no-infra']): + with pytest.raises(SystemExit) as exc_info: + main() + + assert exc_info.value.code == 0 + mock_handler.assert_called_once_with("config.json") + + def test_main_export_json(self): + """Test main with --export-all --JSON.""" + from cli.main import main + + with patch('cli.main.handle_export_all', return_value=0) as mock_handler, \ + patch('cli.main.manage_infrastructure'), \ + patch('cli.main.os.path.dirname', return_value="/test"): + + with patch.object(sys, 'argv', ['fixmycodedb', '--export-all', './out', '--JSON', '--no-infra']): + with pytest.raises(SystemExit) as exc_info: + main() + + assert exc_info.value.code == 0 + mock_handler.assert_called_once() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3ef7dbc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,100 @@ +""" +Pytest configuration and shared fixtures for all test modules. +""" +import pytest +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock + + +# ============================================================================ +# Sample Data Fixtures +# ============================================================================ + +@pytest.fixture +def sample_code_entry_dict(): + """A sample CodeEntry as a dictionary.""" + return { + "_id": "507f1f77bcf86cd799439011", + "code_original": 'int main() { int *p; delete p; return 0; }', + "code_fixed": 'int main() { int *p = nullptr; if(p) delete p; return 0; }', + "code_hash": "a" * 64, + "repo": { + "url": "https://github.com/test/repo", + "commit_hash": "abc123def456", + "commit_date": "2024-01-15T10:30:00", + }, + "ingest_timestamp": "2024-01-15T10:30:00", + "labels": { + "cppcheck": ["nullPointer", "memleak"], + "clang": {}, + "groups": { + "memory_management": True, + "invalid_access": False, + "uninitialized": False, + "concurrency": False, + "logic_error": False, + "resource_leak": True, + "security_portability": False, + "code_quality_performance": False, + }, + }, + } + + +@pytest.fixture +def sample_repo_info(): + """Sample RepoInfo dictionary.""" + return { + "url": "https://github.com/test/repo", + "commit_hash": "abc123def456", + "commit_date": datetime(2024, 1, 15, 10, 30, 0), + } + + +@pytest.fixture +def sample_labels_groups(): + """Sample LabelsGroup dictionary.""" + return { + "memory_management": True, + "invalid_access": False, + "uninitialized": False, + "concurrency": False, + "logic_error": False, + "resource_leak": True, + "security_portability": False, + "code_quality_performance": False, + } + + +# ============================================================================ +# Mock Fixtures +# ============================================================================ + +@pytest.fixture +def mock_mongodb(): + """Mock AsyncIOMotorDatabase for MongoDB operations.""" + mock_db = MagicMock() + mock_collection = MagicMock() + + # Setup common async methods + mock_collection.insert_one = AsyncMock() + mock_collection.find_one = AsyncMock() + mock_collection.update_one = AsyncMock() + mock_collection.delete_one = AsyncMock() + mock_collection.find = MagicMock() + + mock_db.__getitem__ = MagicMock(return_value=mock_collection) + + return mock_db + + +@pytest.fixture +def mock_socket(): + """Mock socket for network operations.""" + mock_sock = MagicMock() + mock_sock.connect = MagicMock() + mock_sock.sendall = MagicMock() + mock_sock.recv = MagicMock() + mock_sock.settimeout = MagicMock() + mock_sock.close = MagicMock() + return mock_sock diff --git a/tests/fastapi/__init__.py b/tests/fastapi/__init__.py new file mode 100644 index 0000000..db7e251 --- /dev/null +++ b/tests/fastapi/__init__.py @@ -0,0 +1 @@ +# FastAPI Module Tests diff --git a/tests/fastapi/test_crud.py b/tests/fastapi/test_crud.py new file mode 100644 index 0000000..311ab9f --- /dev/null +++ b/tests/fastapi/test_crud.py @@ -0,0 +1,344 @@ +""" +Unit tests for fastapi_app/crud.py +Tests CRUD operations with mocked MongoDB. +""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from bson import ObjectId +from datetime import datetime + +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'fastapi_app')) + + +# ============================================================================ +# Fixtures +# ============================================================================ + +@pytest.fixture +def mock_db(): + """Create a mock database with mock collection.""" + db = MagicMock() + collection = MagicMock() + db.__getitem__ = MagicMock(return_value=collection) + return db, collection + + +@pytest.fixture +def sample_entry(): + """Sample CodeEntry-like object for testing.""" + class MockEntry: + def model_dump(self, by_alias=False, exclude=None): + data = { + "_id": "507f1f77bcf86cd799439011", + "code_original": "int main() {}", + "code_fixed": "int main() { return 0; }", + "code_hash": "a" * 64, + "repo": { + "url": "https://github.com/test/repo", + "commit_hash": "abc123", + "commit_date": "2024-01-15T10:30:00", + }, + "ingest_timestamp": "2024-01-15T10:30:00", + "labels": { + "cppcheck": ["nullPointer"], + "clang": {}, + "groups": {"memory_management": True}, + }, + } + if exclude and "id" in exclude: + data.pop("_id", None) + return data + return MockEntry() + + +# ============================================================================ +# Test create_entry +# ============================================================================ + +class TestCreateEntry: + """Tests for create_entry function.""" + + @pytest.mark.asyncio + async def test_create_entry_success(self, mock_db, sample_entry): + """Test successful entry creation.""" + db, collection = mock_db + mock_result = MagicMock() + mock_result.inserted_id = ObjectId("507f1f77bcf86cd799439011") + collection.insert_one = AsyncMock(return_value=mock_result) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.create_entry(db, sample_entry) + + assert result == "507f1f77bcf86cd799439011" + collection.insert_one.assert_called_once() + + +# ============================================================================ +# Test get_entry +# ============================================================================ + +class TestGetEntry: + """Tests for get_entry function.""" + + @pytest.mark.asyncio + async def test_get_entry_success(self, mock_db, sample_code_entry_dict): + """Test successful entry retrieval.""" + db, collection = mock_db + collection.find_one = AsyncMock(return_value=sample_code_entry_dict) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.get_entry(db, "507f1f77bcf86cd799439011") + + assert result is not None + collection.find_one.assert_called_once() + + @pytest.mark.asyncio + async def test_get_entry_not_found(self, mock_db): + """Test entry not found returns None.""" + db, collection = mock_db + collection.find_one = AsyncMock(return_value=None) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.get_entry(db, "507f1f77bcf86cd799439011") + + assert result is None + + @pytest.mark.asyncio + async def test_get_entry_invalid_id(self, mock_db): + """Test invalid ObjectId returns None.""" + db, collection = mock_db + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.get_entry(db, "invalid_id") + + assert result is None + collection.find_one.assert_not_called() + + +# ============================================================================ +# Test update_entry +# ============================================================================ + +class TestUpdateEntry: + """Tests for update_entry function.""" + + @pytest.mark.asyncio + async def test_update_entry_success(self, mock_db): + """Test successful entry update.""" + db, collection = mock_db + mock_result = MagicMock() + mock_result.modified_count = 1 + collection.update_one = AsyncMock(return_value=mock_result) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.update_entry(db, "507f1f77bcf86cd799439011", {"code_fixed": "new code"}) + + assert result == 1 + collection.update_one.assert_called_once() + + @pytest.mark.asyncio + async def test_update_entry_not_found(self, mock_db): + """Test update on non-existent entry.""" + db, collection = mock_db + mock_result = MagicMock() + mock_result.modified_count = 0 + collection.update_one = AsyncMock(return_value=mock_result) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.update_entry(db, "507f1f77bcf86cd799439011", {"code_fixed": "new"}) + + assert result == 0 + + @pytest.mark.asyncio + async def test_update_entry_invalid_id(self, mock_db): + """Test update with invalid ObjectId.""" + db, collection = mock_db + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.update_entry(db, "invalid", {"code_fixed": "new"}) + + assert result == 0 + + +# ============================================================================ +# Test delete_entry +# ============================================================================ + +class TestDeleteEntry: + """Tests for delete_entry function.""" + + @pytest.mark.asyncio + async def test_delete_entry_success(self, mock_db): + """Test successful entry deletion.""" + db, collection = mock_db + mock_result = MagicMock() + mock_result.deleted_count = 1 + collection.delete_one = AsyncMock(return_value=mock_result) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.delete_entry(db, "507f1f77bcf86cd799439011") + + assert result == 1 + + @pytest.mark.asyncio + async def test_delete_entry_not_found(self, mock_db): + """Test delete on non-existent entry.""" + db, collection = mock_db + mock_result = MagicMock() + mock_result.deleted_count = 0 + collection.delete_one = AsyncMock(return_value=mock_result) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.delete_entry(db, "507f1f77bcf86cd799439011") + + assert result == 0 + + @pytest.mark.asyncio + async def test_delete_entry_invalid_id(self, mock_db): + """Test delete with invalid ObjectId.""" + db, collection = mock_db + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.delete_entry(db, "invalid") + + assert result == 0 + + +# ============================================================================ +# Test list_entries +# ============================================================================ + +class TestListEntries: + """Tests for list_entries function.""" + + @pytest.mark.asyncio + async def test_list_entries_empty(self, mock_db): + """Test listing with no entries.""" + db, collection = mock_db + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=[]) + mock_cursor.sort = MagicMock(return_value=mock_cursor) + collection.find = MagicMock(return_value=mock_cursor) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.list_entries(db) + + assert result == [] + + @pytest.mark.asyncio + async def test_list_entries_with_results(self, mock_db, sample_code_entry_dict): + """Test listing with results.""" + db, collection = mock_db + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=[sample_code_entry_dict]) + mock_cursor.sort = MagicMock(return_value=mock_cursor) + collection.find = MagicMock(return_value=mock_cursor) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.list_entries(db) + + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_list_entries_with_filter(self, mock_db, sample_code_entry_dict): + """Test listing with filter.""" + db, collection = mock_db + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=[sample_code_entry_dict]) + mock_cursor.sort = MagicMock(return_value=mock_cursor) + collection.find = MagicMock(return_value=mock_cursor) + + filter_dict = {"labels.groups.memory_management": True} + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.list_entries(db, filter_dict=filter_dict) + + collection.find.assert_called_once_with(filter_dict) + + @pytest.mark.asyncio + async def test_list_entries_invalid_id_filter(self, mock_db): + """Test listing with invalid _id filter returns empty.""" + db, collection = mock_db + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.list_entries(db, filter_dict={"_id": "invalid"}) + + assert result == [] + + +# ============================================================================ +# Test cppcheck label functions +# ============================================================================ + +class TestCppcheckLabelFunctions: + """Tests for add_to_cppcheck_labels and remove_from_cppcheck_labels.""" + + @pytest.mark.asyncio + async def test_add_to_cppcheck_labels_success(self, mock_db): + """Test adding labels to cppcheck array.""" + db, collection = mock_db + mock_result = MagicMock() + mock_result.modified_count = 1 + collection.update_one = AsyncMock(return_value=mock_result) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.add_to_cppcheck_labels( + db, "507f1f77bcf86cd799439011", ["newLabel"] + ) + + assert result == 1 + + @pytest.mark.asyncio + async def test_add_to_cppcheck_labels_invalid_id(self, mock_db): + """Test adding labels with invalid id.""" + db, collection = mock_db + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.add_to_cppcheck_labels(db, "invalid", ["label"]) + + assert result == 0 + + @pytest.mark.asyncio + async def test_remove_from_cppcheck_labels_success(self, mock_db): + """Test removing labels from cppcheck array.""" + db, collection = mock_db + mock_result = MagicMock() + mock_result.modified_count = 1 + collection.update_one = AsyncMock(return_value=mock_result) + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.remove_from_cppcheck_labels( + db, "507f1f77bcf86cd799439011", ["oldLabel"] + ) + + assert result == 1 + + @pytest.mark.asyncio + async def test_remove_from_cppcheck_labels_invalid_id(self, mock_db): + """Test removing labels with invalid id.""" + db, collection = mock_db + + with patch.dict(sys.modules, {'models': MagicMock()}): + import crud + result = await crud.remove_from_cppcheck_labels(db, "invalid", ["label"]) + + assert result == 0 diff --git a/tests/fastapi/test_main.py b/tests/fastapi/test_main.py new file mode 100644 index 0000000..a3c7ac1 --- /dev/null +++ b/tests/fastapi/test_main.py @@ -0,0 +1,353 @@ +""" +Unit tests for fastapi_app/main.py endpoints. +Tests API endpoints using FastAPI TestClient with mocked database. +""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from datetime import datetime +import sys +import os + +# Add fastapi_app to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'fastapi_app')) + +from fastapi.testclient import TestClient + + +# ============================================================================ +# Test Fixtures +# ============================================================================ + +@pytest.fixture +def mock_crud(): + """Mock the crud module.""" + with patch('main.crud') as mock: + yield mock + + +@pytest.fixture +def client(): + """Create a test client with mocked MongoDB connection.""" + with patch('main.AsyncIOMotorClient') as mock_motor: + mock_client = MagicMock() + mock_db = MagicMock() + mock_client.__getitem__ = MagicMock(return_value=mock_db) + mock_motor.return_value = mock_client + + from main import app + app.mongodb_client = mock_client + app.mongodb = mock_db + + with TestClient(app, raise_server_exceptions=False) as c: + yield c + + +# ============================================================================ +# Test Label Mapping +# ============================================================================ + +class TestLabelMapping: + """Tests for LABEL_TO_GROUP_FIELD mapping.""" + + def test_label_mapping_exists(self): + """Test that label mapping is defined.""" + from main import LABEL_TO_GROUP_FIELD + + assert "MemError" in LABEL_TO_GROUP_FIELD + assert "LogicError" in LABEL_TO_GROUP_FIELD + assert LABEL_TO_GROUP_FIELD["MemError"] == "memory_management" + assert LABEL_TO_GROUP_FIELD["LogicError"] == "logic_error" + + def test_label_mapping_bidirectional(self): + """Test both friendly names and field names are mapped.""" + from main import LABEL_TO_GROUP_FIELD + + # Friendly names + assert LABEL_TO_GROUP_FIELD["MemError"] == "memory_management" + # Direct field names + assert LABEL_TO_GROUP_FIELD["memory_management"] == "memory_management" + + +# ============================================================================ +# Test POST /entries/ +# ============================================================================ + +class TestCreateEndpoint: + """Tests for POST /entries/ endpoint.""" + + def test_create_entry_success(self, client, sample_code_entry_dict): + """Test successful entry creation.""" + with patch('main.crud.create_entry', new_callable=AsyncMock) as mock_create: + mock_create.return_value = "507f1f77bcf86cd799439011" + + # Remove _id for creation + entry_data = sample_code_entry_dict.copy() + entry_data.pop("_id", None) + + response = client.post("/entries/", json=entry_data) + + assert response.status_code == 201 + assert response.json() == {"id": "507f1f77bcf86cd799439011"} + + def test_create_entry_duplicate(self, client, sample_code_entry_dict): + """Test duplicate entry returns 409.""" + from pymongo.errors import DuplicateKeyError + + with patch('main.crud.create_entry', new_callable=AsyncMock) as mock_create: + mock_create.side_effect = DuplicateKeyError("duplicate key") + + entry_data = sample_code_entry_dict.copy() + entry_data.pop("_id", None) + + response = client.post("/entries/", json=entry_data) + + assert response.status_code == 409 + + def test_create_entry_validation_error(self, client): + """Test invalid entry returns 422.""" + response = client.post("/entries/", json={"invalid": "data"}) + assert response.status_code == 422 + + +# ============================================================================ +# Test GET /entries/{entry_id} +# ============================================================================ + +class TestReadEndpoint: + """Tests for GET /entries/{entry_id} endpoint.""" + + def test_read_entry_success(self, client, sample_code_entry_dict): + """Test successful entry retrieval.""" + from fastapi_app.models import CodeEntry + + with patch('main.crud.get_entry', new_callable=AsyncMock) as mock_get: + mock_entry = CodeEntry(**sample_code_entry_dict) + mock_get.return_value = mock_entry + + response = client.get("/entries/507f1f77bcf86cd799439011") + + assert response.status_code == 200 + assert response.json()["_id"] == "507f1f77bcf86cd799439011" + + def test_read_entry_not_found(self, client): + """Test entry not found returns 404.""" + with patch('main.crud.get_entry', new_callable=AsyncMock) as mock_get: + mock_get.return_value = None + + response = client.get("/entries/507f1f77bcf86cd799439011") + + assert response.status_code == 404 + + +# ============================================================================ +# Test GET /entries/ +# ============================================================================ + +class TestGetAllEndpoint: + """Tests for GET /entries/ endpoint.""" + + def test_get_all_entries_empty(self, client): + """Test getting all entries when empty.""" + with patch('main.crud.list_entries', new_callable=AsyncMock) as mock_list: + mock_list.return_value = [] + + response = client.get("/entries/") + + assert response.status_code == 200 + assert response.json() == [] + + def test_get_all_entries_with_limit(self, client, sample_code_entry_dict): + """Test getting all entries with limit.""" + from fastapi_app.models import CodeEntry + + with patch('main.crud.list_entries', new_callable=AsyncMock) as mock_list: + mock_entry = CodeEntry(**sample_code_entry_dict) + mock_list.return_value = [mock_entry] + + response = client.get("/entries/?limit=50") + + assert response.status_code == 200 + mock_list.assert_called_once() + + +# ============================================================================ +# Test PUT /entries/{entry_id} +# ============================================================================ + +class TestUpdateEndpoint: + """Tests for PUT /entries/{entry_id} endpoint.""" + + def test_update_entry_success(self, client): + """Test successful entry update.""" + with patch('main.crud.update_entry', new_callable=AsyncMock) as mock_update: + mock_update.return_value = 1 + + response = client.put( + "/entries/507f1f77bcf86cd799439011", + json={"code_fixed": "updated code"} + ) + + assert response.status_code == 200 + assert response.json() == {"updated": 1} + + def test_update_entry_not_found(self, client): + """Test update on non-existent entry.""" + with patch('main.crud.update_entry', new_callable=AsyncMock) as mock_update: + mock_update.return_value = 0 + + response = client.put( + "/entries/507f1f77bcf86cd799439011", + json={"code_fixed": "updated code"} + ) + + assert response.status_code == 404 + + +# ============================================================================ +# Test DELETE /entries/{entry_id} +# ============================================================================ + +class TestDeleteEndpoint: + """Tests for DELETE /entries/{entry_id} endpoint.""" + + def test_delete_entry_success(self, client): + """Test successful entry deletion.""" + with patch('main.crud.delete_entry', new_callable=AsyncMock) as mock_delete: + mock_delete.return_value = 1 + + response = client.delete("/entries/507f1f77bcf86cd799439011") + + assert response.status_code == 200 + assert response.json() == {"deleted": 1} + + def test_delete_entry_not_found(self, client): + """Test delete on non-existent entry.""" + with patch('main.crud.delete_entry', new_callable=AsyncMock) as mock_delete: + mock_delete.return_value = 0 + + response = client.delete("/entries/507f1f77bcf86cd799439011") + + assert response.status_code == 404 + + +# ============================================================================ +# Test PATCH /entries/{entry_id}/labels +# ============================================================================ + +class TestUpdateLabelsEndpoint: + """Tests for PATCH /entries/{entry_id}/labels endpoint.""" + + def test_update_labels_add_group(self, client, sample_code_entry_dict): + """Test adding a group label.""" + from fastapi_app.models import CodeEntry + + with patch('main.crud.get_entry', new_callable=AsyncMock) as mock_get, \ + patch('main.crud.update_entry', new_callable=AsyncMock) as mock_update: + + mock_entry = CodeEntry(**sample_code_entry_dict) + mock_get.return_value = mock_entry + mock_update.return_value = 1 + + response = client.patch( + "/entries/507f1f77bcf86cd799439011/labels", + json={"add": ["MemError"], "remove": []} + ) + + assert response.status_code == 200 + + def test_update_labels_entry_not_found(self, client): + """Test updating labels on non-existent entry.""" + with patch('main.crud.get_entry', new_callable=AsyncMock) as mock_get: + mock_get.return_value = None + + response = client.patch( + "/entries/507f1f77bcf86cd799439011/labels", + json={"add": ["MemError"], "remove": []} + ) + + assert response.status_code == 404 + + def test_update_labels_add_cppcheck(self, client, sample_code_entry_dict): + """Test adding a cppcheck label (not a group label).""" + from fastapi_app.models import CodeEntry + + with patch('main.crud.get_entry', new_callable=AsyncMock) as mock_get, \ + patch('main.crud.add_to_cppcheck_labels', new_callable=AsyncMock) as mock_add: + + mock_entry = CodeEntry(**sample_code_entry_dict) + mock_get.return_value = mock_entry + mock_add.return_value = 1 + + response = client.patch( + "/entries/507f1f77bcf86cd799439011/labels", + json={"add": ["customLabel"], "remove": []} + ) + + assert response.status_code == 200 + mock_add.assert_called_once() + + +# ============================================================================ +# Test POST /entries/query/ +# ============================================================================ + +class TestQueryEndpoint: + """Tests for POST /entries/query/ endpoint.""" + + def test_query_entries_empty_filter(self, client): + """Test query with empty filter.""" + with patch('main.crud.list_entries', new_callable=AsyncMock) as mock_list: + mock_list.return_value = [] + + response = client.post("/entries/query/", json={}) + + assert response.status_code == 200 + + def test_query_entries_with_filter(self, client, sample_code_entry_dict): + """Test query with filter.""" + from fastapi_app.models import CodeEntry + + with patch('main.crud.list_entries', new_callable=AsyncMock) as mock_list: + mock_entry = CodeEntry(**sample_code_entry_dict) + mock_list.return_value = [mock_entry] + + response = client.post( + "/entries/query/", + json={ + "filter": {"labels.groups.memory_management": True}, + "limit": 50 + } + ) + + assert response.status_code == 200 + assert len(response.json()) == 1 + + +# ============================================================================ +# Test GET /entries/export-all +# ============================================================================ + +class TestExportAllEndpoint: + """Tests for GET /entries/export-all endpoint.""" + + def test_export_all_empty(self, client): + """Test export when no entries exist.""" + mock_cursor = MagicMock() + + async def async_gen(): + return + yield # Make it an async generator + + mock_cursor.__aiter__ = lambda self: async_gen() + + with patch.object(client.app, 'mongodb') as mock_db: + mock_collection = MagicMock() + mock_db.__getitem__ = MagicMock(return_value=mock_collection) + + mock_find = MagicMock(return_value=mock_cursor) + mock_find.sort = MagicMock(return_value=mock_cursor) + mock_collection.find = MagicMock(return_value=mock_find) + + response = client.get("/entries/export-all") + + assert response.status_code == 200 diff --git a/tests/fastapi/test_models.py b/tests/fastapi/test_models.py new file mode 100644 index 0000000..4f7de53 --- /dev/null +++ b/tests/fastapi/test_models.py @@ -0,0 +1,186 @@ +""" +Unit tests for fastapi_app/models.py +Tests Pydantic model validation and serialization. +""" +import pytest +from datetime import datetime +from pydantic import ValidationError + +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'fastapi_app')) + +from models import ( + RepoInfo, + LabelsGroup, + Labels, + CodeEntry, + LabelUpdateRequest, +) + + +class TestRepoInfo: + """Tests for RepoInfo model.""" + + def test_valid_repo_info(self): + """Test creating a valid RepoInfo.""" + repo = RepoInfo( + url="https://github.com/test/repo", + commit_hash="abc123", + commit_date=datetime(2024, 1, 15, 10, 30, 0), + ) + assert repo.url == "https://github.com/test/repo" + assert repo.commit_hash == "abc123" + assert repo.commit_date == datetime(2024, 1, 15, 10, 30, 0) + + def test_repo_info_from_string_date(self): + """Test RepoInfo with string date (should be parsed).""" + repo = RepoInfo( + url="https://github.com/test/repo", + commit_hash="abc123", + commit_date="2024-01-15T10:30:00", + ) + assert isinstance(repo.commit_date, datetime) + + def test_repo_info_missing_url(self): + """Test RepoInfo requires url.""" + with pytest.raises(ValidationError): + RepoInfo(commit_hash="abc123", commit_date=datetime.now()) + + +class TestLabelsGroup: + """Tests for LabelsGroup model.""" + + def test_default_values(self): + """Test all defaults are False.""" + groups = LabelsGroup() + assert groups.memory_management is False + assert groups.invalid_access is False + assert groups.uninitialized is False + assert groups.concurrency is False + assert groups.logic_error is False + assert groups.resource_leak is False + assert groups.security_portability is False + assert groups.code_quality_performance is False + + def test_set_values(self): + """Test setting specific flags.""" + groups = LabelsGroup(memory_management=True, logic_error=True) + assert groups.memory_management is True + assert groups.logic_error is True + assert groups.invalid_access is False + + +class TestLabels: + """Tests for Labels model.""" + + def test_labels_with_groups(self): + """Test Labels with groups.""" + labels = Labels( + cppcheck=["nullPointer", "memleak"], + groups=LabelsGroup(memory_management=True), + ) + assert labels.cppcheck == ["nullPointer", "memleak"] + assert labels.groups.memory_management is True + assert labels.clang == {} + + def test_labels_default_cppcheck(self): + """Test default empty cppcheck list.""" + labels = Labels(groups=LabelsGroup()) + assert labels.cppcheck == [] + + def test_labels_requires_groups(self): + """Test Labels requires groups field.""" + with pytest.raises(ValidationError): + Labels(cppcheck=["test"]) + + +class TestCodeEntry: + """Tests for CodeEntry model.""" + + def test_valid_code_entry(self, sample_code_entry_dict): + """Test creating a valid CodeEntry.""" + entry = CodeEntry(**sample_code_entry_dict) + assert entry.code_original == sample_code_entry_dict["code_original"] + assert entry.code_hash == sample_code_entry_dict["code_hash"] + assert entry.repo.url == "https://github.com/test/repo" + assert entry.labels.groups.memory_management is True + + def test_code_entry_invalid_hash(self): + """Test CodeEntry rejects invalid code_hash format.""" + with pytest.raises(ValidationError) as exc_info: + CodeEntry( + code_original="int main() {}", + code_hash="not_a_valid_hash", # Not 64 hex chars + repo=RepoInfo( + url="https://github.com/test", + commit_hash="abc", + commit_date=datetime.now(), + ), + ingest_timestamp=datetime.now(), + labels=Labels(groups=LabelsGroup()), + ) + assert "code_hash" in str(exc_info.value) + + def test_code_entry_id_alias(self, sample_code_entry_dict): + """Test _id alias works correctly.""" + entry = CodeEntry(**sample_code_entry_dict) + assert entry.id == "507f1f77bcf86cd799439011" + + def test_code_entry_optional_id(self): + """Test CodeEntry works without _id.""" + entry = CodeEntry( + code_original="int main() {}", + code_hash="a" * 64, + repo=RepoInfo( + url="https://github.com/test", + commit_hash="abc", + commit_date=datetime.now(), + ), + ingest_timestamp=datetime.now(), + labels=Labels(groups=LabelsGroup()), + ) + assert entry.id is None + + def test_code_entry_optional_code_fixed(self): + """Test code_fixed is optional.""" + entry = CodeEntry( + code_original="int main() {}", + code_hash="a" * 64, + repo=RepoInfo( + url="https://github.com/test", + commit_hash="abc", + commit_date=datetime.now(), + ), + ingest_timestamp=datetime.now(), + labels=Labels(groups=LabelsGroup()), + ) + assert entry.code_fixed is None + + +class TestLabelUpdateRequest: + """Tests for LabelUpdateRequest model.""" + + def test_add_labels_only(self): + """Test request with only add labels.""" + req = LabelUpdateRequest(add=["MemError", "LogicError"]) + assert req.add == ["MemError", "LogicError"] + assert req.remove == [] + + def test_remove_labels_only(self): + """Test request with only remove labels.""" + req = LabelUpdateRequest(remove=["Concurrency"]) + assert req.add == [] + assert req.remove == ["Concurrency"] + + def test_both_add_and_remove(self): + """Test request with both add and remove.""" + req = LabelUpdateRequest(add=["MemError"], remove=["LogicError"]) + assert req.add == ["MemError"] + assert req.remove == ["LogicError"] + + def test_empty_request(self): + """Test empty request (valid but does nothing).""" + req = LabelUpdateRequest() + assert req.add == [] + assert req.remove == [] diff --git a/tests/scraper/__init__.py b/tests/scraper/__init__.py new file mode 100644 index 0000000..0c1de10 --- /dev/null +++ b/tests/scraper/__init__.py @@ -0,0 +1 @@ +# Scraper Module Tests diff --git a/tests/scraper/test_analyzers.py b/tests/scraper/test_analyzers.py new file mode 100644 index 0000000..2133f60 --- /dev/null +++ b/tests/scraper/test_analyzers.py @@ -0,0 +1,303 @@ +""" +Unit tests for scraper/labeling/analyzers.py +Tests analyzer wrappers with mocked subprocesses. +""" +import pytest +from unittest.mock import MagicMock, patch +import subprocess + + +class TestCppcheckAnalyzer: + """Tests for CppcheckAnalyzer class.""" + + def test_init_finds_cppcheck(self): + """Test analyzer finds cppcheck in PATH.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/cppcheck" + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30) + + assert analyzer.cppcheck_path == "/usr/bin/cppcheck" + assert analyzer.timeout == 30 + + def test_init_cppcheck_not_found(self): + """Test analyzer raises when cppcheck not found.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = None + + from scraper.labeling.analyzers import CppcheckAnalyzer + + with pytest.raises(RuntimeError, match="cppcheck not found"): + CppcheckAnalyzer(timeout=30) + + def test_init_with_temp_dir(self): + """Test analyzer accepts temp_dir parameter.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/cppcheck" + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30, temp_dir="/dev/shm") + + assert analyzer.temp_dir == "/dev/shm" + + def test_run_empty_code(self): + """Test run with empty code returns empty list.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/cppcheck" + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30) + result = analyzer.run("") + + assert result == [] + + def test_run_whitespace_only(self): + """Test run with whitespace-only code returns empty list.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/cppcheck" + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30) + result = analyzer.run(" \n\t ") + + assert result == [] + + def test_run_finds_issues(self, tmp_path): + """Test run finds and parses issues.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which, \ + patch('scraper.labeling.analyzers.subprocess.run') as mock_run, \ + patch('scraper.labeling.analyzers.tempfile.NamedTemporaryFile') as mock_tempfile, \ + patch('scraper.labeling.analyzers.Path') as mock_path: + + mock_which.return_value = "/usr/bin/cppcheck" + + mock_file = MagicMock() + mock_file.name = str(tmp_path / "test.cpp") + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_tempfile.return_value = mock_file + + mock_result = MagicMock() + mock_result.stderr = "/tmp/test.cpp:5:10: error: Null pointer dereference [nullPointer]\n" + mock_run.return_value = mock_result + + mock_path_instance = MagicMock() + mock_path.return_value = mock_path_instance + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30) + result = analyzer.run("int main() { int *p; *p = 1; }") + + assert len(result) == 1 + assert result[0]["id"] == "nullPointer" + + def test_run_filters_suppressed_issues(self, tmp_path): + """Test run filters out suppressed issues.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which, \ + patch('scraper.labeling.analyzers.subprocess.run') as mock_run, \ + patch('scraper.labeling.analyzers.tempfile.NamedTemporaryFile') as mock_tempfile, \ + patch('scraper.labeling.analyzers.Path') as mock_path: + + mock_which.return_value = "/usr/bin/cppcheck" + + mock_file = MagicMock() + mock_file.name = str(tmp_path / "test.cpp") + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_tempfile.return_value = mock_file + + mock_result = MagicMock() + mock_result.stderr = """ +/tmp/test.cpp:1:0: information: Missing include [missingInclude] +/tmp/test.cpp:5:10: error: Null pointer [nullPointer] +""" + mock_run.return_value = mock_result + + mock_path_instance = MagicMock() + mock_path.return_value = mock_path_instance + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30) + result = analyzer.run("int main() {}") + + # missingInclude should be filtered + issue_ids = [r["id"] for r in result] + assert "missingInclude" not in issue_ids + assert "nullPointer" in issue_ids + + def test_run_handles_timeout(self, tmp_path): + """Test run handles subprocess timeout.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which, \ + patch('scraper.labeling.analyzers.subprocess.run') as mock_run, \ + patch('scraper.labeling.analyzers.tempfile.NamedTemporaryFile') as mock_tempfile, \ + patch('scraper.labeling.analyzers.Path') as mock_path: + + mock_which.return_value = "/usr/bin/cppcheck" + + mock_file = MagicMock() + mock_file.name = str(tmp_path / "test.cpp") + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_tempfile.return_value = mock_file + + mock_run.side_effect = subprocess.TimeoutExpired("cppcheck", 30) + + mock_path_instance = MagicMock() + mock_path.return_value = mock_path_instance + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30) + result = analyzer.run("int main() {}") + + assert result == [] + + def test_run_handles_file_not_found(self, tmp_path): + """Test run handles cppcheck not found at runtime.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which, \ + patch('scraper.labeling.analyzers.subprocess.run') as mock_run, \ + patch('scraper.labeling.analyzers.tempfile.NamedTemporaryFile') as mock_tempfile, \ + patch('scraper.labeling.analyzers.Path') as mock_path: + + mock_which.return_value = "/usr/bin/cppcheck" + + mock_file = MagicMock() + mock_file.name = str(tmp_path / "test.cpp") + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_tempfile.return_value = mock_file + + mock_run.side_effect = FileNotFoundError() + + mock_path_instance = MagicMock() + mock_path.return_value = mock_path_instance + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30) + result = analyzer.run("int main() {}") + + assert result == [] + + def test_run_no_issues_found(self, tmp_path): + """Test run with clean code.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which, \ + patch('scraper.labeling.analyzers.subprocess.run') as mock_run, \ + patch('scraper.labeling.analyzers.tempfile.NamedTemporaryFile') as mock_tempfile, \ + patch('scraper.labeling.analyzers.Path') as mock_path: + + mock_which.return_value = "/usr/bin/cppcheck" + + mock_file = MagicMock() + mock_file.name = str(tmp_path / "test.cpp") + mock_file.__enter__ = MagicMock(return_value=mock_file) + mock_file.__exit__ = MagicMock(return_value=False) + mock_tempfile.return_value = mock_file + + mock_result = MagicMock() + mock_result.stderr = "" + mock_run.return_value = mock_result + + mock_path_instance = MagicMock() + mock_path.return_value = mock_path_instance + + from scraper.labeling.analyzers import CppcheckAnalyzer + + analyzer = CppcheckAnalyzer(timeout=30) + result = analyzer.run("int main() { return 0; }") + + assert result == [] + + +class TestClangTidyAnalyzer: + """Tests for ClangTidyAnalyzer class.""" + + def test_init_finds_clang_tidy(self): + """Test analyzer finds clang-tidy in PATH.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/clang-tidy" + + from scraper.labeling.analyzers import ClangTidyAnalyzer + + analyzer = ClangTidyAnalyzer(timeout=30) + + assert analyzer.clang_tidy_path == "/usr/bin/clang-tidy" + + def test_init_clang_tidy_not_found(self): + """Test analyzer raises when clang-tidy not found.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = None + + from scraper.labeling.analyzers import ClangTidyAnalyzer + + with pytest.raises(RuntimeError, match="clang-tidy not found"): + ClangTidyAnalyzer(timeout=30) + + def test_run_empty_code(self): + """Test run with empty code.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/clang-tidy" + + from scraper.labeling.analyzers import ClangTidyAnalyzer + + analyzer = ClangTidyAnalyzer(timeout=30) + result = analyzer.run("") + + assert result == [] + + def test_parse_clang_output_warning(self): + """Test parsing clang-tidy warning output.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/clang-tidy" + + from scraper.labeling.analyzers import ClangTidyAnalyzer + + analyzer = ClangTidyAnalyzer(timeout=30) + + output = "/tmp/test.cpp:5:10: warning: some warning [check-name]" + result = analyzer._parse_clang_output(output) + + assert len(result) == 1 + assert result[0]["id"] == "check-name" + + def test_parse_clang_output_multiple(self): + """Test parsing multiple warnings.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/clang-tidy" + + from scraper.labeling.analyzers import ClangTidyAnalyzer + + analyzer = ClangTidyAnalyzer(timeout=30) + + output = """ +/tmp/test.cpp:5:10: warning: warning 1 [check-1] +/tmp/test.cpp:10:5: error: error 1 [check-2] +""" + result = analyzer._parse_clang_output(output) + + assert len(result) == 2 + ids = [r["id"] for r in result] + assert "check-1" in ids + assert "check-2" in ids + + def test_parse_clang_output_generic_warning(self): + """Test parsing warning without check name.""" + with patch('scraper.labeling.analyzers.shutil.which') as mock_which: + mock_which.return_value = "/usr/bin/clang-tidy" + + from scraper.labeling.analyzers import ClangTidyAnalyzer + + analyzer = ClangTidyAnalyzer(timeout=30) + + output = "/tmp/test.cpp:5:10: warning: some generic warning" + result = analyzer._parse_clang_output(output) + + assert len(result) == 1 + assert result[0]["id"] == "generic-warning" diff --git a/tests/scraper/test_config_mapper.py b/tests/scraper/test_config_mapper.py new file mode 100644 index 0000000..ad24710 --- /dev/null +++ b/tests/scraper/test_config_mapper.py @@ -0,0 +1,144 @@ +""" +Unit tests for scraper/labeling/config_mapper.py +Tests configuration-based label mapping. +""" +import pytest +import json + + +class TestConfigBasedMapper: + """Tests for ConfigBasedMapper class.""" + + @pytest.fixture + def mock_config_path(self, tmp_path): + """Create a mock labels_config.json file.""" + config = { + "error_classification": { + "memory_management": ["memleak", "deallocuse", "doubleFree"], + "invalid_access": ["nullPointer", "arrayIndexOutOfBounds"], + "uninitialized": ["uninitvar", "uninitMemberVar"], + "concurrency": ["raceCondition", "deadlock"], + "logic_error": ["duplicateBreak", "unreachableCode"], + "resource_leak": ["resourceLeak", "fdLeak"], + "security_portability": ["bufferAccessOutOfBounds"], + "code_quality_performance": ["unusedVariable", "constParameter"], + }, + "ignore_list": ["syntaxError", "preprocessorError", "checkersReport"], + } + config_path = tmp_path / "labels_config.json" + config_path.write_text(json.dumps(config)) + return str(config_path) + + def test_mapper_init(self, mock_config_path): + """Test mapper initialization.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + assert mapper.error_classification is not None + assert mapper.ignore_set is not None + assert "syntaxError" in mapper.ignore_set + + def test_filter_issues_removes_ignored(self, mock_config_path): + """Test filter_issues removes ignored issues.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + issues = ["nullPointer", "syntaxError", "memleak", "preprocessorError"] + filtered = mapper.filter_issues(issues) + + assert "nullPointer" in filtered + assert "memleak" in filtered + assert "syntaxError" not in filtered + assert "preprocessorError" not in filtered + + def test_filter_issues_empty_list(self, mock_config_path): + """Test filter_issues with empty list.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + filtered = mapper.filter_issues([]) + + assert filtered == [] + + def test_filter_issues_all_ignored(self, mock_config_path): + """Test filter_issues when all are ignored.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + issues = ["syntaxError", "preprocessorError", "checkersReport"] + filtered = mapper.filter_issues(issues) + + assert filtered == [] + + def test_map_to_groups_single_category(self, mock_config_path): + """Test mapping single issue to group.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + groups = mapper.map_to_groups(["nullPointer"]) + + assert groups["invalid_access"] is True + assert groups["memory_management"] is False + assert groups["uninitialized"] is False + + def test_map_to_groups_multiple_categories(self, mock_config_path): + """Test mapping multiple issues to multiple groups.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + groups = mapper.map_to_groups(["nullPointer", "memleak", "raceCondition"]) + + assert groups["invalid_access"] is True + assert groups["memory_management"] is True + assert groups["concurrency"] is True + assert groups["uninitialized"] is False + + def test_map_to_groups_all_false_when_empty(self, mock_config_path): + """Test all groups are False when no issues.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + groups = mapper.map_to_groups([]) + + assert all(v is False for v in groups.values()) + assert len(groups) == 8 + + def test_map_to_groups_unknown_issue(self, mock_config_path): + """Test unknown issues don't affect groups.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + groups = mapper.map_to_groups(["unknownIssue"]) + + assert all(v is False for v in groups.values()) + + def test_issue_to_category_mapping(self, mock_config_path): + """Test issue_to_category reverse mapping.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + assert mapper.issue_to_category["nullPointer"] == "invalid_access" + assert mapper.issue_to_category["memleak"] == "memory_management" + assert mapper.issue_to_category["raceCondition"] == "concurrency" + assert mapper.issue_to_category["unusedVariable"] == "code_quality_performance" + + def test_map_to_groups_ignores_filtered_issues(self, mock_config_path): + """Test map_to_groups ignores issues in ignore set.""" + from scraper.labeling.config_mapper import ConfigBasedMapper + + mapper = ConfigBasedMapper(mock_config_path) + + # Even if we pass ignored issues, they shouldn't affect groups + groups = mapper.map_to_groups(["syntaxError", "nullPointer"]) + + assert groups["invalid_access"] is True + # syntaxError is in ignore list, shouldn't map to anything diff --git a/tests/scraper/test_config_utils.py b/tests/scraper/test_config_utils.py new file mode 100644 index 0000000..72bdb3d --- /dev/null +++ b/tests/scraper/test_config_utils.py @@ -0,0 +1,176 @@ +""" +Unit tests for scraper/config/config_utils.py +Tests configuration loading and validation. +""" +import pytest +from unittest.mock import patch, mock_open +import json + + +class TestLoadConfig: + """Tests for load_config function.""" + + def test_load_valid_config(self): + """Test loading valid configuration.""" + config_data = { + "repositories": [ + "https://github.com/test/repo1", + {"url": "https://github.com/test/repo2"} + ], + "github_tokens": ["token1", "token2"], + "target_record_count": 500 + } + mock_file = mock_open(read_data=json.dumps(config_data)) + + with patch('builtins.open', mock_file): + from scraper.config.config_utils import load_config + + result = load_config("test.json") + + assert len(result.repositories) == 2 + assert result.repositories[0].url == "https://github.com/test/repo1" + assert result.github_tokens == ["token1", "token2"] + assert result.target_record_count == 500 + + def test_load_config_with_dates(self): + """Test loading config with date fields.""" + config_data = { + "repositories": [ + { + "url": "https://github.com/test/repo", + "start_date": "2024-01-01", + "end_date": "2024-12-31" + } + ] + } + mock_file = mock_open(read_data=json.dumps(config_data)) + + with patch('builtins.open', mock_file): + from scraper.config.config_utils import load_config + + result = load_config("config.json") + + assert len(result.repositories) == 1 + assert result.repositories[0].start_date is not None + + def test_load_file_not_found(self): + """Test handling file not found returns empty config.""" + with patch('builtins.open', side_effect=FileNotFoundError()): + from scraper.config.config_utils import load_config + + result = load_config("nonexistent.json") + + # Returns empty ScraperConfig + assert len(result.repositories) == 0 + + def test_load_invalid_json(self): + """Test handling invalid JSON returns empty config.""" + mock_file = mock_open(read_data="not valid json {") + + with patch('builtins.open', mock_file): + from scraper.config.config_utils import load_config + + result = load_config("invalid.json") + + # Returns empty ScraperConfig + assert len(result.repositories) == 0 + + def test_load_empty_repositories(self): + """Test loading config with empty repositories.""" + config_data = {"repositories": []} + mock_file = mock_open(read_data=json.dumps(config_data)) + + with patch('builtins.open', mock_file): + from scraper.config.config_utils import load_config + + result = load_config("empty.json") + + assert len(result.repositories) == 0 + + def test_load_config_with_fix_regexes(self): + """Test loading config with custom fix regexes.""" + config_data = { + "repositories": [ + {"url": "https://github.com/test/repo", "fix_regexes": [r"custom.*fix"]} + ], + "fix_regexes": [r"global.*fix"] + } + mock_file = mock_open(read_data=json.dumps(config_data)) + + with patch('builtins.open', mock_file): + from scraper.config.config_utils import load_config + + result = load_config("config.json") + + assert len(result.repositories) == 1 + # Repo has its own fix_regexes + assert result.repositories[0].fix_regexes == [r"custom.*fix"] + + def test_load_config_with_legacy_token(self): + """Test loading config with legacy single token.""" + config_data = { + "repositories": ["https://github.com/test/repo"], + "github_token": "legacy_token" + } + mock_file = mock_open(read_data=json.dumps(config_data)) + + with patch('builtins.open', mock_file): + from scraper.config.config_utils import load_config + + result = load_config("config.json") + + assert result.github_token == "legacy_token" + + def test_load_config_with_workers(self): + """Test loading config with worker count.""" + config_data = { + "repositories": ["https://github.com/test/repo"], + "num_consumer_workers": 8 + } + mock_file = mock_open(read_data=json.dumps(config_data)) + + with patch('builtins.open', mock_file): + from scraper.config.config_utils import load_config + + result = load_config("config.json") + + assert result.num_consumer_workers == 8 + + +class TestParseDate: + """Tests for parse_date function.""" + + def test_parse_valid_date_string(self): + """Test parsing valid date string.""" + from scraper.config.config_utils import parse_date + + result = parse_date("2024-06-15") + + assert result is not None + assert result.year == 2024 + assert result.month == 6 + assert result.day == 15 + + def test_parse_none(self): + """Test parsing None returns None.""" + from scraper.config.config_utils import parse_date + + result = parse_date(None) + + assert result is None + + def test_parse_invalid_date_string(self): + """Test parsing invalid date returns None.""" + from scraper.config.config_utils import parse_date + + result = parse_date("not-a-date") + + assert result is None + + def test_parse_empty_string(self): + """Test parsing empty string returns None.""" + from scraper.config.config_utils import parse_date + + result = parse_date("") + + assert result is None diff --git a/tests/scraper/test_engine.py b/tests/scraper/test_engine.py new file mode 100644 index 0000000..2f8ca48 --- /dev/null +++ b/tests/scraper/test_engine.py @@ -0,0 +1,343 @@ +""" +Unit tests for scraper/core/engine.py +Tests the Producer-Consumer scraper engine functions. +""" +import pytest +from unittest.mock import MagicMock, patch, AsyncMock +import hashlib + + +class TestCalculateHash: + """Tests for calculate_hash function.""" + + def test_hash_empty_string(self): + """Test hashing empty string.""" + from scraper.core.engine import calculate_hash + + result = calculate_hash("") + expected = hashlib.sha256("".encode("utf-8")).hexdigest() + + assert result == expected + + def test_hash_simple_string(self): + """Test hashing simple string.""" + from scraper.core.engine import calculate_hash + + result = calculate_hash("hello world") + expected = hashlib.sha256("hello world".encode("utf-8")).hexdigest() + + assert result == expected + + def test_hash_code_content(self): + """Test hashing code content.""" + from scraper.core.engine import calculate_hash + + code = "int main() { return 0; }" + result = calculate_hash(code) + + assert len(result) == 64 # SHA-256 produces 64 hex chars + assert result == hashlib.sha256(code.encode("utf-8")).hexdigest() + + def test_hash_deterministic(self): + """Test that same input produces same hash.""" + from scraper.core.engine import calculate_hash + + result1 = calculate_hash("test") + result2 = calculate_hash("test") + + assert result1 == result2 + + def test_hash_different_inputs(self): + """Test that different inputs produce different hashes.""" + from scraper.core.engine import calculate_hash + + result1 = calculate_hash("abc") + result2 = calculate_hash("abd") + + assert result1 != result2 + + +class TestGetRepoSlug: + """Tests for get_repo_slug function.""" + + def test_standard_url(self): + """Test parsing standard GitHub URL.""" + from scraper.core.engine import get_repo_slug + + result = get_repo_slug("https://github.com/owner/repo") + + assert result == "owner/repo" + + def test_url_with_git_suffix(self): + """Test parsing URL with .git suffix.""" + from scraper.core.engine import get_repo_slug + + result = get_repo_slug("https://github.com/owner/repo.git") + + assert result == "owner/repo" + + def test_url_with_trailing_slash(self): + """Test parsing URL with trailing content.""" + from scraper.core.engine import get_repo_slug + + result = get_repo_slug("https://github.com/owner/repo/tree/main") + + assert result == "owner/repo" + + def test_invalid_url(self): + """Test invalid URL raises ValueError.""" + from scraper.core.engine import get_repo_slug + + with pytest.raises(ValueError, match="Invalid GitHub URL"): + get_repo_slug("https://gitlab.com/owner/repo") + + +class TestCandidateTask: + """Tests for CandidateTask dataclass.""" + + def test_create_task(self): + """Test creating CandidateTask.""" + from scraper.core.engine import CandidateTask + + task = CandidateTask( + code_original="int x;", + code_fixed="int x = 0;", + repo_url="https://github.com/test/repo", + commit_sha="abc123", + commit_date="2024-01-01T00:00:00", + base_name="file.c" + ) + + assert task.code_original == "int x;" + assert task.code_fixed == "int x = 0;" + assert task.repo_url == "https://github.com/test/repo" + assert task.commit_sha == "abc123" + + +class TestGetGithubContent: + """Tests for get_github_content function.""" + + def test_get_content_success(self): + """Test successful content retrieval.""" + from scraper.core.engine import get_github_content + + mock_repo = MagicMock() + mock_content = MagicMock() + mock_content.decoded_content = b"int main() {}" + mock_repo.get_contents.return_value = mock_content + + result = get_github_content(mock_repo, "abc123", "file.c") + + assert result == "int main() {}" + mock_repo.get_contents.assert_called_once_with("file.c", ref="abc123") + + def test_get_content_error(self): + """Test content retrieval returns empty on error.""" + from scraper.core.engine import get_github_content + + mock_repo = MagicMock() + mock_repo.get_contents.side_effect = Exception("Not found") + + result = get_github_content(mock_repo, "abc123", "file.c") + + assert result == "" + + +class TestGetAllRepoFiles: + """Tests for get_all_repo_files function.""" + + def test_get_files_success(self): + """Test successful file list retrieval.""" + from scraper.core.engine import get_all_repo_files + + mock_repo = MagicMock() + mock_tree = MagicMock() + mock_element1 = MagicMock() + mock_element1.path = "src/main.c" + mock_element2 = MagicMock() + mock_element2.path = "include/header.h" + mock_tree.tree = [mock_element1, mock_element2] + mock_repo.get_git_tree.return_value = mock_tree + + result = get_all_repo_files(mock_repo, "abc123") + + assert result == ["src/main.c", "include/header.h"] + mock_repo.get_git_tree.assert_called_once_with("abc123", recursive=True) + + def test_get_files_error(self): + """Test file list returns empty on error.""" + from scraper.core.engine import get_all_repo_files + + mock_repo = MagicMock() + mock_repo.get_git_tree.side_effect = Exception("Error") + + result = get_all_repo_files(mock_repo, "abc123") + + assert result == [] + + +class TestFindCorrespondingFile: + """Tests for find_corresponding_file function.""" + + def test_find_header_for_source(self): + """Test finding header file for source file.""" + from scraper.core.engine import find_corresponding_file + + all_files = ["src/main.c", "src/main.h", "src/utils.c"] + + result = find_corresponding_file("src/main.c", [".h"], all_files) + + assert result == "src/main.h" + + def test_find_source_for_header(self): + """Test finding source file for header.""" + from scraper.core.engine import find_corresponding_file + + all_files = ["src/main.c", "src/main.h", "src/utils.cpp"] + + result = find_corresponding_file("src/main.h", [".c", ".cpp"], all_files) + + assert result == "src/main.c" + + def test_no_corresponding_file(self): + """Test when no corresponding file exists.""" + from scraper.core.engine import find_corresponding_file + + all_files = ["src/main.c", "src/other.h"] + + result = find_corresponding_file("src/main.c", [".h"], all_files) + + assert result is None + + +class TestFormatContext: + """Tests for format_context function.""" + + def test_format_with_both(self): + """Test formatting with header and implementation.""" + from scraper.core.engine import format_context + + result = format_context("int x;", "int main() {}") + + assert "int x;" in result + assert "int main() {}" in result + + def test_format_with_empty_header(self): + """Test formatting with empty header.""" + from scraper.core.engine import format_context + + result = format_context("", "int main() {}") + + assert "int main() {}" in result + + def test_format_with_empty_impl(self): + """Test formatting with empty implementation.""" + from scraper.core.engine import format_context + + result = format_context("int x;", "") + + assert "int x;" in result + + +class TestInsertPayloadToDb: + """Tests for insert_payload_to_db function.""" + + def test_insert_success(self): + """Test successful DB insert.""" + from scraper.core.engine import insert_payload_to_db + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = {"id": "123"} + + with patch('scraper.core.engine.requests.post', return_value=mock_response): + result = insert_payload_to_db({"code": "test"}) + + assert result == "123" + + def test_insert_duplicate(self): + """Test duplicate entry returns None.""" + from scraper.core.engine import insert_payload_to_db + + mock_response = MagicMock() + mock_response.status_code = 409 + + with patch('scraper.core.engine.requests.post', return_value=mock_response): + result = insert_payload_to_db({"code": "test"}) + + assert result is None + + def test_insert_error(self): + """Test API error returns None.""" + from scraper.core.engine import insert_payload_to_db + import requests as req + + with patch('scraper.core.engine.requests.post', side_effect=req.exceptions.ConnectionError("Error")): + result = insert_payload_to_db({"code": "test"}) + + assert result is None + + +class TestSavePayloadToFile: + """Tests for save_payload_to_file function.""" + + def test_save_creates_file(self, tmp_path): + """Test saving payload creates file.""" + from scraper.core.engine import save_payload_to_file + + payload = {"code": "int main() {}", "code_hash": "abc123"} + + save_payload_to_file(payload, str(tmp_path)) + + files = list(tmp_path.glob("*.json")) + assert len(files) == 1 + + def test_save_with_custom_directory(self, tmp_path): + """Test saving to custom directory.""" + from scraper.core.engine import save_payload_to_file + import json + + payload = {"code": "test", "code_hash": "hash123"} + output_dir = tmp_path / "output" + + save_payload_to_file(payload, str(output_dir)) + + assert output_dir.exists() + files = list(output_dir.glob("*.json")) + assert len(files) == 1 + + # Verify content + with open(files[0]) as f: + saved = json.load(f) + assert saved["code"] == "test" + + +class TestPoisonPill: + """Tests for poison pill constant.""" + + def test_poison_pill_is_none(self): + """Test POISON_PILL is None.""" + from scraper.core.engine import POISON_PILL + + assert POISON_PILL is None + + +class TestApiUrl: + """Tests for API_URL constant.""" + + def test_api_url_default(self): + """Test default API_URL value.""" + import os + # Clear env var if set + original = os.environ.pop("API_URL", None) + + try: + # Need to reload module to get default + import importlib + import scraper.core.engine as engine + importlib.reload(engine) + + assert "fastapi" in engine.API_URL or "localhost" in engine.API_URL or "8000" in engine.API_URL + finally: + if original: + os.environ["API_URL"] = original diff --git a/tests/scraper/test_labeler.py b/tests/scraper/test_labeler.py new file mode 100644 index 0000000..77d3897 --- /dev/null +++ b/tests/scraper/test_labeler.py @@ -0,0 +1,181 @@ +""" +Unit tests for scraper/labeling/labeler.py +Tests labeling functionality with mocked analyzers. +""" +import pytest +from unittest.mock import MagicMock, patch +import json +import tempfile +import os + + +class TestLabeler: + """Tests for Labeler class.""" + + @pytest.fixture + def mock_labels_config(self, tmp_path): + """Create a mock labels_config.json file.""" + config = { + "error_classification": { + "memory_management": ["memleak", "deallocuse"], + "invalid_access": ["nullPointer", "arrayIndexOutOfBounds"], + "uninitialized": ["uninitvar", "uninitMemberVar"], + "concurrency": ["raceCondition"], + "logic_error": ["duplicateBreak", "unreachableCode"], + "resource_leak": ["resourceLeak"], + "security_portability": ["bufferAccessOutOfBounds"], + "code_quality_performance": ["unusedVariable", "constParameter"], + }, + "ignore_list": ["syntaxError", "preprocessorError"], + } + config_path = tmp_path / "labels_config.json" + config_path.write_text(json.dumps(config)) + return str(config_path) + + def test_labeler_init(self, mock_labels_config): + """Test Labeler initialization.""" + with patch('scraper.labeling.labeler.CppcheckAnalyzer') as mock_analyzer: + mock_analyzer.return_value = MagicMock() + + from scraper.labeling.labeler import Labeler + + labeler = Labeler(timeout=30, config_path=mock_labels_config) + + assert labeler.cppcheck is not None + assert labeler.mapper is not None + + def test_analyze_with_issues(self, mock_labels_config): + """Test analyze finds issues in buggy code.""" + with patch('scraper.labeling.labeler.CppcheckAnalyzer') as mock_analyzer_class: + mock_analyzer = MagicMock() + mock_analyzer.run.return_value = [ + {"id": "nullPointer"}, + {"id": "memleak"}, + ] + mock_analyzer_class.return_value = mock_analyzer + + from scraper.labeling.labeler import Labeler + + labeler = Labeler(timeout=30, config_path=mock_labels_config) + result = labeler.analyze("int main() { int *p; *p = 1; }") + + assert "cppcheck" in result + assert "groups" in result + assert "nullPointer" in result["cppcheck"] + assert "memleak" in result["cppcheck"] + + def test_analyze_with_fixed_code(self, mock_labels_config): + """Test analyze compares buggy and fixed code.""" + with patch('scraper.labeling.labeler.CppcheckAnalyzer') as mock_analyzer_class: + mock_analyzer = MagicMock() + # Buggy code has two issues + # Fixed code has one issue + call_count = [0] + + def run_side_effect(code): + call_count[0] += 1 + if call_count[0] == 1: + return [{"id": "nullPointer"}, {"id": "memleak"}] + return [{"id": "memleak"}] # nullPointer was fixed + + mock_analyzer.run.side_effect = run_side_effect + mock_analyzer_class.return_value = mock_analyzer + + from scraper.labeling.labeler import Labeler + + labeler = Labeler(timeout=30, config_path=mock_labels_config) + result = labeler.analyze("buggy code", "fixed code") + + # Only nullPointer should be in result (it was fixed) + assert "nullPointer" in result["cppcheck"] + assert "memleak" not in result["cppcheck"] + + def test_analyze_no_issues(self, mock_labels_config): + """Test analyze with clean code.""" + with patch('scraper.labeling.labeler.CppcheckAnalyzer') as mock_analyzer_class: + mock_analyzer = MagicMock() + mock_analyzer.run.return_value = [] + mock_analyzer_class.return_value = mock_analyzer + + from scraper.labeling.labeler import Labeler + + labeler = Labeler(timeout=30, config_path=mock_labels_config) + result = labeler.analyze("int main() { return 0; }") + + assert result["cppcheck"] == [] + + def test_analyze_filters_ignored_issues(self, mock_labels_config): + """Test analyze filters out ignored issues.""" + with patch('scraper.labeling.labeler.CppcheckAnalyzer') as mock_analyzer_class: + mock_analyzer = MagicMock() + mock_analyzer.run.return_value = [ + {"id": "syntaxError"}, # Should be filtered + {"id": "nullPointer"}, # Should remain + ] + mock_analyzer_class.return_value = mock_analyzer + + from scraper.labeling.labeler import Labeler + + labeler = Labeler(timeout=30, config_path=mock_labels_config) + result = labeler.analyze("buggy code") + + assert "syntaxError" not in result["cppcheck"] + assert "nullPointer" in result["cppcheck"] + + def test_analyze_maps_to_groups(self, mock_labels_config): + """Test analyze maps issues to groups correctly.""" + with patch('scraper.labeling.labeler.CppcheckAnalyzer') as mock_analyzer_class: + mock_analyzer = MagicMock() + mock_analyzer.run.return_value = [ + {"id": "nullPointer"}, + {"id": "memleak"}, + ] + mock_analyzer_class.return_value = mock_analyzer + + from scraper.labeling.labeler import Labeler + + labeler = Labeler(timeout=30, config_path=mock_labels_config) + result = labeler.analyze("buggy code") + + assert result["groups"]["invalid_access"] is True + assert result["groups"]["memory_management"] is True + assert result["groups"]["concurrency"] is False + + def test_extract_unique_issues(self, mock_labels_config): + """Test _extract_unique_issues extracts unique IDs.""" + with patch('scraper.labeling.labeler.CppcheckAnalyzer') as mock_analyzer_class: + mock_analyzer_class.return_value = MagicMock() + + from scraper.labeling.labeler import Labeler + + labeler = Labeler(timeout=30, config_path=mock_labels_config) + + results = [ + {"id": "nullPointer"}, + {"id": "nullPointer"}, # Duplicate + {"id": "memleak"}, + {"id": "unknown"}, # Should be filtered + ] + + unique = labeler._extract_unique_issues(results) + + assert len(unique) == 2 + assert "nullPointer" in unique + assert "memleak" in unique + assert "unknown" not in unique + + def test_labeler_with_temp_dir(self, mock_labels_config, tmp_path): + """Test Labeler passes temp_dir to analyzer.""" + with patch('scraper.labeling.labeler.CppcheckAnalyzer') as mock_analyzer_class: + from scraper.labeling.labeler import Labeler + + labeler = Labeler( + timeout=30, + config_path=mock_labels_config, + temp_dir=str(tmp_path) + ) + + mock_analyzer_class.assert_called_once_with( + timeout=30, + temp_dir=str(tmp_path) + ) diff --git a/tests/scraper/test_scraper_config.py b/tests/scraper/test_scraper_config.py new file mode 100644 index 0000000..12b6396 --- /dev/null +++ b/tests/scraper/test_scraper_config.py @@ -0,0 +1,253 @@ +""" +Unit tests for scraper/config/scraper_config.py +Tests ScraperConfig and RepoConfig dataclasses. +""" +import pytest +from unittest.mock import patch +import os +from datetime import date + + +class TestRepoConfig: + """Tests for RepoConfig dataclass.""" + + def test_init_with_url_only(self): + """Test creating RepoConfig with just URL.""" + from scraper.config.scraper_config import RepoConfig + + config = RepoConfig(url="https://github.com/owner/repo") + + assert config.url == "https://github.com/owner/repo" + assert config.start_date is None + assert config.end_date is None + assert config.fix_regexes == [] + + def test_init_with_dates(self): + """Test creating RepoConfig with date range.""" + from scraper.config.scraper_config import RepoConfig + + start = date(2024, 1, 1) + end = date(2024, 12, 31) + config = RepoConfig( + url="https://github.com/owner/repo", + start_date=start, + end_date=end + ) + + assert config.start_date == start + assert config.end_date == end + + def test_init_with_regexes(self): + """Test creating RepoConfig with fix regexes.""" + from scraper.config.scraper_config import RepoConfig + + regexes = [r"fix.*bug", r"resolve.*issue"] + config = RepoConfig( + url="https://github.com/owner/repo", + fix_regexes=regexes + ) + + assert config.fix_regexes == regexes + assert len(config.fix_regexes) == 2 + + def test_full_config(self): + """Test fully specified RepoConfig.""" + from scraper.config.scraper_config import RepoConfig + + config = RepoConfig( + url="https://github.com/test/repo", + start_date=date(2023, 6, 1), + end_date=date(2023, 12, 31), + fix_regexes=[r"bug\s*fix"] + ) + + assert config.url == "https://github.com/test/repo" + assert config.start_date.month == 6 + assert len(config.fix_regexes) == 1 + + +class TestScraperConfig: + """Tests for ScraperConfig dataclass.""" + + def test_init_minimal(self): + """Test ScraperConfig with minimal required fields.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig(repositories=repos) + + assert len(config.repositories) == 1 + assert config.github_tokens == [] + assert config.target_record_count == 1000 + assert config.queue_max_size == 100 + + def test_init_with_tokens(self): + """Test ScraperConfig with GitHub tokens.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig( + repositories=repos, + github_tokens=["token1", "token2"] + ) + + assert config.github_tokens == ["token1", "token2"] + + def test_init_with_custom_workers(self): + """Test ScraperConfig with custom worker count.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig( + repositories=repos, + num_consumer_workers=4 + ) + + assert config.num_consumer_workers == 4 + + def test_init_with_custom_temp_dir(self): + """Test ScraperConfig with custom temp directory.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig( + repositories=repos, + temp_work_dir="/custom/temp" + ) + + assert config.temp_work_dir == "/custom/temp" + + def test_init_with_legacy_token(self): + """Test ScraperConfig with legacy single token.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig( + repositories=repos, + github_token="legacy_token" + ) + + assert config.github_token == "legacy_token" + + def test_get_effective_tokens_from_list(self): + """Test get_effective_tokens returns list tokens.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig( + repositories=repos, + github_tokens=["token1", "token2"] + ) + + with patch.dict(os.environ, {}, clear=True): + if "GITHUB_TOKEN" in os.environ: + del os.environ["GITHUB_TOKEN"] + + tokens = config.get_effective_tokens() + + assert "token1" in tokens + assert "token2" in tokens + + def test_get_effective_tokens_includes_legacy(self): + """Test get_effective_tokens includes legacy token.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig( + repositories=repos, + github_tokens=["token1"], + github_token="legacy_token" + ) + + with patch.dict(os.environ, {}, clear=True): + tokens = config.get_effective_tokens() + + assert "token1" in tokens + assert "legacy_token" in tokens + + def test_get_effective_tokens_includes_env(self): + """Test get_effective_tokens includes environment token.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig(repositories=repos) + + with patch.dict(os.environ, {"GITHUB_TOKEN": "env_token"}): + tokens = config.get_effective_tokens() + + assert "env_token" in tokens + + def test_get_effective_tokens_no_duplicates(self): + """Test get_effective_tokens removes duplicates.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig( + repositories=repos, + github_tokens=["same_token"], + github_token="same_token" + ) + + with patch.dict(os.environ, {"GITHUB_TOKEN": "same_token"}): + tokens = config.get_effective_tokens() + + # Should contain only one instance + assert tokens.count("same_token") == 1 + + def test_get_effective_tokens_empty(self): + """Test get_effective_tokens with no tokens.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig(repositories=repos) + + with patch.dict(os.environ, {}, clear=True): + # Ensure GITHUB_TOKEN is not set + env_backup = os.environ.get("GITHUB_TOKEN") + if "GITHUB_TOKEN" in os.environ: + del os.environ["GITHUB_TOKEN"] + + try: + tokens = config.get_effective_tokens() + assert tokens == [] + finally: + if env_backup: + os.environ["GITHUB_TOKEN"] = env_backup + + def test_default_temp_dir_uses_tempfile(self): + """Test default temp_work_dir uses tempfile.gettempdir().""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + import tempfile + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig(repositories=repos) + + # Should use system temp directory + assert config.temp_work_dir == tempfile.gettempdir() or callable(config.temp_work_dir) + + def test_default_workers_based_on_cpu(self): + """Test default workers is based on CPU count.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [RepoConfig(url="https://github.com/test/repo")] + config = ScraperConfig(repositories=repos) + + # Should be at least 1 + assert config.num_consumer_workers >= 1 + + def test_multiple_repositories(self): + """Test ScraperConfig with multiple repositories.""" + from scraper.config.scraper_config import ScraperConfig, RepoConfig + + repos = [ + RepoConfig(url="https://github.com/owner1/repo1"), + RepoConfig(url="https://github.com/owner2/repo2"), + RepoConfig(url="https://github.com/owner3/repo3"), + ] + config = ScraperConfig( + repositories=repos, + target_record_count=5000 + ) + + assert len(config.repositories) == 3 + assert config.target_record_count == 5000 diff --git a/tests/scraper/test_server.py b/tests/scraper/test_server.py new file mode 100644 index 0000000..a3b070e --- /dev/null +++ b/tests/scraper/test_server.py @@ -0,0 +1,179 @@ +""" +Unit tests for scraper/network/server.py +Tests socket server with mocked connections. +""" +import pytest +from unittest.mock import MagicMock, patch +import socket + + +class TestSendProgress: + """Tests for send_progress function.""" + + def test_send_progress_with_connection(self): + """Test sending progress when connected.""" + import scraper.network.server as server + + mock_conn = MagicMock() + server._current_conn = mock_conn + + server.send_progress(5, 10, "abc1234") + + mock_conn.sendall.assert_called_once() + call_args = mock_conn.sendall.call_args[0][0] + assert b"PROGRESS: 5/10" in call_args + assert b"abc1234" in call_args + + def test_send_progress_without_connection(self): + """Test sending progress when not connected.""" + import scraper.network.server as server + + server._current_conn = None + + # Should not raise + server.send_progress(5, 10, "abc1234") + + def test_send_progress_with_error(self): + """Test sending progress when socket errors.""" + import scraper.network.server as server + + mock_conn = MagicMock() + mock_conn.sendall.side_effect = OSError("connection reset") + server._current_conn = mock_conn + + # Should not raise, just log + server.send_progress(5, 10, "abc1234") + + +class TestStartServer: + """Tests for start_server function.""" + + def test_server_handles_scrape_command(self): + """Test server processes SCRAPE command.""" + from scraper.network.server import start_server + + mock_callback = MagicMock() + + with patch('scraper.network.server.socket.socket') as mock_socket_class: + mock_server_socket = MagicMock() + mock_conn = MagicMock() + + mock_socket_class.return_value = mock_server_socket + + # First accept returns connection, second raises to exit loop + call_count = [0] + + def accept_side_effect(): + call_count[0] += 1 + if call_count[0] == 1: + return (mock_conn, ("127.0.0.1", 12345)) + raise KeyboardInterrupt() + + mock_server_socket.accept.side_effect = accept_side_effect + mock_conn.recv.return_value = b"SCRAPE config.json" + mock_conn.__enter__ = MagicMock(return_value=mock_conn) + mock_conn.__exit__ = MagicMock(return_value=False) + + try: + start_server(mock_callback) + except KeyboardInterrupt: + pass + + mock_callback.assert_called_once() + # Verify ACK was sent + mock_conn.sendall.assert_called() + + def test_server_handles_invalid_command(self): + """Test server handles invalid command.""" + from scraper.network.server import start_server + + mock_callback = MagicMock() + + with patch('scraper.network.server.socket.socket') as mock_socket_class: + mock_server_socket = MagicMock() + mock_conn = MagicMock() + + mock_socket_class.return_value = mock_server_socket + + call_count = [0] + + def accept_side_effect(): + call_count[0] += 1 + if call_count[0] == 1: + return (mock_conn, ("127.0.0.1", 12345)) + raise KeyboardInterrupt() + + mock_server_socket.accept.side_effect = accept_side_effect + mock_conn.recv.return_value = b"INVALID command" + mock_conn.__enter__ = MagicMock(return_value=mock_conn) + mock_conn.__exit__ = MagicMock(return_value=False) + + try: + start_server(mock_callback) + except KeyboardInterrupt: + pass + + # Callback should NOT be called for invalid command + mock_callback.assert_not_called() + # Error response should be sent + assert mock_conn.sendall.called + call_args = mock_conn.sendall.call_args[0][0] + assert b"ERROR" in call_args + + def test_server_handles_empty_data(self): + """Test server handles empty data.""" + from scraper.network.server import start_server + + mock_callback = MagicMock() + + with patch('scraper.network.server.socket.socket') as mock_socket_class: + mock_server_socket = MagicMock() + mock_conn = MagicMock() + + mock_socket_class.return_value = mock_server_socket + + call_count = [0] + + def accept_side_effect(): + call_count[0] += 1 + if call_count[0] <= 2: + return (mock_conn, ("127.0.0.1", 12345)) + raise KeyboardInterrupt() + + mock_server_socket.accept.side_effect = accept_side_effect + # First connection sends empty, second sends valid then exits + recv_count = [0] + + def recv_side_effect(size): + recv_count[0] += 1 + if recv_count[0] == 1: + return b"" # Empty + return b"SCRAPE config.json" + + mock_conn.recv.side_effect = recv_side_effect + mock_conn.__enter__ = MagicMock(return_value=mock_conn) + mock_conn.__exit__ = MagicMock(return_value=False) + + try: + start_server(mock_callback) + except KeyboardInterrupt: + pass + + def test_server_binds_correctly(self): + """Test server binds to correct address.""" + from scraper.network.server import start_server + + mock_callback = MagicMock() + + with patch('scraper.network.server.socket.socket') as mock_socket_class: + mock_server_socket = MagicMock() + mock_socket_class.return_value = mock_server_socket + mock_server_socket.accept.side_effect = KeyboardInterrupt() + + try: + start_server(mock_callback) + except KeyboardInterrupt: + pass + + mock_server_socket.bind.assert_called_once_with(("0.0.0.0", 8080)) + mock_server_socket.listen.assert_called_once()