diff --git a/.github/workflows/python-test.yaml b/.github/workflows/python-test.yaml index 45902a32..a80770f9 100644 --- a/.github/workflows/python-test.yaml +++ b/.github/workflows/python-test.yaml @@ -3,7 +3,7 @@ name: "Run Python Tests" on: pull_request: types: [opened, reopened, synchronize] - branches: [dev] + branches: [dev, dev-workflow-test] paths: - "src/ecooptimizer/**/*.py" diff --git a/.gitignore b/.gitignore index 51b86108..95b60b23 100644 --- a/.gitignore +++ b/.gitignore @@ -286,4 +286,24 @@ TSWLatexianTemp* # DRAW.IO files *.drawio -*.drawio.bkp \ No newline at end of file +*.drawio.bkp + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] + +.venv/ + +# Rope +.ropeproject + +*.egg-info/ + +# Package files +outputs/ +build/ +tests/temp_dir/ + +# Coverage +.coverage +coverage.* \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..17662ddc --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "plugin/capstone--sco-vs-code-plugin"] + path = plugin/capstone--sco-vs-code-plugin + url = https://github.com/ssm-lab/capstone--sco-vs-code-plugin.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..2ad9d923 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.7.4 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. 
+ - id: ruff-format + \ No newline at end of file diff --git a/docs/projMngmnt/Rev0_Team_Contrib.pdf b/docs/projMngmnt/Rev0_Team_Contrib.pdf index 4d8f2f1a..b614dae0 100644 Binary files a/docs/projMngmnt/Rev0_Team_Contrib.pdf and b/docs/projMngmnt/Rev0_Team_Contrib.pdf differ diff --git a/src/analyzers/__init__.py b/plugin/README.md similarity index 100% rename from src/analyzers/__init__.py rename to plugin/README.md diff --git a/plugin/capstone--sco-vs-code-plugin b/plugin/capstone--sco-vs-code-plugin new file mode 160000 index 00000000..55908450 --- /dev/null +++ b/plugin/capstone--sco-vs-code-plugin @@ -0,0 +1 @@ +Subproject commit 55908450f8041d4a4ad041eada803597bf5d0bfc diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..81ef3535 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,139 @@ +[build-system] +requires = ["setuptools >= 61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "ecooptimizer" +version = "0.0.1" +dependencies = [ + "pylint", + "rope", + "astor", + "codecarbon", + "asttokens", + "uvicorn", + "fastapi", + "pydantic", + "libcst", + "websockets", +] +requires-python = ">=3.9" +authors = [ + { name = "Sevhena Walker" }, + { name = "Mya Hussain" }, + { name = "Nivetha Kuruparan" }, + { name = "Ayushi Amin" }, + { name = "Tanveer Brar" }, +] + +description = "A source code eco optimizer" +readme = "README.md" +license = { file = "LICENSE" } + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-cov", + "pytest-mock", + "ruff", + "coverage", + "pyright", + "pre-commit", +] + +[project.scripts] +eco-local = "ecooptimizer.__main__:main" +eco-ext = "ecooptimizer.api.__main__:main" +eco-ext-dev = "ecooptimizer.api.__main__:dev" + +[project.urls] +Documentation = "https://readthedocs.org" +Repository = "https://github.com/ssm-lab/capstone--source-code-optimizer" +"Bug Tracker" = "https://github.com/ssm-lab/capstone--source-code-optimizer/issues" + +[tool.pytest.ini_options] +norecursedirs = 
["tests/temp*", "tests/input", "tests/_input_copies"] +addopts = ["--basetemp=tests/temp_dir"] +testpaths = ["tests"] +pythonpath = "src" + +[tool.coverage.run] +omit = [ + "*/__main__.py", + '*/__init__.py', + '*/utils/*', + "*/test_*.py", + "*/analyzers/*_analyzer.py", + "*/api/app.py", +] + +[tool.ruff] +extend-exclude = [ + "*tests/input/**/*.py", + "tests/_input_copies", + "tests/temp_dir", +] +line-length = 100 + +[tool.ruff.lint] +select = [ + "E", # Enforce Python Error rules (e.g., syntax errors, exceptions). + "UP", # Check for unnecessary passes and other unnecessary constructs. + "ANN001", # Ensure type annotations are present where needed. + "ANN002", + "ANN003", + "ANN401", + "INP", # Flag invalid Python patterns or usage. + "PTH", # Check path-like or import-related issues. + "F", # Enforce function-level checks (e.g., complexity, arguments). + "B", # Enforce best practices for Python coding (general style rules). + "PT", # Enforce code formatting and Pythonic idioms. + "W", # Enforce warnings (e.g., suspicious constructs or behaviours). + "A", # Flag common anti-patterns or bad practices. + "RUF", # Ruff-specific rules. + "ARG", # Check for function argument issues., +] + +# Avoid enforcing line-length violations (`E501`) +ignore = ["E501", "RUF003"] + +# Avoid trying to fix flake8-bugbear (`B`) violations. +unfixable = ["B"] + +# Ignore `E402` (import violations) in all `__init__.py` files, and in selected subdirectories. 
+[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402"] +"**/{tests,docs,tools}/*" = ["E402", "ANN", "INP001"] + +[tool.ruff.lint.flake8-annotations] +suppress-none-returning = true +mypy-init-return = true + +[tool.pyright] +include = ["src", "tests"] +exclude = ["tests/input", "tests/_input*", "tests/temp_dir"] + +disableBytesTypePromotions = true +reportAttributeAccessIssue = false +reportPropertyTypeMismatch = true +reportFunctionMemberAccess = true +reportMissingImports = true +reportUnusedVariable = "warning" +reportDuplicateImport = "warning" +reportUntypedFunctionDecorator = true +reportUntypedClassDecorator = true +reportUntypedBaseClass = true +reportUntypedNamedTuple = true +reportPrivateUsage = true +reportConstantRedefinition = "warning" +reportDeprecated = "warning" +reportIncompatibleMethodOverride = true +reportIncompatibleVariableOverride = true +reportInconsistentConstructor = true +reportOverlappingOverload = true +reportMissingTypeArgument = true +reportCallInDefaultInitializer = "warning" +reportUnnecessaryIsInstance = "warning" +reportUnnecessaryCast = "warning" +reportUnnecessaryComparison = true +reportMatchNotExhaustive = "warning" diff --git a/src/analyzers/base_analyzer.py b/src/analyzers/base_analyzer.py deleted file mode 100644 index cad46036..00000000 --- a/src/analyzers/base_analyzer.py +++ /dev/null @@ -1,9 +0,0 @@ -from abc import ABC, abstractmethod - -class BaseAnalyzer(ABC): - def __init__(self, code_path: str): - self.code_path = code_path - - @abstractmethod - def analyze(self): - pass diff --git a/src/analyzers/pylint_analyzer.py b/src/analyzers/pylint_analyzer.py deleted file mode 100644 index c8675a50..00000000 --- a/src/analyzers/pylint_analyzer.py +++ /dev/null @@ -1,70 +0,0 @@ -import subprocess -import json -from analyzers.base_analyzer import BaseAnalyzer - -class PylintAnalyzer(BaseAnalyzer): - def __init__(self, code_path: str): - super().__init__(code_path) - self.code_smells = { - "R0902": "Large Class", # Too 
many instance attributes - "R0913": "Long Parameter List", # Too many arguments - "R0915": "Long Method", # Too many statements - "C0200": "Complex List Comprehension", # Loop can be simplified - "C0103": "Invalid Naming Convention", # Non-standard names - # Add other pylint codes as needed - } - - def analyze(self): - """ - Runs Pylint on the specified code path and returns a report of code smells. - """ - pylint_command = [ - "pylint", "--output-format=json", self.code_path - ] - - try: - result = subprocess.run(pylint_command, capture_output=True, text=True, check=True) - pylint_output = result.stdout - report = self._parse_pylint_output(pylint_output) - return report - except subprocess.CalledProcessError as e: - print("Pylint analysis failed:", e) - return {} - except FileNotFoundError: - print("Pylint is not installed or not found in PATH.") - return {} - except json.JSONDecodeError: - print("Failed to parse pylint output. Check if pylint output is in JSON format.") - return {} - - def _parse_pylint_output(self, output: str): - """ - Parses the Pylint JSON output to identify specific code smells. 
- """ - try: - pylint_results = json.loads(output) - except json.JSONDecodeError: - print("Error: Failed to parse pylint output") - return [] - - code_smell_report = [] - - for entry in pylint_results: - message_id = entry.get("message-id") - if message_id in self.code_smells: - code_smell_report.append({ - "type": self.code_smells[message_id], - "message": entry.get("message"), - "line": entry.get("line"), - "column": entry.get("column"), - "path": entry.get("path") - }) - - return code_smell_report - -# Example usage -if __name__ == "__main__": - analyzer = PylintAnalyzer("your_file.py") - report = analyzer.analyze() - for issue in report: - print(f"{issue['type']} at {issue['path']}:{issue['line']}:{issue['column']} - {issue['message']}") diff --git a/src/README.md b/src/ecooptimizer/README.md similarity index 100% rename from src/README.md rename to src/ecooptimizer/README.md diff --git a/src/ecooptimizer/__init__.py b/src/ecooptimizer/__init__.py new file mode 100644 index 00000000..493243ca --- /dev/null +++ b/src/ecooptimizer/__init__.py @@ -0,0 +1,9 @@ +# Path of current directory +from pathlib import Path + +DIRNAME = Path(__file__).parent + +# Entire project directory path +SAMPLE_PROJ_DIR = (DIRNAME / Path("../../tests/input/project_car_stuff")).resolve() +SOURCE = SAMPLE_PROJ_DIR / "main.py" +TEST_FILE = SAMPLE_PROJ_DIR / "test_main.py" diff --git a/src/ecooptimizer/__main__.py b/src/ecooptimizer/__main__.py new file mode 100644 index 00000000..bbe683c2 --- /dev/null +++ b/src/ecooptimizer/__main__.py @@ -0,0 +1,132 @@ +import ast +import logging +from pathlib import Path +import shutil +from tempfile import TemporaryDirectory, mkdtemp # noqa: F401 + +import libcst as cst + +from .utils.output_manager import LoggingManager +from .utils.output_manager import save_file, save_json_files, copy_file_to_output + + +from .api.routes.refactor_smell import ChangedFile, RefactoredData + +from .measurements.codecarbon_energy_meter import CodeCarbonEnergyMeter + 
+from .analyzers.analyzer_controller import AnalyzerController + +from .refactorers.refactorer_controller import RefactorerController + +from . import ( + SAMPLE_PROJ_DIR, + SOURCE, +) + +from .config import CONFIG + +loggingManager = LoggingManager() + +CONFIG["loggingManager"] = loggingManager + +detect_logger = loggingManager.loggers["detect"] +refactor_logger = loggingManager.loggers["refactor"] + +CONFIG["detectLogger"] = detect_logger +CONFIG["refactorLogger"] = refactor_logger + + +# FILE CONFIGURATION IN __init__.py !!! + + +def main(): + # Save ast + save_file("source_ast.txt", ast.dump(ast.parse(SOURCE.read_text()), indent=4), "w") + save_file("source_cst.txt", str(cst.parse_module(SOURCE.read_text())), "w") + + # Measure initial energy + energy_meter = CodeCarbonEnergyMeter() + energy_meter.measure_energy(Path(SOURCE)) + initial_emissions = energy_meter.emissions + + if not initial_emissions: + logging.error("Could not retrieve initial emissions. Exiting.") + exit(1) + + analyzer_controller = AnalyzerController() + # update_smell_registry(["no-self-use"]) + smells_data = analyzer_controller.run_analysis(SOURCE) + save_json_files("code_smells.json", [smell.model_dump() for smell in smells_data]) + + copy_file_to_output(SOURCE, "refactored-test-case.py") + refactorer_controller = RefactorerController() + output_paths = [] + + for smell in smells_data: + # Use the line below and comment out "with TemporaryDirectory()" if you want to see the refactored code + # It basically copies the source directory into a temp dir that you can find in your systems TEMP folder + # It varies per OS. The location of the folder can be found in the 'refactored-data.json' file in outputs. + # If you use the other line know that you will have to manually delete the temp dir after running the + # code. 
It will NOT auto delete which, hence allowing you to see the refactoring results + + # tempDir = mkdtemp(prefix="ecooptimizer-") # < UNCOMMENT THIS LINE and shift code under to the left + + with TemporaryDirectory() as tempDir: # COMMENT OUT THIS ONE + source_copy = Path(tempDir) / SAMPLE_PROJ_DIR.name + target_file_copy = Path(str(SOURCE).replace(str(SAMPLE_PROJ_DIR), str(source_copy), 1)) + + # source_copy = project_copy / SOURCE.name + + shutil.copytree(SAMPLE_PROJ_DIR, source_copy) + + try: + modified_files: list[Path] = refactorer_controller.run_refactorer( + target_file_copy, source_copy, smell, overwrite=False + ) + except NotImplementedError as e: + print(e) + continue + + energy_meter.measure_energy(target_file_copy) + final_emissions = energy_meter.emissions + + if not final_emissions: + refactor_logger.error("Could not retrieve final emissions. Discarding refactoring.") + print("Refactoring Failed.\n") + + elif final_emissions >= initial_emissions: + refactor_logger.info("No measured energy savings. 
Discarding refactoring.\n") + print("Refactoring Failed.\n") + + else: + refactor_logger.info("Energy saved!") + refactor_logger.info( + f"Initial emissions: {initial_emissions} | Final emissions: {final_emissions}" + ) + + print("Refactoring Succesful!\n") + + refactor_data = RefactoredData( + tempDir=tempDir, + targetFile=ChangedFile(original=str(SOURCE), refactored=str(target_file_copy)), + energySaved=(final_emissions - initial_emissions), + affectedFiles=[ + ChangedFile( + original=str(file).replace(str(source_copy), str(SAMPLE_PROJ_DIR)), + refactored=str(file), + ) + for file in modified_files + ], + ) + + output_paths = refactor_data.affectedFiles + + # In reality the original code will now be overwritten but thats too much work + + save_json_files("refactoring-data.json", refactor_data.model_dump()) # type: ignore + + print(output_paths) + + +if __name__ == "__main__": + main() diff --git a/src/measurement/__init__.py b/src/ecooptimizer/analyzers/__init__.py similarity index 100% rename from src/measurement/__init__.py rename to src/ecooptimizer/analyzers/__init__.py diff --git a/src/ecooptimizer/analyzers/analyzer_controller.py b/src/ecooptimizer/analyzers/analyzer_controller.py new file mode 100644 index 00000000..65835b0c --- /dev/null +++ b/src/ecooptimizer/analyzers/analyzer_controller.py @@ -0,0 +1,137 @@ +# pyright: reportOptionalMemberAccess=false +from pathlib import Path +from typing import Callable, Any + +from ..data_types.smell_record import SmellRecord + +from ..config import CONFIG + +from ..data_types.smell import Smell + +from .pylint_analyzer import PylintAnalyzer +from .ast_analyzer import ASTAnalyzer +from .astroid_analyzer import AstroidAnalyzer + +from ..utils.smells_registry import retrieve_smell_registry + + +class AnalyzerController: + def __init__(self): + """Initializes analyzers for different analysis methods.""" + self.pylint_analyzer = PylintAnalyzer() + self.ast_analyzer = ASTAnalyzer() + self.astroid_analyzer = 
AstroidAnalyzer() + + def run_analysis(self, file_path: Path, selected_smells: str | list[str] = "ALL"): + """ + Runs multiple analysis tools on the given Python file and logs the results. + Returns a list of detected code smells. + """ + + smells_data: list[Smell] = [] + + if not selected_smells: + raise TypeError("At least 1 smell must be selected for detection") + + SMELL_REGISTRY = retrieve_smell_registry(selected_smells) + + try: + pylint_smells = self.filter_smells_by_method(SMELL_REGISTRY, "pylint") + ast_smells = self.filter_smells_by_method(SMELL_REGISTRY, "ast") + astroid_smells = self.filter_smells_by_method(SMELL_REGISTRY, "astroid") + + CONFIG["detectLogger"].info("🟒 Starting analysis process") + CONFIG["detectLogger"].info(f"πŸ“‚ Analyzing file: {file_path}") + + if pylint_smells: + CONFIG["detectLogger"].info(f"πŸ” Running Pylint analysis on {file_path}") + pylint_options = self.generate_pylint_options(pylint_smells) + pylint_results = self.pylint_analyzer.analyze(file_path, pylint_options) + smells_data.extend(pylint_results) + CONFIG["detectLogger"].info( + f"βœ… Pylint analysis completed. {len(pylint_results)} smells detected." + ) + + if ast_smells: + CONFIG["detectLogger"].info(f"πŸ” Running AST analysis on {file_path}") + ast_options = self.generate_custom_options(ast_smells) + ast_results = self.ast_analyzer.analyze(file_path, ast_options) + smells_data.extend(ast_results) + CONFIG["detectLogger"].info( + f"βœ… AST analysis completed. {len(ast_results)} smells detected." + ) + + if astroid_smells: + CONFIG["detectLogger"].info(f"πŸ” Running Astroid analysis on {file_path}") + astroid_options = self.generate_custom_options(astroid_smells) + astroid_results = self.astroid_analyzer.analyze(file_path, astroid_options) + smells_data.extend(astroid_results) + CONFIG["detectLogger"].info( + f"βœ… Astroid analysis completed. {len(astroid_results)} smells detected." 
+ ) + + if smells_data: + CONFIG["detectLogger"].info("⚠️ Detected Code Smells:") + for smell in smells_data: + if smell.occurences: + first_occurrence = smell.occurences[0] + total_occurrences = len(smell.occurences) + line_info = ( + f"(Starting at Line {first_occurrence.line}, {total_occurrences} occurrences)" + if total_occurrences > 1 + else f"(Line {first_occurrence.line})" + ) + else: + line_info = "" + + CONFIG["detectLogger"].info(f" β€’ {smell.symbol} {line_info}: {smell.message}") + else: + CONFIG["detectLogger"].info("πŸŽ‰ No code smells detected.") + + except Exception as e: + CONFIG["detectLogger"].error(f"❌ Error during analysis: {e!s}") + + return smells_data + + @staticmethod + def filter_smells_by_method( + smell_registry: dict[str, SmellRecord], method: str + ) -> dict[str, SmellRecord]: + filtered = { + name: smell + for name, smell in smell_registry.items() + if smell["enabled"] and (method == smell["analyzer_method"]) + } + return filtered + + @staticmethod + def generate_pylint_options(filtered_smells: dict[str, SmellRecord]) -> list[str]: + pylint_smell_symbols = [] + extra_pylint_options = [ + "--disable=all", + ] + + for symbol, smell in zip(filtered_smells.keys(), filtered_smells.values()): + pylint_smell_symbols.append(symbol) + + if len(smell["analyzer_options"]) > 0: + for param_data in smell["analyzer_options"].values(): + flag = param_data["flag"] + value = param_data["value"] + if value: + extra_pylint_options.append(f"{flag}={value}") + + extra_pylint_options.append(f"--enable={','.join(pylint_smell_symbols)}") + return extra_pylint_options + + @staticmethod + def generate_custom_options( + filtered_smells: dict[str, SmellRecord], + ) -> list[tuple[Callable, dict[str, Any]]]: # type: ignore + ast_options = [] + for smell in filtered_smells.values(): + method = smell["checker"] + options = smell["analyzer_options"] + ast_options.append((method, options)) + + return ast_options diff --git a/src/ecooptimizer/analyzers/ast_analyzer.py 
b/src/ecooptimizer/analyzers/ast_analyzer.py new file mode 100644 index 00000000..e9c0b051 --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzer.py @@ -0,0 +1,27 @@ +from typing import Callable, Any +from pathlib import Path +from ast import AST, parse + + +from .base_analyzer import Analyzer +from ..data_types.smell import Smell + + +class ASTAnalyzer(Analyzer): + def analyze( + self, + file_path: Path, + extra_options: list[tuple[Callable[[Path, AST], list[Smell]], dict[str, Any]]], + ): + smells_data: list[Smell] = [] + + source_code = file_path.read_text() + + tree = parse(source_code) + + for detector, params in extra_options: + if callable(detector): + result = detector(file_path, tree, **params) + smells_data.extend(result) + + return smells_data diff --git a/src/refactorer/__init__.py b/src/ecooptimizer/analyzers/ast_analyzers/__init__.py similarity index 100% rename from src/refactorer/__init__.py rename to src/ecooptimizer/analyzers/ast_analyzers/__init__.py diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_long_element_chain.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_element_chain.py new file mode 100644 index 00000000..3fa39d86 --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_element_chain.py @@ -0,0 +1,73 @@ +import ast +from pathlib import Path + +from ...utils.smell_enums import CustomSmell + +from ...data_types.smell import LECSmell +from ...data_types.custom_fields import AdditionalInfo, Occurence + + +def detect_long_element_chain(file_path: Path, tree: ast.AST, threshold: int = 5) -> list[LECSmell]: + """ + Detects long element chains in the given Python code and returns a list of Smell objects. + + Args: + file_path (Path): The file path to analyze. + tree (ast.AST): The Abstract Syntax Tree (AST) of the source code. + threshold (int): The minimum length of a dictionary chain. Default is 3. 
+ + Returns: + list[Smell]: A list of Smell objects, each containing details about a detected long chain. + """ + # Initialize an empty list to store detected Smell objects + results: list[LECSmell] = [] + used_lines = set() + + # Function to calculate the length of a dictionary chain and detect long chains + def check_chain(node: ast.Subscript, chain_length: int = 0): + # Ensure each line is only reported once + if node.lineno in used_lines: + return + + current = node + # Traverse through the chain to count its length + while isinstance(current, ast.Subscript): + chain_length += 1 + current = current.value + + print(chain_length) + if chain_length >= threshold: + # Create a descriptive message for the detected long chain + message = f"Dictionary chain too long ({chain_length}/{threshold})" + print(node.lineno) + # Instantiate a Smell object with details about the detected issue + smell = LECSmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol="long-element-chain", + message=message, + messageId=CustomSmell.LONG_ELEMENT_CHAIN.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=node.lineno, + endLine=node.end_lineno, + column=node.col_offset, + endColumn=node.end_col_offset, + ) + ], + additionalInfo=AdditionalInfo(), + ) + + used_lines.add(node.lineno) + results.append(smell) + + # Traverse the AST to identify nodes representing dictionary chains + for node in ast.walk(tree): + if isinstance(node, ast.Subscript): + check_chain(node) + + # Return the list of detected Smell objects + return results diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_long_lambda_expression.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_lambda_expression.py new file mode 100644 index 00000000..2ff0fccb --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_lambda_expression.py @@ -0,0 +1,152 @@ +import ast +from pathlib import Path + +from ...utils.smell_enums import CustomSmell + +from 
...data_types.smell import LLESmell +from ...data_types.custom_fields import AdditionalInfo, Occurence + + +def count_expressions(node: ast.expr) -> int: + """ + Recursively counts the number of sub-expressions inside a lambda body. + Ensures `sum()` only operates on integers. + """ + if isinstance(node, (ast.BinOp, ast.BoolOp, ast.Compare, ast.Call, ast.IfExp)): + return 1 + sum( + count_expressions(child) + for child in ast.iter_child_nodes(node) + if isinstance(child, ast.expr) + ) + + # Ensure all recursive calls return an integer + return sum( + ( + count_expressions(child) + for child in ast.iter_child_nodes(node) + if isinstance(child, ast.expr) + ), + start=0, + ) + + +# Helper function to get the string representation of the lambda expression +def get_lambda_code(lambda_node: ast.Lambda) -> str: + """ + Constructs the string representation of a lambda expression. + + Args: + lambda_node (ast.Lambda): The lambda node to reconstruct. + + Returns: + str: The string representation of the lambda expression. + """ + # Reconstruct the lambda arguments and body as a string + args = ", ".join(arg.arg for arg in lambda_node.args.args) + + # Convert the body to a string by using ast's built-in functionality + body = ast.unparse(lambda_node.body) + + # Combine to form the lambda expression + return f"lambda {args}: {body}" + + +def detect_long_lambda_expression( + file_path: Path, + tree: ast.AST, + threshold_length: int = 100, + threshold_count: int = 5, +) -> list[LLESmell]: + """ + Detects lambda functions that are too long, either by the number of expressions or the total length in characters. + + Args: + file_path (Path): The file path to analyze. + tree (ast.AST): The Abstract Syntax Tree (AST) of the source code. + threshold_length (int): The maximum number of characters allowed in the lambda expression. + threshold_count (int): The maximum number of expressions allowed inside the lambda function. 
+ + Returns: + list[Smell]: A list of Smell objects, each containing details about detected long lambda functions. + """ + # Initialize an empty list to store detected Smell objects + results: list[LLESmell] = [] + used_lines = set() + + # Function to check the length of lambda expressions + def check_lambda(node: ast.Lambda): + """ + Analyzes a lambda node to check if it exceeds the specified thresholds + for the number of expressions or total character length. + + Args: + node (ast.Lambda): The lambda node to analyze. + """ + # Count the number of expressions in the lambda body + lambda_length = count_expressions(node.body) + + # Check if the lambda expression exceeds the threshold based on the number of expressions + if lambda_length >= threshold_count: + message = f"Lambda function too long ({lambda_length}/{threshold_count} expressions)" + # Initialize the Smell instance + smell = LLESmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol="long-lambda-expression", + message=message, + messageId=CustomSmell.LONG_LAMBDA_EXPR.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=node.lineno, + endLine=node.end_lineno, + column=node.col_offset, + endColumn=node.end_col_offset, + ) + ], + additionalInfo=AdditionalInfo(), + ) + + if node.lineno in used_lines: + return + used_lines.add(node.lineno) + results.append(smell) + + # Convert the lambda function to a string and check its total length in characters + lambda_code = get_lambda_code(node) + if len(lambda_code) > threshold_length: + message = f"Lambda function too long ({len(lambda_code)} characters, max {threshold_length})" + smell = LLESmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol="long-lambda-expr", + message=message, + messageId=CustomSmell.LONG_LAMBDA_EXPR.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=node.lineno, + endLine=node.end_lineno, + column=node.col_offset, + 
endColumn=node.end_col_offset, + ) + ], + additionalInfo=AdditionalInfo(), + ) + + if node.lineno in used_lines: + return + used_lines.add(node.lineno) + results.append(smell) + + # Walk through the AST to find lambda expressions + for node in ast.walk(tree): + if isinstance(node, ast.Lambda): + check_lambda(node) + + # Return the list of detected Smell objects + return results diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_long_message_chain.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_message_chain.py new file mode 100644 index 00000000..b3d59c73 --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_long_message_chain.py @@ -0,0 +1,85 @@ +import ast +from pathlib import Path + +from ...utils.smell_enums import CustomSmell + +from ...data_types.smell import LMCSmell +from ...data_types.custom_fields import AdditionalInfo, Occurence + + +def compute_chain_length(node: ast.expr) -> int: + """ + Recursively determines how many consecutive calls exist in a chain + ending at 'node'. Each .something() is +1. + """ + if isinstance(node, ast.Call): + # We have a call, so that's +1 + if isinstance(node.func, ast.Attribute): + # The chain might continue if node.func.value is also a call + return 1 + compute_chain_length(node.func.value) + else: + return 1 + elif isinstance(node, ast.Attribute): + # If it's just an attribute (like `details` or `obj.x`), + # we keep looking up the chain but *don’t increment*, + # because we only count calls. + return compute_chain_length(node.value) + else: + # If it's a Name or something else, we stop + return 0 + + +def detect_long_message_chain( + file_path: Path, tree: ast.AST, threshold: int = 5 +) -> list[LMCSmell]: + """ + Detects long message chains in the given Python code. + + Args: + file_path (Path): The file path to analyze. + tree (ast.AST): The Abstract Syntax Tree (AST) of the source code. + threshold (int): The minimum number of chained method calls to flag as a long chain. 
Default is 5. + + Returns: + list[Smell]: A list of Smell objects, each containing details about the detected long chains. + """ + # Initialize an empty list to store detected Smell objects + results: list[LMCSmell] = [] + used_lines = set() + + # Walk through the AST to find method calls and attribute chains + for node in ast.walk(tree): + # Check only method calls (Call node whose func is an Attribute) + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + length = compute_chain_length(node) + if length >= threshold: + line = node.lineno + # Make sure we haven’t already reported on this line + if line not in used_lines: + used_lines.add(line) + + message = f"Method chain too long ({length}/{threshold})" + # Create the smell object + smell = LMCSmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol="long-message-chain", + message=message, + messageId=CustomSmell.LONG_MESSAGE_CHAIN.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=node.lineno, + endLine=node.end_lineno, + column=node.col_offset, + endColumn=node.end_col_offset, + ) + ], + additionalInfo=AdditionalInfo(), + ) + results.append(smell) + + # Return the list of detected Smell objects + return results diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_repeated_calls.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_repeated_calls.py new file mode 100644 index 00000000..01c893c6 --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_repeated_calls.py @@ -0,0 +1,68 @@ +import ast +from collections import defaultdict +from pathlib import Path + +import astor + +from ...data_types.custom_fields import CRCInfo, Occurence + +from ...data_types.smell import CRCSmell + +from ...utils.smell_enums import CustomSmell + + +def detect_repeated_calls(file_path: Path, tree: ast.AST, threshold: int = 3): + results: list[CRCSmell] = [] + + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.For, 
ast.While)): + call_counts: dict[str, list[ast.Call]] = defaultdict(list) + modified_lines = set() + + for subnode in ast.walk(node): + if isinstance(subnode, (ast.Assign, ast.AugAssign)): + # targets = [target.id for target in getattr(subnode, "targets", []) if isinstance(target, ast.Name)] + modified_lines.add(subnode.lineno) + + for subnode in ast.walk(node): + if isinstance(subnode, ast.Call): + callString = astor.to_source(subnode).strip() + call_counts[callString].append(subnode) + + for callString, occurrences in call_counts.items(): + if len(occurrences) >= threshold: + skip_due_to_modification = any( + line in modified_lines + for start_line, end_line in zip( + [occ.lineno for occ in occurrences[:-1]], + [occ.lineno for occ in occurrences[1:]], + ) + for line in range(start_line + 1, end_line) + ) + + if skip_due_to_modification: + continue + + smell = CRCSmell( + path=str(file_path), + type="performance", + obj=None, + module=file_path.stem, + symbol="cached-repeated-calls", + message=f"Repeated function call detected ({len(occurrences)}/{threshold}). 
Consider caching the result: {callString}", + messageId=CustomSmell.CACHE_REPEATED_CALLS.value, + confidence="HIGH" if len(occurrences) > threshold else "MEDIUM", + occurences=[ + Occurence( + line=occ.lineno, + endLine=occ.end_lineno, + column=occ.col_offset, + endColumn=occ.end_col_offset, + ) + for occ in occurrences + ], + additionalInfo=CRCInfo(repetitions=len(occurrences), callString=callString), + ) + results.append(smell) + + return results diff --git a/src/ecooptimizer/analyzers/ast_analyzers/detect_unused_variables_and_attributes.py b/src/ecooptimizer/analyzers/ast_analyzers/detect_unused_variables_and_attributes.py new file mode 100644 index 00000000..60bbea53 --- /dev/null +++ b/src/ecooptimizer/analyzers/ast_analyzers/detect_unused_variables_and_attributes.py @@ -0,0 +1,121 @@ +import ast +from pathlib import Path + +from ...utils.smell_enums import CustomSmell + +from ...data_types.custom_fields import AdditionalInfo, Occurence +from ...data_types.smell import UVASmell + + +def detect_unused_variables_and_attributes(file_path: Path, tree: ast.AST) -> list[UVASmell]: + """ + Detects unused variables and class attributes in the given Python code. + + Args: + file_path (Path): The file path to analyze. + tree (ast.AST): The Abstract Syntax Tree (AST) of the source code. + + Returns: + list[Smell]: A list of Smell objects containing details about detected unused variables or attributes. + """ + # Store variable and attribute declarations and usage + results: list[UVASmell] = [] + declared_vars = set() + used_vars = set() + + # Helper function to gather declared variables (including class attributes) + def gather_declarations(node: ast.AST): + """ + Identifies declared variables or class attributes. + + Args: + node (ast.AST): The AST node to analyze. 
+ """ + # For assignment statements (variables or class attributes) + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name): # Simple variable + declared_vars.add(target.id) + elif isinstance(target, ast.Attribute): # Class attribute + declared_vars.add(f"{target.value.id}.{target.attr}") # type: ignore + + # For class attribute assignments (e.g., self.attribute) + elif isinstance(node, ast.ClassDef): + for class_node in ast.walk(node): + if isinstance(class_node, ast.Assign): + for target in class_node.targets: + if isinstance(target, ast.Name): + declared_vars.add(target.id) + elif isinstance(target, ast.Attribute): + declared_vars.add(f"{target.value.id}.{target.attr}") # type: ignore + + # Helper function to gather used variables and class attributes + def gather_usages(node: ast.AST): + """ + Identifies variables or class attributes that are used. + + Args: + node (ast.AST): The AST node to analyze. + """ + if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load): # Variable usage + used_vars.add(node.id) + elif isinstance(node, ast.Attribute) and isinstance(node.ctx, ast.Load): # Attribute usage + # Check if the attribute is accessed as `self.attribute` + if isinstance(node.value, ast.Name) and node.value.id == "self": + # Only add to used_vars if it’s in the form of `self.attribute` + used_vars.add(f"self.{node.attr}") + + # Gather declared and used variables + for node in ast.walk(tree): + gather_declarations(node) + gather_usages(node) + + # Detect unused variables by finding declared variables not in used variables + unused_vars = declared_vars - used_vars + + for var in unused_vars: + # Locate the line number for each unused variable or attribute + line_no, column_no = 0, 0 + symbol = "" + for node in ast.walk(tree): + if isinstance(node, ast.Name) and node.id == var: + line_no = node.lineno + column_no = node.col_offset + symbol = "unused-variable" + break + elif ( + isinstance(node, ast.Attribute) + 
and f"self.{node.attr}" == var + and isinstance(node.value, ast.Name) + and node.value.id == "self" + ): + line_no = node.lineno + column_no = node.col_offset + symbol = "unused-attribute" + break + + # Create a Smell object for the unused variable or attribute + smell = UVASmell( + path=str(file_path), + module=file_path.stem, + obj=None, + type="convention", + symbol=symbol, + message=f"Unused variable or attribute '{var}'", + messageId=CustomSmell.UNUSED_VAR_OR_ATTRIBUTE.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=line_no, + endLine=None, + column=column_no, + endColumn=None, + ) + ], + additionalInfo=AdditionalInfo(), + ) + + results.append(smell) + + # Return the list of detected Smell objects + return results diff --git a/src/ecooptimizer/analyzers/astroid_analyzer.py b/src/ecooptimizer/analyzers/astroid_analyzer.py new file mode 100644 index 00000000..e2622c4d --- /dev/null +++ b/src/ecooptimizer/analyzers/astroid_analyzer.py @@ -0,0 +1,32 @@ +from typing import Callable, Any +from pathlib import Path +from astroid import nodes, parse + + +from .base_analyzer import Analyzer +from ..data_types.smell import Smell + + +class AstroidAnalyzer(Analyzer): + def analyze( + self, + file_path: Path, + extra_options: list[ + tuple[ + Callable[[Path, nodes.Module], list[Smell]], + dict[str, Any], + ] + ], + ): + smells_data: list[Smell] = [] + + source_code = file_path.read_text() + + tree = parse(source_code) + + for detector, params in extra_options: + if callable(detector): + result = detector(file_path, tree, **params) + smells_data.extend(result) + + return smells_data diff --git a/src/testing/__init__.py b/src/ecooptimizer/analyzers/astroid_analyzers/__init__.py similarity index 100% rename from src/testing/__init__.py rename to src/ecooptimizer/analyzers/astroid_analyzers/__init__.py diff --git a/src/ecooptimizer/analyzers/astroid_analyzers/detect_string_concat_in_loop.py 
b/src/ecooptimizer/analyzers/astroid_analyzers/detect_string_concat_in_loop.py new file mode 100644 index 00000000..442c6452 --- /dev/null +++ b/src/ecooptimizer/analyzers/astroid_analyzers/detect_string_concat_in_loop.py @@ -0,0 +1,266 @@ +from pathlib import Path +import re +from astroid import nodes, util, parse, AttributeInferenceError + +from ...data_types.custom_fields import Occurence, SCLInfo +from ...data_types.smell import SCLSmell +from ...utils.smell_enums import CustomSmell + + +def detect_string_concat_in_loop(file_path: Path, tree: nodes.Module): + """ + Detects string concatenation inside loops within a Python AST tree. + + Parameters: + file_path (Path): The file path to analyze. + tree (nodes.Module): The parsed AST tree of the Python code. + + Returns: + list[dict]: A list of dictionaries containing details about detected string concatenation smells. + """ + smells: list[SCLSmell] = [] + in_loop_counter = 0 + current_loops: list[nodes.NodeNG] = [] + # current_semlls = { var_name : ( index of smell, index of loop )} + current_smells: dict[str, tuple[int, int]] = {} + + def create_smell(node: nodes.Assign): + nonlocal current_loops, current_smells + + if node.lineno and node.col_offset: + smells.append( + SCLSmell( + path=str(file_path), + module=file_path.name, + obj=None, + type="performance", + symbol="string-concat-loop", + message="String concatenation inside loop detected", + messageId=CustomSmell.STR_CONCAT_IN_LOOP.value, + confidence="UNDEFINED", + occurences=[create_smell_occ(node)], + additionalInfo=SCLInfo( + innerLoopLine=current_loops[ + current_smells[node.targets[0].as_string()][1] + ].lineno, # type: ignore + concatTarget=node.targets[0].as_string(), + ), + ) + ) + + def create_smell_occ(node: nodes.Assign | nodes.AugAssign) -> Occurence: + return Occurence( + line=node.lineno, # type: ignore + endLine=node.end_lineno, + column=node.col_offset, # type: ignore + endColumn=node.end_col_offset, + ) + + def visit(node: nodes.NodeNG): + 
nonlocal smells, in_loop_counter, current_loops, current_smells + + if isinstance(node, (nodes.For, nodes.While)): + in_loop_counter += 1 + current_loops.append(node) + for stmt in node.body: + visit(stmt) + + in_loop_counter -= 1 + + current_smells = { + key: val for key, val in current_smells.items() if val[1] != in_loop_counter + } + current_loops.pop() + + elif in_loop_counter > 0 and isinstance(node, nodes.Assign): + target = None + value = None + + if len(node.targets) == 1 > 1: + return + + target = node.targets[0] + value = node.value + + if target and isinstance(value, nodes.BinOp) and value.op == "+": + if ( + target.as_string() not in current_smells + and is_string_type(node) + and is_concatenating_with_self(value, target) + and is_not_referenced(node) + ): + current_smells[target.as_string()] = ( + len(smells), + in_loop_counter - 1, + ) + create_smell(node) + elif target.as_string() in current_smells and is_concatenating_with_self( + value, target + ): + smell_id = current_smells[target.as_string()][0] + smells[smell_id].occurences.append(create_smell_occ(node)) + else: + for child in node.get_children(): + visit(child) + + def is_not_referenced(node: nodes.Assign): + nonlocal current_loops + + loop_source_str = current_loops[-1].as_string() + loop_source_str = loop_source_str.replace(node.as_string(), "", 1) + lines = loop_source_str.splitlines() + for line in lines: + if ( + line.find(node.targets[0].as_string()) != -1 + and re.search(rf"\b{re.escape(node.targets[0].as_string())}\b\s*=", line) is None + ): + return False + return True + + def is_concatenating_with_self(binop_node: nodes.BinOp, target: nodes.NodeNG): + """Check if the BinOp node includes the target variable being added.""" + + def is_same_variable(var1: nodes.NodeNG, var2: nodes.NodeNG): + if isinstance(var1, nodes.Name) and isinstance(var2, nodes.AssignName): + return var1.name == var2.name + if isinstance(var1, nodes.Attribute) and isinstance(var2, nodes.AssignAttr): + return 
var1.as_string() == var2.as_string() + if isinstance(var1, nodes.Subscript) and isinstance(var2, nodes.Subscript): + if isinstance(var1.slice, nodes.Const) and isinstance(var2.slice, nodes.Const): + return var1.as_string() == var2.as_string() + if isinstance(var1, nodes.BinOp) and var1.op == "+": + return is_same_variable(var1.left, target) or is_same_variable(var1.right, target) + return False + + left, right = binop_node.left, binop_node.right + return is_same_variable(left, target) or is_same_variable(right, target) + + def is_string_type(node: nodes.Assign) -> bool: + target = node.targets[0] + + # Check type hints first + if has_type_hints_str(node, target): + return True + + # Infer types + for inferred in target.infer(): + if inferred.repr_name() == "str": + return True + if isinstance(inferred, util.UninferableBase): + print(f"here: {node}") + if has_str_format(node.value) or has_str_interpolation(node.value): + return True + for var in node.value.nodes_of_class( + (nodes.Name, nodes.Attribute, nodes.Subscript) + ): + if var.as_string() == target.as_string(): + for inferred_target in var.infer(): + if inferred_target.repr_name() == "str": + return True + + print(f"Checking type hints for {var}") + if has_type_hints_str(node, var): + return True + + return False + + def has_type_hints_str(context: nodes.NodeNG, target: nodes.NodeNG) -> bool: + """Checks if a variable has an explicit type hint for `str`""" + parent = context.scope() + + # Function argument type hints + if isinstance(parent, nodes.FunctionDef) and parent.args.args: + for arg, ann in zip(parent.args.args, parent.args.annotations): + print(f"arg: {arg}, target: {target}, ann: {ann}") + if arg.name == target.as_string() and ann and ann.as_string() == "str": + return True + + # Class attributes (annotations in class scope or __init__) + if "self." 
in target.as_string(): + class_def = parent.frame() + if not isinstance(class_def, nodes.ClassDef): + class_def = next( + ( + ancestor + for ancestor in context.node_ancestors() + if isinstance(ancestor, nodes.ClassDef) + ), + None, + ) + + if class_def: + attr_name = target.as_string().replace("self.", "") + try: + for attr in class_def.instance_attr(attr_name): + if ( + isinstance(attr, nodes.AnnAssign) + and attr.annotation.as_string() == "str" + ): + return True + if any(inf.repr_name() == "str" for inf in attr.infer()): + return True + except AttributeInferenceError: + pass + + # Global/scope variable annotations before assignment + for child in parent.nodes_of_class((nodes.AnnAssign, nodes.Assign)): + if child == context: + break + if ( + isinstance(child, nodes.AnnAssign) + and child.target.as_string() == target.as_string() + ): + return child.annotation.as_string() == "str" + print("checking var types") + if isinstance(child, nodes.Assign) and is_string_type(child): + return True + + return False + + def has_str_format(node: nodes.NodeNG): + if isinstance(node, nodes.BinOp) and node.op == "+": + str_repr = node.as_string() + match = re.search("{.*}", str_repr) + if match: + return True + + return False + + def has_str_interpolation(node: nodes.NodeNG): + if isinstance(node, nodes.BinOp) and node.op == "+": + str_repr = node.as_string() + match = re.search("%[a-z]", str_repr) + if match: + return True + return False + + def transform_augassign_to_assign(code_file: str): + """ + Changes all AugAssign occurences to Assign in a code file. 
+ + :param code_file: The source code file as a string + :return: The same string source code with all AugAssign stmts changed to Assign + """ + str_code = code_file.splitlines() + + for i in range(len(str_code)): + eq_col = str_code[i].find(" +=") + + if eq_col == -1: + continue + + target_var = str_code[i][0:eq_col].strip() + + # Replace '+=' with '=' to form an Assign string + str_code[i] = str_code[i].replace("+=", f"= {target_var} +", 1) + + return "\n".join(str_code) + + # Change all AugAssigns to Assigns + tree = parse(transform_augassign_to_assign(file_path.read_text())) + + # Start traversal + for child in tree.get_children(): + visit(child) + + return smells diff --git a/src/ecooptimizer/analyzers/base_analyzer.py b/src/ecooptimizer/analyzers/base_analyzer.py new file mode 100644 index 00000000..a20673f4 --- /dev/null +++ b/src/ecooptimizer/analyzers/base_analyzer.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + + +from ..data_types.smell import Smell + + +class Analyzer(ABC): + @abstractmethod + def analyze(self, file_path: Path, extra_options: list[Any]) -> list[Smell]: + pass diff --git a/src/ecooptimizer/analyzers/pylint_analyzer.py b/src/ecooptimizer/analyzers/pylint_analyzer.py new file mode 100644 index 00000000..e11f2e22 --- /dev/null +++ b/src/ecooptimizer/analyzers/pylint_analyzer.py @@ -0,0 +1,61 @@ +from io import StringIO +import json +from pathlib import Path +from pylint.lint import Run +from pylint.reporters.json_reporter import JSON2Reporter + +from ..config import CONFIG + +from ..data_types.custom_fields import AdditionalInfo, Occurence + +from .base_analyzer import Analyzer +from ..data_types.smell import Smell + + +class PylintAnalyzer(Analyzer): + def _build_smells(self, pylint_smells: dict): # type: ignore + """Casts initial list of pylint smells to the Eco Optimizer's Smell configuration.""" + smells: list[Smell] = [] + + for smell in pylint_smells: + smells.append( + 
Smell( + confidence=smell["confidence"], + message=smell["message"], + messageId=smell["messageId"], + module=smell["module"], + obj=smell["obj"], + path=smell["absolutePath"], + symbol=smell["symbol"], + type=smell["type"], + occurences=[ + Occurence( + line=smell["line"], + endLine=smell["endLine"], + column=smell["column"], + endColumn=smell["endColumn"], + ) + ], + additionalInfo=AdditionalInfo(), + ) + ) + + return smells + + def analyze(self, file_path: Path, extra_options: list[str]): + smells_data: list[Smell] = [] + pylint_options = [str(file_path), *extra_options] + + with StringIO() as buffer: + reporter = JSON2Reporter(buffer) + + try: + Run(pylint_options, reporter=reporter, exit=False) + buffer.seek(0) + smells_data.extend(self._build_smells(json.loads(buffer.getvalue())["messages"])) + except json.JSONDecodeError as e: + CONFIG["detectLogger"].error(f"❌ Failed to parse JSON output from pylint: {e}") # type: ignore + except Exception as e: + CONFIG["detectLogger"].error(f"❌ An error occurred during pylint analysis: {e}") # type: ignore + + return smells_data diff --git a/src/utils/__init__.py b/src/ecooptimizer/api/__init__.py similarity index 100% rename from src/utils/__init__.py rename to src/ecooptimizer/api/__init__.py diff --git a/src/ecooptimizer/api/__main__.py b/src/ecooptimizer/api/__main__.py new file mode 100644 index 00000000..aa1f1713 --- /dev/null +++ b/src/ecooptimizer/api/__main__.py @@ -0,0 +1,57 @@ +import logging +import sys +import uvicorn + +from .app import app + +from ..config import CONFIG + + +class HealthCheckFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + return "/health" not in record.getMessage() + + +# Apply the filter to Uvicorn's access logger +logging.getLogger("uvicorn.access").addFilter(HealthCheckFilter()) + + +def start(): + # ANSI codes + RESET = "\u001b[0m" + BLUE = "\u001b[36m" + PURPLE = "\u001b[35m" + + mode_message = f"{CONFIG['mode'].upper()} MODE" + msg_len = 
len(mode_message) + + print(f"\n\t\t\t***{'*'*msg_len}***") + print(f"\t\t\t* {BLUE}{mode_message}{RESET} *") + print(f"\t\t\t***{'*'*msg_len}***\n") + if CONFIG["mode"] == "production": + print(f"{PURPLE}hint: add --dev flag at the end to ignore energy checks\n") + + logging.info("πŸš€ Running EcoOptimizer Application...") + logging.info(f"{'=' * 100}\n") + uvicorn.run( + app, + host="127.0.0.1", + port=8000, + log_level="info", + access_log=True, + timeout_graceful_shutdown=2, + ) + + +def main(): + CONFIG["mode"] = "development" if "--dev" in sys.argv else "production" + start() + + +def dev(): + CONFIG["mode"] = "development" + start() + + +if __name__ == "__main__": + main() diff --git a/src/ecooptimizer/api/app.py b/src/ecooptimizer/api/app.py new file mode 100644 index 00000000..bace8451 --- /dev/null +++ b/src/ecooptimizer/api/app.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI +from .routes import RefactorRouter, DetectRouter, LogRouter + + +app = FastAPI(title="Ecooptimizer") + +# Include API routes +app.include_router(RefactorRouter) +app.include_router(DetectRouter) +app.include_router(LogRouter) + + +@app.get("/health") +async def ping(): + return {"status": "ok"} diff --git a/src/ecooptimizer/api/routes/__init__.py b/src/ecooptimizer/api/routes/__init__.py new file mode 100644 index 00000000..b0b59465 --- /dev/null +++ b/src/ecooptimizer/api/routes/__init__.py @@ -0,0 +1,5 @@ +from .refactor_smell import router as RefactorRouter +from .detect_smells import router as DetectRouter +from .show_logs import router as LogRouter + +__all__ = ["DetectRouter", "LogRouter", "RefactorRouter"] diff --git a/src/ecooptimizer/api/routes/detect_smells.py b/src/ecooptimizer/api/routes/detect_smells.py new file mode 100644 index 00000000..fb86357c --- /dev/null +++ b/src/ecooptimizer/api/routes/detect_smells.py @@ -0,0 +1,66 @@ +# pyright: reportOptionalMemberAccess=false +from pathlib import Path +from fastapi import APIRouter, HTTPException +from pydantic import 
BaseModel +import time + +from ...config import CONFIG + +from ...analyzers.analyzer_controller import AnalyzerController +from ...data_types.smell import Smell + +router = APIRouter() + +analyzer_controller = AnalyzerController() + + +class SmellRequest(BaseModel): + file_path: str + enabled_smells: list[str] + + +@router.post("/smells", response_model=list[Smell]) +def detect_smells(request: SmellRequest): + """ + Detects code smells in a given file, logs the process, and measures execution time. + """ + + CONFIG["detectLogger"].info(f"{'=' * 100}") + CONFIG["detectLogger"].info(f"πŸ“‚ Received smell detection request for: {request.file_path}") + + start_time = time.time() + + try: + file_path_obj = Path(request.file_path) + + if not file_path_obj.exists(): + CONFIG["detectLogger"].error(f"❌ File does not exist: {file_path_obj}") + raise FileNotFoundError(f"File not found: {file_path_obj}") + + CONFIG["detectLogger"].debug( + f"πŸ”Ž Enabled smells: {', '.join(request.enabled_smells) if request.enabled_smells else 'None'}" + ) + + # Run analysis + CONFIG["detectLogger"].info(f"🎯 Running analysis on: {file_path_obj}") + smells_data = analyzer_controller.run_analysis(file_path_obj, request.enabled_smells) + + execution_time = round(time.time() - start_time, 2) + CONFIG["detectLogger"].info(f"πŸ“Š Execution Time: {execution_time} seconds") + + CONFIG["detectLogger"].info( + f"🏁 Analysis completed for {file_path_obj}. {len(smells_data)} smells found." 
+ ) + CONFIG["detectLogger"].info(f"{'=' * 100}\n") + + return smells_data + + except FileNotFoundError as e: + CONFIG["detectLogger"].error(f"❌ File not found: {e}") + CONFIG["detectLogger"].info(f"{'=' * 100}\n") + raise HTTPException(status_code=404, detail=str(e)) from e + + except Exception as e: + CONFIG["detectLogger"].error(f"❌ Error during smell detection: {e!s}") + CONFIG["detectLogger"].info(f"{'=' * 100}\n") + raise HTTPException(status_code=500, detail="Internal server error") from e diff --git a/src/ecooptimizer/api/routes/refactor_smell.py b/src/ecooptimizer/api/routes/refactor_smell.py new file mode 100644 index 00000000..ae762401 --- /dev/null +++ b/src/ecooptimizer/api/routes/refactor_smell.py @@ -0,0 +1,192 @@ +# pyright: reportOptionalMemberAccess=false +import shutil +import math +from pathlib import Path +from tempfile import mkdtemp +import traceback +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +from typing import Any, Optional + +from ...config import CONFIG +from ...analyzers.analyzer_controller import AnalyzerController +from ...exceptions import EnergySavingsError, RefactoringError, remove_readonly +from ...refactorers.refactorer_controller import RefactorerController +from ...measurements.codecarbon_energy_meter import CodeCarbonEnergyMeter +from ...data_types.smell import Smell + +router = APIRouter() +analyzer_controller = AnalyzerController() +refactorer_controller = RefactorerController() +energy_meter = CodeCarbonEnergyMeter() + + +class ChangedFile(BaseModel): + original: str + refactored: str + + +class RefactoredData(BaseModel): + tempDir: str + targetFile: ChangedFile + energySaved: Optional[float] = None + affectedFiles: list[ChangedFile] + + +class RefactorRqModel(BaseModel): + source_dir: str + smell: Smell + + +class RefactorResModel(BaseModel): + refactoredData: Optional[RefactoredData] = None + updatedSmells: list[Smell] + + +@router.post("/refactor", response_model=RefactorResModel) +def 
refactor(request: RefactorRqModel): + """Handles the refactoring process for a given smell.""" + CONFIG["refactorLogger"].info(f"{'=' * 100}") + CONFIG["refactorLogger"].info("πŸ”„ Received refactor request.") + + try: + CONFIG["refactorLogger"].info( + f"πŸ” Analyzing smell: {request.smell.symbol} in {request.source_dir}" + ) + refactor_data, updated_smells = perform_refactoring(Path(request.source_dir), request.smell) + + CONFIG["refactorLogger"].info( + f"βœ… Refactoring process completed. Updated smells: {len(updated_smells)}" + ) + + if refactor_data: + refactor_data = clean_refactored_data(refactor_data) + CONFIG["refactorLogger"].info(f"{'=' * 100}\n") + return RefactorResModel(refactoredData=refactor_data, updatedSmells=updated_smells) + + CONFIG["refactorLogger"].info(f"{'=' * 100}\n") + return RefactorResModel(updatedSmells=updated_smells) + + except OSError as e: + CONFIG["refactorLogger"].error(f"❌ OS error: {e!s}") + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + CONFIG["refactorLogger"].error(f"❌ Refactoring error: {e!s}") + CONFIG["refactorLogger"].info(f"{'=' * 100}\n") + raise HTTPException(status_code=400, detail=str(e)) from e + + +def perform_refactoring(source_dir: Path, smell: Smell): + """Executes the refactoring process for a given smell.""" + target_file = Path(smell.path) + + CONFIG["refactorLogger"].info( + f"πŸš€ Starting refactoring for {smell.symbol} at line {smell.occurences[0].line} in {target_file}" + ) + + if not source_dir.is_dir(): + CONFIG["refactorLogger"].error(f"❌ Directory does not exist: {source_dir}") + raise OSError(f"Directory {source_dir} does not exist.") + + initial_emissions = measure_energy(target_file) + + if not initial_emissions: + CONFIG["refactorLogger"].error("❌ Could not retrieve initial emissions.") + raise RuntimeError("Could not retrieve initial emissions.") + + CONFIG["refactorLogger"].info(f"πŸ“Š Initial emissions: {initial_emissions} kg CO2") + + temp_dir = 
mkdtemp(prefix="ecooptimizer-") + source_copy = Path(temp_dir) / source_dir.name + target_file_copy = Path(str(target_file).replace(str(source_dir), str(source_copy), 1)) + + shutil.copytree(source_dir, source_copy, ignore=shutil.ignore_patterns(".git*")) + + modified_files = [] + try: + modified_files: list[Path] = refactorer_controller.run_refactorer( + target_file_copy, source_copy, smell + ) + except NotImplementedError: + print("Not implemented yet.") + except Exception as e: + print(f"An unexpected error occured: {e!s}") + traceback.print_exc() + shutil.rmtree(temp_dir, onerror=remove_readonly) + raise RefactoringError(str(target_file), str(e)) from e + + final_emissions = measure_energy(target_file_copy) + + if not final_emissions: + print("❌ Could not retrieve final emissions. Discarding refactoring.") + + CONFIG["refactorLogger"].error( + "❌ Could not retrieve final emissions. Discarding refactoring." + ) + + shutil.rmtree(temp_dir, onerror=remove_readonly) + raise RuntimeError("Could not retrieve final emissions.") + + if CONFIG["mode"] == "production" and final_emissions >= initial_emissions: + CONFIG["refactorLogger"].info(f"πŸ“Š Final emissions: {final_emissions} kg CO2") + CONFIG["refactorLogger"].info("⚠️ No measured energy savings. Discarding refactoring.") + + print("❌ Could not retrieve final emissions. Discarding refactoring.") + + shutil.rmtree(temp_dir, onerror=remove_readonly) + raise EnergySavingsError(str(target_file), "Energy was not saved after refactoring.") + + CONFIG["refactorLogger"].info( + f"βœ… Energy saved! 
Initial: {initial_emissions}, Final: {final_emissions}" + ) + + refactor_data = { + "tempDir": temp_dir, + "targetFile": { + "original": str(target_file.resolve()), + "refactored": str(target_file_copy.resolve()), + }, + "energySaved": initial_emissions - final_emissions + if not math.isnan(initial_emissions - final_emissions) + else None, + "affectedFiles": [ + { + "original": str(file.resolve()).replace( + str(source_copy.resolve()), str(source_dir.resolve()) + ), + "refactored": str(file.resolve()), + } + for file in modified_files + ], + } + + updated_smells = analyzer_controller.run_analysis(target_file_copy) + return refactor_data, updated_smells + + +def measure_energy(file: Path): + energy_meter.measure_energy(file) + return energy_meter.emissions + + +def clean_refactored_data(refactor_data: dict[str, Any]): + """Ensures the refactored data is correctly structured and handles missing fields.""" + try: + return RefactoredData( + tempDir=refactor_data.get("tempDir", ""), + targetFile=ChangedFile( + original=refactor_data["targetFile"].get("original", ""), + refactored=refactor_data["targetFile"].get("refactored", ""), + ), + energySaved=refactor_data.get("energySaved", None), + affectedFiles=[ + ChangedFile( + original=file.get("original", ""), + refactored=file.get("refactored", ""), + ) + for file in refactor_data.get("affectedFiles", []) + ], + ) + except KeyError as e: + CONFIG["refactorLogger"].error(f"❌ Missing expected key in refactored data: {e}") + raise HTTPException(status_code=500, detail=f"Missing key: {e}") from e diff --git a/src/ecooptimizer/api/routes/show_logs.py b/src/ecooptimizer/api/routes/show_logs.py new file mode 100644 index 00000000..d9b1b647 --- /dev/null +++ b/src/ecooptimizer/api/routes/show_logs.py @@ -0,0 +1,90 @@ +# pyright: reportOptionalMemberAccess=false + +import asyncio +from pathlib import Path +from fastapi import APIRouter, WebSocketException +from fastapi.websockets import WebSocketState, WebSocket, 
WebSocketDisconnect +from pydantic import BaseModel + +from ...utils.output_manager import LoggingManager +from ...config import CONFIG + +router = APIRouter() + + +class LogInit(BaseModel): + log_dir: str + + +@router.post("/logs/init") +def initialize_logs(log_init: LogInit): + try: + loggingManager = LoggingManager(Path(log_init.log_dir), CONFIG["mode"] == "production") + CONFIG["loggingManager"] = loggingManager + CONFIG["detectLogger"] = loggingManager.loggers["detect"] + CONFIG["refactorLogger"] = loggingManager.loggers["refactor"] + + return {"message": "Logging initialized succesfully."} + except Exception as e: + raise WebSocketException(code=500, reason=str(e)) from e + + +@router.websocket("/logs/main") +async def websocket_main_logs(websocket: WebSocket): + await websocket_log_stream(websocket, CONFIG["loggingManager"].log_files["main"]) + + +@router.websocket("/logs/detect") +async def websocket_detect_logs(websocket: WebSocket): + await websocket_log_stream(websocket, CONFIG["loggingManager"].log_files["detect"]) + + +@router.websocket("/logs/refactor") +async def websocket_refactor_logs(websocket: WebSocket): + await websocket_log_stream(websocket, CONFIG["loggingManager"].log_files["refactor"]) + + +async def listen_for_disconnect(websocket: WebSocket): + """Listens for client disconnects.""" + try: + while True: + await websocket.receive() + + if websocket.client_state == WebSocketState.DISCONNECTED: + raise WebSocketDisconnect() + except WebSocketDisconnect: + print("WebSocket disconnected from client.") + raise + except Exception as e: + print(f"Unexpected error in listener: {e}") + + +async def websocket_log_stream(websocket: WebSocket, log_file: Path): + """Streams log file content via WebSocket.""" + await websocket.accept() + + # Start background task to listen for disconnect + listener_task = asyncio.create_task(listen_for_disconnect(websocket)) + + try: + with log_file.open(encoding="utf-8") as file: + file.seek(0, 2) # Start at file end + 
while not listener_task.done(): + if websocket.application_state != WebSocketState.CONNECTED: + raise WebSocketDisconnect(reason="Connection closed") + + line = file.readline() + if line: + await websocket.send_text(line) + else: + await asyncio.sleep(0.5) + except FileNotFoundError: + await websocket.send_text("Error: Log file not found.") + except WebSocketDisconnect as e: + print(e.reason) + except Exception as e: + print(f"Unexpected error: {e}") + finally: + listener_task.cancel() + if websocket.client_state != WebSocketState.DISCONNECTED: + await websocket.close() diff --git a/src/ecooptimizer/config.py b/src/ecooptimizer/config.py new file mode 100644 index 00000000..af693926 --- /dev/null +++ b/src/ecooptimizer/config.py @@ -0,0 +1,20 @@ +from logging import Logger +import logging +from typing import TypedDict + +from .utils.output_manager import LoggingManager + + +class Config(TypedDict): + mode: str + loggingManager: LoggingManager | None + detectLogger: Logger + refactorLogger: Logger + + +CONFIG: Config = { + "mode": "production", + "loggingManager": None, + "detectLogger": logging.getLogger("detect"), + "refactorLogger": logging.getLogger("refactor"), +} diff --git a/src/ecooptimizer/data_types/__init__.py b/src/ecooptimizer/data_types/__init__.py new file mode 100644 index 00000000..1c130bb6 --- /dev/null +++ b/src/ecooptimizer/data_types/__init__.py @@ -0,0 +1,36 @@ +from .custom_fields import ( + AdditionalInfo, + CRCInfo, + Occurence, + SCLInfo, +) + +from .smell import ( + Smell, + CRCSmell, + SCLSmell, + LECSmell, + LLESmell, + LMCSmell, + LPLSmell, + UVASmell, + MIMSmell, + UGESmell, +) + +__all__ = [ + "AdditionalInfo", + "CRCInfo", + "CRCSmell", + "LECSmell", + "LLESmell", + "LMCSmell", + "LPLSmell", + "MIMSmell", + "Occurence", + "SCLInfo", + "SCLSmell", + "Smell", + "UGESmell", + "UVASmell", +] diff --git a/src/ecooptimizer/data_types/custom_fields.py b/src/ecooptimizer/data_types/custom_fields.py new file mode 100644 index 
00000000..f57000f8 --- /dev/null +++ b/src/ecooptimizer/data_types/custom_fields.py @@ -0,0 +1,26 @@ +from typing import Optional +from pydantic import BaseModel + + +class Occurence(BaseModel): + line: int + endLine: int | None + column: int + endColumn: int | None + + +class AdditionalInfo(BaseModel): + innerLoopLine: Optional[int] = None + concatTarget: Optional[str] = None + repetitions: Optional[int] = None + callString: Optional[str] = None + + +class CRCInfo(AdditionalInfo): + callString: str # type: ignore + repetitions: int # type: ignore + + +class SCLInfo(AdditionalInfo): + innerLoopLine: int # type: ignore + concatTarget: str # type: ignore diff --git a/src/ecooptimizer/data_types/smell.py b/src/ecooptimizer/data_types/smell.py new file mode 100644 index 00000000..a12401ce --- /dev/null +++ b/src/ecooptimizer/data_types/smell.py @@ -0,0 +1,50 @@ +from pydantic import BaseModel +from typing import Optional + +from .custom_fields import CRCInfo, Occurence, AdditionalInfo, SCLInfo + + +class Smell(BaseModel): + """ + Represents a code smell detected in a source file, including its location, type, and related metadata. + + Attributes: + confidence (str): The level of confidence for the smell detection (e.g., "high", "medium", "low"). + message (str): A descriptive message explaining the nature of the smell. + messageId (str): A unique identifier for the specific message or warning related to the smell. + module (str): The name of the module or component in which the smell is located. + obj (str): The specific object (e.g., function, class) associated with the smell. + path (str): The relative path to the source file from the project root. + symbol (str): The symbol or code construct (e.g., variable, method) involved in the smell. + type (str): The type or category of the smell (e.g., "complexity", "duplication"). + occurences (list[Occurence]): A list of individual occurences of a same smell, contains positional info. 
additionalInfo (AdditionalInfo): (Optional) Any custom information for a type of smell
+ """ + + id: str + enabled: bool + analyzer_method: str + checker: Callable | None # type: ignore + refactorer: type[BaseRefactorer] # type: ignore # Refers to a class, not an instance + analyzer_options: dict[str, Any] # type: ignore diff --git a/src/ecooptimizer/exceptions.py b/src/ecooptimizer/exceptions.py new file mode 100644 index 00000000..298a5327 --- /dev/null +++ b/src/ecooptimizer/exceptions.py @@ -0,0 +1,25 @@ +import os +import stat + + +class RefactoringError(Exception): + """Exception raised for errors that occured during the refcatoring process. + + Attributes: + targetFile -- file being refactored + message -- explanation of the error + """ + + def __init__(self, targetFile: str, message: str) -> None: + self.targetFile = targetFile + super().__init__(message) + + +class EnergySavingsError(RefactoringError): + pass + + +def remove_readonly(func, path, _): # noqa: ANN001 + # "Clear the readonly bit and reattempt the removal" + os.chmod(path, stat.S_IWRITE) # noqa: PTH101 + func(path) diff --git a/src/measurement/measurement_utils.py b/src/ecooptimizer/measurements/__init__.py similarity index 100% rename from src/measurement/measurement_utils.py rename to src/ecooptimizer/measurements/__init__.py diff --git a/src/ecooptimizer/measurements/base_energy_meter.py b/src/ecooptimizer/measurements/base_energy_meter.py new file mode 100644 index 00000000..425b1fc0 --- /dev/null +++ b/src/ecooptimizer/measurements/base_energy_meter.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod +from pathlib import Path + + +class BaseEnergyMeter(ABC): + def __init__(self): + """ + Base class for energy meters to measure the emissions of a given file. + + :param file_path: Path to the file to measure energy consumption. + :param logger: Logger instance to handle log messages. + """ + self.emissions = None + + @abstractmethod + def measure_energy(self, file_path: Path): + """ + Abstract method to measure the energy consumption of the specified file. 
+ Must be implemented by subclasses. + """ + pass diff --git a/src/ecooptimizer/measurements/codecarbon_energy_meter.py b/src/ecooptimizer/measurements/codecarbon_energy_meter.py new file mode 100644 index 00000000..99c0aa83 --- /dev/null +++ b/src/ecooptimizer/measurements/codecarbon_energy_meter.py @@ -0,0 +1,80 @@ +import logging +import os +from pathlib import Path +import sys +import subprocess +import pandas as pd +from tempfile import TemporaryDirectory +from codecarbon import EmissionsTracker + +from .base_energy_meter import BaseEnergyMeter + + +class CodeCarbonEnergyMeter(BaseEnergyMeter): + def __init__(self): + """ + Initializes the CodeCarbonEnergyMeter with a file path and logger. + + :param file_path: Path to the file to measure energy consumption. + :param logger: Logger instance for logging events. + """ + super().__init__() + self.emissions_data = None + + def measure_energy(self, file_path: Path): + """ + Measures the carbon emissions for the specified file by running it with CodeCarbon. + Logs each step and stores the emissions data if available. 
+ """ + logging.info(f"Starting CodeCarbon energy measurement on {file_path.name}") + + with TemporaryDirectory() as custom_temp_dir: + os.environ["TEMP"] = custom_temp_dir # For Windows + os.environ["TMPDIR"] = custom_temp_dir # For Unix-based systems + + # TODO: Save to logger so doesn't print to console + tracker = EmissionsTracker( + output_dir=custom_temp_dir, + allow_multiple_runs=True, + tracking_mode="process", + log_level="error", + ) # type: ignore + tracker.start() + + try: + subprocess.run( + [sys.executable, file_path], capture_output=True, text=True, check=True + ) + logging.info("CodeCarbon measurement completed successfully.") + except subprocess.CalledProcessError as e: + logging.error(f"Error executing file '{file_path}': {e}") + finally: + self.emissions = tracker.stop() + emissions_file = custom_temp_dir / Path("emissions.csv") + + if emissions_file.exists(): + self.emissions_data = self.extract_emissions_csv(emissions_file) + else: + logging.error( + "Emissions file was not created due to an error during execution." + ) + self.emissions_data = None + + def extract_emissions_csv(self, csv_file_path: Path): + """ + Extracts emissions data from a CSV file generated by CodeCarbon. + + :param csv_file_path: Path to the CSV file. + :return: Dictionary containing the last row of emissions data or None if an error occurs. 
+ """ + str_csv_path = str(csv_file_path) + if csv_file_path.exists(): + try: + df = pd.read_csv(str_csv_path) + return df.to_dict(orient="records")[-1] + except Exception as e: + logging.info(f"Error reading file '{str_csv_path}': {e}") + return None + else: + logging.info(f"File '{str_csv_path}' does not exist.") + return None diff --git a/src/ecooptimizer/refactorers/__init__.py b/src/ecooptimizer/refactorers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ecooptimizer/refactorers/base_refactorer.py b/src/ecooptimizer/refactorers/base_refactorer.py new file mode 100644 index 00000000..e0d0c3b7 --- /dev/null +++ b/src/ecooptimizer/refactorers/base_refactorer.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Generic, TypeVar + +from ..data_types.smell import Smell + +T = TypeVar("T", bound=Smell) + + +class BaseRefactorer(ABC, Generic[T]): + def __init__(self): + self.modified_files: list[Path] = [] + + @abstractmethod + def refactor( + self, + target_file: Path, + source_dir: Path, + smell: T, + output_file: Path, + overwrite: bool = True, + ): + pass diff --git a/src/ecooptimizer/refactorers/concrete/__init__.py b/src/ecooptimizer/refactorers/concrete/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ecooptimizer/refactorers/concrete/list_comp_any_all.py b/src/ecooptimizer/refactorers/concrete/list_comp_any_all.py new file mode 100644 index 00000000..cf7b3834 --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/list_comp_any_all.py @@ -0,0 +1,96 @@ +import ast +from pathlib import Path +from asttokens import ASTTokens + +from ..base_refactorer import BaseRefactorer +from ...data_types.smell import UGESmell + + +class UseAGeneratorRefactorer(BaseRefactorer[UGESmell]): + def __init__(self): + super().__init__() + + def refactor( + self, + target_file: Path, + source_dir: Path, # noqa: ARG002 + smell: UGESmell, + output_file: Path, + overwrite: bool = True, 
+ ): + """ + Refactors an unnecessary list comprehension by converting it to a generator expression. + Modifies the specified instance in the file directly if it results in lower emissions. + """ + line_number = smell.occurences[0].line + start_column = smell.occurences[0].column + end_column = smell.occurences[0].endColumn + + # Load the source file as a list of lines + with target_file.open() as file: + original_lines = file.readlines() + + # Check bounds for line number + if not (1 <= line_number <= len(original_lines)): + return + + # Extract the specific line to refactor + target_line = original_lines[line_number - 1] + + # Preserve the original indentation + leading_whitespace = target_line[: len(target_line) - len(target_line.lstrip())] + + # Remove leading whitespace for parsing + stripped_line = target_line.lstrip() + + # Parse the stripped line + try: + atok = ASTTokens(stripped_line, parse=True) + if not atok.tree: + return + target_ast = atok.tree + except (SyntaxError, ValueError): + return + + # modified = False + + # Traverse the AST and locate the list comprehension at the specified column range + for node in ast.walk(target_ast): + if isinstance(node, ast.ListComp): + # Check if end_col_offset exists and is valid + end_col_offset = getattr(node, "end_col_offset", None) + if end_col_offset is None: + continue + + # Check if the node matches the specified column range + if node.col_offset >= start_column - 1 and end_col_offset <= end_column: + # Calculate offsets relative to the original line + start_offset = node.col_offset + len(leading_whitespace) + end_offset = end_col_offset + len(leading_whitespace) + + # Check if parentheses are already present + if target_line[start_offset - 1] == "(" and target_line[end_offset] == ")": + # Parentheses already exist, avoid adding redundant ones + refactored_code = ( + target_line[:start_offset] + + f"{target_line[start_offset + 1 : end_offset - 1]}" + + target_line[end_offset:] + ) + else: + # Add parentheses 
explicitly if not already wrapped + refactored_code = ( + target_line[:start_offset] + + f"({target_line[start_offset + 1 : end_offset - 1]})" + + target_line[end_offset:] + ) + + original_lines[line_number - 1] = refactored_code + # modified = True + break + + if overwrite: + with target_file.open("w") as f: + f.writelines(original_lines) + else: + with output_file.open("w") as f: + f.writelines(original_lines) diff --git a/src/ecooptimizer/refactorers/concrete/long_element_chain.py b/src/ecooptimizer/refactorers/concrete/long_element_chain.py new file mode 100644 index 00000000..dc246e3d --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/long_element_chain.py @@ -0,0 +1,343 @@ +import ast +import json +from pathlib import Path +import re +from typing import Any, Optional + +from ..multi_file_refactorer import MultiFileRefactorer +from ...data_types.smell import LECSmell + + +class DictAccess: + """Represents a dictionary access pattern found in code.""" + + def __init__( + self, + dictionary_name: str, + full_access: str, + nesting_level: int, + line_number: int, + col_offset: int, + path: Path, + node: ast.AST, + ): + self.dictionary_name = dictionary_name + self.full_access = full_access + self.nesting_level = nesting_level + self.col_offset = col_offset + self.line_number = line_number + self.path = path + self.node = node + + +class LongElementChainRefactorer(MultiFileRefactorer[LECSmell]): + """ + Refactors long element chains by flattening nested dictionaries. + Only implements flatten dictionary strategy as it proved most effective for energy savings. 
+ """ + + def __init__(self): + super().__init__() + self.dict_name: set[str] = set() + self.access_patterns: set[DictAccess] = set() + self.min_value = float("inf") + self.dict_assignment: Optional[dict[str, Any]] = None + self.initial_parsing = True + + def refactor( + self, + target_file: Path, + source_dir: Path, + smell: LECSmell, + output_file: Path, # noqa: ARG002 + overwrite: bool = True, # noqa: ARG002 + ) -> None: + """Main refactoring method that processes the target file and related files.""" + self.target_file = target_file + line_number = smell.occurences[0].line + + tree = ast.parse(target_file.read_text()) + self._find_dict_names(tree, line_number) + + # Abort if dictionary access is too shallow + self.traverse_and_process(source_dir) + if self.min_value <= 1: + return + + self.initial_parsing = False + self.traverse_and_process(source_dir) + + def _find_dict_names(self, tree: ast.AST, line_number: int) -> None: + """Extract dictionary names from the AST at the given line number.""" + for node in ast.walk(tree): + if not ( + isinstance(node, ast.Subscript) + and hasattr(node, "lineno") + and node.lineno == line_number + ): + continue + + if isinstance(node.value, ast.Name): + self.dict_name.add(node.value.id) + else: + dict_name = self._extract_dict_name(node.value) + if dict_name: + self.dict_name.add(dict_name) + self.dict_name.add(dict_name.split(".")[-1]) + + def _extract_dict_name(self, node: ast.AST) -> Optional[str]: + """Extract dictionary name from attribute access chains.""" + while isinstance(node, ast.Subscript): + node = node.value + + if isinstance(node, ast.Attribute): + return f"{node.value.id}.{node.attr}" + return None + + def _process_file(self, file: Path): + tree = ast.parse(file.read_text()) + if self.initial_parsing: + self._find_access_pattern_in_file(tree, file) + else: + self.find_dict_assignment_in_file(tree) + if self._refactor_all_in_file(file): + return True + + return False + + # finds all access patterns in the file + 
def _find_access_pattern_in_file(self, tree: ast.AST, path: Path): + offset = set() + for node in ast.walk(tree): + if isinstance(node, ast.Subscript): # Check for dictionary access (Subscript) + dict_name, full_access, line_number, col_offset = self.extract_full_dict_access( + node + ) + + if (line_number, col_offset) in offset: + continue + offset.add((line_number, col_offset)) + + if dict_name.split(".")[-1] in self.dict_name: + nesting_level = self._count_nested_subscripts(node) + access = DictAccess( + dict_name, full_access, nesting_level, line_number, col_offset, path, node + ) + self.access_patterns.add(access) + print(self.access_patterns) + self.min_value = min(self.min_value, nesting_level) + + def extract_full_dict_access(self, node: ast.Subscript): + """Extracts the full dictionary access chain as a string.""" + access_chain = [] + curr = node + # Traverse nested subscripts to build access path + while isinstance(curr, ast.Subscript): + if isinstance(curr.slice, ast.Constant): # Python 3.8+ + access_chain.append(f"['{curr.slice.value}']") + curr = curr.value # Move to parent node + + # Get the dictionary root (can be a variable or an attribute) + if isinstance(curr, ast.Name): + dict_name = curr.id # Simple variable (e.g., "long_chain") + elif isinstance(curr, ast.Attribute) and isinstance(curr.value, ast.Name): + dict_name = f"{curr.value.id}.{curr.attr}" # Attribute access (e.g., "self.long_chain") + else: + dict_name = "UNKNOWN" + + full_access = f"{dict_name}{''.join(reversed(access_chain))}" + + return dict_name, full_access, curr.lineno, curr.col_offset + + def _count_nested_subscripts(self, node: ast.Subscript): + """ + Counts how many times a dictionary is accessed (nested Subscript nodes). 
+ """ + level = 0 + curr = node + while isinstance(curr, ast.Subscript): + curr = curr.value # Move up the AST + level += 1 + return level + + def find_dict_assignment_in_file(self, tree: ast.AST): + """find the dictionary assignment from AST based on the dict name""" + + class DictVisitor(ast.NodeVisitor): + def visit_Assign(self_, node: ast.Assign): + if isinstance(node.value, ast.Dict) and len(node.targets) == 1: + # dictionary is a varibale + if ( + isinstance(node.targets[0], ast.Name) + and node.targets[0].id in self.dict_name + ): + dict_value = self.extract_dict_literal(node.value) + flattened_version = self.flatten_dict(dict_value) # type: ignore + self.dict_assignment = flattened_version + + # dictionary is an attribute + elif ( + isinstance(node.targets[0], ast.Attribute) + and node.targets[0].attr in self.dict_name + ): + dict_value = self.extract_dict_literal(node.value) + self.dict_assignment = self.flatten_dict(dict_value) # type: ignore + self_.generic_visit(node) + + DictVisitor().visit(tree) + + def extract_dict_literal(self, node: ast.AST): + """Convert AST dict literal to Python dict.""" + if isinstance(node, ast.Dict): + return { + self.extract_dict_literal(k) + if isinstance(k, ast.AST) + else k: self.extract_dict_literal(v) if isinstance(v, ast.AST) else v + for k, v in zip(node.keys, node.values) + } + elif isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.Name): + return node.id + return node + + def flatten_dict( + self, d: dict[str, Any], depth: int = 0, parent_key: str = "" + ) -> dict[str, Any]: + """Recursively flatten a nested dictionary.""" + + if depth >= self.min_value - 1: + # At max_depth, we return the current dictionary as flattened key-value pairs + items = {} + for k, v in d.items(): + new_key = f"{parent_key}_{k}" if parent_key else k + items[new_key] = v + return items + + items = {} + for k, v in d.items(): + new_key = f"{parent_key}_{k}" if parent_key else k + + if isinstance(v, dict): + # 
Recursively flatten the dictionary, increasing the depth + items.update(self.flatten_dict(v, depth + 1, new_key)) + else: + # If it's not a dictionary, just add it to the result + items[new_key] = v + + return items + + def generate_flattened_access(self, access_chain: list[str]) -> str: + """Generate flattened dictionary key only until given min_value.""" + + joined = "_".join(k.strip("'\"") for k in access_chain[: self.min_value]) + if not joined.endswith("']") or not joined.endswith('"]'): # Corrected to check for "']" + joined += "']" + remaining = access_chain[self.min_value :] # Keep the rest unchanged + + rest = "".join(f"[{key}]" for key in remaining) + + return f"{joined}" + rest + + def _refactor_all_in_file(self, file_path: Path): + """Refactor dictionary access patterns in a single file.""" + # Skip if no access patterns found + if not any(access.path == file_path for access in self.access_patterns): + return False + + source_code = file_path.read_text() + lines = source_code.split("\n") + line_modifications = self._collect_line_modifications(file_path) + + refactored_lines = self._apply_modifications(lines, line_modifications) + refactored_lines = self._update_dict_assignment(refactored_lines) + + # Write changes back to file + file_path.write_text("\n".join(refactored_lines)) + + return True + + def _collect_line_modifications(self, file_path: Path) -> dict[int, list[tuple[int, str, str]]]: + """Collect all modifications needed for each line.""" + modifications: dict[int, list[tuple[int, str, str]]] = {} + + for access in sorted(self.access_patterns, key=lambda a: (a.line_number, a.col_offset)): + if access.path != file_path: + continue + + access_chain = access.full_access.split("][") + for i in range(len(access_chain)): + access_chain[i] = access_chain[i].replace("]", "") + new_access = self.generate_flattened_access(access_chain) + + if access.line_number not in modifications: + modifications[access.line_number] = [] + 
modifications[access.line_number].append( + (access.col_offset, access.full_access, new_access) + ) + + return modifications + + def _apply_modifications( + self, lines: list[str], modifications: dict[int, list[tuple[int, str, str]]] + ) -> list[str]: + """Apply collected modifications to each line.""" + refactored_lines = [] + for line_num, original_line in enumerate(lines, start=1): + if line_num in modifications: + # Sort modifications by column offset (reverse to replace from right to left) + mods = sorted(modifications[line_num], key=lambda x: x[0], reverse=True) + modified_line = original_line + # print("this si the og line: " + modified_line) + + for col_offset, old_access, new_access in mods: + end_idx = col_offset + len(old_access) + # Replace specific occurrence using slicing + modified_line = ( + modified_line[:col_offset] + new_access + modified_line[end_idx:] + ) + # print(modified_line) + + refactored_lines.append(modified_line) + else: + # No modification, add original line + refactored_lines.append(original_line) + + return refactored_lines + + def _update_dict_assignment(self, refactored_lines: list[str]) -> None: + """Update dictionary assignment to be the new flattened dictionary.""" + dictionary_assignment_name = self.dict_name + for i, line in enumerate(refactored_lines): + match = next( + ( + name + for name in dictionary_assignment_name + if re.match(rf"^\s*(?:\w+\.)*{re.escape(name)}\s*=", line) + ), + None, + ) + + if match: + # Preserve indentation and the `=` + indent, prefix, _ = re.split(r"(=)", line, maxsplit=1) + + # Convert dict to a properly formatted string + dict_str = json.dumps(self.dict_assignment, separators=(",", ": ")) + # Update the line with the new flattened dictionary + refactored_lines[i] = f"{indent}{prefix} {dict_str}" + + # Remove the following lines of the original nested dictionary, + # leaving only one empty line after them + j = i + 1 + while j < len(refactored_lines) and ( + 
refactored_lines[j].strip().startswith('"') + or refactored_lines[j].strip().startswith("}") + ): + refactored_lines[j] = "Remove this line" # Mark for removal + j += 1 + break + + refactored_lines = [line for line in refactored_lines if line.strip() != "Remove this line"] + + return refactored_lines diff --git a/src/ecooptimizer/refactorers/concrete/long_lambda_function.py b/src/ecooptimizer/refactorers/concrete/long_lambda_function.py new file mode 100644 index 00000000..76c5e6bc --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/long_lambda_function.py @@ -0,0 +1,153 @@ +from pathlib import Path +import re +from ..base_refactorer import BaseRefactorer +from ...data_types.smell import LLESmell + + +class LongLambdaFunctionRefactorer(BaseRefactorer[LLESmell]): + """ + Refactorer that targets long lambda functions by converting them into normal functions. + """ + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def truncate_at_top_level_comma(body: str) -> str: + """ + Truncate the lambda body at the first top-level comma, ignoring commas + within nested parentheses, brackets, or braces. + """ + truncated_body = [] + open_parens = 0 + + for char in body: + if char in "([{": + open_parens += 1 + elif char in ")]}": + open_parens -= 1 + elif char == "," and open_parens == 0: + # Stop at the first top-level comma + break + + truncated_body.append(char) + + return "".join(truncated_body).strip() + + def refactor( + self, + target_file: Path, + source_dir: Path, # noqa: ARG002 + smell: LLESmell, + output_file: Path, + overwrite: bool = True, + ): + """ + Refactor long lambda functions by converting them into normal functions + and writing the refactored code to a new file. 
+ """ + # Extract details from smell + line_number = smell.occurences[0].line + + # Read the original file + content = target_file.read_text(encoding="utf-8") + lines = content.splitlines(keepends=True) + + # Capture the entire logical line containing the lambda + current_line = line_number - 1 + lambda_lines = [lines[current_line].rstrip()] + + # Check if lambda is wrapped in parentheses + has_parentheses = lambda_lines[0].strip().startswith("(") + + # Find continuation lines only if needed + if has_parentheses: + while current_line < len(lines) - 1 and not lambda_lines[ + -1 + ].strip().endswith(")"): + current_line += 1 + lambda_lines.append(lines[current_line].rstrip()) + else: + # Handle single-line lambda + lambda_lines = [lines[current_line].rstrip()] + + full_lambda_line = " ".join(lambda_lines).strip() + + # Remove surrounding parentheses if present + if has_parentheses: + full_lambda_line = re.sub(r"^\((.*)\)$", r"\1", full_lambda_line) + + # Extract leading whitespace for correct indentation + original_indent = re.match(r"^\s*", lambda_lines[0]).group() # type: ignore + + # Use different regex based on whether the lambda line starts with a parenthesis + if has_parentheses: + lambda_match = re.search( + r"lambda\s+([\w, ]+):\s+(.+?)(?=\s*\))", full_lambda_line + ) + else: + lambda_match = re.search(r"lambda\s+([\w, ]+):\s+(.+)", full_lambda_line) + + if not lambda_match: + return + + # Extract arguments and body of the lambda + lambda_args = lambda_match.group(1).strip() + lambda_body_before = lambda_match.group(2).strip() + lambda_body_before = LongLambdaFunctionRefactorer.truncate_at_top_level_comma( + lambda_body_before + ) + + # Ensure that the lambda body does not contain extra trailing characters + # Remove any trailing commas or mismatched closing brackets + lambda_body = re.sub(r",\s*\)$", "", lambda_body_before).strip() + + lambda_body_no_extra_space = re.sub(r"\s{2,}", " ", lambda_body) + # Generate a unique function name + function_name = 
f"converted_lambda_{line_number}" + + # Find the start of the block containing the lambda + original_indent_len = len(original_indent) + block_start = line_number - 1 + while block_start > 0: + prev_line = lines[block_start - 1].rstrip() + prev_indent = len(re.match(r"^\s*", prev_line).group()) # type: ignore + if prev_line.endswith(":") and prev_indent < original_indent_len: + break + block_start -= 1 + + # Get proper block indentation + block_indentation = re.match(r"^\s*", lines[block_start]).group() # type: ignore + function_indent = block_indentation + body_indent = function_indent + " " * 4 + + # Create properly indented function definition + function_def = ( + f"{function_indent}def {function_name}({lambda_args}):\n" + f"{body_indent}result = {lambda_body_no_extra_space}\n" + f"{body_indent}return result\n\n" + ) + + # Prepare refactored line with original indentation + replacement_line = full_lambda_line.replace( + f"lambda {lambda_args}: {lambda_body}", function_name + ) + refactored_line = f"{original_indent}{replacement_line.strip()}" + + # Split multi-line function definition into individual lines + function_lines = function_def.splitlines(keepends=True) + + # Replace the lambda line with the refactored line in place + lines[current_line] = f"{refactored_line}\n" + + # Insert the new function definition immediately at the beginning of the block + lines.insert(block_start, "".join(function_lines)) + + # Write changes + new_content = "".join(lines) + if overwrite: + target_file.write_text(new_content, encoding="utf-8") + else: + output_file.write_text(new_content, encoding="utf-8") + + self.modified_files.append(target_file) diff --git a/src/ecooptimizer/refactorers/concrete/long_message_chain.py b/src/ecooptimizer/refactorers/concrete/long_message_chain.py new file mode 100644 index 00000000..663778dc --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/long_message_chain.py @@ -0,0 +1,181 @@ +from pathlib import Path +import re +from 
..base_refactorer import BaseRefactorer +from ...data_types.smell import LMCSmell + + +class LongMessageChainRefactorer(BaseRefactorer[LMCSmell]): + """ + Refactorer that targets long method chains to improve performance. + """ + + def __init__(self) -> None: + super().__init__() + + @staticmethod + def remove_unmatched_brackets(input_string: str): + """ + Removes unmatched brackets from the input string. + + Args: + input_string (str): The string to process. + + Returns: + str: The string with unmatched brackets removed. + """ + stack = [] + indexes_to_remove = set() + + # Iterate through the string to find unmatched brackets + for i, char in enumerate(input_string): + if char == "(": + stack.append(i) + elif char == ")": + if stack: + stack.pop() # Matched bracket, remove from stack + else: + indexes_to_remove.add(i) # Unmatched closing bracket + + # Add any unmatched opening brackets left in the stack + indexes_to_remove.update(stack) + + # Build the result string without unmatched brackets + result = "".join( + char for i, char in enumerate(input_string) if i not in indexes_to_remove + ) + + return result + + def refactor( + self, + target_file: Path, + source_dir: Path, # noqa: ARG002 + smell: LMCSmell, + output_file: Path, + overwrite: bool = True, + ): + """ + Refactor long message chains by breaking them into separate statements + and writing the refactored code to a new file. 
+ """ + # Extract details from smell + line_number = smell.occurences[0].line + # temp_filename = output_file + + # Read file content using read_text + content = target_file.read_text(encoding="utf-8") + lines = content.splitlines(keepends=True) # Preserve line endings + + # Identify the line with the long method chain + line_with_chain = lines[line_number - 1].rstrip() + + # Extract leading whitespace for correct indentation + leading_whitespace = re.match(r"^\s*", line_with_chain).group() # type: ignore + + # Check if the line contains an f-string + f_string_pattern = r"f\".*?\"" + if re.search(f_string_pattern, line_with_chain): + # Determine if original was print or assignment + is_print = line_with_chain.startswith("print(") + original_var = ( + None if is_print else line_with_chain.split("=", 1)[0].strip() + ) + + # Extract f-string and methods + f_string_content = re.search(f_string_pattern, line_with_chain).group() # type: ignore + remaining_chain = line_with_chain.split(f_string_content, 1)[-1].lstrip(".") + + method_calls = re.split(r"\.(?![^()]*\))", remaining_chain.strip()) + refactored_lines = [] + + # Initial f-string assignment + refactored_lines.append( + f"{leading_whitespace}intermediate_0 = {f_string_content}" + ) + + # Process method calls + for i, method in enumerate(method_calls, start=1): + method = method.strip() + if not method: + continue + + if i < len(method_calls): + refactored_lines.append( + f"{leading_whitespace}intermediate_{i} = " + f"intermediate_{i-1}.{method}" + ) + else: + # Final assignment using original variable name + if is_print: + refactored_lines.append( + f"{leading_whitespace}print(intermediate_{i-1}.{method})" + ) + else: + refactored_lines.append( + f"{leading_whitespace}{original_var} = " + f"intermediate_{i-1}.{method}" + ) + + lines[line_number - 1] = "\n".join(refactored_lines) + "\n" + + else: + # Handle non-f-string chains + original_has_print = "print(" in line_with_chain + chain_content = 
re.sub(r"^\s*print\((.*)\)\s*$", r"\1", line_with_chain) + + # Extract RHS if assignment exists + if "=" in chain_content: + chain_content = chain_content.split("=", 1)[1].strip() + + # Split chain after closing parentheses + method_calls = re.split(r"(?<=\))\.", chain_content) + + if len(method_calls) > 1: + refactored_lines = [] + base_var = method_calls[0].strip() + refactored_lines.append( + f"{leading_whitespace}intermediate_0 = {base_var}" + ) + + # Process subsequent method calls + for i, method in enumerate(method_calls[1:], start=1): + method = method.strip().lstrip(".") + if not method: + continue + + if i < len(method_calls) - 1: + refactored_lines.append( + f"{leading_whitespace}intermediate_{i} = " + f"intermediate_{i-1}.{method}" + ) + else: + # Preserve original assignment/print structure + if original_has_print: + refactored_lines.append( + f"{leading_whitespace}print(intermediate_{i-1}.{method})" + ) + else: + original_assignment = line_with_chain.split("=", 1)[ + 0 + ].strip() + refactored_lines.append( + f"{leading_whitespace}{original_assignment} = " + f"intermediate_{i-1}.{method}" + ) + + lines[line_number - 1] = "\n".join(refactored_lines) + "\n" + + # # Write the refactored file + # with temp_filename.open("w") as f: + # f.writelines(lines) + + # Join lines and write using write_text + new_content = "".join(lines) + + # Write to appropriate file based on overwrite flag + if overwrite: + target_file.write_text(new_content, encoding="utf-8") + else: + output_file.write_text(new_content, encoding="utf-8") + + self.modified_files.append(target_file) diff --git a/src/ecooptimizer/refactorers/concrete/long_parameter_list.py b/src/ecooptimizer/refactorers/concrete/long_parameter_list.py new file mode 100644 index 00000000..8cd49a9e --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/long_parameter_list.py @@ -0,0 +1,594 @@ +import ast +import astor +from pathlib import Path + +from ..multi_file_refactorer import MultiFileRefactorer +from 
import ast
import astor
from pathlib import Path

from ..multi_file_refactorer import MultiFileRefactorer
from ...data_types.smell import LPLSmell


class FunctionCallVisitor(ast.NodeVisitor):
    """Records whether a target function/method (or class constructor) is called in a tree."""

    def __init__(self, function_name: str, class_name: str, is_constructor: bool):
        self.function_name = function_name
        self.is_constructor = is_constructor  # whether the target call is a constructor
        self.class_name = class_name  # class being instantiated if target is a constructor
        self.found = False

    def visit_Call(self, node: ast.Call):
        """Check if the function/class constructor is called."""
        # handle plain function call: foo(...)
        if isinstance(node.func, ast.Name) and node.func.id == self.function_name:
            self.found = True
        # handle method call: obj.foo(...)
        elif isinstance(node.func, ast.Attribute):
            if node.func.attr == self.function_name:
                self.found = True
        # handle class constructor call: ClassName(...)
        elif (
            self.is_constructor
            and isinstance(node.func, ast.Name)
            and node.func.id == self.class_name
        ):
            self.found = True
        self.generic_visit(node)


class LongParameterListRefactorer(MultiFileRefactorer[LPLSmell]):
    """Refactors a function/method with more than 6 parameters.

    Unused parameters are removed; if more than the limit remain in use they are
    grouped into generated DataParams/ConfigParams classes and every call site in
    the source tree is rewritten accordingly.
    """

    def __init__(self):
        super().__init__()
        self.parameter_analyzer = ParameterAnalyzer()
        self.parameter_encapsulator = ParameterEncapsulator()
        self.function_updater = FunctionCallUpdater()
        self.function_node = None  # AST node of the function being refactored
        self.used_params = None  # ordered list of (unclassified) used params
        self.classified_params = None  # {"data": [...], "config": [...]}
        self.classified_param_names = None  # (data class name, config class name)
        self.classified_param_nodes = []
        self.enclosing_class_name = None
        self.is_method = False

    def refactor(
        self,
        target_file: Path,
        source_dir: Path,
        smell: LPLSmell,
        output_file: Path,
        overwrite: bool = True,
    ):
        """Refactor the function flagged by `smell`, then update call sites tree-wide.

        Encapsulates parameters with related names and removes unused ones.
        """
        # maximum number of parameters beyond which the smell is detected
        # (see analyzers_config.py)
        max_param_limit = 6
        self.target_file = target_file

        with target_file.open() as f:
            tree = ast.parse(f.read())

        # locate the function definition at the line flagged by the smell
        target_line = smell.occurences[0].line
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef) and node.lineno == target_line:
                self.function_node = node
                params = [arg.arg for arg in node.args.args if arg.arg != "self"]
                # params with a default in the signature, as {name: default value}
                default_value_params = (
                    self.parameter_analyzer.get_parameters_with_default_value(
                        node.args.defaults, params
                    )
                )

                if len(params) > max_param_limit:
                    # identify used parameters so unused ones can be removed
                    self.used_params = self.parameter_analyzer.get_used_parameters(
                        node, params
                    )
                    if len(self.used_params) > max_param_limit:
                        # still too many: classify into data/config groups
                        self.classified_params = (
                            self.parameter_analyzer.classify_parameters(self.used_params)
                        )
                        self.classified_param_names = (
                            self._generate_unique_param_class_names()
                        )
                        # build class definitions for the encapsulated groups
                        self.classified_param_nodes = (
                            self.parameter_encapsulator.encapsulate_parameters(
                                self.classified_params,
                                default_value_params,
                                self.classified_param_names,
                            )
                        )
                        tree = self._update_tree_with_class_nodes(tree)

                        # update calls first (they need the original param order)
                        updated_tree = self.function_updater.update_function_calls(
                            tree,
                            node,
                            self.used_params,
                            self.classified_params,
                            self.classified_param_names,
                        )
                        # then rewrite the signature and in-body parameter usages
                        # (both mutate `node` in place)
                        self.function_updater.update_function_signature(
                            node, self.classified_params
                        )
                        self.function_updater.update_parameter_usages(
                            node, self.classified_params
                        )
                    else:
                        # within the limit once unused params are dropped
                        updated_function = self.function_updater.remove_unused_params(
                            node, self.used_params, default_value_params
                        )
                        # replace the old function with the updated one
                        for i, body_node in enumerate(tree.body):
                            if body_node == node:
                                tree.body[i] = updated_function
                                break
                        updated_tree = tree

                    modified_source = astor.to_source(updated_tree)

                    with output_file.open("w") as temp_file:
                        temp_file.write(modified_source)
                    if overwrite:
                        with target_file.open("w") as f_out:
                            f_out.write(modified_source)

                    self.is_method = node.name == "__init__"
                    if self.is_method:
                        # BUG FIX: look the class up in the tree that actually
                        # contains `node`; re-parsing the (possibly rewritten)
                        # file produced fresh nodes, so the identity check in
                        # get_enclosing_class_name could never match.
                        self.enclosing_class_name = (
                            FunctionCallUpdater.get_enclosing_class_name(tree, node)
                        )

                    self.traverse_and_process(source_dir)

    def _process_file(self, file: Path):
        """Rewrite call sites in a sibling file; returns True when it was modified."""
        if file.samefile(self.target_file):
            return False

        tree = ast.parse(file.read_text())

        # does the refactored function (or its class) get called here?
        visitor = FunctionCallVisitor(
            self.function_node.name, self.enclosing_class_name, self.is_method
        )
        visitor.visit(tree)
        if not visitor.found:
            return False

        # NOTE(review): when only unused params were removed,
        # self.classified_params is None and update_function_calls would fail on
        # a matching call — confirm that path is never reached for sibling files.
        # insert class definitions before modifying function calls
        updated_tree = self._update_tree_with_class_nodes(tree)
        updated_tree = self.function_updater.update_function_calls(
            updated_tree,
            self.function_node,
            self.used_params,
            self.classified_params,
            self.classified_param_names,
        )

        with file.open("w") as f:
            f.write(astor.to_source(updated_tree))
        return True

    def _generate_unique_param_class_names(self) -> tuple[str, str]:
        """
        Generate unique class names for data params and config params based on
        function name and line number.

        :return: (DataParams class name, ConfigParams class name).
        """
        unique_suffix = f"{self.function_node.name}_{self.function_node.lineno}"
        return f"DataParams_{unique_suffix}", f"ConfigParams_{unique_suffix}"

    def _update_tree_with_class_nodes(self, tree: ast.Module) -> ast.Module:
        """Insert the generated parameter classes before the first function def."""
        insert_index = 0
        for i, node in enumerate(tree.body):
            if isinstance(node, ast.FunctionDef):
                insert_index = i  # first function definition found
                break
        # reversed so the nodes end up in their original relative order
        for class_node in reversed(self.classified_param_nodes):
            tree.body.insert(insert_index, class_node)
        return tree


class ParameterAnalyzer:
    """Static helpers that analyze a function's parameter list."""

    @staticmethod
    def get_used_parameters(
        function_node: ast.FunctionDef, params: list[str]
    ) -> list[str]:
        """Return the params actually read in the body, preserving declaration order."""
        source_code = astor.to_source(function_node)
        tree = ast.parse(source_code)
        used_set: set[str] = set()

        # visitor that tracks parameter reads (Load context only)
        class ParamUsageVisitor(ast.NodeVisitor):
            def visit_Name(self, node: ast.Name):
                if isinstance(node.ctx, ast.Load) and node.id in params:
                    used_set.add(node.id)

        ParamUsageVisitor().visit(tree)
        # filter against the original list to preserve parameter order
        return [param for param in params if param in used_set]

    @staticmethod
    def get_parameters_with_default_value(
        default_values: list[ast.Constant], params: list[str]
    ) -> dict:
        """Map parameter names to default values (defaults align with the tail of params)."""
        # defaults always belong to the last len(default_values) parameters
        offset = len(params) - len(default_values)
        return {
            params[offset + i]: default.value for i, default in enumerate(default_values)
        }

    @staticmethod
    def classify_parameters(params: list[str]) -> dict:
        """Classify params into 'data' and 'config' groups based on naming conventions."""
        data_params: list[str] = []
        config_params: list[str] = []

        data_keywords = {"data", "input", "output", "result", "record", "item"}
        config_keywords = {"config", "setting", "option", "env", "parameter", "path"}

        for param in params:
            param_lower = param.lower()
            if any(keyword in param_lower for keyword in data_keywords):
                data_params.append(param)
            elif any(keyword in param_lower for keyword in config_keywords):
                config_params.append(param)
            else:
                # names matching neither group default to the data group
                data_params.append(param)
        return {"data": data_params, "config": config_params}


class ParameterEncapsulator:
    """Builds the parameter-object classes used to encapsulate grouped params."""

    @staticmethod
    def create_parameter_object_class(
        param_names: list[str],
        default_value_params: dict,
        class_name: str = "ParamsObject",
    ) -> str:
        """Return source for a class whose __init__ stores the given params as attributes."""
        class_def = f"class {class_name}:\n"
        init_params = []
        init_body = []
        for param in param_names:
            if param in default_value_params:
                # BUG FIX: use repr() so string/None defaults stay valid Python
                # source ("x='abc'" rather than "x=abc")
                init_params.append(f"{param}={default_value_params[param]!r}")
            else:
                init_params.append(param)
            init_body.append(f"        self.{param} = {param}\n")

        init_method = "    def __init__(self, {}):\n".format(", ".join(init_params))
        return class_def + init_method + "".join(init_body)

    def encapsulate_parameters(
        self,
        classified_params: dict,
        default_value_params: dict,
        classified_param_names: tuple[str, str],
    ) -> list[ast.ClassDef]:
        """Build AST class definitions for the data/config parameter objects."""
        data_params, config_params = classified_params["data"], classified_params["config"]
        class_nodes = []
        data_class_name, config_class_name = classified_param_names

        if data_params:
            data_param_object_code = self.create_parameter_object_class(
                data_params, default_value_params, class_name=data_class_name
            )
            class_nodes.append(ast.parse(data_param_object_code).body[0])

        if config_params:
            config_param_object_code = self.create_parameter_object_class(
                config_params, default_value_params, class_name=config_class_name
            )
            class_nodes.append(ast.parse(config_param_object_code).body[0])

        return class_nodes


class FunctionCallUpdater:
    """Rewrites function signatures, bodies and call sites for the refactoring."""

    @staticmethod
    def get_method_type(func_node: ast.FunctionDef) -> str:
        """Classify a def as static/class/instance method via decorators and first arg."""
        for decorator in func_node.decorator_list:
            if isinstance(decorator, ast.Name) and decorator.id == "staticmethod":
                return "static method"
            if isinstance(decorator, ast.Name) and decorator.id == "classmethod":
                return "class method"

        if func_node.args.args:
            first_arg = func_node.args.args[0].arg
            if first_arg == "self":
                return "instance method"
            elif first_arg == "cls":
                return "class method"

        return "unknown method type"

    @staticmethod
    def remove_unused_params(
        function_node: ast.FunctionDef,
        used_params: list[str],
        default_value_params: dict,
    ) -> ast.FunctionDef:
        """Remove unused parameters from the function signature."""
        method_type = FunctionCallUpdater.get_method_type(function_node)
        # keep the implicit receiver argument, if any
        updated_node_args = (
            [ast.arg(arg="self", annotation=None)]
            if method_type == "instance method"
            else [ast.arg(arg="cls", annotation=None)]
            if method_type == "class method"
            else []
        )

        updated_node_defaults = []
        for arg in function_node.args.args:
            if arg.arg in used_params:
                updated_node_args.append(arg)
                if arg.arg in default_value_params:
                    # BUG FIX: args.defaults must hold AST expression nodes; the
                    # original appended raw Python values, which astor cannot
                    # serialize back to source.
                    updated_node_defaults.append(
                        ast.Constant(value=default_value_params[arg.arg])
                    )

        function_node.args.args = updated_node_args
        function_node.args.defaults = updated_node_defaults
        return function_node

    @staticmethod
    def update_function_signature(
        function_node: ast.FunctionDef, params: dict
    ) -> ast.FunctionDef:
        """Replace the parameter list with data_params/config_params objects."""
        data_params, config_params = params["data"], params["config"]

        method_type = FunctionCallUpdater.get_method_type(function_node)
        updated_node_args = (
            [ast.arg(arg="self", annotation=None)]
            if method_type == "instance method"
            else [ast.arg(arg="cls", annotation=None)]
            if method_type == "class method"
            else []
        )

        # only add a wrapper argument for non-empty groups
        if data_params:
            updated_node_args.append(ast.arg(arg="data_params", annotation=None))
        if config_params:
            updated_node_args.append(ast.arg(arg="config_params", annotation=None))

        function_node.args.args = updated_node_args
        function_node.args.defaults = []
        return function_node

    @staticmethod
    def update_parameter_usages(
        function_node: ast.FunctionDef, params: dict
    ) -> ast.FunctionDef:
        """Rewrite reads of grouped params in the body as attribute accesses."""
        data_params, config_params = params["data"], params["config"]

        class ParameterUsageTransformer(ast.NodeTransformer):
            def visit_Name(self, node: ast.Name):
                # x -> data_params.x / config_params.x (reads only)
                if node.id in data_params and isinstance(node.ctx, ast.Load):
                    return ast.Attribute(
                        value=ast.Name(id="data_params", ctx=ast.Load()),
                        attr=node.id,
                        ctx=node.ctx,
                    )
                if node.id in config_params and isinstance(node.ctx, ast.Load):
                    return ast.Attribute(
                        value=ast.Name(id="config_params", ctx=ast.Load()),
                        attr=node.id,
                        ctx=node.ctx,
                    )
                return node

        function_node.body = [
            ParameterUsageTransformer().visit(stmt) for stmt in function_node.body
        ]
        return function_node

    @staticmethod
    def get_enclosing_class_name(
        tree: ast.Module, init_node: ast.FunctionDef
    ) -> str | None:
        """
        Find the class name enclosing the given __init__ function node — i.e. the
        class instantiated by that init method.

        :param tree: AST tree that contains `init_node`
        :param init_node: __init__ function node
        :return: name of the enclosing class, or None if not found
        """
        parent_stack = []  # stack of enclosing ClassDef nodes

        class ClassNameVisitor(ast.NodeVisitor):
            def visit_ClassDef(self, node: ast.ClassDef):
                parent_stack.append(node)
                self.generic_visit(node)
                parent_stack.pop()

            def visit_FunctionDef(self, node: ast.FunctionDef):
                if node is init_node:
                    # nearest enclosing class wins; abort traversal via exception
                    for parent in reversed(parent_stack):
                        if isinstance(parent, ast.ClassDef):
                            raise StopIteration(parent.name)
                self.generic_visit(node)

        try:
            ClassNameVisitor().visit(tree)
        except StopIteration as e:
            return e.value

        return None

    @staticmethod
    def update_function_calls(
        tree: ast.Module,
        function_node: ast.FunctionDef,
        used_params: list[str],
        classified_params: dict,
        classified_param_names: tuple[str, str],
    ) -> ast.Module:
        """
        Update all calls to the given function in `tree` to pass the new
        encapsulated parameter objects.

        :param tree: the AST tree of the code
        :param function_node: AST node of the function whose calls to update
        :param used_params: ordered original (used) parameter names
        :param classified_params: {'data': [...], 'config': [...]}
        :param classified_param_names: generated (data, config) class names
        :return: the updated AST tree
        """

        class FunctionCallTransformer(ast.NodeTransformer):
            def __init__(
                self,
                function_node: ast.FunctionDef,
                unclassified_params: list[str],
                classified_params: dict,
                classified_param_names: tuple[str, str],
                is_constructor: bool = False,
                class_name: str = "",
            ):
                self.function_node = function_node
                self.unclassified_params = unclassified_params
                self.classified_params = classified_params
                self.is_constructor = is_constructor
                self.class_name = class_name
                self.classified_param_names = classified_param_names

            def visit_Call(self, node: ast.Call):
                # ast.Name for plain calls, ast.Attribute for method calls
                if isinstance(node.func, ast.Name):
                    node_name = node.func.id
                elif isinstance(node.func, ast.Attribute):
                    node_name = node.func.attr
                else:
                    # BUG FIX: e.g. a call on a subscript or nested call left
                    # node_name unbound in the original
                    return node

                if (
                    self.is_constructor and node_name == self.class_name
                ) or node_name == self.function_node.name:
                    return self.transform_call(node)
                return node

            def create_ast_call(
                self,
                function_name: str,
                param_list: list[str],
                args_map: dict,
                keywords_map: dict,
            ):
                """Build a Call node for one parameter-object constructor (or None)."""
                return (
                    ast.Call(
                        func=ast.Name(id=function_name, ctx=ast.Load()),
                        args=[args_map[key] for key in param_list if key in args_map],
                        keywords=[
                            ast.keyword(arg=key, value=keywords_map[key])
                            for key in param_list
                            if key in keywords_map
                        ],
                    )
                    if param_list
                    else None
                )

            def transform_call(self, node: ast.Call):
                # original and classified params from the function node
                data_params, config_params = (
                    self.classified_params["data"],
                    self.classified_params["config"],
                )
                data_class_name, config_class_name = self.classified_param_names

                # positional and keyword args passed at the call site
                original_args, original_kargs = node.args, node.keywords

                data_args = {
                    param: original_args[i]
                    for i, param in enumerate(self.unclassified_params)
                    if i < len(original_args) and param in data_params
                }
                config_args = {
                    param: original_args[i]
                    for i, param in enumerate(self.unclassified_params)
                    if i < len(original_args) and param in config_params
                }

                data_keywords = {
                    kw.arg: kw.value for kw in original_kargs if kw.arg in data_params
                }
                config_keywords = {
                    kw.arg: kw.value for kw in original_kargs if kw.arg in config_params
                }

                updated_node_args = []
                if data_node := self.create_ast_call(
                    data_class_name, data_params, data_args, data_keywords
                ):
                    updated_node_args.append(data_node)
                if config_node := self.create_ast_call(
                    config_class_name, config_params, config_args, config_keywords
                ):
                    updated_node_args.append(config_node)

                # keyword arguments now live inside the encapsulated objects
                node.args, node.keywords = updated_node_args, []
                return node

        if function_node.name == "__init__":
            # refactoring a constructor: match call sites by class name
            class_name = FunctionCallUpdater.get_enclosing_class_name(tree, function_node)
            transformer = FunctionCallTransformer(
                function_node,
                used_params,
                classified_params,
                classified_param_names,
                True,
                class_name,
            )
        else:
            transformer = FunctionCallTransformer(
                function_node, used_params, classified_params, classified_param_names
            )
        return transformer.visit(tree)
import astroid
from astroid import nodes, util
import libcst as cst
from libcst.metadata import PositionProvider, MetadataWrapper

from pathlib import Path

from ...config import CONFIG

from ..multi_file_refactorer import MultiFileRefactorer
from ...data_types.smell import MIMSmell


class CallTransformer(cst.CSTTransformer):
    """Rewrites matched `obj.method(...)` calls into static `ClassName.method(...)` calls."""

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self, class_name: str):
        # (caller_source, line, method_name, class_name) tuples; populated via set_calls()
        self.method_calls: list[tuple[str, int, str, str]] | None = None
        self.class_name = class_name  # class name to replace instance callers with
        self.transformed = False

    def set_calls(self, valid_calls: list[tuple[str, int, str, str]]):
        """Install the astroid-validated calls that should be transformed."""
        self.method_calls = valid_calls

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Transform instance calls to static calls if they match."""
        if isinstance(original_node.func, cst.Attribute):
            caller = original_node.func.value
            method = original_node.func.attr.value
            position = self.get_metadata(PositionProvider, original_node, None)

            if not position:
                raise TypeError("What do you mean you can't find the position?")

            # match the astroid results by caller source, method name and line number
            for call_caller, line, call_method, cls in self.method_calls:
                CONFIG["refactorLogger"].debug(
                    f"cst caller: {call_caller} at line {position.start.line}"
                )
                if (
                    method == call_method
                    and position.start.line == line
                    and caller.deep_equals(cst.parse_expression(call_caller))
                ):
                    CONFIG["refactorLogger"].debug("transforming")
                    # Transform `obj.method(args)` -> `ClassName.method(args)`
                    new_func = cst.Attribute(
                        value=cst.Name(cls),  # replace `obj` with the class name
                        attr=original_node.func.attr,
                    )
                    self.transformed = True
                    return updated_node.with_changes(func=new_func)

        return updated_node  # unchanged if no match


def find_valid_method_calls(
    tree: nodes.Module, mim_method: str, valid_classes: set[str]
) -> list[tuple[str, int, str, str]]:
    """
    Find method calls whose receiver is (inferred to be) one of the valid classes.

    Returns:
        A list of (caller_source, line_number, method_name, class_name) tuples.
    """
    valid_calls = []

    CONFIG["refactorLogger"].info("Finding valid method calls")

    for node in tree.body:
        for descendant in node.nodes_of_class(nodes.Call):
            if isinstance(descendant.func, nodes.Attribute):
                CONFIG["refactorLogger"].debug(f"caller: {descendant.func.expr.as_string()}")
                caller = descendant.func.expr  # the object calling the method
                method_name = descendant.func.attrname

                if method_name != mim_method:
                    continue

                inferred_types: list[str] = []
                inferences = caller.infer()

                for inferred in inferences:
                    CONFIG["refactorLogger"].debug(f"inferred: {inferred.repr_name()}")
                    if isinstance(inferred, util.UninferableBase):
                        # astroid could not infer the type; fall back to
                        # annotations and assignments in the enclosing scope
                        hint = check_for_annotations(caller, descendant.scope())
                        inits = check_for_initializations(caller, descendant.scope())
                        if hint:
                            inferred_types.append(hint.as_string())
                        elif inits:
                            inferred_types.extend(inits)
                        else:
                            continue
                    else:
                        inferred_types.append(inferred.repr_name())

                CONFIG["refactorLogger"].debug(f"Inferred types: {inferred_types}")

                # check if any inferred type matches a valid class
                for cls in inferred_types:
                    if cls in valid_classes:
                        # BUG FIX: corrected "Foud" typo in the log message
                        CONFIG["refactorLogger"].debug(
                            f"Found valid call: {caller.as_string()} at line {descendant.lineno}"
                        )
                        valid_calls.append(
                            (caller.as_string(), descendant.lineno, method_name, cls)
                        )

    return valid_calls


def check_for_initializations(caller: nodes.NodeNG, scope: nodes.NodeNG):
    """Collect class names assigned to `caller` via `x = ClassName(...)` in `scope`."""
    inits: list[str] = []

    for assign in scope.nodes_of_class(nodes.Assign):
        if assign.targets[0].as_string() == caller.as_string() and isinstance(
            assign.value, nodes.Call
        ):
            if isinstance(assign.value.func, nodes.Name):
                inits.append(assign.value.func.name)

    return inits


def check_for_annotations(caller: nodes.NodeNG, scope: nodes.NodeNG):
    """Return the type annotation of the parameter matching `caller`, if any."""
    if not isinstance(scope, nodes.FunctionDef):
        return None

    hint = None
    CONFIG["refactorLogger"].debug(f"annotations: {scope.args}")

    args = scope.args.args
    anns = scope.args.annotations
    if args and anns:
        for arg, ann in zip(args, anns):
            if arg.name == caller.as_string() and ann:
                hint = ann
                break

    return hint


class MakeStaticRefactorer(MultiFileRefactorer[MIMSmell], cst.CSTTransformer):
    """Turns a member-ignoring method into a @staticmethod and fixes all call sites."""

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(self):
        super().__init__()
        self.target_line = None  # line of the smelly method definition
        self.mim_method_class = ""  # class declaring the MIM method
        self.mim_method = ""  # name of the MIM method
        self.valid_classes: set[str] = set()  # declaring class + overriding-free subclasses

    def refactor(
        self,
        target_file: Path,
        source_dir: Path,
        smell: MIMSmell,
        output_file: Path,
        overwrite: bool = True,
    ):
        """Apply the static-method transformation, then update calls across source_dir."""
        self.target_line = smell.occurences[0].line
        self.target_file = target_file

        if not smell.obj:
            raise TypeError("No method object found")

        self.mim_method_class, self.mim_method = smell.obj.split(".")
        self.valid_classes.add(self.mim_method_class)

        source_code = target_file.read_text()
        tree = MetadataWrapper(cst.parse_module(source_code))

        # find all subclasses of the target class (they may inherit the method)
        self._find_subclasses(source_dir)

        # NOTE(review): the target file is rewritten in place even when
        # overwrite=False; only an extra copy goes to output_file — confirm
        # this is the intended multi-file contract.
        modified_tree = tree.visit(self)
        target_file.write_text(modified_tree.code)

        self.transformer = CallTransformer(self.mim_method_class)

        self.traverse_and_process(source_dir)
        if not overwrite:
            output_file.write_text(target_file.read_text())

    def _find_subclasses(self, directory: Path):
        """Collect subclasses of the target class that do not override the MIM method."""

        def get_subclasses(tree: nodes.Module):
            subclasses: set[str] = set()
            for klass in tree.nodes_of_class(nodes.ClassDef):
                if any(base == self.mim_method_class for base in klass.basenames):
                    # only subclasses that inherit (don't redefine) the method
                    if not any(
                        method.name == self.mim_method for method in klass.mymethods()
                    ):
                        subclasses.add(klass.name)
            return subclasses

        CONFIG["refactorLogger"].debug("find all subclasses")
        self.traverse(directory)
        for file in self.py_files:
            tree = astroid.parse(file.read_text())
            self.valid_classes = self.valid_classes.union(get_subclasses(tree))
        CONFIG["refactorLogger"].debug(f"valid classes: {self.valid_classes}")

    def _process_file(self, file: Path):
        """Rewrite matched calls in one file; returns True when a sibling file changed."""
        processed = False

        source_code = file.read_text("utf-8")

        # astroid pass validates receivers; libcst pass performs the rewrite
        astroid_tree = astroid.parse(source_code)
        valid_calls = find_valid_method_calls(
            astroid_tree, self.mim_method, self.valid_classes
        )
        self.transformer.set_calls(valid_calls)

        tree = MetadataWrapper(cst.parse_module(source_code))
        modified_tree = tree.visit(self.transformer)

        if self.transformer.transformed:
            file.write_text(modified_tree.code)
            if not file.samefile(self.target_file):
                processed = True
            self.transformer.transformed = False

        return processed

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Add @staticmethod and drop `self` on the targeted method definition."""
        func_name = original_node.name.value
        if func_name and updated_node.deep_equals(original_node):
            position = self.get_metadata(PositionProvider, original_node).start  # type: ignore
            if position.line == self.target_line and func_name == self.mim_method:
                CONFIG["refactorLogger"].debug("Modifying MIM method")
                decorators = [
                    *list(original_node.decorators),
                    cst.Decorator(cst.Name("staticmethod")),
                ]
                params = original_node.params
                if params.params and params.params[0].name.value == "self":
                    params = params.with_changes(params=params.params[1:])
                return updated_node.with_changes(decorators=decorators, params=params)
        return updated_node
+ """ + super().__init__() + self.target_line = None + + def refactor( + self, + target_file: Path, + source_dir: Path, # noqa: ARG002 + smell: CRCSmell, + output_file: Path, + overwrite: bool = True, + ): + """ + Refactor the repeated function call smell and save to a new file. + """ + self.target_file = target_file + self.smell = smell + self.call_string = self.smell.additionalInfo.callString.strip() + + self.cached_var_name = "cached_" + self.call_string.split("(")[0] + + with self.target_file.open("r") as file: + lines = file.readlines() + + # Parse the AST + tree = ast.parse("".join(lines)) + + # Find the valid parent node + parent_node = self._find_valid_parent(tree) + if not parent_node: + return + + # Determine the insertion point for the cached variable + insert_line = self._find_insert_line(parent_node) + indent = self._get_indentation(lines, insert_line) + cached_assignment = f"{indent}{self.cached_var_name} = {self.call_string}\n" + + # Insert the cached variable into the source lines + lines.insert(insert_line - 1, cached_assignment) + line_shift = 1 # Track the shift in line numbers caused by the insertion + + # Replace calls with the cached variable in the affected lines + for occurrence in self.smell.occurences: + adjusted_line_index = occurrence.line - 1 + line_shift + original_line = lines[adjusted_line_index] + updated_line = self._replace_call_in_line( + original_line, self.call_string, self.cached_var_name + ) + if updated_line != original_line: + lines[adjusted_line_index] = updated_line + + # Save the modified file + temp_file_path = output_file + + with temp_file_path.open("w") as refactored_file: + refactored_file.writelines(lines) + + # CHANGE FOR MULTI FILE IMPLEMENTATION + if overwrite: + with target_file.open("w") as f: + f.writelines(lines) + else: + with output_file.open("w") as f: + f.writelines(lines) + + def _get_indentation(self, lines: list[str], line_number: int): + """ + Determine the indentation level of a given line. 
+ + :param lines: List of source code lines. + :param line_number: The line number to check. + :return: The indentation string. + """ + line = lines[line_number - 1] + return line[: len(line) - len(line.lstrip())] + + def _replace_call_in_line(self, line: str, call_string: str, cached_var_name: str): + """ + Replace the repeated call in a line with the cached variable. + + :param line: The original line of source code. + :param call_string: The string representation of the call. + :param cached_var_name: The name of the cached variable. + :return: The updated line. + """ + # Replace all exact matches of the call string with the cached variable + updated_line = line.replace(call_string, cached_var_name) + return updated_line + + def _find_valid_parent(self, tree: ast.Module): + """ + Find the valid parent node that contains all occurences of the repeated call. + + :param tree: The root AST tree. + :return: The valid parent node, or None if not found. + """ + candidate_parent = None + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if all(self._line_in_node_body(node, occ.line) for occ in self.smell.occurences): + candidate_parent = node + if candidate_parent: + print( + f"Valid parent found: {type(candidate_parent).__name__} at line " + f"{getattr(candidate_parent, 'lineno', 'module')}" + ) + return candidate_parent + + def _find_insert_line(self, parent_node: ast.FunctionDef | ast.ClassDef | ast.Module): + """ + Find the line to insert the cached variable assignment. + + :param parent_node: The parent node containing the occurences. + :return: The line number where the cached variable should be inserted. + """ + if isinstance(parent_node, ast.Module): + return 1 # Top of the module + return parent_node.body[0].lineno # Beginning of the parent node's body + + def _line_in_node_body(self, node: ast.FunctionDef | ast.ClassDef | ast.Module, line: int): + """ + Check if a line is within the body of a given AST node. 
        :param node: The AST node to check.
        :param line: The line number to check.
        :return: True if the line is within the node's body, False otherwise.
        """
        if not hasattr(node, "body"):
            return False

        for child in node.body:
            # end_lineno may be missing on some nodes; fall back to lineno.
            if hasattr(child, "lineno") and child.lineno <= line <= getattr(
                child, "end_lineno", child.lineno
            ):
                return True
        return False
diff --git a/src/ecooptimizer/refactorers/concrete/str_concat_in_loop.py b/src/ecooptimizer/refactorers/concrete/str_concat_in_loop.py
new file mode 100644
index 00000000..e4575844
--- /dev/null
+++ b/src/ecooptimizer/refactorers/concrete/str_concat_in_loop.py
@@ -0,0 +1,303 @@
import re

from pathlib import Path
import astroid
from astroid import nodes

from ..base_refactorer import BaseRefactorer
from ...data_types.smell import SCLSmell


class UseListAccumulationRefactorer(BaseRefactorer[SCLSmell]):
    """
    Refactorer that targets string concatenations inside loops.

    Rewrites `s = s + x` / `s += x` accumulation into list appends plus a
    single ''.join(...) after the loop.
    """

    def __init__(self):
        super().__init__()
        self.target_lines: list[int] = []  # lines flagged by the smell occurrences
        self.assign_var = ""  # source text of the concatenation target
        self.target_node: nodes.NodeNG = None  # first assignment target seen
        self.last_assign_node: nodes.Assign | nodes.AugAssign = None  # type: ignore
        self.concat_nodes: list[nodes.Assign | nodes.AugAssign] = []
        self.reassignments: list[nodes.Assign] = []  # non-concat reassigns inside the loop
        self.outer_loop_line: int = 0
        self.outer_loop: nodes.For | nodes.While = None  # type: ignore

    def reset(self):
        # Re-run __init__ to restore pristine per-run state.
        self.__init__()

    def refactor(
        self,
        target_file: Path,
        source_dir: Path,  # noqa: ARG002
        smell: SCLSmell,
        output_file: Path,
        overwrite: bool = True,
    ):
        """
        Refactor string concatenations in loops to use list accumulation and a
        final ''.join.

        :param target_file: Absolute path to the source file to refactor.
        :param source_dir: Source directory (unused; kept for interface parity).
        :param smell: The detected string-concat-in-loop smell.
        :param output_file: Destination when not overwriting in place.
        :param overwrite: If True, write back to target_file, else to output_file.
        :raises RuntimeError: If the smell lacks its required additional info.
        :raises Exception: If the loop or concat nodes cannot be located.
        """
        self.target_lines = [occ.line for occ in smell.occurences]

        if not smell.additionalInfo:
            raise RuntimeError("Missing additional info for 'string-concat-loop' smell")

        self.assign_var = smell.additionalInfo.concatTarget
        self.outer_loop_line = smell.additionalInfo.innerLoopLine

        # Parse the code into an AST
        source_code = target_file.read_text()
        tree = astroid.parse(source_code)
        for node in tree.get_children():
            self.visit(node)

        if not self.outer_loop or len(self.concat_nodes) != len(self.target_lines):
            raise Exception("Missing inner loop or concat nodes.")

        self.find_reassignments()
        self.find_scope()

        # Tag each node with its edit kind so they can be interleaved.
        temp_concat_nodes = [("concat", node) for node in self.concat_nodes]
        temp_reassignments = [("reassign", node) for node in self.reassignments]

        combined_nodes = temp_concat_nodes + temp_reassignments

        # Edit bottom-up so earlier line numbers stay valid while mutating.
        combined_nodes = sorted(
            combined_nodes,
            key=lambda x: x[1].lineno,  # type: ignore
            reverse=True,
        )

        modified_code = self.add_node_to_body(source_code, combined_nodes)

        if overwrite:
            target_file.write_text(modified_code)
        else:
            output_file.write_text(modified_code)

    def visit(self, node: nodes.NodeNG):
        """Recursively collect concat nodes and locate the flagged loop."""
        if isinstance(node, nodes.Assign) and node.lineno in self.target_lines:
            if not self.target_node:
                self.target_node = node.targets[0]
            self.concat_nodes.append(node)
        elif isinstance(node, nodes.AugAssign) and node.lineno in self.target_lines:
            if not self.target_node:
                self.target_node = node.target
            self.concat_nodes.append(node)
        elif isinstance(node, (nodes.For, nodes.While)) and node.lineno == self.outer_loop_line:
            self.outer_loop = node
            for child in node.get_children():
                self.visit(child)
        else:
            for child in node.get_children():
                self.visit(child)

    def find_reassignments(self):
        """Record in-loop assignments to the target that are not concatenations."""
        for node in self.outer_loop.nodes_of_class(nodes.Assign):
            for target in node.targets:
                if target.as_string() == self.assign_var and node.lineno not in self.target_lines:
                    self.reassignments.append(node)

    def find_last_assignment(self, scope_node: nodes.NodeNG):
        """Find the last assignment of the target variable within a given scope node."""
        last_assignment_node = None

        # Traverse the scope node and find assignments within the valid range
        # (anything assigning the target variable strictly before the loop).
        for node in scope_node.nodes_of_class((nodes.AugAssign, nodes.Assign)):
            if isinstance(node, nodes.Assign):
                for target in node.targets:
                    if (
                        target.as_string() == self.assign_var
                        and node.lineno < self.outer_loop.lineno  # type: ignore
                    ):
                        if last_assignment_node is None:
                            last_assignment_node = node
                        elif node.lineno > last_assignment_node.lineno:  # type: ignore
                            last_assignment_node = node
            else:
                if (
                    node.target.as_string() == self.assign_var
                    and node.lineno < self.outer_loop.lineno  # type: ignore
                ):
                    if last_assignment_node is None:
                        last_assignment_node = node
                    elif node.lineno > last_assignment_node.lineno:  # type: ignore
                        last_assignment_node = node

        self.last_assign_node = last_assignment_node  # type: ignore

    def find_scope(self):
        """Locate the second innermost loop if nested, else find first non-loop function/method/module ancestor."""

        for node in self.outer_loop.node_ancestors():
            if isinstance(node, (nodes.For, nodes.While)):
                self.find_last_assignment(node)
                if not self.last_assign_node:
                    # No pre-loop assignment here; widen to the enclosing loop.
                    self.outer_loop = node
                else:
                    self.scope_node = node
                    break
            elif isinstance(node, (nodes.Module, nodes.FunctionDef, nodes.AsyncFunctionDef)):
                self.find_last_assignment(node)
                self.scope_node = node
                break

    def last_assign_is_referenced(self, search_area: str):
        # True when the target's pre-loop value may still matter: it is read in
        # the code between the last assignment and the loop, the last
        # assignment is itself augmented, or its RHS mentions the target.
        return (
            search_area.find(self.assign_var) != -1
            or isinstance(self.last_assign_node, nodes.AugAssign)
            or self.assign_var in self.last_assign_node.value.as_string()
        )

    def generate_temp_list_name(self):
        """Derive a temp accumulator name for non-simple targets (x[i], obj.attr)."""
        node = self.target_node

        def _get_node_representation(node: nodes.NodeNG):
            """Helper function to get a string representation of a node."""
            if isinstance(node, astroid.Const):
                return str(node.value)
            if isinstance(node, astroid.Name):
                return node.name
            if isinstance(node, astroid.Attribute):
                return node.attrname
            return
"unknown" + + if isinstance(node, astroid.Subscript): + # Extracting slice and value for a Subscript node + slice_repr = _get_node_representation(node.slice) + value_repr = _get_node_representation(node.value) + custom_component = f"{value_repr}_at_{slice_repr}" + elif isinstance(node, astroid.AssignAttr): + # Extracting attribute name for an AssignAttr node + attribute_name = node.attrname + custom_component = attribute_name + else: + raise TypeError("Node must be either Subscript or AssignAttr.") + + return f"temp_{custom_component}" + + def add_node_to_body(self, code_file: str, nodes_to_change: list[tuple]): # type: ignore + """ + Add a new AST node + """ + + code_file_lines = code_file.splitlines() + + list_name = self.assign_var + + if not isinstance(self.target_node, nodes.AssignName): + list_name = self.generate_temp_list_name() + + # ------------- ADD JOIN STATEMENT TO SOURCE ---------------- + + join_line = f"{self.assign_var} = ''.join({list_name})" + indent_lno: int = self.outer_loop.lineno - 1 # type: ignore + join_lno: int = self.outer_loop.end_lineno # type: ignore + + source_line = code_file_lines[indent_lno] + outer_scope_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + code_file_lines.insert(join_lno, outer_scope_whitespace + join_line) + + def get_new_concat_line(concat_node: nodes.AugAssign | nodes.Assign): + concat_line = "" + if isinstance(concat_node, nodes.AugAssign): + concat_line = f"{list_name}.append({concat_node.value.as_string()})" + else: + parts = re.split( + rf"\s*[+]*\s*\b{re.escape(self.assign_var)}\b\s*[+]*\s*", + concat_node.value.as_string(), + ) + + if len(parts[0]) == 0: + concat_line = f"{list_name}.append({parts[1]})" + elif len(parts[1]) == 0: + concat_line = f"{list_name}.insert(0, {parts[0]})" + else: + concat_line = [ + f"{list_name}.insert(0, {parts[0]})", + f"{list_name}.append({parts[1]})", + ] + return concat_line + + def get_new_reassign_line(reassign_node: nodes.Assign): + if 
reassign_node.value.as_string() in ["''", '""', "str()"]: + return f"{list_name}.clear()" + else: + return f"{list_name} = [{reassign_node.value.as_string()}]" + + # ------------- REFACTOR CONCATS and REASSIGNS ---------------------------- + + for node in nodes_to_change: + if node[0] == "concat": + new_concat = get_new_concat_line(node[1]) + concat_lno = node[1].lineno - 1 + + if isinstance(new_concat, list): + source_line = code_file_lines[concat_lno] + concat_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + code_file_lines.pop(concat_lno) + code_file_lines.insert(concat_lno, concat_whitespace + new_concat[1]) + code_file_lines.insert(concat_lno, concat_whitespace + new_concat[0]) + else: + source_line = code_file_lines[concat_lno] + concat_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + code_file_lines.pop(concat_lno) + code_file_lines.insert(concat_lno, concat_whitespace + new_concat) + else: + new_reassign = get_new_reassign_line(node[1]) + reassign_lno = node[1].lineno - 1 + + source_line = code_file_lines[reassign_lno] + reassign_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + code_file_lines.pop(reassign_lno) + code_file_lines.insert(reassign_lno, reassign_whitespace + new_reassign) + + # ------------- INITIALIZE TARGET VAR AS A LIST ------------- + if ( + not isinstance(self.target_node, nodes.AssignName) + or not self.last_assign_node + or self.last_assign_is_referenced( + "".join(code_file_lines[self.last_assign_node.lineno : self.outer_loop.lineno - 1]) # type: ignore + ) + ): + list_lno: int = self.outer_loop.lineno - 1 # type: ignore + + source_line = code_file_lines[list_lno] + outer_scope_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + list_line = f"{list_name} = [{self.assign_var}]" + + code_file_lines.insert(list_lno, outer_scope_whitespace + list_line) + + elif self.last_assign_node.value.as_string() in ["''", "str()"]: + list_lno: 
int = self.last_assign_node.lineno - 1 # type: ignore + + source_line = code_file_lines[list_lno] + outer_scope_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + list_line = f"{list_name} = []" + + code_file_lines.pop(list_lno) + code_file_lines.insert(list_lno, outer_scope_whitespace + list_line) + + else: + list_lno: int = self.last_assign_node.lineno - 1 # type: ignore + + source_line = code_file_lines[list_lno] + outer_scope_whitespace = source_line[: len(source_line) - len(source_line.lstrip())] + + list_line = f"{list_name} = [{self.last_assign_node.value.as_string()}]" + + code_file_lines.pop(list_lno) + code_file_lines.insert(list_lno, outer_scope_whitespace + list_line) + + return "\n".join(code_file_lines) diff --git a/src/ecooptimizer/refactorers/concrete/unused.py b/src/ecooptimizer/refactorers/concrete/unused.py new file mode 100644 index 00000000..38ee4cf2 --- /dev/null +++ b/src/ecooptimizer/refactorers/concrete/unused.py @@ -0,0 +1,54 @@ +from pathlib import Path + +from ..base_refactorer import BaseRefactorer +from ...data_types.smell import UVASmell + + +class RemoveUnusedRefactorer(BaseRefactorer[UVASmell]): + def __init__(self): + super().__init__() + + def refactor( + self, + target_file: Path, + source_dir: Path, # noqa: ARG002 + smell: UVASmell, + output_file: Path, + overwrite: bool = True, + ): + """ + Refactors unused imports, variables and class attributes by removing lines where they appear. + Modifies the specified instance in the file if it results in lower emissions. + + :param target_file: Path to the file to be refactored. + :param smell: Dictionary containing details of the Pylint smell, including the line number. + :param initial_emission: Initial emission value before refactoring. 
+ """ + line_number = smell.occurences[0].line + code_type = smell.messageId + + # Load the source code as a list of lines + with target_file.open() as file: + original_lines = file.readlines() + + # Check if the line number is valid within the file + if not (1 <= line_number <= len(original_lines)): + return + + # remove specified line + modified_lines = original_lines[:] + modified_lines[line_number - 1] = "\n" + + # for logging purpose to see what was removed + if code_type != "W0611" and code_type != "UV001": # UNUSED_IMPORT + return + + # Write the modified content to a temporary file + temp_file_path = output_file + + with temp_file_path.open("w") as temp_file: + temp_file.writelines(modified_lines) + + if overwrite: + with target_file.open("w") as f: + f.writelines(modified_lines) diff --git a/src/ecooptimizer/refactorers/multi_file_refactorer.py b/src/ecooptimizer/refactorers/multi_file_refactorer.py new file mode 100644 index 00000000..f5ee57e0 --- /dev/null +++ b/src/ecooptimizer/refactorers/multi_file_refactorer.py @@ -0,0 +1,80 @@ +# pyright: reportOptionalMemberAccess=false +from abc import abstractmethod +import fnmatch +from pathlib import Path +from typing import TypeVar + +from ..config import CONFIG + +from .base_refactorer import BaseRefactorer + +from ..data_types.smell import Smell + + +T = TypeVar("T", bound=Smell) + +DEFAULT_IGNORED_PATTERNS = { + "__pycache__", + "build", + ".venv", + "*.egg-info", + ".git", + "node_modules", + ".*", +} + +DEFAULT_IGNORE_PATH = Path(__file__).parent / "patterns_to_ignore" + + +class MultiFileRefactorer(BaseRefactorer[T]): + def __init__(self): + super().__init__() + self.target_file: Path = None # type: ignore + self.ignore_patterns = self._load_ignore_patterns() + self.py_files: list[Path] = [] + + def _load_ignore_patterns(self, ignore_dir: Path = DEFAULT_IGNORE_PATH) -> set[str]: + """Load ignore patterns from a file, similar to .gitignore.""" + if not ignore_dir.is_dir(): + return 
DEFAULT_IGNORED_PATTERNS + + patterns = DEFAULT_IGNORED_PATTERNS + for file in ignore_dir.iterdir(): + with file.open() as f: + patterns.update( + [line.strip() for line in f if line.strip() and not line.startswith("#")] + ) + + return patterns + + def is_ignored(self, item: Path) -> bool: + """Check if a file or directory matches any ignore pattern.""" + return any(fnmatch.fnmatch(item.name, pattern) for pattern in self.ignore_patterns) + + def traverse(self, directory: Path): + for item in directory.iterdir(): + if item.is_dir(): + CONFIG["refactorLogger"].debug(f"Scanning directory: {item!s}, name: {item.name}") + if self.is_ignored(item): + CONFIG["refactorLogger"].debug(f"Ignored directory: {item!s}") + continue + + CONFIG["refactorLogger"].debug(f"Entering directory: {item!s}") + self.traverse_and_process(item) + elif item.is_file() and item.suffix == ".py": + self.py_files.append(item) + + def traverse_and_process(self, directory: Path): + if not self.py_files: + self.traverse(directory) + for file in self.py_files: + CONFIG["refactorLogger"].debug(f"Checking file: {file!s}") + if self._process_file(file): + if file not in self.modified_files and not file.samefile(self.target_file): + self.modified_files.append(file.resolve()) + CONFIG["refactorLogger"].debug("finished processing file") + + @abstractmethod + def _process_file(self, file: Path) -> bool: + """Abstract method to be implemented by subclasses to handle file processing.""" + pass diff --git a/src/ecooptimizer/refactorers/patterns_to_ignore/.generalignore b/src/ecooptimizer/refactorers/patterns_to_ignore/.generalignore new file mode 100644 index 00000000..e36e56d3 --- /dev/null +++ b/src/ecooptimizer/refactorers/patterns_to_ignore/.generalignore @@ -0,0 +1,32 @@ +# Build and distribution artifacts +*.whl + +# IDE and editor files +.vscode/ +.idea/ +*.sublime-* + +# Version control and OS metadata +.git/ +.gitignore +.gitattributes +.svn/ +.DS_Store +Thumbs.db + +# Containerisation and deployment 
+Dockerfile +.dockerignore +.env +*.log + +# Dependency managers and tooling +poetry.lock +pyproject.toml +requirements.txt +*.ipynb_checkpoints/ + +# Hidden files and miscellaneous patterns +.* +*.bak +*.swp diff --git a/src/ecooptimizer/refactorers/patterns_to_ignore/.pythonignore b/src/ecooptimizer/refactorers/patterns_to_ignore/.pythonignore new file mode 100644 index 00000000..1800114d --- /dev/null +++ b/src/ecooptimizer/refactorers/patterns_to_ignore/.pythonignore @@ -0,0 +1,174 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc \ No newline at end of file diff --git a/src/ecooptimizer/refactorers/refactorer_controller.py b/src/ecooptimizer/refactorers/refactorer_controller.py new file mode 100644 index 00000000..214dd29d --- /dev/null +++ b/src/ecooptimizer/refactorers/refactorer_controller.py @@ -0,0 +1,54 @@ +# pyright: reportOptionalMemberAccess=false +from pathlib import Path + +from ..config import CONFIG + +from ..data_types.smell import Smell +from ..utils.smells_registry import get_refactorer + + +class RefactorerController: + def __init__(self): + """Manages the execution of refactorers for detected code smells.""" + self.smell_counters = {} + + def run_refactorer( + self, target_file: Path, source_dir: Path, smell: Smell, overwrite: bool = True + ): + """Executes the appropriate refactorer for the given smell. + + Args: + target_file (Path): The file to be refactored. + source_dir (Path): The source directory containing the file. + smell (Smell): The detected smell to be refactored. 
            overwrite (bool, optional): Whether to overwrite existing files. Defaults to True.

        Returns:
            list[Path]: A list of modified files resulting from the refactoring process.

        Raises:
            NotImplementedError: If no refactorer exists for the given smell.
        """
        smell_id = smell.messageId
        smell_symbol = smell.symbol
        refactorer_class = get_refactorer(smell_symbol)
        modified_files = []

        if refactorer_class:
            # Number the output files per smell type so repeated runs of the
            # same smell do not overwrite each other.
            self.smell_counters[smell_id] = self.smell_counters.get(smell_id, 0) + 1
            file_count = self.smell_counters[smell_id]

            output_file_name = f"{target_file.stem}_path_{smell_id}_{file_count}.py"
            # NOTE(review): path is relative to this module and is not created
            # here — confirm the outputs/ directory exists before refactoring.
            output_file_path = Path(__file__).parent / "../../../outputs" / output_file_name

            CONFIG["refactorLogger"].info(
                f"πŸ”„ Running refactoring for {smell_symbol} using {refactorer_class.__name__}"
            )
            refactorer = refactorer_class()
            refactorer.refactor(target_file, source_dir, smell, output_file_path, overwrite)
            modified_files = refactorer.modified_files
        else:
            CONFIG["refactorLogger"].error(f"❌ No refactorer found for smell: {smell_symbol}")
            raise NotImplementedError(f"No refactorer implemented for smell: {smell_symbol}")

        return modified_files
diff --git a/src/ecooptimizer/utils/__init__.py b/src/ecooptimizer/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/ecooptimizer/utils/output_manager.py b/src/ecooptimizer/utils/output_manager.py
new file mode 100644
index 00000000..8ba2539e
--- /dev/null
+++ b/src/ecooptimizer/utils/output_manager.py
@@ -0,0 +1,124 @@
from enum import Enum
import json
import logging
from pathlib import Path
import shutil
from typing import Any


# Development output root, resolved relative to this module's location.
DEV_OUTPUT = Path(__file__).parent / "../../../outputs"


class EnumEncoder(json.JSONEncoder):
    """JSON encoder that serializes Enum members as their .value."""

    def default(self, o):  # noqa: ANN001
        if isinstance(o, Enum):
            return o.value  # Serialize using the Enum's value
        return super().default(o)


class LoggingManager:
    def __init__(self, logs_dir: Path = DEV_OUTPUT / "logs",
production: bool = False):
        """Initializes log paths based on mode.

        Args:
            logs_dir (Path): Directory that receives all log files.
            production (bool): When True, skip creating the dev output root.
        """

        self.production = production
        self.logs_dir = logs_dir

        self._initialize_output_structure()
        self.log_files = {
            "main": self.logs_dir / "main.log",
            "detect": self.logs_dir / "detect.log",
            "refactor": self.logs_dir / "refactor.log",
        }
        self._setup_loggers()

    def _initialize_output_structure(self):
        """Ensures required directories exist and clears old logs."""
        # NOTE(review): despite the docstring, old logs are NOT cleared here —
        # _clear_logs exists but is never called in this class; confirm intent.
        if not self.production:
            DEV_OUTPUT.mkdir(exist_ok=True)
        self.logs_dir.mkdir(exist_ok=True)

    def _clear_logs(self):
        """Removes existing log files while preserving the log directory."""
        if self.logs_dir.exists():
            for log_file in self.logs_dir.iterdir():
                if log_file.is_file():
                    log_file.unlink()
            logging.info("πŸ—‘οΈ Cleared existing log files.")

    def _setup_loggers(self):
        """Configures loggers for different EcoOptimizer processes."""
        # Drop any handlers installed by earlier configuration runs.
        logging.root.handlers.clear()

        logging.basicConfig(
            filename=str(self.log_files["main"]),
            filemode="a",
            level=logging.INFO,
            format="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            force=True,
        )

        self.loggers = {
            "detect": self._create_logger(
                "detect", self.log_files["detect"], self.log_files["main"]
            ),
            "refactor": self._create_logger(
                "refactor", self.log_files["refactor"], self.log_files["main"]
            ),
        }

        logging.info("πŸ“ Loggers initialized successfully.")

    def _create_logger(self, name: str, log_file: Path, main_log_file: Path):
        """
        Creates a logger that logs to both its own file and the main log file.

        Args:
            name (str): Name of the logger.
            log_file (Path): Path to the specific log file.
            main_log_file (Path): Path to the main log file.

        Returns:
            logging.Logger: Configured logger instance.
+ """ + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.propagate = False + + file_handler = logging.FileHandler(str(log_file), mode="a", encoding="utf-8") + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d [%(levelname)s] %(message)s", "%Y-%m-%d %H:%M:%S" + ) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + main_handler = logging.FileHandler(str(main_log_file), mode="a", encoding="utf-8") + main_handler.setFormatter(formatter) + logger.addHandler(main_handler) + + logging.info(f"πŸ“ Logger '{name}' initialized and writing to {log_file}.") + return logger + + +def save_file(file_name: str, data: str, mode: str, message: str = ""): + """Saves data to a file in the output directory.""" + file_path = DEV_OUTPUT / file_name + with file_path.open(mode) as file: + file.write(data) + log_message = message if message else f"πŸ“ {file_name} saved to {file_path!s}" + logging.info(log_message) + + +def save_json_files(file_name: str, data: dict[Any, Any] | list[Any]): + """Saves data to a JSON file in the output directory.""" + file_path = DEV_OUTPUT / file_name + file_path.write_text(json.dumps(data, cls=EnumEncoder, sort_keys=True, indent=4)) + logging.info(f"πŸ“ {file_name} saved to {file_path!s} as JSON file") + + +def copy_file_to_output(source_file_path: Path, new_file_name: str): + """Copies a file to the output directory with a new name.""" + destination_path = DEV_OUTPUT / new_file_name + shutil.copy(source_file_path, destination_path) + logging.info(f"πŸ“ {new_file_name} copied to {destination_path!s}") + return destination_path diff --git a/src/ecooptimizer/utils/smell_enums.py b/src/ecooptimizer/utils/smell_enums.py new file mode 100644 index 00000000..3661002e --- /dev/null +++ b/src/ecooptimizer/utils/smell_enums.py @@ -0,0 +1,29 @@ +from enum import Enum + + +class ExtendedEnum(Enum): + @classmethod + def list(cls) -> list[str]: + return [c.value for c in cls] + + def __eq__(self, value: object) 
-> bool: + return str(self.value) == value + + +# Enum class for standard Pylint code smells +class PylintSmell(ExtendedEnum): + LONG_PARAMETER_LIST = "R0913" # Pylint code smell for functions with too many parameters + NO_SELF_USE = "R6301" # Pylint code smell for class methods that don't use any self calls + USE_A_GENERATOR = ( + "R1729" # Pylint code smell for unnecessary list comprehensions inside `any()` or `all()` + ) + + +# Enum class for custom code smells not detected by Pylint +class CustomSmell(ExtendedEnum): + LONG_MESSAGE_CHAIN = "LMC001" # Ast code smell for long message chains + UNUSED_VAR_OR_ATTRIBUTE = "UVA001" # Ast code smell for unused variable or attribute + LONG_ELEMENT_CHAIN = "LEC001" # Ast code smell for long element chains + LONG_LAMBDA_EXPR = "LLE001" # Ast code smell for long lambda expressions + STR_CONCAT_IN_LOOP = "SCL001" # Astroid code smell for string concatenation inside loops + CACHE_REPEATED_CALLS = "CRC001" # Ast code smell for repeated calls diff --git a/src/ecooptimizer/utils/smells_registry.py b/src/ecooptimizer/utils/smells_registry.py new file mode 100644 index 00000000..5504a848 --- /dev/null +++ b/src/ecooptimizer/utils/smells_registry.py @@ -0,0 +1,112 @@ +from copy import deepcopy +from .smell_enums import CustomSmell, PylintSmell + +from ..analyzers.ast_analyzers.detect_long_element_chain import detect_long_element_chain +from ..analyzers.ast_analyzers.detect_long_lambda_expression import detect_long_lambda_expression +from ..analyzers.ast_analyzers.detect_long_message_chain import detect_long_message_chain +from ..analyzers.astroid_analyzers.detect_string_concat_in_loop import detect_string_concat_in_loop +from ..analyzers.ast_analyzers.detect_repeated_calls import detect_repeated_calls +from ..analyzers.ast_analyzers.detect_unused_variables_and_attributes import ( + detect_unused_variables_and_attributes, +) + +from ..refactorers.concrete.list_comp_any_all import UseAGeneratorRefactorer + +from 
..refactorers.concrete.long_lambda_function import LongLambdaFunctionRefactorer +from ..refactorers.concrete.long_element_chain import LongElementChainRefactorer +from ..refactorers.concrete.long_message_chain import LongMessageChainRefactorer +from ..refactorers.concrete.unused import RemoveUnusedRefactorer +from ..refactorers.concrete.member_ignoring_method import MakeStaticRefactorer +from ..refactorers.concrete.long_parameter_list import LongParameterListRefactorer +from ..refactorers.concrete.str_concat_in_loop import UseListAccumulationRefactorer +from ..refactorers.concrete.repeated_calls import CacheRepeatedCallsRefactorer + +from ..data_types.smell_record import SmellRecord + +_SMELL_REGISTRY: dict[str, SmellRecord] = { + "use-a-generator": { + "id": PylintSmell.USE_A_GENERATOR.value, + "enabled": True, + "analyzer_method": "pylint", + "checker": None, + "analyzer_options": {}, + "refactorer": UseAGeneratorRefactorer, + }, + "too-many-arguments": { + "id": PylintSmell.LONG_PARAMETER_LIST.value, + "enabled": True, + "analyzer_method": "pylint", + "checker": None, + "analyzer_options": {"max_args": {"flag": "--max-args", "value": 6}}, + "refactorer": LongParameterListRefactorer, + }, + "no-self-use": { + "id": PylintSmell.NO_SELF_USE.value, + "enabled": True, + "analyzer_method": "pylint", + "checker": None, + "analyzer_options": { + "load-plugin": {"flag": "--load-plugins", "value": "pylint.extensions.no_self_use"} + }, + "refactorer": MakeStaticRefactorer, + }, + "long-lambda-expression": { + "id": CustomSmell.LONG_LAMBDA_EXPR.value, + "enabled": True, + "analyzer_method": "ast", + "checker": detect_long_lambda_expression, + "analyzer_options": {"threshold_length": 100, "threshold_count": 5}, + "refactorer": LongLambdaFunctionRefactorer, + }, + "long-message-chain": { + "id": CustomSmell.LONG_MESSAGE_CHAIN.value, + "enabled": True, + "analyzer_method": "ast", + "checker": detect_long_message_chain, + "analyzer_options": {"threshold": 3}, + "refactorer": 
LongMessageChainRefactorer, + }, + "unused_variables_and_attributes": { + "id": CustomSmell.UNUSED_VAR_OR_ATTRIBUTE.value, + "enabled": False, + "analyzer_method": "ast", + "checker": detect_unused_variables_and_attributes, + "analyzer_options": {}, + "refactorer": RemoveUnusedRefactorer, + }, + "long-element-chain": { + "id": CustomSmell.LONG_ELEMENT_CHAIN.value, + "enabled": True, + "analyzer_method": "ast", + "checker": detect_long_element_chain, + "analyzer_options": {"threshold": 3}, + "refactorer": LongElementChainRefactorer, + }, + "cached-repeated-calls": { + "id": CustomSmell.CACHE_REPEATED_CALLS.value, + "enabled": True, + "analyzer_method": "ast", + "checker": detect_repeated_calls, + "analyzer_options": {"threshold": 2}, + "refactorer": CacheRepeatedCallsRefactorer, + }, + "string-concat-loop": { + "id": CustomSmell.STR_CONCAT_IN_LOOP.value, + "enabled": True, + "analyzer_method": "astroid", + "checker": detect_string_concat_in_loop, + "analyzer_options": {}, + "refactorer": UseListAccumulationRefactorer, + }, +} + + +def retrieve_smell_registry(enabled_smells: list[str] | str): + """Returns a modified SMELL_REGISTRY based on user preferences (enables/disables smells).""" + if enabled_smells == "ALL": + return deepcopy(_SMELL_REGISTRY) + return {key: val for (key, val) in _SMELL_REGISTRY.items() if key in enabled_smells} + + +def get_refactorer(symbol: str): + return _SMELL_REGISTRY[symbol].get("refactorer", None) diff --git a/src/main.py b/src/main.py deleted file mode 100644 index 4508a68d..00000000 --- a/src/main.py +++ /dev/null @@ -1,15 +0,0 @@ -from analyzers.pylint_analyzer import PylintAnalyzer - -def main(): - """ - Entry point for the refactoring tool. - - Create an instance of the analyzer. - - Perform code analysis and print the results. 
- """ - code_path = "path/to/your/code" # Path to the code to analyze - analyzer = PylintAnalyzer(code_path) - report = analyzer.analyze() # Analyze the code - print(report) # Print the analysis report - -if __name__ == "__main__": - main() diff --git a/src/measurement/energy_meter.py b/src/measurement/energy_meter.py deleted file mode 100644 index 8d589d9d..00000000 --- a/src/measurement/energy_meter.py +++ /dev/null @@ -1,59 +0,0 @@ -import time -from typing import Callable -import pyJoules.energy as joules - -class EnergyMeter: - """ - A class to measure the energy consumption of specific code blocks using PyJoules. - """ - - def __init__(self): - """ - Initializes the EnergyMeter class. - """ - # Optional: Any initialization for the energy measurement can go here - pass - - def measure_energy(self, func: Callable, *args, **kwargs): - """ - Measures the energy consumed by the specified function during its execution. - - Parameters: - - func (Callable): The function to measure. - - *args: Arguments to pass to the function. - - **kwargs: Keyword arguments to pass to the function. - - Returns: - - tuple: A tuple containing the return value of the function and the energy consumed (in Joules). - """ - start_energy = joules.getEnergy() # Start measuring energy - start_time = time.time() # Record start time - - result = func(*args, **kwargs) # Call the specified function - - end_time = time.time() # Record end time - end_energy = joules.getEnergy() # Stop measuring energy - - energy_consumed = end_energy - start_energy # Calculate energy consumed - - # Log the timing (optional) - print(f"Execution Time: {end_time - start_time:.6f} seconds") - print(f"Energy Consumed: {energy_consumed:.6f} Joules") - - return result, energy_consumed # Return the result of the function and the energy consumed - - def measure_block(self, code_block: str): - """ - Measures energy consumption for a block of code represented as a string. 
- - Parameters: - - code_block (str): A string containing the code to execute. - - Returns: - - float: The energy consumed (in Joules). - """ - local_vars = {} - exec(code_block, {}, local_vars) # Execute the code block - energy_consumed = joules.getEnergy() # Measure energy after execution - print(f"Energy Consumed for the block: {energy_consumed:.6f} Joules") - return energy_consumed diff --git a/src/refactorer/base_refactorer.py b/src/refactorer/base_refactorer.py deleted file mode 100644 index 698440fb..00000000 --- a/src/refactorer/base_refactorer.py +++ /dev/null @@ -1,24 +0,0 @@ -# src/refactorer/base_refactorer.py - -from abc import ABC, abstractmethod - -class BaseRefactorer(ABC): - """ - Abstract base class for refactorers. - Subclasses should implement the `refactor` method. - """ - - def __init__(self, code): - """ - Initialize the refactorer with the code to refactor. - - :param code: The code that needs refactoring - """ - self.code = code - - def refactor(self): - """ - Perform the refactoring process. - Must be implemented by subclasses. - """ - raise NotImplementedError("Subclasses should implement this method") diff --git a/src/refactorer/complex_list_comprehension_refactorer.py b/src/refactorer/complex_list_comprehension_refactorer.py deleted file mode 100644 index b4a96586..00000000 --- a/src/refactorer/complex_list_comprehension_refactorer.py +++ /dev/null @@ -1,115 +0,0 @@ -import ast -import astor - -class ComplexListComprehensionRefactorer: - """ - Refactorer for complex list comprehensions to improve readability. - """ - - def __init__(self, code: str): - """ - Initializes the refactorer. - - :param code: The source code to refactor. - """ - self.code = code - - def refactor(self): - """ - Refactor the code by transforming complex list comprehensions into for-loops. - - :return: The refactored code. 
- """ - # Parse the code to get the AST - tree = ast.parse(self.code) - - # Walk through the AST and refactor complex list comprehensions - for node in ast.walk(tree): - if isinstance(node, ast.ListComp): - # Check if the list comprehension is complex - if self.is_complex(node): - # Create a for-loop equivalent - for_loop = self.create_for_loop(node) - # Replace the list comprehension with the for-loop in the AST - self.replace_node(node, for_loop) - - # Convert the AST back to code - return self.ast_to_code(tree) - - def create_for_loop(self, list_comp: ast.ListComp) -> ast.For: - """ - Create a for-loop that represents the list comprehension. - - :param list_comp: The ListComp node to convert. - :return: An ast.For node representing the for-loop. - """ - # Create the variable to hold results - result_var = ast.Name(id='result', ctx=ast.Store()) - - # Create the for-loop - for_loop = ast.For( - target=ast.Name(id='item', ctx=ast.Store()), - iter=list_comp.generators[0].iter, - body=[ - ast.Expr(value=ast.Call( - func=ast.Name(id='append', ctx=ast.Load()), - args=[self.transform_value(list_comp.elt)], - keywords=[] - )) - ], - orelse=[] - ) - - # Create a list to hold results - result_list = ast.List(elts=[], ctx=ast.Store()) - return ast.With( - context_expr=ast.Name(id='result', ctx=ast.Load()), - body=[for_loop], - lineno=list_comp.lineno, - col_offset=list_comp.col_offset - ) - - def transform_value(self, value_node: ast.AST) -> ast.AST: - """ - Transform the value in the list comprehension into a form usable in a for-loop. - - :param value_node: The value node to transform. - :return: The transformed value node. - """ - return value_node - - def replace_node(self, old_node: ast.AST, new_node: ast.AST): - """ - Replace an old node in the AST with a new node. - - :param old_node: The node to replace. - :param new_node: The node to insert in its place. 
- """ - parent = self.find_parent(old_node) - if parent: - for index, child in enumerate(ast.iter_child_nodes(parent)): - if child is old_node: - parent.body[index] = new_node - break - - def find_parent(self, node: ast.AST) -> ast.AST: - """ - Find the parent node of a given AST node. - - :param node: The node to find the parent for. - :return: The parent node, or None if not found. - """ - for parent in ast.walk(node): - for child in ast.iter_child_nodes(parent): - if child is node: - return parent - return None - - def ast_to_code(self, tree: ast.AST) -> str: - """ - Convert AST back to source code. - - :param tree: The AST to convert. - :return: The source code as a string. - """ - return astor.to_source(tree) diff --git a/src/refactorer/large_class_refactorer.py b/src/refactorer/large_class_refactorer.py deleted file mode 100644 index aff1f32d..00000000 --- a/src/refactorer/large_class_refactorer.py +++ /dev/null @@ -1,83 +0,0 @@ -import ast - -class LargeClassRefactorer: - """ - Refactorer for large classes that have too many methods. - """ - - def __init__(self, code: str, method_threshold: int = 5): - """ - Initializes the refactorer. - - :param code: The source code of the class to refactor. - :param method_threshold: The number of methods above which a class is considered large. - """ - self.code = code - self.method_threshold = method_threshold - - def refactor(self): - """ - Refactor the class by splitting it into smaller classes if it exceeds the method threshold. - - :return: The refactored code. 
- """ - # Parse the code to get the class definition - tree = ast.parse(self.code) - class_definitions = [node for node in tree.body if isinstance(node, ast.ClassDef)] - - refactored_code = [] - - for class_def in class_definitions: - methods = [n for n in class_def.body if isinstance(n, ast.FunctionDef)] - if len(methods) > self.method_threshold: - # If the class is large, split it - new_classes = self.split_class(class_def, methods) - refactored_code.extend(new_classes) - else: - # Keep the class as is - refactored_code.append(class_def) - - # Convert the AST back to code - return self.ast_to_code(refactored_code) - - def split_class(self, class_def, methods): - """ - Split the large class into smaller classes based on methods. - - :param class_def: The class definition node. - :param methods: The list of methods in the class. - :return: A list of new class definitions. - """ - # For demonstration, we'll simply create two classes based on the method count - half_index = len(methods) // 2 - new_class1 = self.create_new_class(class_def.name + "Part1", methods[:half_index]) - new_class2 = self.create_new_class(class_def.name + "Part2", methods[half_index:]) - - return [new_class1, new_class2] - - def create_new_class(self, new_class_name, methods): - """ - Create a new class definition with the specified methods. - - :param new_class_name: Name of the new class. - :param methods: List of methods to include in the new class. - :return: A new class definition node. - """ - # Create the class definition with methods - class_def = ast.ClassDef( - name=new_class_name, - bases=[], - body=methods, - decorator_list=[] - ) - return class_def - - def ast_to_code(self, nodes): - """ - Convert AST nodes back to source code. - - :param nodes: The AST nodes to convert. - :return: The source code as a string. 
- """ - import astor - return astor.to_source(nodes) diff --git a/src/refactorer/long_method_refactorer.py b/src/refactorer/long_method_refactorer.py deleted file mode 100644 index 459a32e4..00000000 --- a/src/refactorer/long_method_refactorer.py +++ /dev/null @@ -1,14 +0,0 @@ -from .base_refactorer import BaseRefactorer - -class LongMethodRefactorer(BaseRefactorer): - """ - Refactorer that targets long methods to improve readability. - """ - - def refactor(self): - """ - Refactor long methods into smaller methods. - Implement the logic to detect and refactor long methods. - """ - # Logic to identify long methods goes here - pass diff --git a/src/testing/test_runner.py b/src/testing/test_runner.py deleted file mode 100644 index 84fe92a9..00000000 --- a/src/testing/test_runner.py +++ /dev/null @@ -1,17 +0,0 @@ -import unittest -import os -import sys - -# Add the src directory to the path to import modules -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) - -# Discover and run all tests in the 'tests' directory -def run_tests(): - test_loader = unittest.TestLoader() - test_suite = test_loader.discover('tests', pattern='*.py') - - test_runner = unittest.TextTestRunner(verbosity=2) - test_runner.run(test_suite) - -if __name__ == '__main__': - run_tests() diff --git a/src/testing/test_validator.py b/src/testing/test_validator.py deleted file mode 100644 index cbbb29d4..00000000 --- a/src/testing/test_validator.py +++ /dev/null @@ -1,3 +0,0 @@ -def validate_output(original, refactored): - # Compare original and refactored output - return original == refactored diff --git a/src/utils/logger.py b/src/utils/logger.py deleted file mode 100644 index 711c62b5..00000000 --- a/src/utils/logger.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging -import os - -def setup_logger(log_file: str = "app.log", log_level: int = logging.INFO): - """ - Set up the logger configuration. - - Args: - log_file (str): The name of the log file to write logs to. 
- log_level (int): The logging level (default is INFO). - - Returns: - Logger: Configured logger instance. - """ - # Create log directory if it does not exist - log_directory = os.path.dirname(log_file) - if log_directory and not os.path.exists(log_directory): - os.makedirs(log_directory) - - # Configure the logger - logging.basicConfig( - filename=log_file, - filemode='a', # Append mode - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - level=log_level, - ) - - logger = logging.getLogger(__name__) - return logger - -# # Example usage -# if __name__ == "__main__": -# logger = setup_logger() # You can customize the log file and level here -# logger.info("Logger is set up and ready to use.") diff --git a/test/test_analyzer.py b/test/test_analyzer.py deleted file mode 100644 index 3f522dd4..00000000 --- a/test/test_analyzer.py +++ /dev/null @@ -1,12 +0,0 @@ -# import unittest -# from src.analyzer.pylint_analyzer import PylintAnalyzer - -# class TestPylintAnalyzer(unittest.TestCase): -# def test_analyze_method(self): -# analyzer = PylintAnalyzer("path/to/test/code.py") -# report = analyzer.analyze() -# self.assertIsInstance(report, list) # Check if the output is a list -# # Add more assertions based on expected output - -# if __name__ == "__main__": -# unittest.main() diff --git a/test/test_end_to_end.py b/test/test_end_to_end.py deleted file mode 100644 index bef67b8e..00000000 --- a/test/test_end_to_end.py +++ /dev/null @@ -1,16 +0,0 @@ -import unittest - -class TestEndToEnd(unittest.TestCase): - """ - End-to-end tests for the full refactoring flow. - """ - - def test_refactor_flow(self): - """ - Test the complete flow from analysis to refactoring. 
- """ - # Implement the test logic here - self.assertTrue(True) # Placeholder for actual test - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_energy_measure.py b/test/test_energy_measure.py deleted file mode 100644 index 00d381c6..00000000 --- a/test/test_energy_measure.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest -from src.measurement.energy_meter import EnergyMeter - -class TestEnergyMeter(unittest.TestCase): - """ - Unit tests for the EnergyMeter class. - """ - - def test_measurement(self): - """ - Test starting and stopping energy measurement. - """ - meter = EnergyMeter() - meter.start_measurement() - # Logic to execute code - result = meter.stop_measurement() - self.assertIsNotNone(result) # Check that a result is produced - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_refactorer.py b/test/test_refactorer.py deleted file mode 100644 index af992428..00000000 --- a/test/test_refactorer.py +++ /dev/null @@ -1,99 +0,0 @@ -import unittest -from src.refactorer.long_method_refactorer import LongMethodRefactorer -from src.refactorer.large_class_refactorer import LargeClassRefactorer -from src.refactorer.complex_list_comprehension_refactorer import ComplexListComprehensionRefactorer - -class TestRefactorers(unittest.TestCase): - """ - Unit tests for various refactorers. - """ - - def test_refactor_long_method(self): - """ - Test the refactor method of the LongMethodRefactorer. - """ - original_code = """ - def long_method(): - # A long method with too many lines of code - a = 1 - b = 2 - c = a + b - # More complex logic... 
- return c - """ - expected_refactored_code = """ - def long_method(): - result = calculate_result() - return result - - def calculate_result(): - a = 1 - b = 2 - return a + b - """ - refactorer = LongMethodRefactorer(original_code) - result = refactorer.refactor() - self.assertEqual(result.strip(), expected_refactored_code.strip()) - - def test_refactor_large_class(self): - """ - Test the refactor method of the LargeClassRefactorer. - """ - original_code = """ - class LargeClass: - def method1(self): - # Method 1 - pass - - def method2(self): - # Method 2 - pass - - def method3(self): - # Method 3 - pass - - # ... many more methods ... - """ - expected_refactored_code = """ - class LargeClass: - def method1(self): - # Method 1 - pass - - class AnotherClass: - def method2(self): - # Method 2 - pass - - def method3(self): - # Method 3 - pass - """ - refactorer = LargeClassRefactorer(original_code) - result = refactorer.refactor() - self.assertEqual(result.strip(), expected_refactored_code.strip()) - - def test_refactor_complex_list_comprehension(self): - """ - Test the refactor method of the ComplexListComprehensionRefactorer. 
- """ - original_code = """ - def complex_list(): - return [x**2 for x in range(10) if x % 2 == 0 and x > 3] - """ - expected_refactored_code = """ - def complex_list(): - result = [] - for x in range(10): - if x % 2 == 0 and x > 3: - result.append(x**2) - return result - """ - refactorer = ComplexListComprehensionRefactorer(original_code) - result = refactorer.refactor() - self.assertEqual(result.strip(), expected_refactored_code.strip()) - -# Run all tests in the module -if __name__ == "__main__": - unittest.main() diff --git a/test/README.md b/tests/README.md similarity index 100% rename from test/README.md rename to tests/README.md diff --git a/tests/_input_copies/test_2_copy.py b/tests/_input_copies/test_2_copy.py new file mode 100644 index 00000000..4d1f853d --- /dev/null +++ b/tests/_input_copies/test_2_copy.py @@ -0,0 +1,105 @@ +import datetime # unused import + + +class Temp: + + def __init__(self) -> None: + self.unused_class_attribute = True + self.a = 3 + + def temp_function(self): + unused_var = 3 + b = 4 + return self.a + b + + +# LC: Large Class with too many responsibilities +class DataProcessor: + def __init__(self, data): + self.data = data + self.processed_data = [] + + # LM: Long Method - this method does way too much + def process_all_data(self): + results = [] + for item in self.data: + try: + # LPL: Long Parameter List + result = self.complex_calculation( + item, True, False, "multiply", 10, 20, None, "end" + ) + results.append(result) + except ( + Exception + ) as e: # UEH: Unqualified Exception Handling, catching generic exceptions + print("An error occurred:", e) + + # LMC: Long Message Chain + print(self.data[0].upper().strip().replace(" ", "_").lower()) + + # LLF: Long Lambda Function + self.processed_data = list( + filter(lambda x: x != None and x != 0 and len(str(x)) > 1, results) + ) + + return self.processed_data + + # LBCL: Long Base Class List + + +class AdvancedProcessor(DataProcessor): + pass + + # LTCE: Long Ternary Conditional 
Expression + def check_data(self, item): + return ( + True if item > 10 else False if item < -10 else None if item == 0 else item + ) + + # Complex List Comprehension + def complex_comprehension(self): + # CLC: Complex List Comprehension + self.processed_data = [ + x**2 if x % 2 == 0 else x**3 + for x in range(1, 100) + if x % 5 == 0 and x != 50 and x > 3 + ] + + # Long Element Chain + def long_chain(self): + # LEC: Long Element Chain accessing deeply nested elements + try: + deep_value = self.data[0][1]["details"]["info"]["more_info"][2]["target"] + return deep_value + except KeyError: + return None + + # Long Scope Chaining (LSC) + def long_scope_chaining(self): + for a in range(10): + for b in range(10): + for c in range(10): + for d in range(10): + for e in range(10): + if a + b + c + d + e > 25: + return "Done" + + # LPL: Long Parameter List + def complex_calculation( + self, item, flag1, flag2, operation, threshold, max_value, option, final_stage + ): + if operation == "multiply": + result = item * threshold + elif operation == "add": + result = item + max_value + else: + result = item + return result + + +# Main method to execute the code +if __name__ == "__main__": + sample_data = [1, 2, 3, 4, 5] + processor = DataProcessor(sample_data) + processed = processor.process_all_data() + print("Processed Data:", processed) diff --git a/tests/analyzers/test_detect_lec.py b/tests/analyzers/test_detect_lec.py new file mode 100644 index 00000000..d6d63cb5 --- /dev/null +++ b/tests/analyzers/test_detect_lec.py @@ -0,0 +1,300 @@ +import ast +from pathlib import Path +import textwrap +import pytest + +from ecooptimizer.analyzers.ast_analyzers.detect_long_element_chain import detect_long_element_chain +from ecooptimizer.data_types.smell import LECSmell + + +@pytest.fixture +def temp_file(tmp_path): + """Create a temporary file for testing.""" + file_path = tmp_path / "test_code.py" + return file_path + + +def parse_code(code_str): + """Parse code string into an AST.""" + 
return ast.parse(code_str) + + +def test_no_chains(temp_file): + """Test with code that has no chains.""" + code = textwrap.dedent(""" + a = 1 + b = 2 + c = a + b + d = {'key': 'value'} + e = d['key'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree) + + assert len(result) == 0 + + +def test_chains_below_threshold(temp_file): + """Test with chains shorter than threshold.""" + code = textwrap.dedent(""" + a = {'key1': {'key2': 'value'}} + b = a['key1']['key2'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + # Using threshold of 5 + result = detect_long_element_chain(temp_file, tree, 5) + + assert len(result) == 0 + + +def test_chains_at_threshold(temp_file): + """Test with chains exactly at threshold.""" + code = textwrap.dedent(""" + a = {'key1': {'key2': {'key3': 'value'}}} + b = a['key1']['key2']['key3'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + # Using threshold of 3 + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 1 + assert result[0].messageId == "LEC001" + assert result[0].symbol == "long-element-chain" + assert result[0].occurences[0].line == 3 # Line 3 in the code + + +def test_chains_above_threshold(temp_file): + """Test with chains longer than threshold.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': {'d': 'value'}}}} + result = data['a']['b']['c']['d'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + # Using threshold of 3 + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 1 + assert "Dictionary chain too long (4/3)" in result[0].message + + +def test_multiple_chains(temp_file): + """Test with multiple chains in the same file.""" + code = textwrap.dedent(""" + data1 = {'a': {'b': {'c': 'value1'}}} + data2 = {'x': {'y': {'z': 'value2'}}} + + 
result1 = data1['a']['b']['c'] + result2 = data2['x']['y']['z'] + + # Some other code without chains + a = 1 + b = 2 + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 2 + assert result[0].occurences[0].line != result[1].occurences[0].line + + +def test_nested_functions_with_chains(temp_file): + """Test chains inside nested functions and classes.""" + code = textwrap.dedent(""" + def outer_function(): + data = {'a': {'b': {'c': 'value'}}} + + def inner_function(): + return data['a']['b']['c'] + + return inner_function() + + class TestClass: + def method(self): + obj = {'x': {'y': {'z': {'deep': 'nested'}}}} + return obj['x']['y']['z']['deep'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 2 + # Check that we detected the chain in both locations + + +def test_same_line_reported_once(temp_file): + """Test that chains on the same line are reported only once.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': 'value1'}}} + # Two identical chains on the same line + result1, result2 = data['a']['b']['c'], data['a']['b']['c'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 2) + + assert len(result) == 1 + + assert result[0].occurences[0].line == 4 + + +def test_variable_types_chains(temp_file): + """Test chains with different variable types.""" + code = textwrap.dedent(""" + # List within dict chain + data1 = {'a': [{'b': {'c': 'value'}}]} + result1 = data1['a'][0]['b']['c'] + + # Tuple with dict chain + data2 = {'x': ({'y': {'z': 'value'}},)} + result2 = data2['x'][0]['y']['z'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, 
tree, 3) + + assert len(result) == 2 + + +def test_custom_threshold(temp_file): + """Test with a custom threshold value.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': 'value'}}} + result = data['a']['b']['c'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + + # With threshold of 4, no chains should be detected + result1 = detect_long_element_chain(temp_file, tree, 4) + assert len(result1) == 0 + + # With threshold of 2, the chain should be detected + result2 = detect_long_element_chain(temp_file, tree, 2) + assert len(result2) == 1 + assert "Dictionary chain too long (3/2)" in result2[0].message + + +def test_result_structure(temp_file): + """Test the structure of the returned LECSmell object.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': 'value'}}} + result = data['a']['b']['c'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 1 + smell = result[0] + + # Verify it's the correct type + assert isinstance(smell, LECSmell) + + # Check required fields + assert smell.path == str(temp_file) + assert smell.module == temp_file.stem + assert smell.type == "convention" + assert smell.symbol == "long-element-chain" + assert "Dictionary chain too long" in smell.message + + # Check occurrence details + assert len(smell.occurences) == 1 + assert smell.occurences[0].line == 3 + assert smell.occurences[0].column is not None + assert smell.occurences[0].endLine is not None + assert smell.occurences[0].endColumn is not None + + # Verify additional info exists + assert hasattr(smell, "additionalInfo") + + +def test_complex_expressions(temp_file): + """Test chains within complex expressions.""" + code = textwrap.dedent(""" + data = {'a': {'b': {'c': 5}}} + + # Chain in an arithmetic expression + result1 = data['a']['b']['c'] + 10 + + # Chain in a function call + def my_func(x): + return x * 
2 + + result2 = my_func(data['a']['b']['c']) + + # Chain in a comprehension + result3 = [i * data['a']['b']['c'] for i in range(5)] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 3) + + assert len(result) == 3 # Should detect all three chains + + +def test_edge_case_empty_file(temp_file): + """Test with an empty file.""" + code = "" + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree) + + assert len(result) == 0 + + +def test_edge_case_threshold_one(temp_file): + """Test with threshold of 1 (every subscript would be reported).""" + code = textwrap.dedent(""" + data1 = {'a': [{'b': {'c': {'d': 'value'}}}]} + result1 = data1['a'][0]['b']['c']['d'] + """) + + with Path.open(temp_file, "w") as f: + f.write(code) + + tree = parse_code(code) + result = detect_long_element_chain(temp_file, tree, 5) + + assert len(result) == 1 + assert "Dictionary chain too long (5/5)" in result[0].message diff --git a/tests/analyzers/test_long_lambda_element.py b/tests/analyzers/test_long_lambda_element.py new file mode 100644 index 00000000..4306b0f3 --- /dev/null +++ b/tests/analyzers/test_long_lambda_element.py @@ -0,0 +1,178 @@ +import ast +import textwrap +from pathlib import Path +from unittest.mock import patch + +from ecooptimizer.data_types.smell import LLESmell +from ecooptimizer.analyzers.ast_analyzers.detect_long_lambda_expression import ( + detect_long_lambda_expression, +) + + +def test_no_lambdas(): + """Ensures no smells are detected when no lambda is present.""" + code = textwrap.dedent( + """ + def example(): + x = 42 + return x + 1 + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + assert len(smells) == 0 + + +def test_short_single_lambda(): + """ + A single short lambda (well under length=100) 
+ and only one expression -> should NOT be flagged. + """ + code = textwrap.dedent( + """ + def example(): + f = lambda x: x + 1 + return f(5) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 0 + + +def test_lambda_exceeds_expr_count(): + """ + Long lambda due to too many expressions + In the AST, this breaks down as: + (x + 1 if x > 0 else 0) -> ast.IfExp (expression #1) + abs(x) * 2 -> ast.BinOp (Call inside it) (expression #2) + min(x, 5) -> ast.Call (expression #3) + """ + code = textwrap.dedent( + """ + def example(): + func = lambda x: (x + 1 if x > 0 else 0) + (x * 2 if x < 5 else 5) + abs(x) + return func(4) + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 1, "Expected smell due to expression count" + assert isinstance(smells[0], LLESmell) + + +def test_lambda_exceeds_char_length(): + """ + Exceeds threshold_length=100 by using a very long expression in the lambda. 
+ """ + long_str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" * 4 + code = textwrap.dedent( + f""" + def example(): + func = lambda x: x + "{long_str}" + return func("test") + """ + ) + # exceeds 100 char + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 1, "Expected smell due to character length" + assert isinstance(smells[0], LLESmell) + + +def test_lambda_exceeds_both_thresholds(): + """ + Both too many chars and too many expressions + """ + code = textwrap.dedent( + """ + def example(): + giant_lambda = lambda a, b, c: (a + b if a > b else b - c) + (max(a, b, c) * 10) + (min(a, b, c) / 2) + ("hello" + "world") + return giant_lambda(1,2,3) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + # one smell per line + assert len(smells) >= 1 + assert all(isinstance(smell, LLESmell) for smell in smells) + + +def test_lambda_nested(): + """ + Nested lambdas inside one function. + # outer and inner detected + """ + code = textwrap.dedent( + """ + def example(): + outer = lambda x: (x ** 2) + (lambda y: y + 10)(x) + # inner = lambda y: y + 10 is short, but let's make it long + # We'll artificially make it a big expression + inner = lambda a, b: (a + b if a > 0 else 0) + (a * b) + (b - a) + return outer(5) + inner(3,4) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), ast.parse(code), threshold_length=80, threshold_count=3 + ) + # inner and outter + assert len(smells) == 2 + assert isinstance(smells[0], LLESmell) + + +def test_lambda_inline_passed_to_function(): + """ + Lambdas passed inline to a function: sum(map(...)) or filter(..., lambda). 
+ """ + code = textwrap.dedent( + """ + def test_lambdas(): + result = map(lambda x: x*2 + (x//3) if x > 10 else x, range(20)) + + # This lambda has a ternary, but let's keep it short enough + # that it doesn't trigger by default unless threshold_count=2 or so. + # We'll push it with a second ternary + more code to reach threshold_count=3 + + result2 = filter(lambda z: (z+1 if z < 5 else z-1) + (z*3 if z%2==0 else z/2) and z != 0, result) + + return list(result2) + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + # 2 smells + assert len(smells) == 2 + assert all(isinstance(smell, LLESmell) for smell in smells) + + +def test_lambda_no_body_too_short(): + """ + A degenerate case: a lambda that has no real body or is trivially short. + Should produce 0 smells even if it's spread out. + """ + code = textwrap.dedent( + """ + def example(): + trivial = lambda: None + return trivial() + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + assert len(smells) == 0 diff --git a/tests/analyzers/test_long_lambda_function.py b/tests/analyzers/test_long_lambda_function.py new file mode 100644 index 00000000..4306b0f3 --- /dev/null +++ b/tests/analyzers/test_long_lambda_function.py @@ -0,0 +1,178 @@ +import ast +import textwrap +from pathlib import Path +from unittest.mock import patch + +from ecooptimizer.data_types.smell import LLESmell +from ecooptimizer.analyzers.ast_analyzers.detect_long_lambda_expression import ( + detect_long_lambda_expression, +) + + +def test_no_lambdas(): + """Ensures no smells are detected when no lambda is present.""" + code = textwrap.dedent( + """ + def example(): + x = 42 + return x + 1 + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + assert len(smells) == 0 + + +def 
test_short_single_lambda(): + """ + A single short lambda (well under length=100) + and only one expression -> should NOT be flagged. + """ + code = textwrap.dedent( + """ + def example(): + f = lambda x: x + 1 + return f(5) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 0 + + +def test_lambda_exceeds_expr_count(): + """ + Long lambda due to too many expressions + In the AST, this breaks down as: + (x + 1 if x > 0 else 0) -> ast.IfExp (expression #1) + abs(x) * 2 -> ast.BinOp (Call inside it) (expression #2) + min(x, 5) -> ast.Call (expression #3) + """ + code = textwrap.dedent( + """ + def example(): + func = lambda x: (x + 1 if x > 0 else 0) + (x * 2 if x < 5 else 5) + abs(x) + return func(4) + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 1, "Expected smell due to expression count" + assert isinstance(smells[0], LLESmell) + + +def test_lambda_exceeds_char_length(): + """ + Exceeds threshold_length=100 by using a very long expression in the lambda. 
+ """ + long_str = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" * 4 + code = textwrap.dedent( + f""" + def example(): + func = lambda x: x + "{long_str}" + return func("test") + """ + ) + # exceeds 100 char + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + assert len(smells) == 1, "Expected smell due to character length" + assert isinstance(smells[0], LLESmell) + + +def test_lambda_exceeds_both_thresholds(): + """ + Both too many chars and too many expressions + """ + code = textwrap.dedent( + """ + def example(): + giant_lambda = lambda a, b, c: (a + b if a > b else b - c) + (max(a, b, c) * 10) + (min(a, b, c) / 2) + ("hello" + "world") + return giant_lambda(1,2,3) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), + ast.parse(code), + ) + # one smell per line + assert len(smells) >= 1 + assert all(isinstance(smell, LLESmell) for smell in smells) + + +def test_lambda_nested(): + """ + Nested lambdas inside one function. + # outer and inner detected + """ + code = textwrap.dedent( + """ + def example(): + outer = lambda x: (x ** 2) + (lambda y: y + 10)(x) + # inner = lambda y: y + 10 is short, but let's make it long + # We'll artificially make it a big expression + inner = lambda a, b: (a + b if a > 0 else 0) + (a * b) + (b - a) + return outer(5) + inner(3,4) + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression( + Path("fake.py"), ast.parse(code), threshold_length=80, threshold_count=3 + ) + # inner and outter + assert len(smells) == 2 + assert isinstance(smells[0], LLESmell) + + +def test_lambda_inline_passed_to_function(): + """ + Lambdas passed inline to a function: sum(map(...)) or filter(..., lambda). 
+ """ + code = textwrap.dedent( + """ + def test_lambdas(): + result = map(lambda x: x*2 + (x//3) if x > 10 else x, range(20)) + + # This lambda has a ternary, but let's keep it short enough + # that it doesn't trigger by default unless threshold_count=2 or so. + # We'll push it with a second ternary + more code to reach threshold_count=3 + + result2 = filter(lambda z: (z+1 if z < 5 else z-1) + (z*3 if z%2==0 else z/2) and z != 0, result) + + return list(result2) + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + # 2 smells + assert len(smells) == 2 + assert all(isinstance(smell, LLESmell) for smell in smells) + + +def test_lambda_no_body_too_short(): + """ + A degenerate case: a lambda that has no real body or is trivially short. + Should produce 0 smells even if it's spread out. + """ + code = textwrap.dedent( + """ + def example(): + trivial = lambda: None + return trivial() + """ + ) + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_lambda_expression(Path("fake.py"), ast.parse(code)) + assert len(smells) == 0 diff --git a/tests/analyzers/test_long_message_chain.py b/tests/analyzers/test_long_message_chain.py new file mode 100644 index 00000000..52326c4e --- /dev/null +++ b/tests/analyzers/test_long_message_chain.py @@ -0,0 +1,352 @@ +import ast +import textwrap +from pathlib import Path +from unittest.mock import patch + +from ecooptimizer.data_types.smell import LMCSmell +from ecooptimizer.analyzers.ast_analyzers.detect_long_message_chain import ( + detect_long_message_chain, +) + +# NOTE: The default threshold is 5. That means a chain of 5 or more consecutive calls will be flagged. 
+ + +def test_detects_exact_five_calls_chain(): + """Detects a chain with exactly five method calls.""" + code = textwrap.dedent( + """ + def example(): + details = "some text" + details.upper().lower().capitalize().replace("|", "-").strip() + """ + ) + + # This chain has 5 calls: upper -> lower -> capitalize -> replace -> strip + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected exactly one smell for a chain of length 5" + assert isinstance(smells[0], LMCSmell) + assert "Method chain too long" in smells[0].message + assert smells[0].occurences[0].line == 4 + + +def test_detects_six_calls_chain(): + """Detects a chain with six method calls, definitely flagged.""" + code = textwrap.dedent( + """ + def example(): + details = "some text" + details.upper().lower().upper().capitalize().upper().replace("|", "-") + """ + ) + + # This chain has 6 calls: upper -> lower -> upper -> capitalize -> upper -> replace + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected exactly one smell for a chain of length 6" + assert isinstance(smells[0], LMCSmell) + assert "Method chain too long" in smells[0].message + assert smells[0].occurences[0].line == 4 + + +def test_ignores_chain_of_four_calls(): + """Ensures a chain with only four calls is NOT flagged (below threshold).""" + code = textwrap.dedent( + """ + def example(): + text = "some-other" + text.strip().lower().replace("-", "_").title() + """ + ) + + # This chain has 4 calls: strip -> lower -> replace -> title + # The default threshold is 5, so it should not be detected. 
+ with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "Chain of length 4 should NOT be flagged" + + +def test_detects_chain_with_attributes_and_calls(): + """Detects a long chain that involves both attribute and method calls.""" + code = textwrap.dedent( + """ + class Sample: + def __init__(self): + self.details = "some text".upper() + def method(self): + # below is a chain with 5 steps: + # self.details -> lower() -> capitalize() -> isalpha() -> bit_length() + # isalpha() returns bool, bit_length() is from int => means chain length is still counted. + return self.details.upper().lower().capitalize().isalpha().bit_length() + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + # Because we have 5 method calls, it should be flagged. + assert len(smells) == 1, "Expected one smell for chain of length >= 5" + assert isinstance(smells[0], LMCSmell) + + +def test_detects_chain_inside_loop(): + """Detects a chain inside a loop that meets the threshold.""" + code = textwrap.dedent( + """ + def loop_chain(data_list): + for item in data_list: + item.strip().replace("-", "_").split("_").index("some") + """ + ) + + # Calls: strip -> replace -> split -> index = 4 calls total. + # add to 5 + code = code.replace('index("some")', 'index("some").upper()') + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected smell for chain length 5" + assert isinstance(smells[0], LMCSmell) + + +def test_multiple_chains_one_line(): + """Detect multiple separate long chains on the same line. 
Should only report 1 smell, the first chain""" + code = textwrap.dedent( + """ + def combo(): + details = "some text" + other = "other text" + details.lower().title().replace("|", "-").upper().split("-"); other.upper().lower().capitalize().zfill(10).replace("xyz", "abc") + """ + ) + + # On line 5, we have two separate chains: + # 1) details -> lower -> title -> replace -> upper -> split => 5 calls. + # 2) other -> upper -> lower -> capitalize -> zfill -> replace => 5 calls. + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + # The function logic says it only reports once per line. So we expect 1 smell, not 2. + assert len(smells) == 1, "Both chains on the same line => single smell reported" + assert "Method chain too long" in smells[0].message + + +def test_ignores_separate_statements(): + """Ensures that separate statements with fewer calls each are not combined into one chain.""" + code = textwrap.dedent( + """ + def example(): + details = "some-other" + data = details.upper() + data = data.lower() + data = data.capitalize() + data = data.replace("|", "-") + data = data.title() + """ + ) + + # Each statement individually has only 1 call. + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "No single chain of length >= 5 in separate statements" + + +def test_ignores_short_chain_comprehension(): + """Ensures short chain in a comprehension doesn't get flagged.""" + code = textwrap.dedent( + """ + def short_comp(lst): + return [item.replace("-", "_").lower() for item in lst] + """ + ) + + # Only 2 calls in the chain: replace -> lower. 
+ with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0 + + +def test_detects_long_chain_comprehension(): + """Detects a long chain in a list comprehension.""" + code = textwrap.dedent( + """ + def long_comp(lst): + return [item.upper().lower().capitalize().strip().replace("|", "-") for item in lst] + """ + ) + + # 5 calls in the chain: upper -> lower -> capitalize -> strip -> replace. + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected one smell for chain of length 5" + assert isinstance(smells[0], LMCSmell) + + +def test_five_separate_long_chains(): + """ + Five distinct lines in a single function, each with a chain of exactly 5 calls. + Expect 5 separate smells (assuming you record each line). + """ + code = textwrap.dedent( + """ + def combo(): + data = "text" + data.upper().lower().capitalize().replace("|", "-").split("|") + data.capitalize().replace("|", "-").strip().upper().title() + data.lower().upper().replace("|", "-").strip().title() + data.strip().replace("|", "_").split("_").capitalize().title() + data.replace("|", "-").upper().lower().capitalize().title() + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 5, "Expected 5 smells" + assert isinstance(smells[0], LMCSmell) + + +def test_element_access_chain_no_calls(): + """ + A chain of attributes and index lookups only, no parentheses (no actual calls). + Some detectors won't flag this unless they specifically count attribute hops. 
+ """ + code = textwrap.dedent( + """ + def get_nested(nested): + return nested.a.b.c[3][0].x.y + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "Expected 0 smells" + + +def test_chain_with_slicing(): + """ + Demonstrates slicing as part of the chain. + e.g. `text[2:7]` -> `.replace()` -> `.upper()` ... + """ + code = textwrap.dedent( + """ + def slice_chain(text): + return text[2:7].replace("abc", "xyz").upper().strip().split("-").lower() + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected 1 smells" + + +def test_multiline_chain(): + """ + A chain split over multiple lines using parentheses or backslash. + The AST should still see them as a continuous chain of calls. + """ + code = textwrap.dedent( + """ + def multiline_chain(): + var = "some text"\\ + .replace(" ", "-")\\ + .lower()\\ + .title()\\ + .strip()\\ + .upper() + """ + ) + + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected 1 smells" + + +def test_chain_in_lambda(): + """ + A chain inside a lambda's body. + """ + code = textwrap.dedent( + """ + def lambda_test(): + func = lambda x: x.upper().strip().replace("-", "_").lower().title() + return func("HELLO-WORLD") + """ + ) + # That’s 5 calls: upper -> strip -> replace -> lower -> title + # Expect 1 chain smell if you're scanning inside lambda bodies. + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected 1 smells" + + +def test_mixed_return_types_chain(): + """ + It's 5 calls, with type changes from str to bool to int. + Typical 'chain detection' doesn't care about type. 
+ """ + code = textwrap.dedent( + """ + class TypeMix: + def do_stuff(self): + text = "Hello" + return text.lower().capitalize().isalpha().bit_length().to_bytes(2, 'big') + """ + ) + # That’s 5 calls: lower -> capitalize -> isalpha -> bit_length -> to_bytes + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 1, "Expected 1 smells" + + +def test_multiple_short_chains_same_line(): + """ + Two short chains on the same line, each with 3 calls, but they're separate. + They should not combine into 6, so likely 0 smells if threshold=5. + """ + code = textwrap.dedent( + """ + def short_line(): + x = "abc" + y = "def" + x.upper().replace("A", "Z").strip(); y.lower().replace("d", "x").title() + """ + ) + # Each chain is 3 calls, so if threshold is 5, expect 0 smells. + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "Expected 0 smells" + + +def test_conditional_chain(): + """ + A chain inside an inline if/else expression (ternary). + The question: do we see it as a single chain? Usually yes, but only if we actually parse it as an ast.Call chain. 
+ """ + code = textwrap.dedent( + """ + def cond_chain(cond): + text = "some text" + return (text.lower().replace(" ", "_").strip().upper() if cond + else text.upper().replace(" ", "|").lower().split("|")) + """ + ) + # code shouldnt lump them together + with patch.object(Path, "read_text", return_value=code): + smells = detect_long_message_chain(Path("fake.py"), ast.parse(code)) + + assert len(smells) == 0, "Expected 0 smells" diff --git a/tests/analyzers/test_str_concat_in_loop.py b/tests/analyzers/test_str_concat_in_loop.py new file mode 100644 index 00000000..15b9f11d --- /dev/null +++ b/tests/analyzers/test_str_concat_in_loop.py @@ -0,0 +1,542 @@ +from pathlib import Path +from astroid import parse +from unittest.mock import patch + +from ecooptimizer.data_types.smell import SCLSmell +from ecooptimizer.analyzers.astroid_analyzers.detect_string_concat_in_loop import ( + detect_string_concat_in_loop, +) + +# === Basic Concatenation Cases === + + +def test_detects_simple_for_loop_concat(): + """Detects += string concatenation inside a for loop.""" + code = """ + def test(): + result = "" + for i in range(10): + result += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_simple_assign_loop_concat(): + """Detects string concatenation inside a loop.""" + code = """ + def test(): + result = "" + for i in range(10): + result = result + str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == 
"result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_simple_while_loop_concat(): + """Detects += string concatenation inside a while loop.""" + code = """ + def test(): + result = "" + while i < 10: + result += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_list_attribute_concat(): + """Detects += modifying a list item inside a loop.""" + code = """ + class Test: + def __init__(self): + self.text = [""] * 5 + def update(self): + for i in range(5): + self.text[0] += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "self.text[0]" + assert smells[0].additionalInfo.innerLoopLine == 6 + + +def test_detects_object_attribute_concat(): + """Detects += modifying an object attribute inside a loop.""" + code = """ + class Test: + def __init__(self): + self.text = "" + def update(self): + for i in range(5): + self.text += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "self.text" + assert smells[0].additionalInfo.innerLoopLine == 6 + + +def test_detects_dict_value_concat(): + """Detects += modifying a dictionary value inside a loop.""" + code = """ + def test(): + data = {"key": ""} + for i in range(5): + 
data["key"] += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + # astroid changes double quotes to singles + assert smells[0].additionalInfo.concatTarget == "data['key']" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_multi_loop_concat(): + """Detects multiple separate string concats in a loop.""" + code = """ + def test(): + result = "" + logs = [""] * 4 + for i in range(10): + result += str(i) + logs[0] += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 2 + assert all(isinstance(smell, SCLSmell) for smell in smells) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 5 + + assert len(smells[1].occurences) == 1 + assert smells[1].additionalInfo.concatTarget == "logs[0]" + assert smells[1].additionalInfo.innerLoopLine == 5 + + +def test_detects_reset_loop_concat(): + """Detects string concats with re-assignments inside the loop.""" + code = """ + def reset(): + result = '' + for i in range(5): + result += "Iteration: " + str(i) + if i == 2: + result = "" # Resetting `result` + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === Nested Loop Cases === + + +def test_detects_nested_loop_concat(): + """Detects concatenation inside nested loops.""" + code = """ + def test(): + result = "" + for i in range(3): + for j 
in range(3): + result += str(j) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 5 + + +def test_detects_complex_nested_loop_concat(): + """Detects multi level concatenations belonging to the same smell.""" + code = """ + def super_complex(): + result = '' + for i in range(5): + result += "Iteration: " + str(i) + for j in range(3): + result += "Nested: " + str(j) # Contributing to `result` + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === Conditional Cases === + + +def test_detects_if_else_concat(): + """Detects += inside an if-else condition within a loop.""" + code = """ + def test(): + result = "" + for i in range(5): + if i % 2 == 0: + result += "even" + else: + result += "odd" + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === String Interpolation Cases === + + +def test_detects_f_string_concat(): + """Detects += using f-strings inside a loop.""" + code = """ + def test(): + result = "" + for i in range(5): + result += f"{i}" + """ + with patch.object(Path, "read_text", return_value=code): + smells = 
detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_percent_format_concat(): + """Detects += using % formatting inside a loop.""" + code = """ + def test(): + result = "" + for i in range(5): + result += "%d" % i + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_str_format_concat(): + """Detects += using .format() inside a loop.""" + code = """ + def test(): + result = "" + for i in range(5): + result += "{}".format(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +# === False Positives (Should NOT Detect) === + + +def test_ignores_access_inside_loop(): + """Ensures that accessing the concatenation variable inside the loop is NOT flagged.""" + code = """ + def test(): + result = "" + for i in range(5): + print(result) # Accessing result mid-loop + result += str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +def test_ignores_regular_str_assign_inside_loop(): + """Ensures that regular string assignments are NOT flagged.""" + code = """ + def test(): + result = "" + 
for i in range(5): + result = str(i) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +def test_ignores_number_addition_inside_loop(): + """Ensures number operations with the += format are NOT flagged.""" + code = """ + def test(): + num = 1 + for i in range(5): + num += i + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +def test_ignores_concat_outside_loop(): + """Ensures that string concatenation OUTSIDE a loop is NOT flagged.""" + code = """ + def test(): + result = "" + part1 = "Hello" + part2 = "World" + result = result + part1 + part2 + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 0 + + +# === Edge Cases === + + +def test_detects_sequential_concat(): + """Detects a variable concatenated multiple times in the same loop iteration.""" + code = """ + def test(): + result = "" + for i in range(5): + result += str(i) + result += "-" + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 2 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_concat_with_prefix_and_suffix(): + """Detects concatenation where both prefix and suffix are added.""" + code = """ + def test(): + result = "" + for i in range(5): + result = "prefix-" + result + "-suffix" + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert 
len(smells[0].occurences) == 1
+ assert smells[0].additionalInfo.concatTarget == "result"
+ assert smells[0].additionalInfo.innerLoopLine == 4
+
+
+def test_detects_prepend_concat():
+ """Detects += where new values are inserted at the beginning instead of the end."""
+ code = """
+ def test():
+ result = ""
+ for i in range(5):
+ result = str(i) + result
+ """
+ with patch.object(Path, "read_text", return_value=code):
+ smells = detect_string_concat_in_loop(Path("fake.py"), parse(code))
+
+ assert len(smells) == 1
+ assert isinstance(smells[0], SCLSmell)
+
+ assert len(smells[0].occurences) == 1
+ assert smells[0].additionalInfo.concatTarget == "result"
+ assert smells[0].additionalInfo.innerLoopLine == 4
+
+
+# === Typing Cases ===
+
+
+def test_ignores_unknown_type():
+ """Ignores potential smells where type cannot be confirmed as a string."""
+ code = """
+ def test(a, b):
+ result = a
+ for i in range(5):
+ result = result + b
+
+ a = "Hello"
+ b = "world"
+ test(a)
+ """
+ with patch.object(Path, "read_text", return_value=code):
+ smells = detect_string_concat_in_loop(Path("fake.py"), parse(code))
+
+ assert len(smells) == 0
+
+
+def test_detects_param_type_hint_concat():
+ """Detects string concat where type is inferred from the FunctionDef type hints."""
+ code = """
+ def test(a: str, b: str):
+ result = a
+ for i in range(5):
+ result = result + b
+
+ a = "Hello"
+ b = "world"
+ test(a, b)
+ """
+ with patch.object(Path, "read_text", return_value=code):
+ smells = detect_string_concat_in_loop(Path("fake.py"), parse(code))
+
+ assert len(smells) == 1
+ assert isinstance(smells[0], SCLSmell)
+
+ assert len(smells[0].occurences) == 1
+ assert smells[0].additionalInfo.concatTarget == "result"
+ assert smells[0].additionalInfo.innerLoopLine == 4
+
+
+def test_detects_var_type_hint_concat():
+ """Detects string concats where the type is inferred from an assign type hint."""
+ code = """
+ def test(a, b):
+ result: str = a
+ for i in range(5):
+ result = result 
+ b + + a = "Hello" + b = "world" + test(a, b) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 + + +def test_detects_cls_attr_type_hint_concat(): + """Detects string concats where type is inferred from class attributes.""" + code = """ + class Test: + + def __init__(self): + self.text = "word" + + def test(self, a): + result = a + for i in range(5): + result = result + self.text + + a = Test() + a.test("this ") + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 9 + + +def test_detects_inferred_str_type_concat(): + """Detects string concat where type is inferred from the initial value assigned.""" + code = """ + def test(a): + result = "" + for i in range(5): + result = a + result + + a = "hello" + test(a) + """ + with patch.object(Path, "read_text", return_value=code): + smells = detect_string_concat_in_loop(Path("fake.py"), parse(code)) + + assert len(smells) == 1 + assert isinstance(smells[0], SCLSmell) + + assert len(smells[0].occurences) == 1 + assert smells[0].additionalInfo.concatTarget == "result" + assert smells[0].additionalInfo.innerLoopLine == 4 diff --git a/tests/api/test_detect_route.py b/tests/api/test_detect_route.py new file mode 100644 index 00000000..21faf8ec --- /dev/null +++ b/tests/api/test_detect_route.py @@ -0,0 +1,81 @@ +from pathlib import Path +from fastapi.testclient import TestClient +from unittest.mock import patch + +from 
ecooptimizer.api.app import app +from ecooptimizer.data_types import Smell +from ecooptimizer.data_types.custom_fields import Occurence + +client = TestClient(app) + + +def get_mock_smell(): + return Smell( + confidence="UNKNOWN", + message="This is a message", + messageId="smellID", + module="module", + obj="obj", + path="fake_path.py", + symbol="smell-symbol", + type="type", + occurences=[ + Occurence( + line=9, + endLine=999, + column=999, + endColumn=999, + ) + ], + ) + + +def test_detect_smells_success(): + request_data = { + "file_path": "fake_path.py", + "enabled_smells": ["smell1", "smell2"], + } + + with patch("pathlib.Path.exists", return_value=True): + with patch( + "ecooptimizer.analyzers.analyzer_controller.AnalyzerController.run_analysis" + ) as mock_run_analysis: + mock_run_analysis.return_value = [get_mock_smell(), get_mock_smell()] + + response = client.post("/smells", json=request_data) + + assert response.status_code == 200 + assert len(response.json()) == 2 + + +def test_detect_smells_file_not_found(): + request_data = { + "file_path": "path/to/nonexistent/file.py", + "enabled_smells": ["smell1", "smell2"], + } + + response = client.post("/smells", json=request_data) + + assert response.status_code == 404 + assert ( + response.json()["detail"] + == f'File not found: {Path("path","to","nonexistent", "file.py")!s}' + ) + + +def test_detect_smells_internal_server_error(): + request_data = { + "file_path": "fake_path.py", + "enabled_smells": ["smell1", "smell2"], + } + + with patch("pathlib.Path.exists", return_value=True): + with patch( + "ecooptimizer.analyzers.analyzer_controller.AnalyzerController.run_analysis" + ) as mock_run_analysis: + mock_run_analysis.side_effect = Exception("Internal error") + + response = client.post("/smells", json=request_data) + + assert response.status_code == 500 + assert response.json()["detail"] == "Internal server error" diff --git a/tests/api/test_refactor_route.py b/tests/api/test_refactor_route.py new file mode 
100644 index 00000000..79a81155 --- /dev/null +++ b/tests/api/test_refactor_route.py @@ -0,0 +1,157 @@ +# ruff: noqa: PT004 +import pytest + +import shutil +from pathlib import Path +from typing import Any +from collections.abc import Generator +from fastapi.testclient import TestClient +from unittest.mock import patch + + +from ecooptimizer.api.app import app +from ecooptimizer.analyzers.analyzer_controller import AnalyzerController +from ecooptimizer.refactorers.refactorer_controller import RefactorerController + + +client = TestClient(app) + +SAMPLE_SMELL = { + "confidence": "UNKNOWN", + "message": "This is a message", + "messageId": "smellID", + "module": "module", + "obj": "obj", + "path": "fake_path.py", + "symbol": "smell-symbol", + "type": "type", + "occurences": [ + { + "line": 9, + "endLine": 999, + "column": 999, + "endColumn": 999, + } + ], +} + +SAMPLE_SOURCE_DIR = "path\\to\\source_dir" + + +@pytest.fixture(scope="module") +def mock_dependencies() -> Generator[None, Any, None]: + """Fixture to mock all dependencies for the /refactor route.""" + with ( + patch.object(Path, "is_dir"), + patch.object(shutil, "copytree"), + patch.object(shutil, "rmtree"), + patch.object( + RefactorerController, + "run_refactorer", + return_value=[ + Path("path/to/modified_file_1.py"), + Path("path/to/modified_file_2.py"), + ], + ), + patch.object(AnalyzerController, "run_analysis", return_value=[SAMPLE_SMELL]), + patch("tempfile.mkdtemp", return_value="/fake/temp/dir"), + ): + yield + + +def test_refactor_success(mock_dependencies): # noqa: ARG001 + """Test the /refactor route with a successful refactoring process.""" + Path.is_dir.return_value = True # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", side_effect=[10.0, 5.0]): + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 200 + assert "refactoredData" in 
response.json() + assert "updatedSmells" in response.json() + assert len(response.json()["updatedSmells"]) == 1 + + +def test_refactor_source_dir_not_found(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when the source directory does not exist.""" + Path.is_dir.return_value = False # type: ignore + + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 404 + assert f"Directory {SAMPLE_SOURCE_DIR} does not exist" in response.json()["detail"] + + +def test_refactor_energy_not_saved(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when no energy is saved after refactoring.""" + Path.is_dir.return_value = True # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", side_effect=[10.0, 15.0]): + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 400 + assert "Energy was not saved" in response.json()["detail"] + + +def test_refactor_initial_energy_not_retrieved(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when no energy is saved after refactoring.""" + Path.is_dir.return_value = True # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", return_value=None): + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 400 + assert "Could not retrieve initial emissions" in response.json()["detail"] + + +def test_refactor_final_energy_not_retrieved(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when no energy is saved after refactoring.""" + Path.is_dir.return_value = True # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", side_effect=[10.0, None]): + 
request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 400 + assert "Could not retrieve final emissions" in response.json()["detail"] + + +def test_refactor_unexpected_error(mock_dependencies): # noqa: ARG001 + """Test the /refactor route when an unexpected error occurs during refactoring.""" + Path.is_dir.return_value = True # type: ignore + RefactorerController.run_refactorer.side_effect = Exception("Mock error") # type: ignore + + with patch("ecooptimizer.api.routes.refactor_smell.measure_energy", return_value=10.0): + request_data = { + "source_dir": SAMPLE_SOURCE_DIR, + "smell": SAMPLE_SMELL, + } + + response = client.post("/refactor", json=request_data) + + assert response.status_code == 400 + assert "Mock error" in response.json()["detail"] diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..10837a56 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,13 @@ +from pathlib import Path +import pytest + + +# ===== FIXTURES ====================== +@pytest.fixture(scope="session") +def output_dir(tmp_path_factory) -> Path: + return tmp_path_factory.mktemp("output") + + +@pytest.fixture(scope="session") +def source_files(tmp_path_factory) -> Path: + return tmp_path_factory.mktemp("input") diff --git a/tests/controllers/test_analyzer_controller.py b/tests/controllers/test_analyzer_controller.py new file mode 100644 index 00000000..fc8523be --- /dev/null +++ b/tests/controllers/test_analyzer_controller.py @@ -0,0 +1,5 @@ +import pytest + + +def test_placeholder(): + pytest.fail("TODO: Implement this test") diff --git a/tests/controllers/test_refactorer_controller.py b/tests/controllers/test_refactorer_controller.py new file mode 100644 index 00000000..9d8222e8 --- /dev/null +++ b/tests/controllers/test_refactorer_controller.py @@ -0,0 +1,147 @@ +from unittest.mock import Mock +import pytest + +from 
ecooptimizer.data_types.custom_fields import Occurence +from ecooptimizer.refactorers.refactorer_controller import RefactorerController +from ecooptimizer.data_types.smell import LECSmell + + +@pytest.fixture +def mock_refactorer_class(mocker): + mock_class = mocker.Mock() + mock_class.__name__ = "TestRefactorer" + return mock_class + + +@pytest.fixture +def mock_logger(mocker): + logger = Mock() + mocker.patch.dict("ecooptimizer.config.CONFIG", {"refactorLogger": logger}) + return logger + + +@pytest.fixture +def mock_smell(): + """Create a mock smell object for testing.""" + return LECSmell( + confidence="UNDEFINED", + message="Dictionary chain too long (6/4)", + messageId="LEC001", + module="lec_module", + obj="lec_function", + path="path/to/file.py", + symbol="long-element-chain", + type="convention", + occurences=[Occurence(line=10, endLine=10, column=15, endColumn=26)], + additionalInfo=None, + ) + + +def test_run_refactorer_success(mocker, mock_refactorer_class, mock_logger, tmp_path, mock_smell): + # Setup mock refactorer + mock_instance = mock_refactorer_class.return_value + # mock_instance.refactor = Mock() + mock_refactorer_class.return_value = mock_instance + + mock_instance.modified_files = [tmp_path / "modified.py"] + + mocker.patch( + "ecooptimizer.refactorers.refactorer_controller.get_refactorer", + return_value=mock_refactorer_class, + ) + + controller = RefactorerController() + target_file = tmp_path / "test.py" + target_file.write_text("print('test content')") # 🚨 Create file with dummy content + + source_dir = tmp_path + + # Execute + modified_files = controller.run_refactorer(target_file, source_dir, mock_smell) + + # Assertions + assert controller.smell_counters["LEC001"] == 1 + mock_logger.info.assert_called_once_with( + "πŸ”„ Running refactoring for long-element-chain using TestRefactorer" + ) + mock_instance.refactor.assert_called_once_with( + target_file, source_dir, mock_smell, mocker.ANY, True + ) + call_args = 
mock_instance.refactor.call_args + output_path = call_args[0][3] + assert output_path.name == "test_path_LEC001_1.py" + assert modified_files == [tmp_path / "modified.py"] + + +def test_run_refactorer_no_refactorer(mock_logger, mocker, tmp_path, mock_smell): + mocker.patch("ecooptimizer.refactorers.refactorer_controller.get_refactorer", return_value=None) + controller = RefactorerController() + target_file = tmp_path / "test.py" + source_dir = tmp_path + + with pytest.raises(NotImplementedError) as exc_info: + controller.run_refactorer(target_file, source_dir, mock_smell) + + mock_logger.error.assert_called_once_with( + "❌ No refactorer found for smell: long-element-chain" + ) + assert "No refactorer implemented for smell: long-element-chain" in str(exc_info.value) + + +def test_run_refactorer_multiple_calls(mocker, mock_refactorer_class, tmp_path, mock_smell): + mock_instance = mock_refactorer_class.return_value + mock_instance.modified_files = [] + mocker.patch( + "ecooptimizer.refactorers.refactorer_controller.get_refactorer", + return_value=mock_refactorer_class, + ) + mocker.patch.dict("ecooptimizer.config.CONFIG", {"refactorLogger": Mock()}) + + controller = RefactorerController() + target_file = tmp_path / "test.py" + source_dir = tmp_path + smell = mock_smell + + controller.run_refactorer(target_file, source_dir, smell) + controller.run_refactorer(target_file, source_dir, smell) + + assert controller.smell_counters["LEC001"] == 2 + calls = mock_instance.refactor.call_args_list + assert calls[0][0][3].name == "test_path_LEC001_1.py" + assert calls[1][0][3].name == "test_path_LEC001_2.py" + + +def test_run_refactorer_overwrite_false(mocker, mock_refactorer_class, tmp_path, mock_smell): + mock_instance = mock_refactorer_class.return_value + mocker.patch( + "ecooptimizer.refactorers.refactorer_controller.get_refactorer", + return_value=mock_refactorer_class, + ) + mocker.patch.dict("ecooptimizer.config.CONFIG", {"refactorLogger": Mock()}) + + controller = 
RefactorerController() + target_file = tmp_path / "test.py" + source_dir = tmp_path + smell = mock_smell + + controller.run_refactorer(target_file, source_dir, smell, overwrite=False) + call_args = mock_instance.refactor.call_args + assert call_args[0][4] is False # overwrite is the fifth argument + + +def test_run_refactorer_empty_modified_files(mocker, mock_refactorer_class, tmp_path, mock_smell): + mock_instance = mock_refactorer_class.return_value + mock_instance.modified_files = [] + mocker.patch( + "ecooptimizer.refactorers.refactorer_controller.get_refactorer", + return_value=mock_refactorer_class, + ) + mocker.patch.dict("ecooptimizer.config.CONFIG", {"refactorLogger": Mock()}) + + controller = RefactorerController() + target_file = tmp_path / "test.py" + source_dir = tmp_path + smell = mock_smell + + modified_files = controller.run_refactorer(target_file, source_dir, smell) + assert modified_files == [] diff --git a/tests/input/__init__.py b/tests/input/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/inefficient_code_example_1.py b/tests/input/inefficient_code_example_1.py new file mode 100644 index 00000000..dae6717c --- /dev/null +++ b/tests/input/inefficient_code_example_1.py @@ -0,0 +1,33 @@ +# Should trigger Use A Generator code smells + +def has_positive(numbers): + # List comprehension inside `any()` - triggers R1729 + return any([num > 0 for num in numbers]) + +def all_non_negative(numbers): + # List comprehension inside `all()` - triggers R1729 + return all([num >= 0 for num in numbers]) + +def contains_large_strings(strings): + # List comprehension inside `any()` - triggers R1729 + return any([len(s) > 10 for s in strings]) + +def all_uppercase(strings): + # List comprehension inside `all()` - triggers R1729 + return all([s.isupper() for s in strings]) + +def contains_special_numbers(numbers): + # List comprehension inside `any()` - triggers R1729 + return any([num % 5 == 0 and num > 100 for num in numbers]) + 
+def all_lowercase(strings): + # List comprehension inside `all()` - triggers R1729 + return all(s.islower() for s in strings) + +def any_even_numbers(numbers): + # List comprehension inside `any()` - triggers R1729 + return any(num % 2 == 0 for num in numbers) + +def all_strings_start_with_a(strings): + # List comprehension inside `all()` - triggers R1729 + return all(s.startswith('A') for s in strings) diff --git a/tests/input/inefficient_code_example_2.py b/tests/input/inefficient_code_example_2.py new file mode 100644 index 00000000..f68c1f09 --- /dev/null +++ b/tests/input/inefficient_code_example_2.py @@ -0,0 +1,119 @@ +import datetime # unused import + + +class Temp: + + def __init__(self) -> None: + self.unused_class_attribute = True + self.a = 3 + + def temp_function(self): + unused_var = 3 + b = 4 + return self.a + b + + +class DataProcessor: + + def __init__(self, data): + self.data = data + self.processed_data = [] + + def process_all_data(self): + if not self.data: + return [] + results = [] + for item in self.data: + try: + result = self.complex_calculation(item, "multiply", True, False) + results.append(result) + except Exception as e: + print("An error occurred:", e) + if isinstance(self.data[0], str): + print(self.data[0].upper().strip().replace(" ", "_").lower()) + self.processed_data = list( + filter(lambda x: x is not None and x != 0 and len(str(x)) > 1, results) + ) + return self.processed_data + + @staticmethod + def complex_calculation(item, operation, threshold, max_value): + if operation == "multiply": + result = item * threshold + elif operation == "add": + result = item + max_value + else: + result = item + return result + + @staticmethod + def multi_param_calculation( + item1, + item2, + item3, + flag1, + flag2, + flag3, + operation, + threshold, + max_value, + option, + final_stage, + min_value, + ): + value = 0 + if operation == "multiply": + value = item1 * item2 * item3 + elif operation == "add": + value = item1 + item2 + item3 + 
elif flag1 == "true": + value = item1 + elif flag2 == "true": + value = item2 + elif flag3 == "true": + value = item3 + elif max_value < threshold: + value = max_value + else: + value = min_value + return value + + +class AdvancedProcessor(DataProcessor): + + @staticmethod + def check_data(item): + return ( + True if item > 10 else False if item < -10 else None if item == 0 else item + ) + + def complex_comprehension(self): + self.processed_data = [ + (x**2 if x % 2 == 0 else x**3) + for x in range(1, 100) + if x % 5 == 0 and x != 50 and x > 3 + ] + + def long_chain(self): + try: + deep_value = self.data[0][1]["details"]["info"]["more_info"][2]["target"] + return deep_value + except (KeyError, IndexError, TypeError): + return None + + @staticmethod + def long_scope_chaining(): + for a in range(10): + for b in range(10): + for c in range(10): + for d in range(10): + for e in range(10): + if a + b + c + d + e > 25: + return "Done" + + +if __name__ == "__main__": + sample_data = [1, 2, 3, 4, 5] + processor = DataProcessor(sample_data) + processed = processor.process_all_data() + print("Processed Data:", processed) diff --git a/tests/input/inefficient_code_example_2_tests.py b/tests/input/inefficient_code_example_2_tests.py new file mode 100644 index 00000000..4f0c1731 --- /dev/null +++ b/tests/input/inefficient_code_example_2_tests.py @@ -0,0 +1,105 @@ +import unittest +from datetime import datetime + +from inefficient_code_example_2 import ( + AdvancedProcessor, + DataProcessor, +) # Just to show the unused import issue + + +# Assuming the classes DataProcessor and AdvancedProcessor are already defined +# and imported + + +class TestDataProcessor(unittest.TestCase): + + def test_process_all_data(self): + # Test valid data processing + data = [1, 2, 3, 4, 5] + processor = DataProcessor(data) + processed_data = processor.process_all_data() + # Expecting values [10, 20, 30, 40, 50] (because all are greater than 1 character in length) + self.assertEqual(processed_data, 
[10, 20, 30, 40, 50]) + + def test_process_all_data_empty(self): + # Test with empty data list + processor = DataProcessor([]) + processed_data = processor.process_all_data() + self.assertEqual(processed_data, []) + + def test_complex_calculation_multiply(self): + # Test multiplication operation + result = DataProcessor.complex_calculation(True, "multiply", 10, 20) + self.assertEqual(result, 50) # 5 * 10 + + def test_complex_calculation_add(self): + # Test addition operation + result = DataProcessor.complex_calculation(True, "add", 20, 5) + self.assertEqual(result, 25) # 5 + 20 + + def test_complex_calculation_default(self): + # Test default operation + result = DataProcessor.complex_calculation(True, "unknown", 10, 20) + self.assertEqual(result, 5) # Default value is item itself + + +class TestAdvancedProcessor(unittest.TestCase): + + def test_complex_comprehension(self): + # Test complex list comprehension + processor = AdvancedProcessor([1, 2, 3, 4, 5]) + processor.complex_comprehension() + expected_result = [ + 125, + 100, + 3375, + 400, + 15625, + 900, + 42875, + 1600, + 91125, + 166375, + 3600, + 274625, + 4900, + 421875, + 6400, + 614125, + 8100, + 857375, + ] + self.assertEqual(processor.processed_data, expected_result) + + def test_long_chain_valid(self): + # Test valid deep chain access + data = [ + [ + None, + { + "details": { + "info": {"more_info": [{}, {}, {"target": "Valid Value"}]} + } + }, + ] + ] + processor = AdvancedProcessor(data) + result = processor.long_chain() + self.assertEqual(result, "Valid Value") + + def test_long_chain_invalid(self): + # Test invalid deep chain access, should return None + data = [{"details": {"info": {"more_info": [{}]}}}] + processor = AdvancedProcessor(data) + result = processor.long_chain() + self.assertIsNone(result) + + def test_long_scope_chaining(self): + # Test long scope chaining, expecting 'Done' when the sum exceeds 25 + processor = AdvancedProcessor([1, 2, 3, 4, 5]) + result = 
processor.long_scope_chaining() + self.assertEqual(result, "Done") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/input/inefficient_code_example_3.py b/tests/input/inefficient_code_example_3.py new file mode 100644 index 00000000..04cc9573 --- /dev/null +++ b/tests/input/inefficient_code_example_3.py @@ -0,0 +1,22 @@ +import numpy as np +import time + + +def heavy_computation(): + # Start a large matrix multiplication task to consume CPU + print("Starting heavy computation...") + size = 1000 + matrix_a = np.random.rand(size, size) + matrix_b = np.random.rand(size, size) + + start_time = time.time() + result = np.dot(matrix_a, matrix_b) + end_time = time.time() + + print(f"Heavy computation finished in {end_time - start_time:.2f} seconds") + + +# Run the heavy computation in a loop for a longer duration +for _ in range(5): + heavy_computation() + time.sleep(1) # Add a small delay to observe periodic CPU load diff --git a/tests/input/long_param.py b/tests/input/long_param.py new file mode 100644 index 00000000..a95b0cfa --- /dev/null +++ b/tests/input/long_param.py @@ -0,0 +1,252 @@ +################################################ Constructors ############################################################### +class UserDataProcessor1: + # 1. 0 parameters + def __init__(self): + self.config = {} + self.data = [] + +class UserDataProcessor2: + # 2. 4 parameters (no unused) + def __init__(self, user_id, username, email, app_config): + self.user_id = user_id + self.username = username + self.email = email + self.app_config = app_config + +class UserDataProcessor3: + # 3. 4 parameters (1 unused) + def __init__(self, user_id, username, email, theme="light"): + self.user_id = user_id + self.username = username + self.email = email + # theme is unused + +class UserDataProcessor4: + # 4. 
8 parameters (no unused) + def __init__(self, user_id, username, email, preferences, timezone_config, language, notification_settings, is_active): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.timezone_config = timezone_config + self.language = language + self.notification_settings = notification_settings + self.is_active = is_active + +class UserDataProcessor5: + # 5. 8 parameters (1 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, region, notification_settings, theme="light"): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + self.timezone_config = timezone_config + self.region = region + self.notification_settings = notification_settings + # theme is unused + +class UserDataProcessor6: + # 6. 8 parameters (4 unused) + def __init__(self, user_id, username, email, preferences, timezone_config, backup_config=None, display_theme=None, active_status=None): + self.user_id = user_id + self.username = username + self.email = email + self.preferences = preferences + # timezone_config, backup_config, display_theme, active_status are unused + + ################################################ Instance Methods ############################################################### + # 1. 0 parameters + def clear_data(self): + self.data = [] + + # 2. 4 parameters (no unused) + def update_settings(self, display_mode, alert_settings, language_preference, timezone_config): + self.settings["display_mode"] = display_mode + self.settings["alert_settings"] = alert_settings + self.settings["language_preference"] = language_preference + self.settings["timezone"] = timezone_config + + # 3. 4 parameters (1 unused) + def update_profile(self, username, email, timezone_config, bio=None): + self.username = username + self.email = email + self.settings["timezone"] = timezone_config + # bio is unused + + # 4. 
8 parameters (no unused) + def bulk_update(self, username, email, preferences, timezone_config, region, notification_settings, theme="light", is_active=None): + self.username = username + self.email = email + self.preferences = preferences + self.settings["timezone"] = timezone_config + self.settings["region"] = region + self.settings["notifications"] = notification_settings + self.settings["theme"] = theme + self.settings["is_active"] = is_active + + # 5. 8 parameters (1 unused) + def bulk_update_partial(self, username, email, preferences, timezone_config, region, notification_settings, theme, active_status=None): + self.username = username + self.email = email + self.preferences = preferences + self.settings["timezone"] = timezone_config + self.settings["region"] = region + self.settings["notifications"] = notification_settings + self.settings["theme"] = theme + # active_status is unused + + # 6. 7 parameters (3 unused) + def partial_update(self, username, email, preferences, timezone_config, backup_config=None, display_theme=None, active_status=None): + self.username = username + self.email = email + self.preferences = preferences + self.settings["timezone"] = timezone_config + # backup_config, display_theme, active_status are unused + +################################################ Static Methods ############################################################### + + # 1. 0 parameters + @staticmethod + def reset_global_settings(): + return {"theme": "default", "language": "en", "notifications": True} + + # 2. 4 parameters (no unused) + @staticmethod + def validate_user_input(username, email, password, age): + return all([username, email, password, age >= 18]) + + # 3. 4 parameters (2 unused) + @staticmethod + def hash_password(password, salt, encryption="SHA256", retries=1000): + # encryption and retries are unused + return f"hashed({password} + {salt})" + + # 4. 
8 parameters (no unused) + @staticmethod + def generate_report(username, email, preferences, timezone_config, region, notification_settings, theme, is_active): + return { + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + "region": region, + "notifications": notification_settings, + "theme": theme, + "is_active": is_active, + } + + # 5. 8 parameters (1 unused) + @staticmethod + def generate_report_partial(username, email, preferences, timezone_config, region, notification_settings, theme, active_status=None): + return { + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + "region": region, + "notifications": notification_settings, + "active status": active_status, + } + # theme is unused + + # 6. 8 parameters (3 unused) + # @staticmethod + # def minimal_report(username, email, preferences, timezone_config, backup, region="Global", display_mode=None, status=None): + # return { + # "username": username, + # "email": email, + # "preferences": preferences, + # "timezone": timezone_config, + # "region": region + # } + # # backup, display_mode, status are unused + + +################################################ Standalone Functions ############################################################### + +# 1. 0 parameters +def reset_system(): + return "System reset completed" + +# 2. 4 parameters (no unused) +def calculate_discount(price, discount_rate, minimum_purchase, maximum_discount): + if price >= minimum_purchase: + return min(price * discount_rate, maximum_discount) + return 0 + +# 3. 4 parameters (1 unused) +def apply_coupon(coupon_code, expiry_date, discount_rate, minimum_order=None): + return f"Coupon {coupon_code} applied with {discount_rate}% off until {expiry_date}" + # minimum_order is unused + +# 4. 
8 parameters (no unused) +def create_user_report(user_id, username, email, preferences, timezone_config, language, notification_settings, is_active): + return { + "user_id": user_id, + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + "language": language, + "notifications": notification_settings, + "is_active": is_active, + } + +# 5. 8 parameters (1 unused) +def create_partial_report(user_id, username, email, preferences, timezone_config, language, notification_settings, active_status=None): + return { + "user_id": user_id, + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + "language": language, + "notifications": notification_settings, + } + # active_status is unused + +# 6. 8 parameters (3 unused) +def create_minimal_report(user_id, username, email, preferences, timezone_config, backup_config=None, alert_settings=None, active_status=None): + return { + "user_id": user_id, + "username": username, + "email": email, + "preferences": preferences, + "timezone": timezone_config, + } + # backup_config, alert_settings, active_status are unused + +################################################ Calls ############################################################### + +# Constructor calls +user1 = UserDataProcessor1() +user2 = UserDataProcessor2(1, "johndoe", "johndoe@example.com", app_config={"theme": "dark"}) +user3 = UserDataProcessor3(1, "janedoe", email="janedoe@example.com") +user4 = UserDataProcessor4(2, "johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", language="en", notification_settings=False, is_active=True) +user5 = UserDataProcessor5(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "UTC", region="en", notification_settings=False) +user6 = UserDataProcessor6(3, "janedoe", "janedoe@example.com", {"theme": "blue"}, timezone_config="PST") + +# Instance method calls +user6.clear_data() +user6.update_settings("dark_mode", True, "en", 
timezone_config="UTC") +user6.update_profile(username="janedoe", email="janedoe@example.com", timezone_config="PST") +user6.bulk_update("johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", "en", True, "dark", is_active=True) +user6.bulk_update_partial("janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", False, "light", active_status="offline") +user6.partial_update("janedoe", "janedoe@example.com", preferences={"theme": "blue"}, timezone_config="PST") + +# Static method calls +UserDataProcessor6.reset_global_settings() +UserDataProcessor6.validate_user_input("johndoe", "johndoe@example.com", password="password123", age=25) +UserDataProcessor6.hash_password("password123", "salt123", retries=200) +UserDataProcessor6.generate_report("johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", "en", True, "dark", True) +UserDataProcessor6.generate_report_partial("janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", False, theme="green", active_status="online") +# UserDataProcessor6.minimal_report("janedoe", "janedoe@example.com", {"theme": "blue"}, "PST", False, "Canada") + +# Standalone function calls +reset_system() +calculate_discount(price=100, discount_rate=0.1, minimum_purchase=50, maximum_discount=20) +apply_coupon("SAVE10", "2025-12-31", 10, minimum_order=2) +create_user_report(1, "johndoe", "johndoe@example.com", {"theme": "dark"}, "UTC", "en", True, True) +create_partial_report(2, "janedoe", "janedoe@example.com", {"theme": "light"}, "PST", "en", notification_settings=False) +create_minimal_report(3, "janedoe", "janedoe@example.com", {"theme": "blue"}, timezone_config="PST") + diff --git a/tests/input/project_car_stuff/__init__.py b/tests/input/project_car_stuff/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_car_stuff/main.py b/tests/input/project_car_stuff/main.py new file mode 100644 index 00000000..b4b03ea0 --- /dev/null +++ b/tests/input/project_car_stuff/main.py @@ -0,0 +1,162 
@@ +import math # Unused import + +class Test: + def __init__(self, name) -> None: + self.name = name + pass + + def unused_method(self): + print('Hello World!') + + +# Code Smell: Long Parameter List +class Vehicle: + def __init__( + self, make, model, year: int, color, fuel_type, engine_start_stop_option, mileage, suspension_setting, transmission, price, seat_position_setting = None + ): + # Code Smell: Long Parameter List in __init__ + self.make = make # positional argument + self.model = model + self.year = year + self.color = color + self.fuel_type = fuel_type + self.engine_start_stop_option = engine_start_stop_option + self.mileage = mileage + self.suspension_setting = suspension_setting + self.transmission = transmission + self.price = price + self.seat_position_setting = seat_position_setting # default value + self.owner = None # Unused class attribute, used in constructor + + def display_info(self): + # Code Smell: Long Message Chain + random_test = self.make.split('') + print(f"Make: {self.make}, Model: {self.model}, Year: {self.year}".upper().replace(",", "")[::2]) + + def calculate_price(self): + # Code Smell: List Comprehension in an All Statement + condition = all( + [ + isinstance(attribute, str) + for attribute in [self.make, self.model, self.year, self.color] + ] + ) + if condition: + return ( + self.price * 0.9 + ) # Apply a 10% discount if all attributes are strings (totally arbitrary condition) + + return self.price + + def unused_method(self): + # Code Smell: Member Ignoring Method + print( + "This method doesn't interact with instance attributes, it just prints a statement." 
+ ) + +class Car(Vehicle): + + def __init__( + self, + make, + model, + year, + color, + fuel_type, + engine_start_stop_option, + mileage, + suspension_setting, + transmission, + price, + sunroof=False, + ): + super().__init__( + make, model, year, color, fuel_type, engine_start_stop_option, mileage, suspension_setting, transmission, price + ) + self.sunroof = sunroof + self.engine_size = 2.0 # Unused variable in class + + def add_sunroof(self): + # Code Smell: Long Parameter List + self.sunroof = True + print("Sunroof added!") + + def show_details(self): + # Code Smell: Long Message Chain + details = f"Car: {self.make} {self.model} ({self.year}) | Mileage: {self.mileage} | Transmission: {self.transmission} | Sunroof: {self.sunroof} | Engine Start Option: {self.engine_start_stop_option} | Suspension Setting: {self.suspension_setting} | Seat Position {self.seat_position_setting}" + print(details.upper().lower().upper().capitalize().upper().replace("|", "-")) + + +def process_vehicle(vehicle: Vehicle): + # Code Smell: Unused Variables + temp_discount = 0.05 + temp_shipping = 100 + + vehicle.display_info() + price_after_discount = vehicle.calculate_price() + print(f"Price after discount: {price_after_discount}") + + vehicle.unused_method() # Calls a method that doesn't actually use the class attributes + + +def is_all_string(attributes): + # Code Smell: List Comprehension in an All Statement + return all(isinstance(attribute, str) for attribute in attributes) + + +def access_nested_dict(): + nested_dict1 = {"level1": {"level2": {"level3": {"key": "value"}}}} + + nested_dict2 = { + "level1": { + "level2": { + "level3": {"key": "value", "key2": "value2"}, + "level3a": {"key": "value"}, + } + } + } + print(nested_dict1["level1"]["level2"]["level3"]["key"]) + print(nested_dict2["level1"]["level2"]["level3"]["key2"]) + print(nested_dict2["level1"]["level2"]["level3"]["key"]) + print(nested_dict2["level1"]["level2"]["level3a"]["key"]) + 
print(nested_dict1["level1"]["level2"]["level3"]["key"]) + +# Main loop: Arbitrary use of the classes and demonstrating code smells +if __name__ == "__main__": + car1 = Car( + make="Toyota", + model="Camry", + year=2020, + color="Blue", + fuel_type="Gas", + engine_start_stop_option = "no key", + mileage=25000, + suspension_setting = "Sport", + transmission="Automatic", + price=20000, + ) + process_vehicle(car1) + car1.add_sunroof() + car1.show_details() + + car1.unused_method() + + # Testing with another vehicle object + car2 = Vehicle( + "Honda", + model="Civic", + year=2018, + color="Red", + fuel_type="Gas", + engine_start_stop_option = "key", + mileage=30000, + suspension_setting = "Sport", + transmission="Manual", + price=15000, + ) + process_vehicle(car2) + + test = Test('Anna') + test.unused_method() + + print("Hello") diff --git a/tests/input/project_car_stuff/test_main.py b/tests/input/project_car_stuff/test_main.py new file mode 100644 index 00000000..70126d34 --- /dev/null +++ b/tests/input/project_car_stuff/test_main.py @@ -0,0 +1,34 @@ +import pytest +from .main import Vehicle, Car, process_vehicle + +# Fixture to create a car instance +@pytest.fixture +def car1(): + return Car(make="Toyota", model="Camry", year=2020, color="Blue", fuel_type="Gas", mileage=25000, transmission="Automatic", price=20000) + +# Test the price after applying discount +def test_vehicle_price_after_discount(car1): + assert car1.calculate_price() == 20000, "Price after discount should be 18000" + +# Test the add_sunroof method to confirm it works as expected +def test_car_add_sunroof(car1): + car1.add_sunroof() + assert car1.sunroof is True, "Car should have sunroof after add_sunroof() is called" + +# Test that show_details method runs without error +def test_car_show_details(car1, capsys): + car1.show_details() + captured = capsys.readouterr() + assert "CAR: TOYOTA CAMRY" in captured.out # Checking if the output contains car details + +# Test the is_all_string function 
indirectly through the calculate_price method +def test_is_all_string(car1): + price_after_discount = car1.calculate_price() + assert price_after_discount > 0, "Price calculation should return a valid price" + +# Test the process_vehicle function to check its behavior with a Vehicle object +def test_process_vehicle(car1, capsys): + process_vehicle(car1) + captured = capsys.readouterr() + assert "Price after discount" in captured.out, "The process_vehicle function should output the price after discount" + diff --git a/tests/input/project_long_parameter_list/src/__init__.py b/tests/input/project_long_parameter_list/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_long_parameter_list/src/caller_1.py b/tests/input/project_long_parameter_list/src/caller_1.py new file mode 100644 index 00000000..d0409523 --- /dev/null +++ b/tests/input/project_long_parameter_list/src/caller_1.py @@ -0,0 +1,7 @@ +from main import process_data, process_extra + +pd = process_data(1, 2, 3, 4, 3, 2, 3, 5) +pe = process_extra(1, 2, 3, 4, 3, 2, 3, 5) + +print(pd) +print(pe) \ No newline at end of file diff --git a/tests/input/project_long_parameter_list/src/caller_2.py b/tests/input/project_long_parameter_list/src/caller_2.py new file mode 100644 index 00000000..241cf165 --- /dev/null +++ b/tests/input/project_long_parameter_list/src/caller_2.py @@ -0,0 +1,7 @@ +from main import Helper + +pcd = Helper.process_class_data(1, 2, 3, 4, 3, 2, 3, 5) +pmd = Helper.process_more_class_data(1, 2, 3, 4, 3, 2, 3, 5) + +print(pcd) +print(pmd) \ No newline at end of file diff --git a/tests/input/project_long_parameter_list/src/main.py b/tests/input/project_long_parameter_list/src/main.py new file mode 100644 index 00000000..84c3a9bd --- /dev/null +++ b/tests/input/project_long_parameter_list/src/main.py @@ -0,0 +1,44 @@ +import math +print(math.isclose(20, 100)) + +def process_local_call(data_value1, data_value2, data_item1, data_item2, + config_path, 
config_setting, config_option, config_env): + return (data_value1 * data_value2 - data_item1 * data_item2 + + config_path * config_setting - config_option * config_env) + + +def process_data(data_value1, data_value2, data_item1, data_item2, + config_path, config_setting, config_option, config_env): + return (data_value1 + data_value2 + data_item1) * (data_item2 + config_path + ) - (config_setting + config_option + config_env) + + +def process_extra(data_record1, data_record2, data_result1, data_result2, + config_file, config_mode, config_param, config_directory): + return data_record1 - data_record2 + (data_result1 - data_result2) * ( + config_file - config_mode) + (config_param - config_directory) + + +class Helper: + + def process_class_data(self, data_input1, data_input2, data_output1, + data_output2, config_file, config_user, config_theme, config_env): + return (data_input1 * data_input2 + data_output1 * data_output2 - + config_file * config_user + config_theme * config_env) + + def process_more_class_data(self, data_record1, data_record2, + data_item1, data_item2, config_log, config_cache, config_timeout, + config_profile): + return data_record1 + data_record2 - (data_item1 + data_item2) + ( + config_log + config_cache) - (config_timeout + config_profile) + + +def main(): + local_result = process_local_call(1, 2, 3, 4, 3, 2, 3, 5) + print(local_result) + + +if __name__ == '__main__': + main() + + diff --git a/tests/input/project_long_parameter_list/tests/test_main.py b/tests/input/project_long_parameter_list/tests/test_main.py new file mode 100644 index 00000000..c1d6018e --- /dev/null +++ b/tests/input/project_long_parameter_list/tests/test_main.py @@ -0,0 +1,24 @@ +from src.caller_1 import process_data, process_extra +from src.caller_2 import Helper +from src.main import process_local + +def test_process_data(): + assert process_data(1, 2, 3, 4, 5, 6, 7, 8) == 33 + +def test_process_extra(): + assert process_extra(1, 2, 3, 4, 5, 6, 7, 8) == -1 + +def 
test_helper_class(): + h = Helper() + assert h.process_class_data(1, 2, 3, 4, 5, 6, 7, 8) == 40 + assert h.process_more_class_data(1, 2, 3, 4, 5, 6, 7, 8) == -8 + +def test_process_local(): + assert process_local(1, 2, 3, 4, 5, 6, 7, 8) == -36 + +if __name__ == "__main__": + test_process_data() + test_process_extra() + test_helper_class() + test_process_local() + print("All tests passed!") diff --git a/tests/input/project_multi_file_lec/src/__init__.py b/tests/input/project_multi_file_lec/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_multi_file_lec/src/main.py b/tests/input/project_multi_file_lec/src/main.py new file mode 100644 index 00000000..ca18eaf9 --- /dev/null +++ b/tests/input/project_multi_file_lec/src/main.py @@ -0,0 +1,12 @@ +from src.processor import process_data + +def main(): + """ + Main entry point of the application. + """ + sample_data = "hello world" + processed = process_data(sample_data) + print(f"Processed Data: {processed}") + +if __name__ == "__main__": + main() diff --git a/tests/input/project_multi_file_lec/src/processor.py b/tests/input/project_multi_file_lec/src/processor.py new file mode 100644 index 00000000..25dd083c --- /dev/null +++ b/tests/input/project_multi_file_lec/src/processor.py @@ -0,0 +1,16 @@ +from src.utils import Utility + +def process_data(data): + """ + Process some data and call the long_element_chain method from Utility. 
+ """ + util = Utility() + my_call = util.long_chain["level1"]["level2"]["level3"]["level4"]["level5"]["level6"]["level7"] + lastVal = util.get_last_value() + fourthLevel = util.get_4th_level_value() + print(f"My call here: {my_call}") + print(f"Extracted Value1: {lastVal}") + print(f"Extracted Value2: {fourthLevel}") + return data.upper() + + diff --git a/tests/input/project_multi_file_lec/src/utils.py b/tests/input/project_multi_file_lec/src/utils.py new file mode 100644 index 00000000..00075717 --- /dev/null +++ b/tests/input/project_multi_file_lec/src/utils.py @@ -0,0 +1,23 @@ +class Utility: + def __init__(self): + self.long_chain = { + "level1": { + "level2": { + "level3": { + "level4": { + "level5": { + "level6": { + "level7": "deeply nested value" + } + } + } + } + } + } + } + + def get_last_value(self): + return self.long_chain["level1"]["level2"]["level3"]["level4"]["level5"]["level6"]["level7"] + + def get_4th_level_value(self): + return self.long_chain["level1"]["level2"]["level3"]["level4"] diff --git a/tests/input/project_multi_file_mim/src/__init__.py b/tests/input/project_multi_file_mim/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_multi_file_mim/src/main.py b/tests/input/project_multi_file_mim/src/main.py new file mode 100644 index 00000000..ca18eaf9 --- /dev/null +++ b/tests/input/project_multi_file_mim/src/main.py @@ -0,0 +1,12 @@ +from src.processor import process_data + +def main(): + """ + Main entry point of the application. 
+ """ + sample_data = "hello world" + processed = process_data(sample_data) + print(f"Processed Data: {processed}") + +if __name__ == "__main__": + main() diff --git a/tests/input/project_multi_file_mim/src/processor.py b/tests/input/project_multi_file_mim/src/processor.py new file mode 100644 index 00000000..5afb1cd0 --- /dev/null +++ b/tests/input/project_multi_file_mim/src/processor.py @@ -0,0 +1,9 @@ +from src.utils import Utility + +def process_data(data): + """ + Process some data and call the unused_member_method from Utility. + """ + util = Utility() + util.unused_member_method(data) + return data.upper() diff --git a/tests/input/project_multi_file_mim/src/utils.py b/tests/input/project_multi_file_mim/src/utils.py new file mode 100644 index 00000000..5d117544 --- /dev/null +++ b/tests/input/project_multi_file_mim/src/utils.py @@ -0,0 +1,7 @@ +class Utility: + def unused_member_method(self, param): + """ + A method that accepts a parameter but doesn’t use it. + This demonstrates the member ignoring code smell. + """ + print("This method is defined but doesn’t use its parameter.") diff --git a/tests/input/project_multi_file_mim/tests/test_processor.py b/tests/input/project_multi_file_mim/tests/test_processor.py new file mode 100644 index 00000000..6bf0dc29 --- /dev/null +++ b/tests/input/project_multi_file_mim/tests/test_processor.py @@ -0,0 +1,8 @@ +from src.processor import process_data + +def test_process_data(): + """ + Test the process_data function. + """ + result = process_data("test") + assert result == "TEST" diff --git a/tests/input/project_multi_file_mim/tests/test_utils.py b/tests/input/project_multi_file_mim/tests/test_utils.py new file mode 100644 index 00000000..c5ac5b11 --- /dev/null +++ b/tests/input/project_multi_file_mim/tests/test_utils.py @@ -0,0 +1,10 @@ +from src.utils import Utility + +def test_unused_member_method(capfd): + """ + Test the unused_member_method to ensure it behaves as expected. 
+ """ + util = Utility() + util.unused_member_method("test") + captured = capfd.readouterr() + assert "This method is defined but doesn’t use its parameter." in captured.out diff --git a/tests/input/project_repeated_calls/main.py b/tests/input/project_repeated_calls/main.py new file mode 100644 index 00000000..464953d0 --- /dev/null +++ b/tests/input/project_repeated_calls/main.py @@ -0,0 +1,85 @@ +# Example Python file with repeated calls smells + +class Demo: + def __init__(self, value): + self.value = value + + def compute(self): + return self.value * 2 + +# Simple repeated function calls +def simple_repeated_calls(): + value = Demo(10).compute() + result = value + Demo(10).compute() # Repeated call + return result + +# Repeated method calls on an object +def repeated_method_calls(): + demo = Demo(5) + first = demo.compute() + second = demo.compute() # Repeated call on the same object + return first + second + +# Repeated attribute access with method calls +def repeated_attribute_calls(): + demo = Demo(3) + first = demo.compute() + demo.value = 10 # Modify attribute + second = demo.compute() # Repeated but valid since the attribute was modified + return first + second + +# Repeated nested calls +def repeated_nested_calls(): + data = [Demo(i) for i in range(3)] + total = sum(demo.compute() for demo in data) + repeated = sum(demo.compute() for demo in data) # Repeated nested call + return total + repeated + +# Repeated calls in a loop +def repeated_calls_in_loop(): + results = [] + for i in range(5): + results.append(Demo(i).compute()) # Repeated call for each loop iteration + return results + +# Repeated calls with modifications in between +def repeated_calls_with_modification(): + demo = Demo(2) + first = demo.compute() + demo.value = 4 # Modify object + second = demo.compute() # Repeated but valid due to modification + return first + second + +# Repeated calls with mixed contexts +def repeated_calls_mixed_context(): + demo1 = Demo(1) + demo2 = Demo(2) + result1 
= demo1.compute() + result2 = demo2.compute() + result3 = demo1.compute() # Repeated for demo1 + return result1 + result2 + result3 + +# Repeated calls with multiple arguments +def repeated_calls_with_args(): + result = max(Demo(1).compute(), Demo(1).compute()) # Repeated identical calls + return result + +# Repeated calls using a lambda +def repeated_lambda_calls(): + compute_demo = lambda x: Demo(x).compute() + first = compute_demo(3) + second = compute_demo(3) # Repeated lambda call + return first + second + +# Repeated calls with external dependencies +def repeated_calls_with_external_dependency(data): + result = len(data.get('key')) # Repeated external call + repeated = len(data.get('key')) + return result + repeated + +# Repeated calls with slightly different arguments +def repeated_calls_slightly_different(): + demo = Demo(10) + first = demo.compute() + second = Demo(20).compute() # Different object, not a true repeated call + return first + second diff --git a/tests/input/project_string_concat/__init__.py b/tests/input/project_string_concat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/input/project_string_concat/main.py b/tests/input/project_string_concat/main.py new file mode 100644 index 00000000..b7be86dc --- /dev/null +++ b/tests/input/project_string_concat/main.py @@ -0,0 +1,137 @@ +class Demo: + def __init__(self) -> None: + self.test = "" + +def super_complex(): + result = '' + log = '' + for i in range(5): + result += "Iteration: " + str(i) + for j in range(3): + result += "Nested: " + str(j) # Contributing to `result` + log += "Log entry for i=" + str(i) + if i == 2: + result = "" # Resetting `result` + +def concat_with_for_loop_simple_attr(): + result = Demo() + for i in range(10): + result.test += str(i) # Simple concatenation + return result + +def concat_with_for_loop_simple_sub(): + result = {"key": ""} + for i in range(10): + result["key"] += str(i) # Simple concatenation + return result + +def 
concat_with_for_loop_simple(): + result = "" + for i in range(10): + result += str(i) # Simple concatenation + return result + +def concat_with_while_loop_variable_append(): + result = "" + i = 0 + while i < 5: + result += f"Value-{i}" # Using f-string inside while loop + i += 1 + return result + +def nested_loop_string_concat(): + result = "" + for i in range(2): + result = str(i) + for j in range(3): + result += f"({i},{j})" # Nested loop concatenation + return result + +def string_concat_with_condition(): + result = "" + for i in range(5): + if i % 2 == 0: + result += "Even" # Conditional concatenation + else: + result += "Odd" # Different condition + return result + +def concatenate_with_literal(): + result = "Start" + for i in range(4): + result += "-Next" # Concatenating a literal string + return result + +def complex_expression_concat(): + result = "" + for i in range(3): + result += "Complex" + str(i * i) + "End" # Expression inside concatenation + return result + +def repeated_variable_reassignment(): + result = Demo() + for i in range(2): + result.test = result.test + "First" + result.test = result.test + "Second" # Multiple reassignments + return result + +# Concatenation with % operator using only variables +def greet_user_with_percent(name): + greeting = "" + for i in range(2): + greeting += "Hello, " + "%s" % name + return greeting + +# Concatenation with str.format() using only variables +def describe_city_with_format(city): + description = "" + for i in range(2): + description = description + "I live in " + "the city of {}".format(city) + return description + +# Nested interpolation with % and concatenation +def person_description_with_percent(name, age): + description = "" + for i in range(2): + description += "Person: " + "%s, Age: %d" % (name, age) + return description + +# Multiple str.format() calls with concatenation +def values_with_format(x, y): + result = "" + for i in range(2): + result = result + "Value of x: {}".format(x) + ", and y: 
{:.2f}".format(y) + return result + +# Simple variable concatenation (edge case for completeness) +def simple_variable_concat(a: str, b: str): + result = Demo().test + for i in range(2): + result += a + b + return result + +def middle_var_concat(): + result = '' + for i in range(3): + result = str(i) + result + str(i) + return result + +def end_var_concat(): + result = '' + for i in range(3): + result = str(i) + result + return result + +def concat_referenced_in_loop(): + result = "" + for i in range(3): + result += "Complex" + str(i * i) + "End" # Expression inside concatenation + print(result) + return result + +def concat_not_in_loop(): + name = "Bob" + name += "Ross" + return name + +simple_variable_concat("Hello", " World ") \ No newline at end of file diff --git a/tests/input/project_string_concat/test_main.py b/tests/input/project_string_concat/test_main.py new file mode 100644 index 00000000..461ccccb --- /dev/null +++ b/tests/input/project_string_concat/test_main.py @@ -0,0 +1,86 @@ +import pytest +from .main import ( + concat_with_for_loop_simple, + complex_expression_concat, + concat_with_for_loop_simple_attr, + concat_with_for_loop_simple_sub, + concat_with_while_loop_variable_append, + concatenate_with_literal, + simple_variable_concat, + string_concat_with_condition, + nested_loop_string_concat, + repeated_variable_reassignment, + greet_user_with_percent, + describe_city_with_format, + person_description_with_percent, + values_with_format, + middle_var_concat, + end_var_concat +) + +def test_concat_with_for_loop_simple_attr(): + result = concat_with_for_loop_simple_attr() + assert result.test == ''.join(str(i) for i in range(10)) + +def test_concat_with_for_loop_simple_sub(): + result = concat_with_for_loop_simple_sub() + assert result["key"] == ''.join(str(i) for i in range(10)) + +def test_concat_with_for_loop_simple(): + result = concat_with_for_loop_simple() + assert result == ''.join(str(i) for i in range(10)) + +def 
test_concat_with_while_loop_variable_append(): + result = concat_with_while_loop_variable_append() + assert result == ''.join(f"Value-{i}" for i in range(5)) + +def test_nested_loop_string_concat(): + result = nested_loop_string_concat() + expected = "1(1,0)(1,1)(1,2)" + assert result == expected + +def test_string_concat_with_condition(): + result = string_concat_with_condition() + expected = ''.join("Even" if i % 2 == 0 else "Odd" for i in range(5)) + assert result == expected + +def test_concatenate_with_literal(): + result = concatenate_with_literal() + assert result == "Start" + "-Next" * 4 + +def test_complex_expression_concat(): + result = complex_expression_concat() + expected = ''.join(f"Complex{i*i}End" for i in range(3)) + assert result == expected + +def test_repeated_variable_reassignment(): + result = repeated_variable_reassignment() + assert result.test == ("FirstSecond" * 2) + +def test_greet_user_with_percent(): + result = greet_user_with_percent("Alice") + assert result == ("Hello, Alice" * 2) + +def test_describe_city_with_format(): + result = describe_city_with_format("London") + assert result == ("I live in the city of London" * 2) + +def test_person_description_with_percent(): + result = person_description_with_percent("Bob", 25) + assert result == ("Person: Bob, Age: 25" * 2) + +def test_values_with_format(): + result = values_with_format(42, 3.14) + assert result == ("Value of x: 42, and y: 3.14" * 2) + +def test_simple_variable_concat(): + result = simple_variable_concat("foo", "bar") + assert result == ("foobar" * 2) + +def test_end_var_concat(): + result = end_var_concat() + assert result == ("210") + +def test_middle_var_concat(): + result = middle_var_concat() + assert result == ("210012") diff --git a/tests/input/repeated_calls_examples.py b/tests/input/repeated_calls_examples.py new file mode 100644 index 00000000..464953d0 --- /dev/null +++ b/tests/input/repeated_calls_examples.py @@ -0,0 +1,85 @@ +# Example Python file with repeated 
calls smells + +class Demo: + def __init__(self, value): + self.value = value + + def compute(self): + return self.value * 2 + +# Simple repeated function calls +def simple_repeated_calls(): + value = Demo(10).compute() + result = value + Demo(10).compute() # Repeated call + return result + +# Repeated method calls on an object +def repeated_method_calls(): + demo = Demo(5) + first = demo.compute() + second = demo.compute() # Repeated call on the same object + return first + second + +# Repeated attribute access with method calls +def repeated_attribute_calls(): + demo = Demo(3) + first = demo.compute() + demo.value = 10 # Modify attribute + second = demo.compute() # Repeated but valid since the attribute was modified + return first + second + +# Repeated nested calls +def repeated_nested_calls(): + data = [Demo(i) for i in range(3)] + total = sum(demo.compute() for demo in data) + repeated = sum(demo.compute() for demo in data) # Repeated nested call + return total + repeated + +# Repeated calls in a loop +def repeated_calls_in_loop(): + results = [] + for i in range(5): + results.append(Demo(i).compute()) # Repeated call for each loop iteration + return results + +# Repeated calls with modifications in between +def repeated_calls_with_modification(): + demo = Demo(2) + first = demo.compute() + demo.value = 4 # Modify object + second = demo.compute() # Repeated but valid due to modification + return first + second + +# Repeated calls with mixed contexts +def repeated_calls_mixed_context(): + demo1 = Demo(1) + demo2 = Demo(2) + result1 = demo1.compute() + result2 = demo2.compute() + result3 = demo1.compute() # Repeated for demo1 + return result1 + result2 + result3 + +# Repeated calls with multiple arguments +def repeated_calls_with_args(): + result = max(Demo(1).compute(), Demo(1).compute()) # Repeated identical calls + return result + +# Repeated calls using a lambda +def repeated_lambda_calls(): + compute_demo = lambda x: Demo(x).compute() + first = 
compute_demo(3) + second = compute_demo(3) # Repeated lambda call + return first + second + +# Repeated calls with external dependencies +def repeated_calls_with_external_dependency(data): + result = len(data.get('key')) # Repeated external call + repeated = len(data.get('key')) + return result + repeated + +# Repeated calls with slightly different arguments +def repeated_calls_slightly_different(): + demo = Demo(10) + first = demo.compute() + second = Demo(20).compute() # Different object, not a true repeated call + return first + second diff --git a/tests/input/string_concat_sample.py b/tests/input/string_concat_sample.py new file mode 100644 index 00000000..b7be86dc --- /dev/null +++ b/tests/input/string_concat_sample.py @@ -0,0 +1,137 @@ +class Demo: + def __init__(self) -> None: + self.test = "" + +def super_complex(): + result = '' + log = '' + for i in range(5): + result += "Iteration: " + str(i) + for j in range(3): + result += "Nested: " + str(j) # Contributing to `result` + log += "Log entry for i=" + str(i) + if i == 2: + result = "" # Resetting `result` + +def concat_with_for_loop_simple_attr(): + result = Demo() + for i in range(10): + result.test += str(i) # Simple concatenation + return result + +def concat_with_for_loop_simple_sub(): + result = {"key": ""} + for i in range(10): + result["key"] += str(i) # Simple concatenation + return result + +def concat_with_for_loop_simple(): + result = "" + for i in range(10): + result += str(i) # Simple concatenation + return result + +def concat_with_while_loop_variable_append(): + result = "" + i = 0 + while i < 5: + result += f"Value-{i}" # Using f-string inside while loop + i += 1 + return result + +def nested_loop_string_concat(): + result = "" + for i in range(2): + result = str(i) + for j in range(3): + result += f"({i},{j})" # Nested loop concatenation + return result + +def string_concat_with_condition(): + result = "" + for i in range(5): + if i % 2 == 0: + result += "Even" # Conditional concatenation 
+ else: + result += "Odd" # Different condition + return result + +def concatenate_with_literal(): + result = "Start" + for i in range(4): + result += "-Next" # Concatenating a literal string + return result + +def complex_expression_concat(): + result = "" + for i in range(3): + result += "Complex" + str(i * i) + "End" # Expression inside concatenation + return result + +def repeated_variable_reassignment(): + result = Demo() + for i in range(2): + result.test = result.test + "First" + result.test = result.test + "Second" # Multiple reassignments + return result + +# Concatenation with % operator using only variables +def greet_user_with_percent(name): + greeting = "" + for i in range(2): + greeting += "Hello, " + "%s" % name + return greeting + +# Concatenation with str.format() using only variables +def describe_city_with_format(city): + description = "" + for i in range(2): + description = description + "I live in " + "the city of {}".format(city) + return description + +# Nested interpolation with % and concatenation +def person_description_with_percent(name, age): + description = "" + for i in range(2): + description += "Person: " + "%s, Age: %d" % (name, age) + return description + +# Multiple str.format() calls with concatenation +def values_with_format(x, y): + result = "" + for i in range(2): + result = result + "Value of x: {}".format(x) + ", and y: {:.2f}".format(y) + return result + +# Simple variable concatenation (edge case for completeness) +def simple_variable_concat(a: str, b: str): + result = Demo().test + for i in range(2): + result += a + b + return result + +def middle_var_concat(): + result = '' + for i in range(3): + result = str(i) + result + str(i) + return result + +def end_var_concat(): + result = '' + for i in range(3): + result = str(i) + result + return result + +def concat_referenced_in_loop(): + result = "" + for i in range(3): + result += "Complex" + str(i * i) + "End" # Expression inside concatenation + print(result) + return 
result + +def concat_not_in_loop(): + name = "Bob" + name += "Ross" + return name + +simple_variable_concat("Hello", " World ") \ No newline at end of file diff --git a/tests/measurements/test_codecarbon_energy_meter.py b/tests/measurements/test_codecarbon_energy_meter.py new file mode 100644 index 00000000..5cd294c5 --- /dev/null +++ b/tests/measurements/test_codecarbon_energy_meter.py @@ -0,0 +1,91 @@ +import pytest +import logging +from pathlib import Path +import subprocess +import pandas as pd +from unittest.mock import patch + +from ecooptimizer.measurements.codecarbon_energy_meter import CodeCarbonEnergyMeter + + +@pytest.fixture +def energy_meter(): + return CodeCarbonEnergyMeter() + + +@patch("codecarbon.EmissionsTracker.start") +@patch("codecarbon.EmissionsTracker.stop", return_value=0.45) +@patch("subprocess.run") +def test_measure_energy_success(mock_run, mock_stop, mock_start, energy_meter, caplog): + mock_run.return_value = subprocess.CompletedProcess( + args=["python3", "../input/project_car_stuff/main.py"], returncode=0 + ) + file_path = Path("../input/project_car_stuff/main.py") + with caplog.at_level(logging.INFO): + energy_meter.measure_energy(file_path) + + assert mock_run.call_count >= 1 + mock_run.assert_any_call( + ["/Library/Frameworks/Python.framework/Versions/3.13/bin/python3", file_path], + capture_output=True, + text=True, + check=True, + ) + mock_start.assert_called_once() + mock_stop.assert_called_once() + assert "CodeCarbon measurement completed successfully." 
in caplog.text + assert energy_meter.emissions == 0.45 + + +@patch("codecarbon.EmissionsTracker.start") +@patch("codecarbon.EmissionsTracker.stop", return_value=0.45) +@patch("subprocess.run", side_effect=subprocess.CalledProcessError(1, "python3")) +def test_measure_energy_failure(mock_run, mock_stop, mock_start, energy_meter, caplog): + file_path = Path("../input/project_car_stuff/main.py") + with caplog.at_level(logging.ERROR): + energy_meter.measure_energy(file_path) + + mock_start.assert_called_once() + mock_run.assert_called_once() + mock_stop.assert_called_once() + assert "Error executing file" in caplog.text + assert ( + energy_meter.emissions_data is None + ) # since execution failed, emissions data should be None + + +@patch("pandas.read_csv") +@patch("pathlib.Path.exists", return_value=True) # mock file existence +def test_extract_emissions_csv_success(mock_exists, mock_read_csv, energy_meter): + # simulate DataFrame return value + mock_read_csv.return_value = pd.DataFrame( + [{"timestamp": "2025-03-01 12:00:00", "emissions": 0.45}] + ) + + csv_path = Path("dummy_path.csv") # fake path + result = energy_meter.extract_emissions_csv(csv_path) + + assert isinstance(result, dict) + assert "emissions" in result + assert result["emissions"] == 0.45 + + +@patch("pandas.read_csv", side_effect=Exception("File read error")) +@patch("pathlib.Path.exists", return_value=True) # mock file existence +def test_extract_emissions_csv_failure(mock_exists, mock_read_csv, energy_meter, caplog): + csv_path = Path("dummy_path.csv") # fake path + with caplog.at_level(logging.INFO): + result = energy_meter.extract_emissions_csv(csv_path) + + assert result is None # since reading the CSV fails, result should be None + assert "Error reading file" in caplog.text + + +@patch("pathlib.Path.exists", return_value=False) +def test_extract_emissions_csv_missing_file(mock_exists, energy_meter, caplog): + csv_path = Path("dummy_path.csv") # fake path + with caplog.at_level(logging.INFO): + 
result = energy_meter.extract_emissions_csv(csv_path) + + assert result is None # since file path does not exist, result should be None + assert "File 'dummy_path.csv' does not exist." in caplog.text diff --git a/tests/refactorers/test_long_lambda_element_refactoring.py b/tests/refactorers/test_long_lambda_element_refactoring.py new file mode 100644 index 00000000..93392872 --- /dev/null +++ b/tests/refactorers/test_long_lambda_element_refactoring.py @@ -0,0 +1,240 @@ +import pytest +import textwrap +from unittest.mock import patch +from pathlib import Path + +from ecooptimizer.refactorers.concrete.long_lambda_function import ( + LongLambdaFunctionRefactorer, +) +from ecooptimizer.data_types import Occurence, LLESmell +from ecooptimizer.utils.smell_enums import CustomSmell + + +@pytest.fixture +def refactorer(): + return LongLambdaFunctionRefactorer() + + +def create_smell(occurences: list[int]): + """Factory function to create lambda smell objects.""" + return lambda: LLESmell( + path="fake.py", + module="some_module", + obj=None, + type="performance", + symbol="long-lambda", + message="Lambda too long", + messageId=CustomSmell.LONG_LAMBDA_EXPR.value, + confidence="UNDEFINED", + occurences=[ + Occurence(line=occ, endLine=999, column=999, endColumn=999) + for occ in occurences + ], + additionalInfo=None, + ) + + +def normalize_code(code: str) -> str: + """Normalize whitespace for reliable comparisons.""" + return "\n".join(line.rstrip() for line in code.strip().splitlines()) + "\n" + + +def test_basic_lambda_conversion(refactorer): + """Tests conversion of simple single-line lambda.""" + code = textwrap.dedent( + """ + def example(): + my_lambda = lambda x: x + 1 + """ + ) + + expected = textwrap.dedent( + """ + def example(): + def converted_lambda_3(x): + result = x + 1 + return result + + my_lambda = converted_lambda_3 + """ + ) + + smell = create_smell([3])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as 
mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + print(written) + assert normalize_code(written) == normalize_code(expected) + + +def test_no_extra_print_statements(refactorer): + """Ensures no print statements are added unnecessarily.""" + code = textwrap.dedent( + """ + def example(): + processor = lambda x: x.strip().lower() + """ + ) + + expected = textwrap.dedent( + """ + def example(): + def converted_lambda_3(x): + result = x.strip().lower() + return result + + processor = converted_lambda_3 + """ + ) + + smell = create_smell([3])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + written = mock_write.call_args[0][0] + assert "print(" not in written + assert normalize_code(written) == normalize_code(expected) + + +def test_lambda_in_function_argument(refactorer): + """Tests lambda passed as argument to another function.""" + code = textwrap.dedent( + """ + def process_data(): + results = list(map(lambda x: x * 2, [1, 2, 3])) + """ + ) + + expected = textwrap.dedent( + """ + def process_data(): + def converted_lambda_3(x): + result = x * 2 + return result + + results = list(map(converted_lambda_3, [1, 2, 3])) + """ + ) + + smell = create_smell([3])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + assert normalize_code(written) == normalize_code(expected) + + +def test_multi_argument_lambda(refactorer): + """Tests lambda with multiple parameters passed as argument.""" + code = textwrap.dedent( + """ + from functools import reduce + def calculate(): + total = reduce(lambda a, b: a + b, [1, 2, 3, 4]) + """ + ) + + expected = 
textwrap.dedent( + """ + from functools import reduce + def calculate(): + def converted_lambda_4(a, b): + result = a + b + return result + + total = reduce(converted_lambda_4, [1, 2, 3, 4]) + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + written = mock_write.call_args[0][0] + assert normalize_code(written) == normalize_code(expected) + + +def test_lambda_with_keyword_arguments(refactorer): + """Tests lambda used with keyword arguments.""" + code = textwrap.dedent( + """ + def configure_settings(): + button = Button( + text="Submit", + on_click=lambda event: handle_event(event, retries=3) + ) + """ + ) + + expected = textwrap.dedent( + """ + def configure_settings(): + def converted_lambda_5(event): + result = handle_event(event, retries=3) + return result + + button = Button( + text="Submit", + on_click=converted_lambda_5 + ) + """ + ) + + smell = create_smell([5])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + written = mock_write.call_args[0][0] + print(written) + assert normalize_code(written) == normalize_code(expected) + + +def test_very_long_lambda_function(refactorer): + """Tests refactoring of a very long lambda function that spans multiple lines.""" + code = textwrap.dedent( + """ + def calculate(): + value = ( + lambda a, b, c: a + b + c + a * b - c / (a + b) + a - b * c + a**2 - b**2 + a*b + a/(b+c) - c*(a-b) + (a+b+c) + )(1, 2, 3) + """ + ) + + expected = textwrap.dedent( + """ + def calculate(): + def converted_lambda_4(a, b, c): + result = a + b + c + a * b - c / (a + b) + a - b * c + a**2 - b**2 + a*b + a/(b+c) - c*(a-b) + (a+b+c) + return result + + value = ( + converted_lambda_4 + )(1, 2, 3) + """ + ) 
+ + smell = create_smell([4])() + with patch.object(Path, "read_text", return_value=code), \ + patch.object(Path, "write_text") as mock_write: + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + written = mock_write.call_args[0][0] + print(written) + assert normalize_code(written) == normalize_code(expected) diff --git a/tests/refactorers/test_long_message_chain_refactoring.py b/tests/refactorers/test_long_message_chain_refactoring.py new file mode 100644 index 00000000..dfd9760c --- /dev/null +++ b/tests/refactorers/test_long_message_chain_refactoring.py @@ -0,0 +1,261 @@ +import pytest +import textwrap +from unittest.mock import patch +from pathlib import Path + +from ecooptimizer.refactorers.concrete.long_message_chain import ( + LongMessageChainRefactorer, +) +from ecooptimizer.data_types import Occurence, LMCSmell +from ecooptimizer.utils.smell_enums import CustomSmell + + +@pytest.fixture +def refactorer(): + return LongMessageChainRefactorer() + + +def create_smell(occurences: list[int]): + """Factory function to create a smell object for long message chains.""" + + def _create(): + return LMCSmell( + path="fake.py", + module="some_module", + obj=None, + type="convention", + symbol="long-message-chain", + message="Method chain too long", + messageId=CustomSmell.LONG_MESSAGE_CHAIN.value, + confidence="UNDEFINED", + occurences=[ + Occurence(line=occ, endLine=999, column=999, endColumn=999) + for occ in occurences + ], + additionalInfo=None, + ) + + return _create + + +def test_basic_method_chain_refactoring(refactorer): + """Tests refactoring of a basic method chain.""" + code = textwrap.dedent( + """ + def example(): + text = "Hello" + result = text.strip().lower().replace("|", "-").title() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + text = "Hello" + intermediate_0 = text.strip() + intermediate_1 = intermediate_0.lower() + intermediate_2 = intermediate_1.replace("|", "-") + result = 
intermediate_2.title() + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() + written_code = mock_write_text.call_args[0][0] + assert written_code.strip() == expected_code.strip() + + +def test_fstring_chain_refactoring(refactorer): + """Tests refactoring of a long message chain with an f-string.""" + code = textwrap.dedent( + """ + def example(): + name = "John" + greeting = f"Hello {name}".strip().replace(" ", "-").upper() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + name = "John" + intermediate_0 = f"Hello {name}" + intermediate_1 = intermediate_0.strip() + intermediate_2 = intermediate_1.replace(" ", "-") + greeting = intermediate_2.upper() + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() + written_code = mock_write_text.call_args[0][0] + assert written_code.strip() == expected_code.strip() + + +def test_modifications_if_no_long_chain(refactorer): + """Ensures modifications occur even if the method chain isnt long.""" + code = textwrap.dedent( + """ + def example(): + text = "Hello" + result = text.strip().lower() + """ + ) + + expected_code = textwrap.dedent( + """ + def example(): + text = "Hello" + intermediate_0 = text.strip() + result = intermediate_0.lower() + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() + written_code = 
mock_write_text.call_args[0][0] + assert written_code.strip() == expected_code.strip() + + +def test_proper_indentation_preserved(refactorer): + """Ensures indentation is preserved after refactoring.""" + code = textwrap.dedent( + """ + def example(): + if True: + text = "Hello" + result = text.strip().lower().replace("|", "-").title() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + if True: + text = "Hello" + intermediate_0 = text.strip() + intermediate_1 = intermediate_0.lower() + intermediate_2 = intermediate_1.replace("|", "-") + result = intermediate_2.title() + """ + ) + + smell = create_smell([5])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() + written_code = mock_write_text.call_args[0][0] + print(written_code, "\n") + assert written_code.splitlines() == expected_code.splitlines() + + +def test_method_chain_with_arguments(refactorer): + """Tests refactoring of method chains containing method arguments.""" + code = textwrap.dedent( + """ + def example(): + text = "Hello" + result = text.strip().replace("H", "J").lower().title() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + text = "Hello" + intermediate_0 = text.strip() + intermediate_1 = intermediate_0.replace("H", "J") + intermediate_2 = intermediate_1.lower() + result = intermediate_2.title() + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + assert written.strip() == expected_code.strip() + + +def test_print_statement_preservation(refactorer): + """Tests refactoring of print statements with method chains.""" + code = textwrap.dedent( 
+ """ + def example(): + text = "Hello" + print(text.strip().lower().title()) + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + text = "Hello" + intermediate_0 = text.strip() + intermediate_1 = intermediate_0.lower() + print(intermediate_1.title()) + """ + ) + + smell = create_smell([4])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + assert written.strip() == expected_code.strip() + + +def test_nested_method_chains(refactorer): + """Tests refactoring of nested method chains.""" + code = textwrap.dedent( + """ + def example(): + result = get_object().config().settings().load() + """ + ) + expected_code = textwrap.dedent( + """ + def example(): + intermediate_0 = get_object() + intermediate_1 = intermediate_0.config() + intermediate_2 = intermediate_1.settings() + result = intermediate_2.load() + """ + ) + + smell = create_smell([3])() + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write, + ): + + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + written = mock_write.call_args[0][0] + assert written.strip() == expected_code.strip() diff --git a/tests/refactorers/test_member_ignoring_method.py b/tests/refactorers/test_member_ignoring_method.py new file mode 100644 index 00000000..1531049b --- /dev/null +++ b/tests/refactorers/test_member_ignoring_method.py @@ -0,0 +1,364 @@ +import pytest + +import textwrap +from pathlib import Path + +from ecooptimizer.refactorers.concrete.member_ignoring_method import MakeStaticRefactorer +from ecooptimizer.data_types import MIMSmell, Occurence +from ecooptimizer.utils.smell_enums import PylintSmell + + +@pytest.fixture +def refactorer(): + return MakeStaticRefactorer() + + +def create_smell(occurences: list[int], obj: str): + 
"""Factory function to create a smell object""" + + def _create(): + return MIMSmell( + path="fake.py", + module="some_module", + obj=obj, + type="refactor", + symbol="no-self-use", + message="Method could be a function", + messageId=PylintSmell.NO_SELF_USE.value, + confidence="INFERENCE", + occurences=[ + Occurence( + line=occ, + endLine=999, + column=999, + endColumn=999, + ) + for occ in occurences + ], + additionalInfo=None, + ) + + return _create + + +def test_mim_basic_case(source_files, refactorer): + """ + Tests that the member ignoring method refactorer: + - Adds @staticmethod decorator. + - Removes 'self' from method signature. + - Updates calls in external files. + """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_basic_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + example = Example() + num = example.mim_method(5) + """) + ) + + # --- File 2: Calls the method --- + file2 = test_dir / "caller.py" + file2.write_text( + textwrap.dedent("""\ + from .class_def import Example + example = Example() + result = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + example = Example() + num = Example.mim_method(5) + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from .class_def import Example + example = Example() + result = Example.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == 
expected_file2.strip() + + +def test_mim_inheritence_case(source_files, refactorer): + """ + Tests that calls originating from a subclass instance are also refactored. + """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_inherited_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + class SubExample(Example): + pass + + example = SubExample() + num = example.mim_method(5) + """) + ) + + # --- File 2: Calls the method --- + file2 = test_dir / "caller.py" + file2.write_text( + textwrap.dedent("""\ + from .class_def import SubExample + example = SubExample() + result = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + class SubExample(Example): + pass + + example = SubExample() + num = SubExample.mim_method(5) + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from .class_def import SubExample + example = SubExample() + result = SubExample.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == expected_file2.strip() + + +def test_mim_inheritence_seperate_subclass(source_files, refactorer): + """ + Tests that subclasses declared in files other than the initial one are detected. 
+ """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_inherited_ss_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + example = Example() + num = example.mim_method(5) + """) + ) + + # --- File 2: Calls the method --- + file2 = test_dir / "caller.py" + file2.write_text( + textwrap.dedent("""\ + from .class_def import Example + + class SubExample(Example): + pass + + example = SubExample() + result = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + example = Example() + num = Example.mim_method(5) + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from .class_def import Example + + class SubExample(Example): + pass + + example = SubExample() + result = SubExample.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == expected_file2.strip() + + +def test_mim_inheritence_subclass_method_override(source_files, refactorer): + """ + Tests that calls to the mim method from subclass instance with method override are NOT changed. 
+ """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_inherited_override_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + class SubExample(Example): + def mim_method(self, x): + return x * 3 + + example = Example() + num = example.mim_method(5) + """) + ) + + # --- File 2: Calls the method --- + file2 = test_dir / "caller.py" + file2.write_text( + textwrap.dedent("""\ + from .class_def import SubExample + example = SubExample() + result = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + class SubExample(Example): + def mim_method(self, x): + return x * 3 + + example = Example() + num = Example.mim_method(5) + """) + + # --- Expected Result for File 2 --- + expected_file2 = textwrap.dedent("""\ + from .class_def import SubExample + example = SubExample() + result = example.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + assert file2.read_text().strip() == expected_file2.strip() + + +def test_mim_type_hint_inferrence(source_files, refactorer): + """ + Tests that type hints declaring and instance type are detected. 
+ """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_mim_type_hint_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + def test(example: Example): + print(example.mim_method(3)) + + example = Example() + num = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + def test(example: Example): + print(Example.mim_method(3)) + + example = Example() + num = Example.mim_method(5) + """) + + # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() diff --git a/tests/refactorers/test_str_concat_in_loop_refactor.py b/tests/refactorers/test_str_concat_in_loop_refactor.py new file mode 100644 index 00000000..ce75616a --- /dev/null +++ b/tests/refactorers/test_str_concat_in_loop_refactor.py @@ -0,0 +1,439 @@ +import pytest +from unittest.mock import patch + +from pathlib import Path + +from ecooptimizer.refactorers.concrete.str_concat_in_loop import UseListAccumulationRefactorer +from ecooptimizer.data_types import SCLInfo, Occurence, SCLSmell +from ecooptimizer.utils.smell_enums import CustomSmell + + +@pytest.fixture +def refactorer(): + return UseListAccumulationRefactorer() + + +def create_smell(occurences: list[int], concat_target: str, inner_loop_line: int): + """Factory function to create a smell object""" + + def _create(): + return SCLSmell( + path="fake.py", + module="some_module", + obj=None, + type="performance", + symbol="string-concat-loop", + message="String concatenation inside loop detected", + 
messageId=CustomSmell.STR_CONCAT_IN_LOOP.value, + confidence="UNDEFINED", + occurences=[ + Occurence( + line=occ, + endLine=999, + column=999, + endColumn=999, + ) + for occ in occurences + ], + additionalInfo=SCLInfo( + concatTarget=concat_target, + innerLoopLine=inner_loop_line, + ), + ) + + return _create + + +@pytest.mark.parametrize("val", [("''"), ('""'), ("str()")]) +def test_empty_initial_var(refactorer, val): + """Test for inital concat var being empty.""" + code = f""" + def example(): + result = {val} + for i in range(5): + result += str(i) + return result + """ + smell = create_smell(occurences=[5], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the modified code is correct + assert "result = []\n" in written_code + assert f"result = {val}\n" not in written_code + + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_non_empty_initial_name_var_not_referenced(refactorer): + """Test for initial concat value being none empty.""" + code = """ + def example(): + result = "Hello" + for i in range(5): + result += str(i) + return result + """ + smell = create_smell(occurences=[5], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the 
modified code is correct + assert "result = ['Hello']\n" in written_code + assert 'result = "Hello"\n' not in written_code + + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_non_empty_initial_name_var_referenced(refactorer): + """Test for initialization when var is referenced after but before the loop start.""" + code = """ + def example(): + result = "Hello" + backup = result + for i in range(5): + result += str(i) + return result + """ + smell = create_smell(occurences=[6], concat_target="result", inner_loop_line=5)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the modified code is correct + assert 'result = "Hello"\n' in written_code + assert "result = [result]\n" in written_code + + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_initial_not_name_var(refactorer): + """Test that none name vars are initialized to a temp list""" + code = """ + def example(): + result = {"key" : "Hello"} + for i in range(5): + result["key"] += str(i) + return result + """ + smell = create_smell(occurences=[5], concat_target='result["key"]', inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + list_name = refactorer.generate_temp_list_name() + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the 
modified code is correct + assert 'result = {"key" : "Hello"}\n' in written_code + assert f'{list_name} = [result["key"]]\n' in written_code + + assert f"{list_name}.append(str(i))\n" in written_code + + assert f"result[\"key\"] = ''.join({list_name})\n" in written_code + + +def test_initial_not_in_scope(refactorer): + """Test for refactoring of a concat variable not initialized in the same scope.""" + code = """ + def example(result: str): + for i in range(5): + result += str(i) + return result + """ + smell = create_smell(occurences=[4], concat_target="result", inner_loop_line=3)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + # Check that the modified code is correct + assert "result = [result]\n" in written_code + + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_insert_on_prefix(refactorer): + """Ensure insert(0) is used for prefix concatenation""" + code = """ + def example(): + result = "" + for i in range(5): + result = str(i) + result + return result + """ + smell = create_smell(occurences=[5], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.insert(0, str(i))\n" in written_code + + assert 
"result = ''.join(result)\n" in written_code + + +def test_concat_with_prefix_and_suffix(refactorer): + """Test for proper refactoring of a concatenation containing both a prefix and suffix concat.""" + code = """ + def example(): + result = "" + for i in range(5): + result = str(i) + result + str(i) + return result + """ + smell = create_smell(occurences=[5], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.insert(0, str(i))\n" in written_code + assert "result.append(str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_multiple_concat_occurrences(refactorer): + """Test for multiple successive concatenations in the same loop for 1 smell.""" + code = """ + def example(): + result = "" + fruits = ["apple", "banana", "orange", "kiwi"] + for fruit in fruits: + result += fruit + result = fruit + result + return result + """ + smell = create_smell(occurences=[6, 7], concat_target="result", inner_loop_line=5)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.append(fruit)\n" in written_code + assert "result.insert(0, fruit)\n" in 
written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_nested_concat(refactorer): + """Test for nested concat in loop.""" + code = """ + def example(): + result = "" + for i in range(5): + for j in range(6): + result = str(i) + result + str(j) + return result + """ + smell = create_smell(occurences=[6], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.append(str(j))\n" in written_code + assert "result.insert(0, str(i))\n" in written_code + + assert "result = ''.join(result)\n" in written_code + + +def test_multi_occurrence_nested_concat(refactorer): + """Test for multiple occurrences of a same smell at different loop levels.""" + code = """ + def example(): + result = "" + for i in range(5): + result += str(i) + for j in range(6): + result = result + str(j) + return result + """ + smell = create_smell(occurences=[5, 7], concat_target="result", inner_loop_line=4)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + assert "result = []\n" in written_code + assert 'result = ""\n' not in written_code + + assert "result.append(str(i))\n" in written_code + assert "result.append(str(j))\n" in written_code + + assert "result = ''.join(result)\n" in 
written_code + + +def test_reassignment(refactorer): + """Ensure list is reset to new val when reassigned inside the loop.""" + code = """ + class Test: + def __init__(self): + self.text = "" + obj = Test() + for word in ["bug", "warning", "Hello", "World"]: + obj.text += word + if word == "warning": + obj.text = "Well, " + """ + smell = create_smell(occurences=[7], concat_target="obj.text", inner_loop_line=6)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + list_name = refactorer.generate_temp_list_name() + + assert f"{list_name} = [obj.text]\n" in written_code + + assert f"{list_name}.append(word)\n" in written_code + assert f"{list_name} = ['Well, ']\n" in written_code # astroid changes quotes + assert 'obj.text = "Well, "\n' not in written_code + + +@pytest.mark.parametrize("val", [("''"), ('""'), ("str()")]) +def test_reassignment_clears_list(refactorer, val): + """Ensure list is cleared when reassigned inside the loop using clear().""" + code = f""" + class Test: + def __init__(self): + self.text = "" + obj = Test() + for word in ["bug", "warning", "Hello", "World"]: + obj.text += word + if word == "warning": + obj.text = {val} + """ + smell = create_smell(occurences=[7], concat_target="obj.text", inner_loop_line=6)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code = mock_write_text.call_args[0][0] # The first argument is the modified code + + list_name = refactorer.generate_temp_list_name() 
+ + assert f"{list_name} = [obj.text]\n" in written_code + + assert f"{list_name}.append(word)\n" in written_code + assert f"{list_name}.clear()\n" in written_code + + +def test_no_unrelated_modifications(refactorer): + """Ensure formatting and any comments for unrelated lines are preserved.""" + code = """ + def example(): + print("Hello World") + # This is a comment + result = "" + unrelated_var = 0 + for i in range(5): # This is also a comment + result += str(i) + unrelated_var += i # Yep, you guessed it, comment + return result # Another one here + random = example() # And another one, why not + """ + smell = create_smell(occurences=[8], concat_target="result", inner_loop_line=7)() + + with ( + patch.object(Path, "read_text", return_value=code), + patch.object(Path, "write_text") as mock_write_text, + ): + refactorer.refactor(Path("fake.py"), Path("fake.py"), smell, Path("fake.py")) + + mock_write_text.assert_called_once() # Ensure write_text was called once + written_code: str = mock_write_text.call_args[0][0] # The first argument is the modified code + + original_lines = code.split("\n") + modified_lines = written_code.split("\n") + + assert all(line_o == line_m for line_o, line_m in zip(original_lines[:4], modified_lines[:4])) + assert all(line_o == line_m for line_o, line_m in zip(original_lines[5:7], modified_lines[5:7])) + assert original_lines[8] == modified_lines[8] + assert original_lines[9] == modified_lines[10] + assert original_lines[10] == modified_lines[11] diff --git a/tests/utils/test_outputs_config.py b/tests/utils/test_outputs_config.py new file mode 100644 index 00000000..fc8523be --- /dev/null +++ b/tests/utils/test_outputs_config.py @@ -0,0 +1,5 @@ +import pytest + + +def test_placeholder(): + pytest.fail("TODO: Implement this test")