From 8536fec8ef875a9fa5b65382ad062049dca25064 Mon Sep 17 00:00:00 2001 From: Eli <43382407+eli64s@users.noreply.github.com> Date: Tue, 9 Jan 2024 16:48:40 -0600 Subject: [PATCH] Reduce complexity of various methods. --- readmeai/cli/options.py | 15 +-- readmeai/core/model.py | 160 ++++++++++++++---------- readmeai/exceptions.py | 26 ++-- readmeai/main.py | 6 +- readmeai/services/git_metadata.py | 2 +- readmeai/services/git_utils.py | 5 +- readmeai/settings/dependency_files.toml | 7 ++ readmeai/settings/ignore_files.toml | 1 + scripts/run_batch.sh | 6 +- tests/conftest.py | 21 ++++ tests/test_core/test_model.py | 120 +++++++++++------- tests/test_core/test_preprocess.py | 96 +++++++------- tests/test_exceptions.py | 19 +-- 13 files changed, 275 insertions(+), 209 deletions(-) diff --git a/readmeai/cli/options.py b/readmeai/cli/options.py index 1e51efc0..2b1e1087 100644 --- a/readmeai/cli/options.py +++ b/readmeai/cli/options.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Optional import click from click import Context, Parameter @@ -11,9 +12,9 @@ def prompt_for_custom_image( - context: Context | None, - parameter: Parameter | None, - value: str | None, + context: Optional[Context], + parameter: Optional[Parameter], + value: Optional[str], ) -> str: """Prompt the user for a custom image URL.""" if value == ImageOptions.CUSTOM.name: @@ -40,9 +41,7 @@ def prompt_for_custom_image( badges = click.option( "-b", "--badges", - type=click.Choice( - [opt.value for opt in BadgeOptions], case_sensitive=False - ), + type=click.Choice([opt.value for opt in BadgeOptions], case_sensitive=False), default=BadgeOptions.DEFAULT.value, help="""\ Badge icon style types to select from when generating README.md badges. The following options are currently available:\n @@ -66,9 +65,7 @@ def prompt_for_custom_image( image = click.option( "-i", "--image", - type=click.Choice( - [opt.name for opt in ImageOptions], case_sensitive=False - ), + type=click.Choice([opt.name for opt in ImageOptions], case_sensitive=False), default=ImageOptions.DEFAULT.name, callback=prompt_for_custom_image, show_choices=True, diff --git a/readmeai/core/model.py b/readmeai/core/model.py index 46776e02..72f91abd 100644 --- a/readmeai/core/model.py +++ b/readmeai/core/model.py @@ -5,8 +5,14 @@ from contextlib import asynccontextmanager from typing import Any, Dict, List, Tuple, Union -import aiohttp import openai +from httpx import ( + AsyncClient, + HTTPStatusError, + Limits, + NetworkError, + TimeoutException, +) from tenacity import ( retry, retry_if_exception_type, @@ -32,20 +38,28 @@ def __init__(self, config: settings.AppConfig) -> None: """Initializes the GPT language model API handler.""" self.cache = {} self.config = config - self.encoder = config.llm.encoding self.logger = Logger(__name__) - self.prompts = config.prompts - self.tokens = config.llm.tokens - self.tokens_max = config.llm.tokens_max - self.http_client = aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=20), - connector=aiohttp.TCPConnector( - limit=200, limit_per_host=100, enable_cleanup_closed=True - ), - ) - self.rate_limit_semaphore = asyncio.Semaphore(config.llm.rate_limit) + self._llm_attributes() + self._http_client() self._handle_response = functools.lru_cache(maxsize=100)(self._handle_response) + def _llm_attributes(self): + """Initializes basic attributes for the class.""" + self.encoder = self.config.llm.encoding + self.prompts = self.config.prompts + self.tokens = self.config.llm.tokens + self.tokens_max = 
self.config.llm.tokens_max
+        self.rate_limit = self.config.llm.rate_limit
+
+    def _http_client(self):
+        """Configures the HTTP client for the class."""
+        self.http_client = AsyncClient(
+            http2=True,
+            timeout=20,
+            limits=Limits(max_keepalive_connections=100, max_connections=200),
+        )
+        self.rate_limit_semaphore = asyncio.Semaphore(self.rate_limit)
+
     @asynccontextmanager
     async def use_api(self) -> None:
         """Context manager for HTTP client used by the LLM API."""
@@ -56,7 +70,7 @@ async def use_api(self) -> None:
 
     async def close(self) -> None:
         """Closes the HTTP client."""
-        await self.http_client.close()
+        await self.http_client.aclose()
 
     async def batch_request(
         self,
@@ -66,21 +80,42 @@ async def batch_request(
     ) -> List[str]:
         """Generates text for the README.md file using GPT language models."""
         prompts = await self._set_prompt_context(file_context, dependencies, summaries)
+        responses = await self._batch_prompts(prompts)
+        return responses
+
+    async def _batch_prompts(
+        self, prompts: List[Union[str, Tuple[str, str]]], batch_size: int = 5
+    ):
+        """Processes prompts in batches and returns the generated text."""
         responses = []
-        for batch in self._batch_prompts(prompts):
+
+        for batch in self._generate_batches(prompts, batch_size):
             batch_responses = await asyncio.gather(
-                *[self._process_prompt(prompt) for prompt in batch]
+                *[self._process_batch(prompt) for prompt in batch]  # , return_exceptions=True
             )
             responses.extend(batch_responses)
+
         return responses
 
-    def _batch_prompts(
-        self, prompts: List[Union[str, Tuple[str, str]]], batch_size: int = 5
-    ) -> List[List[Union[str, Tuple[str, str]]]]:
-        """Batches prompts for the LLM API."""
-        for i in range(0, len(prompts), batch_size):
-            yield prompts[i : i + batch_size]
+    def _generate_batches(self, items: List[Any], batch_size: int):
+        """Generator to create batches from a list of items."""
+        for i in range(0, len(items), batch_size):
+            yield items[i : i + batch_size]
+
+    async def _process_batch(self, prompt: Dict[str, Any]) -> str:
+        """Processes a prompt and returns the generated text."""
+        if prompt["type"] == "summaries":
+            return await self._handle_code_summary_response(prompt["context"])
+        else:
+            formatted_prompt = self._get_prompt_context(
+                prompt["type"], prompt["context"]
+            )
+            tokens = adjust_max_tokens(self.tokens, formatted_prompt)
+            _, summary = await self._handle_response(
+                prompt["type"], formatted_prompt, tokens
+            )
+            return summary
 
     async def _set_prompt_context(
         self,
@@ -128,35 +163,31 @@ async def _set_prompt_context(
         ]
     ]
 
-    def _inject_prompt_context(self, prompt_type, context) -> str:
+    def _get_prompt_context(self, prompt_type, context) -> str:
         """Generates a prompt for the LLM API."""
+        prompt_template = self._get_prompt_template(prompt_type)
+        if not prompt_template:
+            self.logger.error(f"Prompt type '{prompt_type}' not found.")
+            return ""
+        return self._inject_prompt_context(prompt_template, context)
+
+    def _get_prompt_template(self, prompt_type: str) -> str:
+        """Retrieves the template for the given prompt type."""
         prompt_templates = {
            "features": self.prompts.features,
            "overview": self.prompts.overview,
            "slogan": self.prompts.slogan,
        }
-        prompt_template = prompt_templates.get(prompt_type)
+        return prompt_templates.get(prompt_type, "")
 
-        if prompt_template:
-            return prompt_template.format(*[context[key] for key in context])
-        else:
-            self.logger.error(f"Unknown prompt type: {prompt_type}")
+    def _inject_prompt_context(self, template: str, context: dict) -> str:
+        """Formats the template with the provided 
context.""" + try: + return template.format(*[context[key] for key in context]) + except KeyError as exc: + self.logger.error(f"Missing context for prompt key: {exc}") return "" - async def _process_prompt(self, prompt: Dict[str, Any]) -> str: - """Processes a prompt and returns the generated text.""" - if prompt["type"] == "summaries": - return await self._handle_code_summary_response(prompt["context"]) - else: - formatted_prompt = self._inject_prompt_context( - prompt["type"], prompt["context"] - ) - tokens = adjust_max_tokens(self.tokens, formatted_prompt) - _, summary = await self._handle_response( - prompt["type"], formatted_prompt, tokens - ) - return summary - async def _handle_code_summary_response( self, file_context: List[Tuple[str, str]] ) -> List[Tuple[str, str]]: @@ -190,9 +221,9 @@ async def _handle_code_summary_response( wait=wait_exponential(multiplier=1, min=2, max=6), retry=retry_if_exception_type( ( - aiohttp.ClientConnectionError, - aiohttp.ClientResponseError, - aiohttp.ServerTimeoutError, + HTTPStatusError, + NetworkError, + TimeoutException, openai.error.OpenAIError, ) ), @@ -228,24 +259,25 @@ async def _handle_response( prompt = truncate_tokens(self.encoder, prompt, tokens) try: - async with self.rate_limit_semaphore, self.http_client.post( - self.config.llm.endpoint, - headers={"Authorization": f"Bearer {openai.api_key}"}, - json={ - "messages": [ - { - "role": "system", - "content": self.config.llm.content, - }, - {"role": "user", "content": prompt}, - ], - "model": self.config.llm.model, - "temperature": self.config.llm.temperature, - "max_tokens": tokens, - }, - ) as response: + async with self.rate_limit_semaphore: + response = await self.http_client.post( + self.config.llm.endpoint, + headers={"Authorization": f"Bearer {openai.api_key}"}, + json={ + "messages": [ + { + "role": "system", + "content": self.config.llm.content, + }, + {"role": "user", "content": prompt}, + ], + "model": self.config.llm.model, + "temperature": self.config.llm.temperature, + "max_tokens": tokens, + }, + ) response.raise_for_status() - llm_response = await response.json() + llm_response = response.json() llm_text = llm_response["choices"][0]["message"]["content"] llm_text = ( format_sentence(llm_text) @@ -257,9 +289,9 @@ async def _handle_response( return index, llm_text except ( - aiohttp.ClientConnectionError, - aiohttp.ClientResponseError, - aiohttp.ServerTimeoutError, + HTTPStatusError, + NetworkError, + TimeoutException, openai.error.OpenAIError, ) as exc: error_msg = f"Error generating text for {index}: {exc}" diff --git a/readmeai/exceptions.py b/readmeai/exceptions.py index 95b8c219..8caf774b 100644 --- a/readmeai/exceptions.py +++ b/readmeai/exceptions.py @@ -6,16 +6,10 @@ class ReadmeAiException(Exception): """Base exception for the readme-ai application.""" - pass + ... 
-class RepositoryError(ReadmeAiException):
-    """Exceptions related to repository operations."""
-
-    pass
-
-
-class GitCloneError(RepositoryError):
+class GitCloneError(ReadmeAiException):
     """Could not clone repository."""
 
     def __init__(self, repository: str, *args):
@@ -23,16 +17,12 @@ def __init__(self, repository: str, *args):
         super().__init__(f"Failed to clone repository: {repository}", *args)
 
 
-class ReadmeGenerationError(ReadmeAiException):
+class ReadmeGeneratorError(ReadmeAiException):
     """Exceptions related to readme generation."""
 
-    pass
-
-
-class ApiCommunicationError(ReadmeAiException):
-    """Exceptions related to external APIs."""
-
-    pass
+    def __init__(self, traceback, *args):
+        self.traceback = traceback
+        super().__init__(f"Error generating readme: {traceback}", *args)
 
 
 class FileSystemError(ReadmeAiException):
@@ -46,10 +36,10 @@ def __init__(self, message, path, *args):
 class FileReadError(FileSystemError):
     """Could not read file."""
 
-    pass
+    ...
 
 
 class FileWriteError(FileSystemError):
     """Could not write file."""
 
-    pass
+    ...
diff --git a/readmeai/main.py b/readmeai/main.py
index 44ff73c1..de8cbcea 100644
--- a/readmeai/main.py
+++ b/readmeai/main.py
@@ -23,7 +23,7 @@
 from readmeai.core.logger import Logger
 from readmeai.core.model import ModelHandler
 from readmeai.core.preprocess import FileData, process_repository
-from readmeai.exceptions import ReadmeGenerationError
+from readmeai.exceptions import ReadmeGeneratorError
 from readmeai.markdown.builder import build_readme_md
 from readmeai.services.git_utils import clone_to_temporary_directory
 
@@ -82,7 +82,7 @@ async def readme_agent(conf: AppConfig, conf_helper: ConfigHelper) -> None:
 
         logger.info(f"README.md file generated successfully @ {conf.files.output}")
 
     except Exception as exc:
-        raise ReadmeGenerationError(exc, traceback.format_exc()) from exc
+        raise ReadmeGeneratorError(traceback.format_exc()) from exc
 
 
 def main(
@@ -124,7 +124,7 @@ def main(
         asyncio.run(readme_agent(conf, conf_helper))
 
     except Exception as exc:
-        raise ReadmeGenerationError(exc, traceback.format_exc()) from exc
+        raise ReadmeGeneratorError(traceback.format_exc()) from exc
 
 
 def setup_environment(config: AppConfig, api_key: str) -> None:
diff --git a/readmeai/services/git_metadata.py b/readmeai/services/git_metadata.py
index d5b3cf7d..7ce459ad 100644
--- a/readmeai/services/git_metadata.py
+++ b/readmeai/services/git_metadata.py
@@ -77,7 +77,7 @@ async def _fetch_git_metadata(
 
 async def git_api_request(
     session: aiohttp.ClientSession, repo_url: str
-) -> GitHubRepoMetadata | None:
+) -> Optional[GitHubRepoMetadata]:
     """Retrieves repo metadata and returns a GitHubRepoMetadata instance."""
     api_url = await fetch_git_api_url(repo_url)
     if not api_url:
diff --git a/readmeai/services/git_utils.py b/readmeai/services/git_utils.py
index f8f30111..5a15e528 100644
--- a/readmeai/services/git_utils.py
+++ b/readmeai/services/git_utils.py
@@ -4,6 +4,7 @@
 import platform
 import shutil
 from pathlib import Path
+from typing import Optional
 
 import git
 
@@ -67,7 +68,7 @@ def fetch_git_file_url(file_path: str, full_name: str, repo_url: str) -> str:
     return file_path
 
 
-def find_git_executable() -> Path | None:
+def find_git_executable() -> Optional[Path]:
     """Find the path to the git executable, if available."""
     git_exec_path = os.environ.get("GIT_PYTHON_GIT_EXECUTABLE")
     if git_exec_path:
@@ -102,7 +103,7 @@ def validate_file_permissions(temp_dir: Path) -> None:
         )
 
 
-def validate_git_executable(git_exec_path: str | None) -> None:
+def validate_git_executable(git_exec_path: 
Optional[str]) -> None:
     """Validate the path to the git executable."""
     if not git_exec_path or not Path(git_exec_path).exists():
         raise ValueError(f"Git executable not found at {git_exec_path}")
diff --git a/readmeai/settings/dependency_files.toml b/readmeai/settings/dependency_files.toml
index 8bef0dae..007792ad 100644
--- a/readmeai/settings/dependency_files.toml
+++ b/readmeai/settings/dependency_files.toml
@@ -2,6 +2,13 @@
 
 [dependency_files]
 dependency_files = [
+    # Docker
+    'Dockerfile',
+    'docker-compose.yml',
+    'docker-compose.yaml',
+    'docker-compose.override.yml',
+    'docker-compose.dev.yml',
+
     # C/C++
     'CMakeLists.txt',
     'Makefile',
diff --git a/readmeai/settings/ignore_files.toml b/readmeai/settings/ignore_files.toml
index d6da0904..88f3bb92 100644
--- a/readmeai/settings/ignore_files.toml
+++ b/readmeai/settings/ignore_files.toml
@@ -65,6 +65,7 @@ extensions = [
     #'json5',
     #'jsonl',
     'key',
+    'lock',
     'lockb',
     'log',
     'md',
diff --git a/scripts/run_batch.sh b/scripts/run_batch.sh
index 6b46346a..2b8303db 100644
--- a/scripts/run_batch.sh
+++ b/scripts/run_batch.sh
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash
 
-version="0.4.992"
+version="0.4.998"
 run_date=$(date +"%Y%m%d")
 filenames=(
     #"readme-litellm"
@@ -44,8 +44,8 @@ for index in "${!repositories[@]}"; do
     alignment=${align[$RANDOM % ${#align[@]}]}
     rand_choice=$((RANDOM % 2))
 
-    #cmd="python3 -m readmeai.cli.commands -o \"$filename\" -r \"$repo\""
-    cmd="readmeai -o \"$filename\" -r \"$repo\""
+    cmd="python3 -m readmeai.cli.commands -o \"$filename\" -r \"$repo\""
+    #cmd="readmeai -o \"$filename\" -r \"$repo\""
 
     if [ "$random_badge" != "default" ]; then
         cmd+=" -b \"$random_badge\""
diff --git a/tests/conftest.py b/tests/conftest.py
index b2975812..96f2f44f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,8 @@
     load_config,
     load_config_helper,
 )
+from readmeai.core.model import ModelHandler
+from readmeai.core.preprocess import FileData, RepoProcessor
 
 
 @pytest.fixture(scope="session")
@@ -36,3 +38,22 @@ def mock_summaries():
         ("/path/to/file2.py", "This is summary for file2.py"),
         (".github/workflows/ci.yml", "This is summary for ci.yml"),
     ]
+
+
+@pytest.fixture(scope="session")
+def mock_file_data(mock_dependencies):
+    """Returns the default file data."""
+    return FileData(
+        path="/path/to/file1.py",
+        name="file1.py",
+        content="This is content of file1.py",
+        extension="py",
+        tokens=10,
+        dependencies=mock_dependencies,
+    )
+
+
+@pytest.fixture(scope="session")
+def repo_processor(mock_config, mock_config_helper):
+    """Fixture for RepoProcessor class."""
+    return RepoProcessor(mock_config, mock_config_helper)
diff --git a/tests/test_core/test_model.py b/tests/test_core/test_model.py
index 20f2cc1b..875e5722 100644
--- a/tests/test_core/test_model.py
+++ b/tests/test_core/test_model.py
@@ -1,8 +1,8 @@
 """Unit tests for the GPT LLM API handler."""
 
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, Mock, call, patch
 
-import aiohttp
+import httpx
 import pytest
 
 from readmeai.core.model import ModelHandler
@@ -23,9 +23,7 @@ def json(self):
     def raise_for_status(self):
         """Raise an error if the status code is not 200."""
         if self.status_code != 200:
-            raise aiohttp.ClientResponseError(
-                message="HTTP Error", request=MockRequest()
-            )
+            raise httpx.HTTPStatusError(message="HTTP Error", request=MockRequest(), response=self)
 
 
 class MockRequest:
@@ -36,7 +34,7 @@ def __init__(self):
         self.url = "http://mockurl.com"
 
 
-class MockHTTPStatusError(aiohttp.ClientResponseError):
+class MockHTTPStatusError(httpx.HTTPStatusError):
     """Mock 
HTTP status error."""
 
     def __init__(self):
@@ -49,13 +47,32 @@
 
 
 @pytest.mark.asyncio
-async def test_batch_request(mock_config, mock_dependencies, mock_summaries):
-    handler = ModelHandler(mock_config)
-    handler._process_prompt = AsyncMock(side_effect=lambda p: f"Processed: {p}")
-    responses = await handler.batch_request(
-        [Mock(), Mock()], mock_dependencies, mock_summaries
+async def test_batch_request(
+    mock_config, mock_file_data, mock_dependencies, mock_summaries
+):
+    """Test the batch_request function."""
+    model_handler = ModelHandler(mock_config)
+    model_handler._set_prompt_context = AsyncMock(return_value=[])
+    model_handler._batch_prompts = AsyncMock(return_value=[])
+    responses = await model_handler.batch_request(
+        mock_file_data, mock_dependencies, mock_summaries
     )
-    assert "Processed" in responses[0]
+    await model_handler.close()
+    assert isinstance(responses, list)
+
+
+def test_generate_batches(mock_config):
+    """Test the _generate_batches function."""
+    batch_size = 3
+    items = [1, 2, 3, 4, 5, 6, 7, 8]
+    model_handler = ModelHandler(mock_config)
+    batches = list(model_handler._generate_batches(items, batch_size))
+    assert isinstance(batches, list)
+    assert len(batches) == 3
+    assert batches[0] == [1, 2, 3]
+    assert batches[1] == [4, 5, 6]
+    assert batches[2] == [7, 8]
 
 
 @pytest.mark.asyncio
@@ -81,6 +98,7 @@ async def test_set_prompt_context(mock_config, mock_dependencies, mock_summaries
         mock_dependencies,
         mock_summaries,
     )
+    await handler.close()
     assert len(prompts) == 4
     expected_prompts = []
     for prompt, expected in zip(prompts, expected_prompts):
@@ -88,73 +106,89 @@
 
 
 @pytest.mark.asyncio
-async def test_process_prompt_summaries(mock_config):
+async def test_process_batch_summaries(mock_config):
     """Test the _process_prompt function."""
     handler = ModelHandler(mock_config)
     handler._handle_code_summary_response = AsyncMock(return_value="Processed summary")
     mock_prompt = {"type": "summaries", "context": "Some context"}
-    result = await handler._process_prompt(mock_prompt)
+    result = await handler._process_batch(mock_prompt)
+    await handler.close()
     assert result == "Processed summary"
 
 
 @pytest.mark.asyncio
-async def test_process_prompt_other_types(mock_config):
+async def test_process_batch_other_types(mock_config):
     """Test the _process_prompt function."""
     handler = ModelHandler(mock_config)
-    handler._inject_prompt_context = Mock(return_value="Injected prompt")
+    handler._get_prompt_context = Mock(return_value="Injected prompt")
     handler._handle_response = AsyncMock(
         return_value=("type", "Processed other prompt")
     )
     mock_prompt = {"type": "overview", "context": "Some context"}
-    result = await handler._process_prompt(mock_prompt)
+    result = await handler._process_batch(mock_prompt)
+    await handler.close()
     assert result == "Processed other prompt"
 
 
 @pytest.mark.asyncio
 async def test_handle_response(mock_config):
     """Test the _handle_response function."""
+    content = "Extension for Python code files?"
     handler = ModelHandler(mock_config)
     handler.http_client.post_async = AsyncMock(
-        side_effect=[
-            MockResponse(
-                json_data={
-                    "choices": [
-                        {
-                            "message": {
-                                "content": "Python is a programming language for .py files." 
- } - } - ] - } - ), - ], - ) - index, response = await handler._handle_response( - "overview", "what programming language is .py?", 50 + side_effect=MockResponse( + json_data={"choices": [{"message": {"content": content}}]} + ) ) + index, response = await handler._handle_response("overview", content, 30) await handler.close() - assert index == "overview" - assert "python" in response.lower() + assert "py" in response @pytest.mark.asyncio -async def test_handle_response_client_connection_error(mock_config): - with patch("aiohttp.ClientSession.post") as mock_post: - mock_post.side_effect = aiohttp.ClientConnectionError() +async def test_handle_response_http_status_error(mock_config): + with patch( + "httpx.AsyncClient.post", + side_effect=httpx.HTTPStatusError( + response=MockResponse(status_code=404), + request=MockRequest(), + message="HttpStatusError", + ), + ): + handler = ModelHandler(mock_config) + index, response = await handler._handle_response("overview", "test prompt", 50) + await handler.close() + assert "HttpStatusError" in response + + +@pytest.mark.asyncio +async def test_handle_response_network_error(mock_config): + with patch( + "httpx.AsyncClient.post", + side_effect=httpx.NetworkError( + request=MockRequest(), + message="NetworkError", + ), + ): handler = ModelHandler(mock_config) index, response = await handler._handle_response("overview", "test prompt", 50) await handler.close() - assert "aiohttp.client_exceptions.ClientConnectionError" in response + assert "NetworkError" in response @pytest.mark.asyncio -async def test_handle_response_server_timeout_error(mock_config): - with patch("aiohttp.ClientSession.post") as mock_post: - mock_post.side_effect = aiohttp.ServerTimeoutError() +async def test_handle_response_timeout_error(mock_config): + with patch( + "httpx.AsyncClient.post", + side_effect=httpx.TimeoutException( + request=MockRequest(), + message="TimeoutException", + ), + ): handler = ModelHandler(mock_config) index, response = await handler._handle_response("overview", "test prompt", 50) await handler.close() - assert "aiohttp.client_exceptions.ServerTimeoutError" in response + assert "TimeoutException" in response @pytest.mark.asyncio diff --git a/tests/test_core/test_preprocess.py b/tests/test_core/test_preprocess.py index db213cfe..07c750e7 100644 --- a/tests/test_core/test_preprocess.py +++ b/tests/test_core/test_preprocess.py @@ -9,9 +9,29 @@ @pytest.fixture -def repo_processor(mock_config, mock_config_helper): - """Fixture for RepoProcessor.""" - return RepoProcessor(mock_config, mock_config_helper) +def mock_file_data(): + file1 = FileData( + path="path/to/file1.py", + name="file1.py", + content="", + extension="py", + dependencies=["dependency1"], + ) + file2 = FileData( + path="path/to/file2.js", + name="file2.js", + content="", + extension="js", + dependencies=["dependency2"], + ) + file3 = FileData( + path="path/to/file3.txt", + name="file3.txt", + content="", + extension="txt", + dependencies=[], + ) + return [file1, file2, file3] def test_generate_contents(repo_processor, tmp_path): @@ -53,9 +73,7 @@ def test_generate_file_info(repo_processor, tmp_path): def test_generate_file_info_exception_handling(repo_processor, caplog): """Test the generate_file_info method.""" mock_file = MagicMock() - mock_file.open.side_effect = UnicodeDecodeError( - "utf-8", b"", 0, 1, "error" - ) + mock_file.open.side_effect = UnicodeDecodeError("utf-8", b"", 0, 1, "error") mock_path = MagicMock() mock_path.rglob.return_value = [mock_file] 
list(repo_processor.generate_file_info(mock_path)) @@ -81,32 +99,33 @@ def test_extract_dependencies(repo_processor): ) mock_parser = MagicMock() mock_parser.parse.return_value = ["flask==1.1.4"] - with patch( - "readmeai.parsers.factory.parser_factory", return_value=mock_parser - ): + with patch("readmeai.parsers.factory.parser_factory", return_value=mock_parser): result = repo_processor.extract_dependencies(file_data) assert "flask" in result -def test_language_mapper(repo_processor): - """Test the language_mapper method.""" +@pytest.mark.parametrize( + "file_extension, expected", + [ + ("py", "python"), + ("js", "javascript"), + ("md", "markdown"), + ("txt", "text"), + ("rs", "rust"), + ], +) +def test_language_mapping(repo_processor, file_extension, expected): + """Test method that maps file extensions to programming languages.""" contents = [ FileData( - name="main.py", - path=Path("main.py"), - content="import streamlit as st\nimport pandas as pd", - extension="py", - ), - FileData( - name="README.md", - path=Path("README.md"), - content="## This is a test README file", - extension="md", + name=f"main.{file_extension}", + path=Path(f"main.{file_extension}"), + content="...", + extension=file_extension, ), ] updated = repo_processor.language_mapper(contents) - assert updated[0].language == "python" - assert updated[1].language == "markdown" + assert updated[0].language == expected @patch("readmeai.core.tokens.token_counter", return_value=7) @@ -140,32 +159,6 @@ def test_tokenize_content_offline_mode(repo_processor): assert result[0].tokens == 0 -@pytest.fixture -def mock_file_data(): - file1 = FileData( - path="path/to/file1.py", - name="file1.py", - content="", - extension="py", - dependencies=["dependency1"], - ) - file2 = FileData( - path="path/to/file2.js", - name="file2.js", - content="", - extension="js", - dependencies=["dependency2"], - ) - file3 = FileData( - path="path/to/file3.txt", - name="file3.txt", - content="", - extension="txt", - dependencies=[], - ) - return [file1, file2, file3] - - def test_get_dependencies_normal_behavior( mock_file_data, mock_config, mock_config_helper ): @@ -184,9 +177,6 @@ def test_get_dependencies_exception_handling( ): """Test the get_dependencies method.""" processor = RepoProcessor(mock_config, mock_config_helper) - processor.extract_dependencies = MagicMock( - side_effect=Exception("Test exception") - ) + processor.extract_dependencies = MagicMock(side_effect=Exception("Test exception")) dependencies = processor.get_dependencies(mock_file_data) - assert isinstance(dependencies, list) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index b95e7cc9..b4727a3c 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,38 +1,31 @@ """Tests for the custom exceptions module.""" from readmeai.exceptions import ( - ApiCommunicationError, FileReadError, FileSystemError, FileWriteError, GitCloneError, ReadmeAiException, - ReadmeGenerationError, + ReadmeGeneratorError, ) def test_readme_ai_exception(): """Test the ReadmeAIException class.""" ex = ReadmeAiException("General error") - assert str(ex) == "General error" + assert isinstance(ex, Exception) def test_git_clone_exception(): """Test the RepositoryCloneException class.""" - ex = GitCloneError("https://example.com/repo", ValueError()) - assert "Failed to clone repository" in str(ex) + ex = GitCloneError("https://example.com/repo", "Traceback") + assert isinstance(ex, ReadmeAiException) def test_readme_generation_exception(): """Test the ReadmeGenerationException 
class.""" - ex = ReadmeGenerationError("Error during README generation") - assert str(ex) == "Error during README generation" - - -def test_api_communication_exception(): - """Test the APICommunicationException class.""" - ex = ApiCommunicationError("API communication error") - assert str(ex) == "API communication error" + ex = ReadmeGeneratorError("Traceback") + assert isinstance(ex, ReadmeAiException) def test_read_file_exception():