diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 2fd2aea2..4897cfe6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -18,4 +18,18 @@ jobs: uses: psf/black@stable with: options: "--check --verbose" - version: "23.3.0" \ No newline at end of file + version: "23.3.0" + + - name: Install Python 3 + uses: actions/setup-python@v4 + with: + python-version: 3.10.5 + + - name: Install dependencies pyright + run: | + python -m pip install --upgrade pip + pip install pyright + python -m pip install -e . + + - name: Run pyright + run: pyright \ No newline at end of file diff --git a/glitch/__main__.py b/glitch/__main__.py index 6ef17617..fc701988 100644 --- a/glitch/__main__.py +++ b/glitch/__main__.py @@ -1,4 +1,7 @@ import click, os, sys + +from pathlib import Path +from typing import Tuple, List, Set, Optional from glitch.analysis.rules import Error, RuleVisitor from glitch.helpers import RulesListOption, get_smell_types, get_smells from glitch.parsers.docker import DockerParser @@ -6,26 +9,35 @@ from glitch.stats.stats import FileStats from glitch.tech import Tech from glitch.repr.inter import UnitBlockType +from glitch.parsers.parser import Parser from glitch.parsers.ansible import AnsibleParser from glitch.parsers.chef import ChefParser from glitch.parsers.puppet import PuppetParser from glitch.parsers.terraform import TerraformParser from pkg_resources import resource_filename -from alive_progress import alive_bar -from pathlib import Path +from alive_progress import alive_bar # type: ignore + # NOTE: These are necessary in order for python to load the visitors. # Otherwise, python will not consider these types of rules. -from glitch.analysis.design import DesignVisitor -from glitch.analysis.security import SecurityVisitor - - -def parse_and_check(type, path, module, parser, analyses, errors, stats): +from glitch.analysis.design import DesignVisitor # type: ignore +from glitch.analysis.security import SecurityVisitor # type: ignore + + +def parse_and_check( + type: UnitBlockType, + path: str, + module: bool, + parser: Parser, + analyses: List[RuleVisitor], + errors: List[Error], + stats: FileStats, +) -> None: inter = parser.parse(path, type, module) if inter != None: for analysis in analyses: errors += analysis.check(inter) - stats.compute(inter) + stats.compute(inter) @click.command( @@ -33,7 +45,7 @@ def parse_and_check(type, path, module, parser, analyses, errors, stats): ) @click.option( "--tech", - type=click.Choice(Tech), + type=click.Choice([t.value for t in Tech]), required=True, help="The IaC technology in which the scripts analyzed are written in.", ) @@ -46,7 +58,7 @@ def parse_and_check(type, path, module, parser, analyses, errors, stats): ) @click.option( "--type", - type=click.Choice(UnitBlockType), + type=click.Choice([t.value for t in UnitBlockType]), default=UnitBlockType.unknown, help="The type of scripts being analyzed.", ) @@ -97,19 +109,22 @@ def parse_and_check(type, path, module, parser, analyses, errors, stats): @click.argument("path", type=click.Path(exists=True), required=True) @click.argument("output", type=click.Path(), required=False) def glitch( - tech, - type, - path, - config, - module, - csv, - dataset, - includeall, - smell_types, - output, - tableformat, - linter, + tech: str, + type: str, + path: str, + config: str, + module: bool, + csv: bool, + dataset: bool, + includeall: Tuple[str, ...], + smell_types: Tuple[str, ...], + output: Optional[str], + tableformat: str, + linter: bool, ): + tech = Tech(tech) + type = UnitBlockType(type) + if config != "configs/default.ini" and not os.path.exists(config): raise click.BadOptionUsage( "config", f"Invalid value for 'config': Path '{config}' does not exist." @@ -138,7 +153,7 @@ def glitch( if smell_types == (): smell_types = get_smell_types() - analyses = [] + analyses: List[RuleVisitor] = [] rules = RuleVisitor.__subclasses__() for r in rules: if smell_types == () or r.get_name() in smell_types: @@ -146,10 +161,10 @@ def glitch( analysis.config(config) analyses.append(analysis) - errors = [] + errors: List[Error] = [] if dataset: if includeall != (): - iac_files = [] + iac_files: Set[str] = set() for root, _, files in os.walk(path): for name in files: name_split = name.split(".") @@ -157,13 +172,12 @@ def glitch( name_split[-1] in includeall and not Path(os.path.join(root, name)).is_symlink() ): - iac_files.append(os.path.join(root, name)) - iac_files = set(iac_files) + iac_files.add(os.path.join(root, name)) with alive_bar( len(iac_files), title=f"ANALYZING ALL FILES WITH EXTENSIONS {includeall}", - ) as bar: + ) as bar: # type: ignore for file in iac_files: parse_and_check( type, file, module, parser, analyses, errors, file_stats @@ -171,7 +185,7 @@ def glitch( bar() else: subfolders = [f.path for f in os.scandir(f"{path}") if f.is_dir()] - with alive_bar(len(subfolders), title="ANALYZING SUBFOLDERS") as bar: + with alive_bar(len(subfolders), title="ANALYZING SUBFOLDERS") as bar: # type: ignore for d in subfolders: parse_and_check( type, d, module, parser, analyses, errors, file_stats @@ -180,7 +194,7 @@ def glitch( files = [f.path for f in os.scandir(f"{path}") if f.is_file()] - with alive_bar(len(files), title="ANALYZING FILES IN ROOT FOLDER") as bar: + with alive_bar(len(files), title="ANALYZING FILES IN ROOT FOLDER") as bar: # type: ignore for file in files: parse_and_check( type, file, module, parser, analyses, errors, file_stats @@ -212,7 +226,7 @@ def glitch( print_stats(errors, get_smells(smell_types, tech), file_stats, tableformat) -def main(): +def main() -> None: glitch(prog_name="glitch") diff --git a/glitch/analysis/design.py b/glitch/analysis/design.py index b20dd388..9effbe75 100644 --- a/glitch/analysis/design.py +++ b/glitch/analysis/design.py @@ -1,16 +1,17 @@ -from cmath import inf import json import re import configparser + +from cmath import inf from glitch.analysis.rules import Error, RuleVisitor, SmellChecker from glitch.tech import Tech - from glitch.repr.inter import * +from typing import List, Tuple, Dict, Set class DesignVisitor(RuleVisitor): class ImproperAlignmentSmell(SmellChecker): - def check(self, element, file: str): + def check(self, element: CodeElement, file: str) -> List[Error]: if isinstance(element, AtomicUnit): identation = None for a in element.attributes: @@ -36,7 +37,12 @@ class PuppetImproperAlignmentSmell(SmellChecker): cached_file = "" lines = [] - def check(self, element, file: str) -> list[Error]: + def check(self, element: CodeElement, file: str) -> List[Error]: + if not isinstance(element, AtomicUnit) and not isinstance( + element, UnitBlock + ): + return [] + if DesignVisitor.PuppetImproperAlignmentSmell.cached_file != file: with open(file, "r") as f: DesignVisitor.PuppetImproperAlignmentSmell.lines = f.readlines() @@ -81,17 +87,17 @@ def check(self, element, file: str) -> list[Error]: class AnsibleImproperAlignmentSmell(SmellChecker): # YAML does not allow improper alignments (it also would have problems with generic attributes for all modules) - def check(self, element: AtomicUnit, file: str): + def check(self, element: CodeElement, file: str) -> List[Error]: return [] class MisplacedAttribute(SmellChecker): - def check(self, element, file: str): + def check(self, element: CodeElement, file: str) -> List[Error]: return [] class ChefMisplacedAttribute(SmellChecker): - def check(self, element, file: str): + def check(self, element: CodeElement, file: str) -> List[Error]: if isinstance(element, AtomicUnit): - order = [] + order: List[int] = [] for attribute in element.attributes: if attribute.name == "source": order.append(1) @@ -111,7 +117,7 @@ def check(self, element, file: str): return [] class PuppetMisplacedAttribute(SmellChecker): - def check(self, element, file: str): + def check(self, element: CodeElement, file: str) -> List[Error]: if isinstance(element, AtomicUnit): for i, attr in enumerate(element.attributes): if attr.name == "ensure" and i != 0: @@ -161,15 +167,15 @@ def __init__(self, tech: Tech) -> None: else: self.comment = "//" - self.variable_stack = [] - self.variables_names = [] + self.variable_stack: List[int] = [] + self.variables_names: List[str] = [] self.first_code_line = inf @staticmethod def get_name() -> str: return "design" - def config(self, config_path: str): + def config(self, config_path: str) -> None: config = configparser.ConfigParser() config.read(config_path) DesignVisitor.__EXEC = json.loads(config["design"]["exec_atomic_units"]) @@ -179,7 +185,7 @@ def config(self, config_path: str): if "var_refer_symbol" not in config["design"]: DesignVisitor.__VAR_REFER_SYMBOL = None else: - DesignVisitor.__VAR_REFER_SYMBOL = json.loads( + DesignVisitor.__VAR_REFER_SYMBOL = json.loads( # type: ignore config["design"]["var_refer_symbol"] ) @@ -190,8 +196,8 @@ def check_module(self, m: Module) -> list[Error]: # errors.append(Error('design_unnecessary_abstraction', m, m.path, repr(m))) return errors - def check_unitblock(self, u: UnitBlock) -> list[Error]: - def count_atomic_units(ub: UnitBlock): + def check_unitblock(self, u: UnitBlock) -> List[Error]: + def count_atomic_units(ub: UnitBlock) -> Tuple[int, int]: count_resources = len(ub.atomic_units) count_execs = 0 for au in ub.atomic_units: @@ -224,7 +230,7 @@ def count_atomic_units(ub: UnitBlock): for attr in u.attributes: self.variables_names.append(attr.name) - errors = [] + errors: List[Error] = [] # The order is important for au in u.atomic_units: errors += self.check_atomicunit(au, u.path) @@ -254,7 +260,7 @@ def count_atomic_units(ub: UnitBlock): error.line = i + 1 errors.append(error) - def count_variables(vars: list[Variable]): + def count_variables(vars: List[KeyValue]) -> int: count = 0 for var in vars: if isinstance(var.value, type(None)): @@ -266,7 +272,7 @@ def count_variables(vars: list[Variable]): # The UnitBlock should not be of type vars, because these files are supposed to only # have variables if ( - count_variables(u.variables) / max(len(code_lines), 1) > 0.3 + count_variables(u.variables) / max(len(code_lines), 1) > 0.3 # type: ignore and u.type != UnitBlockType.vars ): errors.append( @@ -293,12 +299,13 @@ def count_variables(vars: list[Variable]): error.line = i + 1 errors.append(error) - def get_line(i, lines): + def get_line(i: int, lines: List[Tuple[int, int]]): for j, line in lines: if i < j: return line + raise RuntimeError("Line not found") - lines = [] + lines: List[Tuple[int, int]] = [] current_line = 1 i = 0 for c in all_code: @@ -310,7 +317,7 @@ def get_line(i, lines): i += 1 lines.append((i, current_line)) - blocks = {} + blocks: Dict[int, List[int]] = {} for i in range(len(code) - 150): hash = code[i : i + 150].__hash__() if hash not in blocks: @@ -319,7 +326,7 @@ def get_line(i, lines): blocks[hash].append(i) # Note: changing the structure to a set instead of a list increased the speed A LOT - checked = set() + checked: Set[int] = set() for _, value in blocks.items(): if len(value) >= 2: for i in value: @@ -363,7 +370,9 @@ def check_atomicunit(self, au: AtomicUnit, file: str) -> list[Error]: errors += self.misplaced_attr.check(au, file) if au.type in DesignVisitor.__EXEC: - if "&&" in au.name or ";" in au.name or "|" in au.name: + if isinstance(au.name, str) and ( + "&&" in au.name or ";" in au.name or "|" in au.name + ): errors.append( Error("design_multifaceted_abstraction", au, file, repr(au)) ) @@ -391,9 +400,7 @@ def check_atomicunit(self, au: AtomicUnit, file: str) -> list[Error]: def check_dependency(self, d: Dependency, file: str) -> list[Error]: return [] - def check_attribute( - self, a: Attribute, file: str, au: AtomicUnit = None, parent_name: str = "" - ) -> list[Error]: + def check_attribute(self, a: Attribute, file: str) -> list[Error]: return [] def check_variable(self, v: Variable, file: str) -> list[Error]: @@ -401,7 +408,7 @@ def check_variable(self, v: Variable, file: str) -> list[Error]: return [] def check_comment(self, c: Comment, file: str) -> list[Error]: - errors = [] + errors: List[Error] = [] if c.line >= self.first_non_comm_line: errors.append(Error("design_avoid_comments", c, file, repr(c))) return errors diff --git a/glitch/analysis/rules.py b/glitch/analysis/rules.py index 49e5ba33..6ea900c4 100644 --- a/glitch/analysis/rules.py +++ b/glitch/analysis/rules.py @@ -1,10 +1,13 @@ +from typing import Dict, Optional, Union, List, Any +from abc import ABC, abstractmethod from glitch.tech import Tech from glitch.repr.inter import * -from abc import ABC, abstractmethod + +ErrorDict = Dict[str, Union[Union[Tech, str], "ErrorDict"]] class Error: - ERRORS = { + ERRORS: ErrorDict = { "security": { "sec_https": "Use of HTTP without TLS - The developers should always favor the usage of HTTPS. (CWE-319)", "sec_susp_comm": "Suspicious comment - Comments with keywords such as TODO HACK or FIXME may reveal problems possibly exploitable. (CWE-546)", @@ -59,11 +62,11 @@ class Error: }, } - ALL_ERRORS = {} + ALL_ERRORS: Dict[str, str] = {} @staticmethod - def agglomerate_errors(): - def aux_agglomerate_errors(key, errors): + def agglomerate_errors() -> None: + def aux_agglomerate_errors(key: str, errors: Union[str, ErrorDict]) -> None: if isinstance(errors, dict): for k, v in errors.items(): aux_agglomerate_errors(k, v) @@ -73,7 +76,7 @@ def aux_agglomerate_errors(key, errors): aux_agglomerate_errors("", Error.ERRORS) def __init__( - self, code: str, el, path: str, repr: str, opt_msg: str = None + self, code: str, el: Any, path: str, repr: str, opt_msg: Optional[str] = None ) -> None: self.code: str = code self.el = el @@ -110,7 +113,7 @@ def __repr__(self) -> str: def __hash__(self): return hash((self.code, self.path, self.line, self.opt_msg)) - def __eq__(self, other): + def __eq__(self, other: Any): if not isinstance(other, type(self)): return NotImplemented return ( @@ -129,24 +132,22 @@ def __init__(self, tech: Tech) -> None: self.tech = tech self.code = None - def check(self, code) -> list[Error]: + def check(self, code: Project | Module | UnitBlock) -> List[Error]: self.code = code if isinstance(code, Project): return self.check_project(code) elif isinstance(code, Module): return self.check_module(code) - elif isinstance(code, UnitBlock): + else: return self.check_unitblock(code) - def check_element( - self, c, file: str, au_type=None, parent_name: str = "" - ) -> list[Error]: + def check_element(self, c: CodeElement, file: str) -> list[Error]: if isinstance(c, AtomicUnit): return self.check_atomicunit(c, file) elif isinstance(c, Dependency): return self.check_dependency(c, file) elif isinstance(c, Attribute): - return self.check_attribute(c, file, au_type, parent_name) + return self.check_attribute(c, file) elif isinstance(c, Variable): return self.check_variable(c, file) elif isinstance(c, ConditionalStatement): @@ -154,13 +155,14 @@ def check_element( elif isinstance(c, Comment): return self.check_comment(c, file) elif isinstance(c, dict): - errors = [] - for k, v in c.items(): - errors += self.check_element(k, file) + self.check_element(v, file) + errors: List[Error] = [] + for k, v in c.items(): # type: ignore + errors += self.check_element(k, file) + self.check_element(v, file) # type: ignore return errors else: return [] + @staticmethod @abstractmethod def get_name() -> str: pass @@ -170,7 +172,7 @@ def config(self, config_path: str): pass def check_project(self, p: Project) -> list[Error]: - errors = [] + errors: List[Error] = [] for m in p.modules: errors += self.check_module(m) @@ -180,14 +182,14 @@ def check_project(self, p: Project) -> list[Error]: return errors def check_module(self, m: Module) -> list[Error]: - errors = [] + errors: List[Error] = [] for u in m.blocks: errors += self.check_unitblock(u) return errors def check_unitblock(self, u: UnitBlock) -> list[Error]: - errors = [] + errors: List[Error] = [] for au in u.atomic_units: errors += self.check_atomicunit(au, u.path) for c in u.comments: @@ -204,9 +206,9 @@ def check_unitblock(self, u: UnitBlock) -> list[Error]: return errors def check_atomicunit(self, au: AtomicUnit, file: str) -> list[Error]: - errors = [] + errors: List[Error] = [] for a in au.attributes: - errors += self.check_attribute(a, file, au.type) + errors += self.check_attribute(a, file) for s in au.statements: errors += self.check_element(s, file) @@ -218,9 +220,7 @@ def check_dependency(self, d: Dependency, file: str) -> list[Error]: pass @abstractmethod - def check_attribute( - self, a: Attribute, file: str, au_type: None, parent_name: str = "" - ) -> list[Error]: + def check_attribute(self, a: Attribute, file: str) -> list[Error]: pass @abstractmethod @@ -228,7 +228,7 @@ def check_variable(self, v: Variable, file: str) -> list[Error]: pass def check_condition(self, c: ConditionalStatement, file: str) -> list[Error]: - errors = [] + errors: List[Error] = [] for s in c.statements: errors += self.check_element(s, file) @@ -245,8 +245,8 @@ def check_comment(self, c: Comment, file: str) -> list[Error]: class SmellChecker(ABC): def __init__(self) -> None: - self.code = None + self.code: Optional[Project | UnitBlock | Module] = None @abstractmethod - def check(self, element, file: str) -> list[Error]: + def check(self, element: CodeElement, file: str) -> list[Error]: pass diff --git a/glitch/analysis/security.py b/glitch/analysis/security.py index aed7fbeb..31972910 100644 --- a/glitch/analysis/security.py +++ b/glitch/analysis/security.py @@ -5,7 +5,7 @@ import configparser from urllib.parse import urlparse from glitch.analysis.rules import Error, RuleVisitor, SmellChecker -from nltk.tokenize import WordPunctTokenizer +from nltk.tokenize import WordPunctTokenizer # type: ignore from typing import Tuple, List, Optional from glitch.tech import Tech @@ -18,15 +18,15 @@ class SecurityVisitor(RuleVisitor): __URL_REGEX = r"^(http:\/\/www\.|https:\/\/www\.|http:\/\/|https:\/\/)?[a-z0-9]+([_\-\.]{1}[a-z0-9]+)*\.[a-z]{2,5}(:[0-9]{1,5})?(\/.*)?$" class EmptyChecker(SmellChecker): - def check(self, element, file: str): + def check(self, element: CodeElement, file: str) -> List[Error]: return [] class NonOfficialImageSmell(SmellChecker): - def check(self, element, file: str) -> List[Error]: + def check(self, element: CodeElement, file: str) -> List[Error]: return [] class DockerNonOfficialImageSmell(SmellChecker): - def check(self, element, file: str) -> List[Error]: + def check(self, element: CodeElement, file: str) -> List[Error]: if ( not isinstance(element, UnitBlock) or element.name is None @@ -40,7 +40,7 @@ def check(self, element, file: str) -> List[Error]: def __init__(self, tech: Tech) -> None: super().__init__(tech) - self.checkers = [] + self.checkers: List[SmellChecker] = [] if tech == Tech.terraform: for child in TerraformSmellChecker.__subclasses__(): @@ -55,7 +55,7 @@ def __init__(self, tech: Tech) -> None: def get_name() -> str: return "security" - def config(self, config_path: str): + def config(self, config_path: str) -> None: config = configparser.ConfigParser() config.read(config_path) SecurityVisitor.__WRONG_WORDS = json.loads( @@ -94,76 +94,76 @@ def config(self, config_path: str): ) if self.tech == Tech.terraform: - SecurityVisitor._INTEGRITY_POLICY = json.loads( + SecurityVisitor.INTEGRITY_POLICY = json.loads( config["security"]["integrity_policy"] ) - SecurityVisitor._HTTPS_CONFIGS = json.loads( + SecurityVisitor.HTTPS_CONFIGS = json.loads( config["security"]["ensure_https"] ) - SecurityVisitor._SSL_TLS_POLICY = json.loads( + SecurityVisitor.SSL_TLS_POLICY = json.loads( config["security"]["ssl_tls_policy"] ) - SecurityVisitor._DNSSEC_CONFIGS = json.loads( + SecurityVisitor.DNSSEC_CONFIGS = json.loads( config["security"]["ensure_dnssec"] ) - SecurityVisitor._PUBLIC_IP_CONFIGS = json.loads( + SecurityVisitor.PUBLIC_IP_CONFIGS = json.loads( config["security"]["use_public_ip"] ) - SecurityVisitor._POLICY_KEYWORDS = json.loads( + SecurityVisitor.POLICY_KEYWORDS = json.loads( config["security"]["policy_keywords"] ) - SecurityVisitor._ACCESS_CONTROL_CONFIGS = json.loads( + SecurityVisitor.ACCESS_CONTROL_CONFIGS = json.loads( config["security"]["insecure_access_control"] ) - SecurityVisitor._AUTHENTICATION = json.loads( + SecurityVisitor.AUTHENTICATION = json.loads( config["security"]["authentication"] ) - SecurityVisitor._POLICY_ACCESS_CONTROL = json.loads( + SecurityVisitor.POLICY_ACCESS_CONTROL = json.loads( config["security"]["policy_insecure_access_control"] ) - SecurityVisitor._POLICY_AUTHENTICATION = json.loads( + SecurityVisitor.POLICY_AUTHENTICATION = json.loads( config["security"]["policy_authentication"] ) - SecurityVisitor._MISSING_ENCRYPTION = json.loads( + SecurityVisitor.MISSING_ENCRYPTION = json.loads( config["security"]["missing_encryption"] ) - SecurityVisitor._CONFIGURATION_KEYWORDS = json.loads( + SecurityVisitor.CONFIGURATION_KEYWORDS = json.loads( config["security"]["configuration_keywords"] ) - SecurityVisitor._ENCRYPT_CONFIG = json.loads( + SecurityVisitor.ENCRYPT_CONFIG = json.loads( config["security"]["encrypt_configuration"] ) - SecurityVisitor._FIREWALL_CONFIGS = json.loads( + SecurityVisitor.FIREWALL_CONFIGS = json.loads( config["security"]["firewall"] ) - SecurityVisitor._MISSING_THREATS_DETECTION_ALERTS = json.loads( + SecurityVisitor.MISSING_THREATS_DETECTION_ALERTS = json.loads( config["security"]["missing_threats_detection_alerts"] ) - SecurityVisitor._PASSWORD_KEY_POLICY = json.loads( + SecurityVisitor.PASSWORD_KEY_POLICY = json.loads( config["security"]["password_key_policy"] ) - SecurityVisitor._KEY_MANAGEMENT = json.loads( + SecurityVisitor.KEY_MANAGEMENT = json.loads( config["security"]["key_management"] ) - SecurityVisitor._NETWORK_SECURITY_RULES = json.loads( + SecurityVisitor.NETWORK_SECURITY_RULES = json.loads( config["security"]["network_security_rules"] ) - SecurityVisitor._PERMISSION_IAM_POLICIES = json.loads( + SecurityVisitor.PERMISSION_IAM_POLICIES = json.loads( config["security"]["permission_iam_policies"] ) - SecurityVisitor._GOOGLE_IAM_MEMBER = json.loads( + SecurityVisitor.GOOGLE_IAM_MEMBER = json.loads( config["security"]["google_iam_member_resources"] ) - SecurityVisitor._LOGGING = json.loads(config["security"]["logging"]) - SecurityVisitor._GOOGLE_SQL_DATABASE_LOG_FLAGS = json.loads( + SecurityVisitor.LOGGING = json.loads(config["security"]["logging"]) + SecurityVisitor.GOOGLE_SQL_DATABASE_LOG_FLAGS = json.loads( config["security"]["google_sql_database_log_flags"] ) - SecurityVisitor._POSSIBLE_ATTACHED_RESOURCES = json.loads( + SecurityVisitor.POSSIBLE_ATTACHED_RESOURCES = json.loads( config["security"]["possible_attached_resources_aws_route53"] ) - SecurityVisitor._VERSIONING = json.loads(config["security"]["versioning"]) - SecurityVisitor._NAMING = json.loads(config["security"]["naming"]) - SecurityVisitor._REPLICATION = json.loads(config["security"]["replication"]) + SecurityVisitor.VERSIONING = json.loads(config["security"]["versioning"]) + SecurityVisitor.NAMING = json.loads(config["security"]["naming"]) + SecurityVisitor.REPLICATION = json.loads(config["security"]["replication"]) SecurityVisitor.__FILE_COMMANDS = json.loads( config["security"]["file_commands"] @@ -194,11 +194,6 @@ def check_atomicunit(self, au: AtomicUnit, file: str) -> List[Error]: continue for a in au.attributes: values = [a.value] - if isinstance(a.value, ConditionalStatement): - statements = a.value.statements - if len(statements) == 0: - continue - values = statements[0].values() for value in values: if not isinstance(value, str): continue @@ -209,6 +204,13 @@ def check_atomicunit(self, au: AtomicUnit, file: str) -> List[Error]: Error("sec_full_permission_filesystem", a, file, repr(a)) ) + for attribute in au.attributes: + if ( + au.type in SecurityVisitor.__GITHUB_ACTIONS + and attribute.name == "plaintext_value" + ): + errors.append(Error("sec_hard_secr", attribute, file, repr(attribute))) + if au.type in SecurityVisitor.__OBSOLETE_COMMANDS: errors.append(Error("sec_obsolete_command", au, file, repr(au))) elif any(au.type.endswith(res) for res in SecurityVisitor.__SHELL_RESOURCES): @@ -234,17 +236,15 @@ def check_atomicunit(self, au: AtomicUnit, file: str) -> List[Error]: def check_dependency(self, d: Dependency, file: str) -> List[Error]: return [] - def __check_keyvalue( - self, c: KeyValue, file: str, au_type=None, parent_name: str = "" - ): - errors = [] + def __check_keyvalue(self, c: KeyValue, file: str) -> List[Error]: + errors: List[Error] = [] c.name = c.name.strip().lower() if isinstance(c.value, type(None)): for child in c.keyvalues: - errors += self.check_element(child, file, au_type, c.name) + errors += self.check_element(child, file) return errors - elif isinstance(c.value, str): + elif isinstance(c.value, str): # type: ignore c.value = c.value.strip().lower() else: errors += self.check_element(c.value, file) @@ -258,7 +258,7 @@ def __check_keyvalue( or (c.name == "ip" and c.value in {"*", "::"}) or ( c.name in SecurityVisitor.__IP_BIND_COMMANDS - and (c.value == True or c.value in {"*", "::"}) + and (c.value == True or c.value in {"*", "::"}) # type: ignore ) ): errors.append(Error("sec_invalid_bind", c, file, repr(c))) @@ -279,10 +279,12 @@ def __check_keyvalue( errors.append(Error("sec_def_admin", c, file, repr(c))) break - def get_au(c, name: str, type: str): + def get_au( + c: Project | Module | UnitBlock | None, name: str, type: str + ) -> AtomicUnit | None: if isinstance(c, Project): module_name = os.path.basename(os.path.dirname(file)) - for m in self.code.modules: + for m in c.modules: if m.name == module_name: return get_au(m, name, type) elif isinstance(c, Module): @@ -296,10 +298,12 @@ def get_au(c, name: str, type: str): return au return None - def get_module_var(c, name: str): + def get_module_var( + c: Project | Module | UnitBlock | None, name: str + ) -> Variable | None: if isinstance(c, Project): module_name = os.path.basename(os.path.dirname(file)) - for m in self.code.modules: + for m in c.modules: if m.name == module_name: return get_module_var(m, name) elif isinstance(c, Module): @@ -386,9 +390,6 @@ def get_module_var(c, name: str): if "password" in item_value: errors.append(Error("sec_hard_pass", c, file, repr(c))) - if au_type in SecurityVisitor.__GITHUB_ACTIONS and c.name == "plaintext_value": - errors.append(Error("sec_hard_secr", c, file, repr(c))) - if c.has_variable and var is not None: c.has_variable = var.has_variable c.value = var.value @@ -399,22 +400,20 @@ def get_module_var(c, name: str): return errors - def check_attribute( - self, a: Attribute, file: str, au_type=None, parent_name: str = "" - ) -> list[Error]: - return self.__check_keyvalue(a, file, au_type, parent_name) + def check_attribute(self, a: Attribute, file: str) -> list[Error]: + return self.__check_keyvalue(a, file) def check_variable(self, v: Variable, file: str) -> list[Error]: return self.__check_keyvalue(v, file) def check_comment(self, c: Comment, file: str) -> List[Error]: - errors = [] + errors: List[Error] = [] lines = c.content.split("\n") stop = False for word in SecurityVisitor.__WRONG_WORDS: for line in lines: tokenizer = WordPunctTokenizer() - tokens = tokenizer.tokenize(line.lower()) + tokens = tokenizer.tokenize(line.lower()) # type: ignore if word in tokens: errors.append(Error("sec_susp_comm", c, file, line)) stop = True @@ -465,6 +464,9 @@ def check_unitblock(self, u: UnitBlock) -> List[Error]: @staticmethod def check_integrity_check(au: AtomicUnit, path: str) -> Optional[Tuple[str, Error]]: for item in SecurityVisitor.__DOWNLOAD: + if not isinstance(au.name, str): + continue + if not re.search( r"(http|https|www)[^ ,]*\.{text}".format(text=item), au.name ): @@ -489,14 +491,14 @@ def check_integrity_check(au: AtomicUnit, path: str) -> Optional[Tuple[str, Erro continue if SecurityVisitor.__has_integrity_check(au.attributes): return None - return os.path.basename(a.value), Error( + return os.path.basename(a.value), Error( # type: ignore "sec_no_int_check", au, path, repr(a) - ) + ) # type: ignore return None @staticmethod def check_has_checksum(au: AtomicUnit) -> Optional[str]: - if au.type not in SecurityVisitor.__CHECKSUM: + if au.type not in SecurityVisitor.__CHECKSUM or au.name is None: return None if any(d in au.name for d in SecurityVisitor.__DOWNLOAD): return os.path.basename(au.name) @@ -517,9 +519,13 @@ def __has_integrity_check(attributes: List[Attribute]) -> bool: name = attr.name.strip().lower() if any([check in name for check in SecurityVisitor.__CHECKSUM]): return True + return False @staticmethod - def __is_http_url(value: str) -> bool: + def __is_http_url(value: str | None) -> bool: + if value is None: + return False + if ( re.match(SecurityVisitor.__URL_REGEX, value) and ("http" in value or "www" in value) @@ -536,7 +542,10 @@ def __is_http_url(value: str) -> bool: return False @staticmethod - def __is_weak_crypt(value: str, name: str) -> bool: + def __is_weak_crypt(value: str, name: str | None) -> bool: + if name is None: + return False + if any(crypt in value for crypt in SecurityVisitor.__CRYPT): whitelist = any( word in name or word in value diff --git a/glitch/analysis/terraform/__init__.py b/glitch/analysis/terraform/__init__.py index 3d8d8a6e..de10ffe6 100644 --- a/glitch/analysis/terraform/__init__.py +++ b/glitch/analysis/terraform/__init__.py @@ -1,7 +1,8 @@ import os +from typing import List -__all__ = [] +__all__: List[str] = [] for file in os.listdir(os.path.dirname(__file__)): if file.endswith(".py") and file != "__init__.py": - __all__.append(file[:-3]) + __all__.append(file[:-3]) # type: ignore diff --git a/glitch/analysis/terraform/access_control.py b/glitch/analysis/terraform/access_control.py index 3340a8ee..c52bf871 100644 --- a/glitch/analysis/terraform/access_control.py +++ b/glitch/analysis/terraform/access_control.py @@ -3,22 +3,28 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformAccessControl(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for item in SecurityVisitor._POLICY_KEYWORDS: + for item in SecurityVisitor.POLICY_KEYWORDS: if item.lower() == attribute.name: - for config in SecurityVisitor._POLICY_ACCESS_CONTROL: - expr = config["keyword"].lower() + "\s*" + config["value"].lower() + for config in SecurityVisitor.POLICY_ACCESS_CONTROL: + expr = config["keyword"].lower() + "\\s*" + config["value"].lower() pattern = re.compile(rf"{expr}") - allow_expr = '"effect":' + "\s*" + '"allow"' + allow_expr = '"effect":' + "\\s*" + '"allow"' allow_pattern = re.compile(rf"{allow_expr}") - if re.search(pattern, attribute.value) and re.search( - allow_pattern, attribute.value + if ( + isinstance(attribute.value, str) + and re.search(pattern, attribute.value) + and re.search(allow_pattern, attribute.value) ): return [ Error( @@ -50,16 +56,18 @@ def _check_attribute( attribute.name == "email" and parent_name == "service_account" and atomic_unit.type == "resource.google_compute_instance" + and isinstance(attribute.value, str) and re.search(r".-compute@developer.gserviceaccount.com", attribute.value) ): return [Error("sec_access_control", attribute, file, repr(attribute))] - for config in SecurityVisitor._ACCESS_CONTROL_CONFIGS: + for config in SecurityVisitor.ACCESS_CONTROL_CONFIGS: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] and parent_name in config["parents"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] and config["values"] != [""] ): @@ -67,8 +75,8 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.aws_api_gateway_method": http_method = self.check_required_attribute( @@ -77,7 +85,12 @@ def check(self, element, file: str): authorization = self.check_required_attribute( element.attributes, [""], "authorization" ) - if http_method and authorization: + if ( + isinstance(http_method, KeyValue) + and isinstance(authorization, KeyValue) + and http_method.value is not None + and authorization.value is not None + ): if ( http_method.value.lower() == "get" and authorization.value.lower() == "none" @@ -86,7 +99,8 @@ def check(self, element, file: str): element.attributes, [""], "api_key_required" ) if ( - api_key_required + isinstance(api_key_required, KeyValue) + and api_key_required.value is not None and f"{api_key_required.value}".lower() != "true" ): errors.append( @@ -121,7 +135,9 @@ def check(self, element, file: str): visibility = self.check_required_attribute( element.attributes, [""], "visibility" ) - if visibility is not None: + if isinstance(visibility, KeyValue) and isinstance( + visibility.value, str + ): if visibility.value.lower() not in ["private", "internal"]: errors.append( Error( @@ -132,7 +148,7 @@ def check(self, element, file: str): private = self.check_required_attribute( element.attributes, [""], "private" ) - if private is not None: + if isinstance(private, KeyValue) and isinstance(private.value, str): if f"{private.value}".lower() != "true": errors.append( Error( @@ -158,7 +174,7 @@ def check(self, element, file: str): "off", ) elif element.type == "resource.aws_s3_bucket": - expr = "\${aws_s3_bucket\." + f"{element.name}\." + expr = "\\${aws_s3_bucket\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") if ( self.get_associated_au( @@ -181,7 +197,7 @@ def check(self, element, file: str): ) ) - for config in SecurityVisitor._ACCESS_CONTROL_CONFIGS: + for config in SecurityVisitor.ACCESS_CONTROL_CONFIGS: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/attached_resource.py b/glitch/analysis/terraform/attached_resource.py index a7ff7ae2..fe12f36e 100644 --- a/glitch/analysis/terraform/attached_resource.py +++ b/glitch/analysis/terraform/attached_resource.py @@ -1,15 +1,19 @@ +from typing import List from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit +from glitch.repr.inter import AtomicUnit, CodeElement, KeyValue, Attribute class TerraformAttachedResource(TerraformSmellChecker): - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] + if isinstance(element, AtomicUnit): - def check_attached_resource(attributes, resource_types): + def check_attached_resource( + attributes: List[KeyValue] | List[Attribute], resource_types: List[str] + ) -> bool: for a in attributes: if a.value != None: for resource_type in resource_types: @@ -32,7 +36,7 @@ def check_attached_resource(attributes, resource_types): element.attributes, [""], "type", "a" ) if type_A and not check_attached_resource( - element.attributes, SecurityVisitor._POSSIBLE_ATTACHED_RESOURCES + element.attributes, SecurityVisitor.POSSIBLE_ATTACHED_RESOURCES ): errors.append( Error("sec_attached_resource", element, file, repr(element)) diff --git a/glitch/analysis/terraform/authentication.py b/glitch/analysis/terraform/authentication.py index 07ca6f79..c2839b4f 100644 --- a/glitch/analysis/terraform/authentication.py +++ b/glitch/analysis/terraform/authentication.py @@ -3,22 +3,28 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformAuthentication(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for item in SecurityVisitor._POLICY_KEYWORDS: + for item in SecurityVisitor.POLICY_KEYWORDS: if item.lower() == attribute.name: - for config in SecurityVisitor._POLICY_AUTHENTICATION: + for config in SecurityVisitor.POLICY_AUTHENTICATION: if atomic_unit.type in config["au_type"]: expr = ( - config["keyword"].lower() + "\s*" + config["value"].lower() + config["keyword"].lower() + "\\s*" + config["value"].lower() ) pattern = re.compile(rf"{expr}") - if not re.search(pattern, attribute.value): + if isinstance(attribute.value, str) and not re.search( + pattern, attribute.value + ): return [ Error( "sec_authentication", @@ -28,12 +34,13 @@ def _check_attribute( ) ] - for config in SecurityVisitor._AUTHENTICATION: + for config in SecurityVisitor.AUTHENTICATION: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] and parent_name in config["parents"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] and config["values"] != [""] ): @@ -41,8 +48,8 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.google_sql_database_instance": @@ -54,7 +61,7 @@ def check(self, element, file: str): "off", ) elif element.type == "resource.aws_iam_group": - expr = "\${aws_iam_group\." + f"{element.name}\." + expr = "\\${aws_iam_group\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") if not self.get_associated_au( file, "resource.aws_iam_group_policy", "group", pattern, [""] @@ -70,7 +77,7 @@ def check(self, element, file: str): ) ) - for config in SecurityVisitor._AUTHENTICATION: + for config in SecurityVisitor.AUTHENTICATION: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/dns_policy.py b/glitch/analysis/terraform/dns_policy.py index 4712814e..f0474955 100644 --- a/glitch/analysis/terraform/dns_policy.py +++ b/glitch/analysis/terraform/dns_policy.py @@ -2,29 +2,34 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformDnsWithoutDnssec(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._DNSSEC_CONFIGS: + for config in SecurityVisitor.DNSSEC_CONFIGS: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] and parent_name in config["parents"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] and config["values"] != [""] ): return [Error("sec_dnssec", attribute, file, repr(attribute))] return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): - for config in SecurityVisitor._DNSSEC_CONFIGS: + for config in SecurityVisitor.DNSSEC_CONFIGS: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/firewall_misconfig.py b/glitch/analysis/terraform/firewall_misconfig.py index 2d783c8c..b1c065a1 100644 --- a/glitch/analysis/terraform/firewall_misconfig.py +++ b/glitch/analysis/terraform/firewall_misconfig.py @@ -2,14 +2,18 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformFirewallMisconfig(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._FIREWALL_CONFIGS: + for config in SecurityVisitor.FIREWALL_CONFIGS: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] @@ -18,6 +22,7 @@ def _check_attribute( ): if ( "any_not_empty" in config["values"] + and isinstance(attribute.value, str) and attribute.value.lower() == "" ): return [ @@ -28,6 +33,7 @@ def _check_attribute( elif ( "any_not_empty" not in config["values"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [ @@ -37,10 +43,10 @@ def _check_attribute( ] return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): - for config in SecurityVisitor._FIREWALL_CONFIGS: + for config in SecurityVisitor.FIREWALL_CONFIGS: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/http_without_tls.py b/glitch/analysis/terraform/http_without_tls.py index 37b3bc30..0f2bcb46 100644 --- a/glitch/analysis/terraform/http_without_tls.py +++ b/glitch/analysis/terraform/http_without_tls.py @@ -2,31 +2,40 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformHttpWithoutTls(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._HTTPS_CONFIGS: + for config in SecurityVisitor.HTTPS_CONFIGS: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] and parent_name in config["parents"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [Error("sec_https", attribute, file, repr(attribute))] return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "data.http": url = self.check_required_attribute(element.attributes, [""], "url") - if "${" in url.value: + if ( + isinstance(url, KeyValue) + and isinstance(url.value, str) + and "${" in url.value + ): vars = url.value.split("${") r = url.value.split("${")[1].split("}")[0] for var in vars: @@ -44,7 +53,7 @@ def check(self, element, file: str): if self.get_au(file, resource_name, type + "." + resource_type): errors.append(Error("sec_https", url, file, repr(url))) - for config in SecurityVisitor._HTTPS_CONFIGS: + for config in SecurityVisitor.HTTPS_CONFIGS: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/integrity_policy.py b/glitch/analysis/terraform/integrity_policy.py index 7202b87f..cdf658a4 100644 --- a/glitch/analysis/terraform/integrity_policy.py +++ b/glitch/analysis/terraform/integrity_policy.py @@ -2,29 +2,34 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformIntegrityPolicy(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for policy in SecurityVisitor._INTEGRITY_POLICY: + for policy in SecurityVisitor.INTEGRITY_POLICY: if ( attribute.name == policy["attribute"] and atomic_unit.type in policy["au_type"] and parent_name in policy["parents"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in policy["values"] ): return [Error("sec_integrity_policy", attribute, file, repr(attribute))] return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): - for policy in SecurityVisitor._INTEGRITY_POLICY: + for policy in SecurityVisitor.INTEGRITY_POLICY: if ( policy["required"] == "yes" and element.type in policy["au_type"] diff --git a/glitch/analysis/terraform/key_management.py b/glitch/analysis/terraform/key_management.py index 67e69856..3fa30bca 100644 --- a/glitch/analysis/terraform/key_management.py +++ b/glitch/analysis/terraform/key_management.py @@ -3,14 +3,18 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformKeyManagement(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._KEY_MANAGEMENT: + for config in SecurityVisitor.KEY_MANAGEMENT: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] @@ -19,6 +23,7 @@ def _check_attribute( ): if ( "any_not_empty" in config["values"] + and isinstance(attribute.value, str) and attribute.value.lower() == "" ): return [ @@ -27,6 +32,7 @@ def _check_attribute( elif ( "any_not_empty" not in config["values"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [ @@ -39,7 +45,9 @@ def _check_attribute( ): expr1 = r"\d+\.\d{0,9}s" expr2 = r"\d+s" - if re.search(expr1, attribute.value) or re.search(expr2, attribute.value): + if isinstance(attribute.value, str) and ( + re.search(expr1, attribute.value) or re.search(expr2, attribute.value) + ): if int(attribute.value.split("s")[0]) > 7776000: return [ Error("sec_key_management", attribute, file, repr(attribute)) @@ -60,11 +68,11 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.azurerm_storage_account": - expr = "\${azurerm_storage_account\." + f"{element.name}\." + expr = "\\${azurerm_storage_account\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") if not self.get_associated_au( file, @@ -83,7 +91,7 @@ def check(self, element, file: str): + f"associated to an 'azurerm_storage_account' resource.", ) ) - for config in SecurityVisitor._KEY_MANAGEMENT: + for config in SecurityVisitor.KEY_MANAGEMENT: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/logging.py b/glitch/analysis/terraform/logging.py index c27b2432..8d1d242c 100644 --- a/glitch/analysis/terraform/logging.py +++ b/glitch/analysis/terraform/logging.py @@ -4,19 +4,19 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformLogging(TerraformSmellChecker): def __check_log_attribute( self, - element, + element: AtomicUnit, attribute_name: str, file: str, values: List[str], all: bool = False, - ): - errors = [] + ) -> List[Error]: + errors: List[Error] = [] attribute = self.check_required_attribute( element.attributes, [""], f"{attribute_name}[0]" ) @@ -25,14 +25,18 @@ def __check_log_attribute( active = True for v in values[:]: attribute_checked, _ = self.iterate_required_attributes( - element.attributes, attribute_name, lambda x: x.value.lower() == v + element.attributes, + attribute_name, + lambda x: isinstance(x.value, str) and x.value.lower() == v, ) if attribute_checked: values.remove(v) active = active and attribute_checked else: active, _ = self.iterate_required_attributes( - element.attributes, attribute_name, lambda x: x.value.lower() in values + element.attributes, + attribute_name, + lambda x: isinstance(x.value, str) and x.value.lower() in values, ) if attribute is None: @@ -60,16 +64,22 @@ def __check_log_attribute( return errors - def __check_azurerm_storage_container(self, element, file: str): - errors = [] + def __check_azurerm_storage_container(self, element: AtomicUnit, file: str): + errors: List[Error] = [] container_access_type = self.check_required_attribute( element.attributes, [""], "container_access_type" ) - if container_access_type and container_access_type.value.lower() not in [ - "blob", - "private", - ]: + if ( + container_access_type + and isinstance(container_access_type, Attribute) + and isinstance(container_access_type.value, str) + and container_access_type.value.lower() + not in [ + "blob", + "private", + ] + ): errors.append( Error( "sec_logging", @@ -84,6 +94,8 @@ def __check_azurerm_storage_container(self, element, file: str): ) if not ( storage_account_name is not None + and isinstance(storage_account_name, Attribute) + and isinstance(storage_account_name.value, str) and storage_account_name.value.lower().startswith( "${azurerm_storage_account." ) @@ -115,7 +127,7 @@ def __check_azurerm_storage_container(self, element, file: str): ) return errors - expr = "\${azurerm_storage_account\." + f"{name}\." + expr = "\\${azurerm_storage_account\\." + f"{name}\\." pattern = re.compile(rf"{expr}") assoc_au = self.get_associated_au( file, @@ -153,7 +165,7 @@ def __check_azurerm_storage_container(self, element, file: str): return errors contains_blob_name, _ = self.iterate_required_attributes( - assoc_au.attributes, "blob_container_names", lambda x: x.value + assoc_au.attributes, "blob_container_names", lambda x: x.value # type: ignore ) if not contains_blob_name: errors.append( @@ -168,13 +180,19 @@ def __check_azurerm_storage_container(self, element, file: str): return errors def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: if ( attribute.name == "cloud_watch_logs_group_arn" and atomic_unit.type == "resource.aws_cloudtrail" ): - if re.match(r"^\${aws_cloudwatch_log_group\..", attribute.value): + if isinstance(attribute.value, str) and re.match( + r"^\${aws_cloudwatch_log_group\..", attribute.value + ): aws_cloudwatch_log_group_name = attribute.value.split(".")[1] if not self.get_au( file, @@ -209,8 +227,11 @@ def _check_attribute( and atomic_unit.type == "resource.azurerm_network_watcher_flow_log" ) ) and ( - (not attribute.value.isnumeric()) - or (attribute.value.isnumeric() and int(attribute.value) < 90) + isinstance(attribute.value, str) + and ( + not attribute.value.isnumeric() + or (attribute.value.isnumeric() and int(attribute.value) < 90) + ) ): return [Error("sec_logging", attribute, file, repr(attribute))] elif ( @@ -218,13 +239,16 @@ def _check_attribute( and parent_name == "retention_policy" and atomic_unit.type == "resource.azurerm_monitor_log_profile" and ( - not attribute.value.isnumeric() - or (attribute.value.isnumeric() and int(attribute.value) < 365) + isinstance(attribute.value, str) + and ( + not attribute.value.isnumeric() + or (attribute.value.isnumeric() and int(attribute.value) < 365) + ) ) ): return [Error("sec_logging", attribute, file, repr(attribute))] - for config in SecurityVisitor._LOGGING: + for config in SecurityVisitor.LOGGING: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] @@ -233,20 +257,22 @@ def _check_attribute( ): if ( "any_not_empty" in config["values"] + and isinstance(attribute.value, str) and attribute.value.lower() == "" ): return [Error("sec_logging", attribute, file, repr(attribute))] elif ( "any_not_empty" not in config["values"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [Error("sec_logging", attribute, file, repr(attribute))] return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.aws_eks_cluster": errors.extend( @@ -268,21 +294,27 @@ def check(self, element, file: str): broker_logs = self.check_required_attribute( element.attributes, ["logging_info"], "broker_logs" ) - if broker_logs is not None: + if isinstance(broker_logs, KeyValue): active = False logs_type = ["cloudwatch_logs", "firehose", "s3"] - a_list = [] + a_list: List[KeyValue] = [] for type in logs_type: log = self.check_required_attribute( broker_logs.keyvalues, [""], type ) - if log is not None: + if isinstance(log, KeyValue): enabled = self.check_required_attribute( log.keyvalues, [""], "enabled" ) - if enabled and f"{enabled.value}".lower() == "true": + if ( + isinstance(enabled, KeyValue) + and f"{enabled.value}".lower() == "true" + ): active = True - elif enabled and f"{enabled.value}".lower() != "true": + elif ( + isinstance(enabled, KeyValue) + and f"{enabled.value}".lower() != "true" + ): a_list.append(enabled) if not active and a_list == []: errors.append( @@ -325,7 +357,7 @@ def check(self, element, file: str): ) ) elif element.type == "resource.azurerm_mssql_server": - expr = "\${azurerm_mssql_server\." + f"{element.name}\." + expr = "\\${azurerm_mssql_server\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") assoc_au = self.get_associated_au( file, @@ -346,7 +378,7 @@ def check(self, element, file: str): ) ) elif element.type == "resource.azurerm_mssql_database": - expr = "\${azurerm_mssql_database\." + f"{element.name}\." + expr = "\\${azurerm_mssql_database\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") assoc_au = self.get_associated_au( file, @@ -370,10 +402,12 @@ def check(self, element, file: str): name = self.check_required_attribute(element.attributes, [""], "name") value = self.check_required_attribute(element.attributes, [""], "value") if ( - name + isinstance(name, KeyValue) + and isinstance(name.value, str) and name.value.lower() in ["log_connections", "connection_throttling", "log_checkpoints"] - and value + and isinstance(value, KeyValue) + and isinstance(value.value, str) and value.value.lower() != "on" ): errors.append(Error("sec_logging", value, file, repr(value))) @@ -388,7 +422,7 @@ def check(self, element, file: str): ) ) elif element.type == "resource.google_sql_database_instance": - for flag in SecurityVisitor._GOOGLE_SQL_DATABASE_LOG_FLAGS: + for flag in SecurityVisitor.GOOGLE_SQL_DATABASE_LOG_FLAGS: required_flag = True if flag["required"] == "no": required_flag = False @@ -410,8 +444,11 @@ def check(self, element, file: str): enabled = self.check_required_attribute( element.attributes, ["setting"], "value" ) - if enabled is not None: - if enabled.value.lower() != "enabled": + if isinstance(enabled, KeyValue): + if ( + isinstance(enabled.value, str) + and enabled.value.lower() != "enabled" + ): errors.append( Error("sec_logging", enabled, file, repr(enabled)) ) @@ -436,7 +473,7 @@ def check(self, element, file: str): ) ) elif element.type == "resource.aws_vpc": - expr = "\${aws_vpc\." + f"{element.name}\." + expr = "\\${aws_vpc\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") assoc_au = self.get_associated_au( file, "resource.aws_flow_log", "vpc_id", pattern, [""] @@ -453,7 +490,7 @@ def check(self, element, file: str): ) ) - for config in SecurityVisitor._LOGGING: + for config in SecurityVisitor.LOGGING: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/missing_encryption.py b/glitch/analysis/terraform/missing_encryption.py index 5637b694..e0fcc46a 100644 --- a/glitch/analysis/terraform/missing_encryption.py +++ b/glitch/analysis/terraform/missing_encryption.py @@ -3,14 +3,18 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformMissingEncryption(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._MISSING_ENCRYPTION: + for config in SecurityVisitor.MISSING_ENCRYPTION: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] @@ -19,6 +23,7 @@ def _check_attribute( ): if ( "any_not_empty" in config["values"] + and isinstance(attribute.value, str) and attribute.value.lower() == "" ): return [ @@ -29,6 +34,7 @@ def _check_attribute( elif ( "any_not_empty" not in config["values"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [ @@ -36,16 +42,17 @@ def _check_attribute( "sec_missing_encryption", attribute, file, repr(attribute) ) ] - for item in SecurityVisitor._CONFIGURATION_KEYWORDS: + for item in SecurityVisitor.CONFIGURATION_KEYWORDS: if item.lower() == attribute.name: - for config in SecurityVisitor._ENCRYPT_CONFIG: + for config in SecurityVisitor.ENCRYPT_CONFIG: if atomic_unit.type in config["au_type"]: expr = ( - config["keyword"].lower() + "\s*" + config["value"].lower() + config["keyword"].lower() + "\\s*" + config["value"].lower() ) pattern = re.compile(rf"{expr}") if ( - not re.search(pattern, attribute.value) + isinstance(attribute.value, str) + and not re.search(pattern, attribute.value) and config["required"] == "yes" ): return [ @@ -57,7 +64,8 @@ def _check_attribute( ) ] elif ( - re.search(pattern, attribute.value) + isinstance(attribute.value, str) + and re.search(pattern, attribute.value) and config["required"] == "must_not_exist" ): return [ @@ -71,11 +79,11 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.aws_s3_bucket": - expr = "\${aws_s3_bucket\." + f"{element.name}\." + expr = "\\${aws_s3_bucket\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") r = self.get_associated_au( file, @@ -99,12 +107,15 @@ def check(self, element, file: str): resources = self.check_required_attribute( element.attributes, ["encryption_config"], "resources[0]" ) - if resources is not None: + if isinstance(resources, KeyValue): i = 0 valid = False - while resources: + while isinstance(resources, KeyValue): a = resources - if resources.value.lower() == "secrets": + if ( + isinstance(resources.value, str) + and resources.value.lower() == "secrets" + ): valid = True break i += 1 @@ -112,7 +123,7 @@ def check(self, element, file: str): element.attributes, ["encryption_config"], f"resources[{i}]" ) if not valid: - errors.append(Error("sec_missing_encryption", a, file, repr(a))) + errors.append(Error("sec_missing_encryption", a, file, repr(a))) # type: ignore else: errors.append( Error( @@ -130,7 +141,7 @@ def check(self, element, file: str): ebs_block_device = self.check_required_attribute( element.attributes, [""], "ebs_block_device" ) - if ebs_block_device is not None: + if isinstance(ebs_block_device, KeyValue): encrypted = self.check_required_attribute( ebs_block_device.keyvalues, [""], "encrypted" ) @@ -148,11 +159,11 @@ def check(self, element, file: str): volume = self.check_required_attribute( element.attributes, [""], "volume" ) - if volume is not None: + if isinstance(volume, KeyValue): efs_volume_config = self.check_required_attribute( volume.keyvalues, [""], "efs_volume_configuration" ) - if efs_volume_config is not None: + if isinstance(efs_volume_config, KeyValue): transit_encryption = self.check_required_attribute( efs_volume_config.keyvalues, [""], "transit_encryption" ) @@ -168,7 +179,7 @@ def check(self, element, file: str): ) ) - for config in SecurityVisitor._MISSING_ENCRYPTION: + for config in SecurityVisitor.MISSING_ENCRYPTION: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/naming.py b/glitch/analysis/terraform/naming.py index 3d9f31ff..508f4a2c 100644 --- a/glitch/analysis/terraform/naming.py +++ b/glitch/analysis/terraform/naming.py @@ -3,21 +3,27 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformNaming(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: if attribute.name == "name" and atomic_unit.type in [ "resource.azurerm_storage_account" ]: pattern = r"^[a-z0-9]{3,24}$" - if not re.match(pattern, attribute.value): + if isinstance(attribute.value, str) and not re.match( + pattern, attribute.value + ): return [Error("sec_naming", attribute, file, repr(attribute))] - for config in SecurityVisitor._NAMING: + for config in SecurityVisitor.NAMING: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] @@ -26,20 +32,22 @@ def _check_attribute( ): if ( "any_not_empty" in config["values"] + and isinstance(attribute.value, str) and attribute.value.lower() == "" ): return [Error("sec_naming", attribute, file, repr(attribute))] elif ( "any_not_empty" not in config["values"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [Error("sec_naming", attribute, file, repr(attribute))] return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.aws_security_group": ingress = self.check_required_attribute( @@ -48,7 +56,7 @@ def check(self, element, file: str): egress = self.check_required_attribute( element.attributes, [""], "egress" ) - if ingress and not self.check_required_attribute( + if isinstance(ingress, KeyValue) and not self.check_required_attribute( ingress.keyvalues, [""], "description" ): errors.append( @@ -60,7 +68,7 @@ def check(self, element, file: str): f"Suggestion: check for a required attribute with name 'ingress.description'.", ) ) - if egress and not self.check_required_attribute( + if isinstance(egress, KeyValue) and not self.check_required_attribute( egress.keyvalues, [""], "description" ): errors.append( @@ -76,7 +84,10 @@ def check(self, element, file: str): resource_labels = self.check_required_attribute( element.attributes, [""], "resource_labels", None ) - if resource_labels and resource_labels.value == None: + if ( + isinstance(resource_labels, KeyValue) + and resource_labels.value is None + ): if resource_labels.keyvalues == []: errors.append( Error( @@ -98,7 +109,7 @@ def check(self, element, file: str): ) ) - for config in SecurityVisitor._NAMING: + for config in SecurityVisitor.NAMING: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/network_policy.py b/glitch/analysis/terraform/network_policy.py index 84d32dca..83c001c0 100644 --- a/glitch/analysis/terraform/network_policy.py +++ b/glitch/analysis/terraform/network_policy.py @@ -3,14 +3,18 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformNetworkSecurityRules(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for rule in SecurityVisitor._NETWORK_SECURITY_RULES: + for rule in SecurityVisitor.NETWORK_SECURITY_RULES: if ( attribute.name == rule["attribute"] and atomic_unit.type in rule["au_type"] @@ -28,47 +32,71 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.azurerm_network_security_rule": access = self.check_required_attribute( element.attributes, [""], "access" ) - if access and access.value.lower() == "allow": + if ( + isinstance(access, KeyValue) + and isinstance(access.value, str) + and access.value.lower() == "allow" + ): protocol = self.check_required_attribute( element.attributes, [""], "protocol" ) - if protocol and protocol.value.lower() == "udp": + if ( + isinstance(protocol, KeyValue) + and isinstance(protocol.value, str) + and protocol.value.lower() == "udp" + ): errors.append( Error( "sec_network_security_rules", access, file, repr(access) ) ) - elif protocol and protocol.value.lower() == "tcp": + elif ( + isinstance(protocol, KeyValue) + and isinstance(protocol.value, str) + and protocol.value.lower() == "tcp" + ): dest_port_range = self.check_required_attribute( element.attributes, [""], "destination_port_range" ) - port = dest_port_range and dest_port_range.value.lower() in [ - "22", - "3389", - "*", - ] + port = ( + isinstance(dest_port_range, KeyValue) + and isinstance(dest_port_range.value, str) + and dest_port_range.value.lower() + in [ + "22", + "3389", + "*", + ] + ) port_ranges, _ = self.iterate_required_attributes( element.attributes, "destination_port_ranges", - lambda x: (x.value.lower() in ["22", "3389", "*"]), + lambda x: ( + isinstance(x.value, str) + and x.value.lower() in ["22", "3389", "*"] + ), ) if port or port_ranges: source_address_prefix = self.check_required_attribute( element.attributes, [""], "source_address_prefix" ) - if source_address_prefix and ( - source_address_prefix.value.lower() - in ["*", "/0", "internet", "any"] - or re.match( - r"^0.0.0.0", source_address_prefix.value.lower() + if ( + isinstance(source_address_prefix, KeyValue) + and isinstance(source_address_prefix.value, str) + and ( + source_address_prefix.value.lower() + in ["*", "/0", "internet", "any"] + or re.match( + r"^0.0.0.0", source_address_prefix.value.lower() + ) ) ): errors.append( @@ -83,35 +111,56 @@ def check(self, element, file: str): access = self.check_required_attribute( element.attributes, ["security_rule"], "access" ) - if access and access.value.lower() == "allow": + if ( + isinstance(access, KeyValue) + and isinstance(access.value, str) + and access.value.lower() == "allow" + ): protocol = self.check_required_attribute( element.attributes, ["security_rule"], "protocol" ) - if protocol and protocol.value.lower() == "udp": + if ( + isinstance(protocol, KeyValue) + and isinstance(protocol.value, str) + and protocol.value.lower() == "udp" + ): errors.append( Error( "sec_network_security_rules", access, file, repr(access) ) ) - elif protocol and protocol.value.lower() == "tcp": + elif ( + isinstance(protocol, KeyValue) + and isinstance(protocol.value, str) + and protocol.value.lower() == "tcp" + ): dest_port_range = self.check_required_attribute( element.attributes, ["security_rule"], "destination_port_range", ) - if dest_port_range and dest_port_range.value.lower() in [ - "22", - "3389", - "*", - ]: + if ( + isinstance(dest_port_range, KeyValue) + and isinstance(dest_port_range.value, str) + and dest_port_range.value.lower() + in [ + "22", + "3389", + "*", + ] + ): source_address_prefix = self.check_required_attribute( element.attributes, [""], "source_address_prefix" ) - if source_address_prefix and ( - source_address_prefix.value.lower() - in ["*", "/0", "internet", "any"] - or re.match( - r"^0.0.0.0", source_address_prefix.value.lower() + if ( + isinstance(source_address_prefix, KeyValue) + and isinstance(source_address_prefix.value, str) + and ( + source_address_prefix.value.lower() + in ["*", "/0", "internet", "any"] + or re.match( + r"^0.0.0.0", source_address_prefix.value.lower() + ) ) ): errors.append( @@ -123,7 +172,7 @@ def check(self, element, file: str): ) ) - for rule in SecurityVisitor._NETWORK_SECURITY_RULES: + for rule in SecurityVisitor.NETWORK_SECURITY_RULES: if ( rule["required"] == "yes" and element.type in rule["au_type"] diff --git a/glitch/analysis/terraform/permission_iam_policies.py b/glitch/analysis/terraform/permission_iam_policies.py index ebf3708e..fa7ed426 100644 --- a/glitch/analysis/terraform/permission_iam_policies.py +++ b/glitch/analysis/terraform/permission_iam_policies.py @@ -3,16 +3,21 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformPermissionIAMPolicies(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: if ( (attribute.name == "member" or attribute.name.split("[")[0] == "members") - and atomic_unit.type in SecurityVisitor._GOOGLE_IAM_MEMBER + and atomic_unit.type in SecurityVisitor.GOOGLE_IAM_MEMBER + and isinstance(attribute.value, str) and ( re.search(r".-compute@developer.gserviceaccount.com", attribute.value) or re.search(r".@appspot.gserviceaccount.com", attribute.value) @@ -23,7 +28,7 @@ def _check_attribute( Error("sec_permission_iam_policies", attribute, file, repr(attribute)) ] - for config in SecurityVisitor._PERMISSION_IAM_POLICIES: + for config in SecurityVisitor.PERMISSION_IAM_POLICIES: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] @@ -33,9 +38,11 @@ def _check_attribute( if ( config["logic"] == "equal" and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ) or ( config["logic"] == "diff" + and isinstance(attribute.value, str) and attribute.value.lower() in config["values"] ): return [ @@ -49,11 +56,11 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.aws_iam_user": - expr = "\${aws_iam_user\." + f"{element.name}\." + expr = "\\${aws_iam_user\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") assoc_au = self.get_associated_au( file, "resource.aws_iam_user_policy", "user", pattern, [""] diff --git a/glitch/analysis/terraform/public_ip.py b/glitch/analysis/terraform/public_ip.py index 3ff6b89a..b249ddc6 100644 --- a/glitch/analysis/terraform/public_ip.py +++ b/glitch/analysis/terraform/public_ip.py @@ -2,14 +2,18 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformPublicIp(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._PUBLIC_IP_CONFIGS: + for config in SecurityVisitor.PUBLIC_IP_CONFIGS: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] @@ -23,10 +27,10 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): - for config in SecurityVisitor._PUBLIC_IP_CONFIGS: + for config in SecurityVisitor.PUBLIC_IP_CONFIGS: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/replication.py b/glitch/analysis/terraform/replication.py index 74b9d1bf..cf858270 100644 --- a/glitch/analysis/terraform/replication.py +++ b/glitch/analysis/terraform/replication.py @@ -3,31 +3,36 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, CodeElement, KeyValue class TerraformReplication(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._REPLICATION: + for config in SecurityVisitor.REPLICATION: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] and parent_name in config["parents"] and config["values"] != [""] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [Error("sec_replication", attribute, file, repr(attribute))] return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type == "resource.aws_s3_bucket": - expr = "\${aws_s3_bucket\." + f"{element.name}\." + expr = "\\${aws_s3_bucket\\." + f"{element.name}\\." pattern = re.compile(rf"{expr}") if not self.get_associated_au( file, @@ -47,7 +52,7 @@ def check(self, element, file: str): ) ) - for config in SecurityVisitor._REPLICATION: + for config in SecurityVisitor.REPLICATION: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/sensitive_iam_action.py b/glitch/analysis/terraform/sensitive_iam_action.py index 50e462f1..964fec3e 100644 --- a/glitch/analysis/terraform/sensitive_iam_action.py +++ b/glitch/analysis/terraform/sensitive_iam_action.py @@ -1,19 +1,20 @@ import json +from typing import List, Dict from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error -from glitch.repr.inter import AtomicUnit +from glitch.repr.inter import AtomicUnit, CodeElement, KeyValue class TerraformSensitiveIAMAction(TerraformSmellChecker): - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] - def convert_string_to_dict(input_string): + def convert_string_to_dict(input_string: str): cleaned_string = input_string.strip() try: dict_data = json.loads(cleaned_string) return dict_data - except json.JSONDecodeError as e: + except json.JSONDecodeError: return None if not isinstance(element, AtomicUnit): @@ -25,14 +26,20 @@ def convert_string_to_dict(input_string): statements = self.check_required_attribute( element.attributes, [""], "statement", return_all=True ) - if statements is not None: + if isinstance(statements, list): for statement in statements: allow = self.check_required_attribute( statement.keyvalues, [""], "effect" ) - if (allow and allow.value.lower() == "allow") or (not allow): + if ( + isinstance(allow, KeyValue) + and isinstance(allow.value, str) + and allow.value.lower() == "allow" + ) or (not allow): sensitive_action, action = self.iterate_required_attributes( - statement.keyvalues, "actions", lambda x: "*" in x.value.lower() + statement.keyvalues, + "actions", + lambda x: isinstance(x.value, str) and "*" in x.value.lower(), ) if sensitive_action: errors.append( @@ -44,8 +51,8 @@ def convert_string_to_dict(input_string): wildcarded_resource, resource = self.iterate_required_attributes( statement.keyvalues, "resources", - lambda x: (x.value.lower() in ["*"]) - or (":*" in x.value.lower()), + lambda x: isinstance(x.value, str) + and ((x.value.lower() in ["*"]) or (":*" in x.value.lower())), ) if wildcarded_resource: errors.append( @@ -63,18 +70,18 @@ def convert_string_to_dict(input_string): "resource.aws_iam_group_policy", ]: policy = self.check_required_attribute(element.attributes, [""], "policy") - if policy is None: + if not isinstance(policy, KeyValue) or not isinstance(policy.value, str): return errors policy_dict = convert_string_to_dict(policy.value.lower()) if not (policy_dict and policy_dict["statement"]): return errors - statements = policy_dict["statement"] + policy_statements = policy_dict["statement"] if isinstance(statements, dict): - statements = [statements] + policy_statements: List[Dict[str, str | List[str]]] = [statements] - for statement in statements: + for statement in policy_statements: if not ( statement["effect"] and statement["action"] diff --git a/glitch/analysis/terraform/smell_checker.py b/glitch/analysis/terraform/smell_checker.py index e0f97544..b453e7c5 100644 --- a/glitch/analysis/terraform/smell_checker.py +++ b/glitch/analysis/terraform/smell_checker.py @@ -1,12 +1,20 @@ import os import re -from typing import List, Callable + +from re import Pattern +from typing import Optional, List, Callable, Any from glitch.repr.inter import * from glitch.analysis.rules import Error, SmellChecker class TerraformSmellChecker(SmellChecker): - def get_au(self, file: str, name: str, type: str, c=None): + def get_au( + self, + file: str, + name: str, + type: str, + c: Project | Module | UnitBlock | None = None, + ) -> Optional[AtomicUnit]: c = self.code if c is None else c if isinstance(c, Project): module_name = os.path.basename(os.path.dirname(file)) @@ -29,10 +37,10 @@ def get_associated_au( file: str, type: str, attribute_name: str, - pattern, - attribute_parents: list, - code=None, - ): + pattern: Pattern[str], + attribute_parents: List[str], + code: Project | Module | UnitBlock | None = None, + ) -> Optional[AtomicUnit]: code = self.code if code is None else code if isinstance(code, Project): module_name = os.path.basename(os.path.dirname(file)) @@ -57,17 +65,30 @@ def get_associated_au( return None def get_attributes_with_name_and_value( - self, attributes, parents, name, value=None, pattern=None - ): - aux = [] + self, + attributes: List[KeyValue] | List[Attribute], + parents: List[str], + name: str, + value: Optional[Any] = None, + pattern: Optional[Pattern[str]] = None, + ) -> List[KeyValue]: + aux: List[KeyValue] = [] for a in attributes: if a.name.split("dynamic.")[-1] == name and parents == [""]: - if (value and a.value.lower() == value) or ( - pattern and re.match(pattern, a.value.lower()) + if ( + value and isinstance(a.value, str) and a.value.lower() == value + ) or ( + pattern + and isinstance(a.value, str) + and re.match(pattern, a.value.lower()) ): aux.append(a) - elif (value and a.value.lower() != value) or ( - pattern and not re.match(pattern, a.value.lower()) + elif ( + value and isinstance(a.value, str) and a.value.lower() != value + ) or ( + pattern + and isinstance(a.value, str) + and not re.match(pattern, a.value.lower()) ): continue elif not value and not pattern: @@ -83,17 +104,23 @@ def get_attributes_with_name_and_value( return aux def check_required_attribute( - self, attributes, parents, name, value=None, pattern=None, return_all=False - ): + self, + attributes: List[Attribute] | List[KeyValue], + parents: List[str], + name: str, + value: Optional[Any] = None, + pattern: Optional[Pattern[str]] = None, + return_all: bool = False, + ) -> Union[Optional[KeyValue], List[KeyValue]]: attributes = self.get_attributes_with_name_and_value( attributes, parents, name, value, pattern ) if attributes != []: if return_all: - return attributes + return attributes # type: ignore return attributes[0] - else: - return None + + return None def check_database_flags( self, @@ -102,13 +129,13 @@ def check_database_flags( smell: str, flag_name: str, safe_value: str, - required_flag=True, - ): + required_flag: bool = True, + ) -> List[Error]: database_flags = self.get_attributes_with_name_and_value( au.attributes, ["settings"], "database_flags" ) found_flag = False - errors = [] + errors: List[Error] = [] if database_flags != []: for flag in database_flags: name = self.check_required_attribute( @@ -117,7 +144,11 @@ def check_database_flags( if name is not None: found_flag = True value = self.check_required_attribute(flag.keyvalues, [""], "value") - if value and value.value.lower() != safe_value: + if ( + isinstance(value, KeyValue) + and isinstance(value.value, str) + and value.value.lower() != safe_value + ): errors.append(Error(smell, value, file, repr(value))) break elif not value and required_flag: @@ -144,12 +175,15 @@ def check_database_flags( return errors def iterate_required_attributes( - self, attributes: List[KeyValue], name: str, check: Callable[[KeyValue], bool] + self, + attributes: List[KeyValue] | List[Attribute], + name: str, + check: Callable[[KeyValue], bool], ): i = 0 attribute = self.check_required_attribute(attributes, [""], f"{name}[{i}]") - while attribute: + while isinstance(attribute, KeyValue): if check(attribute): return True, attribute i += 1 @@ -158,14 +192,22 @@ def iterate_required_attributes( return False, None def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - pass + return [] def __check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - errors = [] + errors: List[Error] = [] errors += self._check_attribute(attribute, atomic_unit, parent_name, file) for attr_child in attribute.keyvalues: errors += self.__check_attribute( @@ -174,7 +216,7 @@ def __check_attribute( return errors def _check_attributes(self, atomic_unit: AtomicUnit, file: str) -> List[Error]: - errors = [] + errors: List[Error] = [] for attribute in atomic_unit.attributes: errors += self.__check_attribute(attribute, atomic_unit, "", file) return errors diff --git a/glitch/analysis/terraform/ssl_tls_policy.py b/glitch/analysis/terraform/ssl_tls_policy.py index 83a9c715..97a7b230 100644 --- a/glitch/analysis/terraform/ssl_tls_policy.py +++ b/glitch/analysis/terraform/ssl_tls_policy.py @@ -2,14 +2,18 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, KeyValue, CodeElement class TerraformSslTlsPolicy(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for policy in SecurityVisitor._SSL_TLS_POLICY: + for policy in SecurityVisitor.SSL_TLS_POLICY: if ( attribute.name == policy["attribute"] and atomic_unit.type in policy["au_type"] @@ -22,8 +26,8 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): if element.type in [ "resource.aws_alb_listener", @@ -32,7 +36,11 @@ def check(self, element, file: str): protocol = self.check_required_attribute( element.attributes, [""], "protocol" ) - if protocol and protocol.value.lower() in ["https", "tls"]: + if ( + isinstance(protocol, KeyValue) + and isinstance(protocol.value, str) + and protocol.value.lower() in ["https", "tls"] + ): ssl_policy = self.check_required_attribute( element.attributes, [""], "ssl_policy" ) @@ -47,7 +55,7 @@ def check(self, element, file: str): ) ) - for policy in SecurityVisitor._SSL_TLS_POLICY: + for policy in SecurityVisitor.SSL_TLS_POLICY: if ( policy["required"] == "yes" and element.type in policy["au_type"] diff --git a/glitch/analysis/terraform/threats_detection.py b/glitch/analysis/terraform/threats_detection.py index 7678c564..2ae509f4 100644 --- a/glitch/analysis/terraform/threats_detection.py +++ b/glitch/analysis/terraform/threats_detection.py @@ -2,14 +2,18 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, KeyValue, CodeElement class TerraformThreatsDetection(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._MISSING_THREATS_DETECTION_ALERTS: + for config in SecurityVisitor.MISSING_THREATS_DETECTION_ALERTS: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] @@ -18,6 +22,7 @@ def _check_attribute( ): if ( "any_not_empty" in config["values"] + and isinstance(attribute.value, str) and attribute.value.lower() == "" ): return [ @@ -31,6 +36,7 @@ def _check_attribute( elif ( "any_not_empty" not in config["values"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [ @@ -44,10 +50,10 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): - for config in SecurityVisitor._MISSING_THREATS_DETECTION_ALERTS: + for config in SecurityVisitor.MISSING_THREATS_DETECTION_ALERTS: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/versioning.py b/glitch/analysis/terraform/versioning.py index 5642238b..5d640d10 100644 --- a/glitch/analysis/terraform/versioning.py +++ b/glitch/analysis/terraform/versioning.py @@ -2,30 +2,35 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, KeyValue, CodeElement class TerraformVersioning(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for config in SecurityVisitor._VERSIONING: + for config in SecurityVisitor.VERSIONING: if ( attribute.name == config["attribute"] and atomic_unit.type in config["au_type"] and parent_name in config["parents"] and config["values"] != [""] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in config["values"] ): return [Error("sec_versioning", attribute, file, repr(attribute))] return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): - for config in SecurityVisitor._VERSIONING: + for config in SecurityVisitor.VERSIONING: if ( config["required"] == "yes" and element.type in config["au_type"] diff --git a/glitch/analysis/terraform/weak_password_key_policy.py b/glitch/analysis/terraform/weak_password_key_policy.py index 08965a91..052d92a1 100644 --- a/glitch/analysis/terraform/weak_password_key_policy.py +++ b/glitch/analysis/terraform/weak_password_key_policy.py @@ -2,14 +2,18 @@ from glitch.analysis.terraform.smell_checker import TerraformSmellChecker from glitch.analysis.rules import Error from glitch.analysis.security import SecurityVisitor -from glitch.repr.inter import AtomicUnit, Attribute +from glitch.repr.inter import AtomicUnit, Attribute, KeyValue, CodeElement class TerraformWeakPasswordKeyPolicy(TerraformSmellChecker): def _check_attribute( - self, attribute: Attribute, atomic_unit: AtomicUnit, parent_name: str, file: str + self, + attribute: Attribute | KeyValue, + atomic_unit: AtomicUnit, + parent_name: str, + file: str, ) -> List[Error]: - for policy in SecurityVisitor._PASSWORD_KEY_POLICY: + for policy in SecurityVisitor.PASSWORD_KEY_POLICY: if ( attribute.name == policy["attribute"] and atomic_unit.type in policy["au_type"] @@ -19,6 +23,7 @@ def _check_attribute( if policy["logic"] == "equal": if ( "any_not_empty" in policy["values"] + and isinstance(attribute.value, str) and attribute.value.lower() == "" ): return [ @@ -32,6 +37,7 @@ def _check_attribute( elif ( "any_not_empty" not in policy["values"] and not attribute.has_variable + and isinstance(attribute.value, str) and attribute.value.lower() not in policy["values"] ): return [ @@ -42,8 +48,13 @@ def _check_attribute( repr(attribute), ) ] - elif (policy["logic"] == "gte" and not attribute.value.isnumeric()) or ( + elif ( policy["logic"] == "gte" + and isinstance(attribute.value, str) + and not attribute.value.isnumeric() + ) or ( + policy["logic"] == "gte" + and isinstance(attribute.value, str) and attribute.value.isnumeric() and int(attribute.value) < int(policy["values"][0]) ): @@ -55,8 +66,13 @@ def _check_attribute( repr(attribute), ) ] - elif (policy["logic"] == "lte" and not attribute.value.isnumeric()) or ( + elif ( + policy["logic"] == "lte" + and isinstance(attribute.value, str) + and not attribute.value.isnumeric() + ) or ( policy["logic"] == "lte" + and isinstance(attribute.value, str) and attribute.value.isnumeric() and int(attribute.value) > int(policy["values"][0]) ): @@ -71,10 +87,10 @@ def _check_attribute( return [] - def check(self, element, file: str): - errors = [] + def check(self, element: CodeElement, file: str) -> List[Error]: + errors: List[Error] = [] if isinstance(element, AtomicUnit): - for policy in SecurityVisitor._PASSWORD_KEY_POLICY: + for policy in SecurityVisitor.PASSWORD_KEY_POLICY: if ( policy["required"] == "yes" and element.type in policy["au_type"] diff --git a/glitch/exceptions.py b/glitch/exceptions.py index c31ae278..883fde4c 100644 --- a/glitch/exceptions.py +++ b/glitch/exceptions.py @@ -15,5 +15,5 @@ } -def throw_exception(exception, *args): +def throw_exception(exception: str, *args: str) -> None: print(exception.format(*args), file=sys.stderr) diff --git a/glitch/helpers.py b/glitch/helpers.py index b36630bc..84396b3b 100644 --- a/glitch/helpers.py +++ b/glitch/helpers.py @@ -1,6 +1,7 @@ import click -from typing import List +from click.types import ParamType +from typing import Optional, List, Tuple, Iterable, Sequence, Any, Union from glitch.tech import Tech from glitch.analysis.rules import Error @@ -8,22 +9,22 @@ class RulesListOption(click.Option): def __init__( self, - param_decls=None, - show_default=False, - prompt=False, - confirmation_prompt=False, - hide_input=False, - is_flag=None, - flag_value=None, - multiple=False, - count=False, - allow_from_autoenv=True, - type=None, - help=None, - hidden=False, - show_choices=True, - show_envvar=False, - ): + param_decls: Optional[Sequence[str]] = None, + show_default: bool = False, + prompt: bool = False, + confirmation_prompt: bool = False, + hide_input: bool = False, + is_flag: Optional[bool] = None, + flag_value: Optional[Any] = None, + multiple: bool = False, + count: bool = False, + allow_from_autoenv: bool = True, + type: Optional[Union[ParamType, Any]] = None, + help: Optional[str] = None, + hidden: bool = False, + show_choices: bool = True, + show_envvar: bool = False, + ) -> None: super().__init__( param_decls=param_decls, show_default=show_default, @@ -44,16 +45,16 @@ def __init__( self.type = click.Choice(get_smell_types(), case_sensitive=False) -def get_smell_types() -> List[str]: +def get_smell_types() -> Tuple[str, ...]: """Get list of smell types. Returns: List[str]: List of smell types. """ - return Error.ERRORS.keys() + return tuple(Error.ERRORS.keys()) -def get_smells(smell_types: List[str], tech: Tech) -> List[str]: +def get_smells(smell_types: Iterable[str], tech: Tech) -> List[str]: """Get list of smells. Args: @@ -64,19 +65,20 @@ def get_smells(smell_types: List[str], tech: Tech) -> List[str]: List[str]: List of smells. """ - smells = [] + smells: List[str] = [] for smell_type in smell_types: errors = Error.ERRORS[smell_type] for error in errors: if error == tech: - smells.extend(errors[error].keys()) + smells.extend(errors[error].keys()) # type: ignore elif not isinstance(error, Tech): smells.append(error) return smells -def remove_unmatched_brackets(string): - stack, aux = [], "" +def remove_unmatched_brackets(string: str): + stack: List[str] = [] + aux = "" for c in string: if c in ["(", "[", "{"]: @@ -101,10 +103,10 @@ def remove_unmatched_brackets(string): # Python program for KMP Algorithm (https://www.geeksforgeeks.org/python-program-for-kmp-algorithm-for-pattern-searching-2/) # Based on code by Bhavya Jain -def kmp_search(pat, txt): +def kmp_search(pat: str, txt: str): M = len(pat) N = len(txt) - res = [] + res: List[int] = [] # create lps[] that will hold the longest prefix suffix # values for pattern @@ -136,7 +138,7 @@ def kmp_search(pat, txt): return res -def compute_LPS_array(pat, M, lps): +def compute_LPS_array(pat: str, M: int, lps: List[int]) -> None: len = 0 # length of the previous longest prefix suffix lps[0] i = 1 diff --git a/glitch/parsers/ansible.py b/glitch/parsers/ansible.py index d608cfc9..3fee9b9d 100644 --- a/glitch/parsers/ansible.py +++ b/glitch/parsers/ansible.py @@ -1,23 +1,28 @@ import os -import ruamel.yaml as yaml -from ruamel.yaml import ( +import glitch.parsers.parser as p + +from typing import List, TextIO, Tuple, Union, Any, Optional, Callable +from ruamel.yaml.main import YAML +from ruamel.yaml.nodes import ( + Node, ScalarNode, MappingNode, SequenceNode, - CommentToken, CollectionNode, ) +from ruamel.yaml.tokens import Token, CommentToken from glitch.exceptions import EXCEPTIONS, throw_exception - -import glitch.parsers.parser as p from glitch.repr.inter import * +RecursiveTokenList = List[Union[Token, "RecursiveTokenList", None]] + + class AnsibleParser(p.Parser): @staticmethod - def __get_yaml_comments(d, file): - def extract_from_token(tokenlist): - res = [] + def __get_yaml_comments(d: Node, file: TextIO): + def extract_from_token(tokenlist: RecursiveTokenList) -> List[Tuple[int, str]]: + res: List[Tuple[int, str]] = [] for token in tokenlist: if token is None: continue @@ -27,8 +32,8 @@ def extract_from_token(tokenlist): res.append((token.start_mark.line, token.value)) return res - def yaml_comments(d): - res = [] + def yaml_comments(d: Node) -> List[Tuple[int, str]]: + res: List[Tuple[int, str]] = [] if isinstance(d, MappingNode): if d.comment is not None: @@ -53,7 +58,7 @@ def yaml_comments(d): file.seek(0, 0) f_lines = file.readlines() - comments = [] + comments: List[Tuple[int, str]] = [] for c_group in yaml_comments(d): line = c_group[0] c_group_comments = c_group[1].strip().split("\n") @@ -75,7 +80,11 @@ def yaml_comments(d): return set(comments) @staticmethod - def __get_element_code(start_token, end_token, code): + def __get_element_code( + start_token: Token | Node, + end_token: List[Token | Node] | Token | Node | str, + code: List[str], + ): if isinstance(end_token, list) and len(end_token) > 0: end_token = end_token[-1] elif isinstance(end_token, list) or isinstance(end_token, str): @@ -97,8 +106,16 @@ def __get_element_code(start_token, end_token, code): return res @staticmethod - def __parse_vars(unit_block, cur_name, token, code, child=False): - def create_variable(token, name, value, child=False) -> Variable: + def __parse_vars( + unit_block: UnitBlock, + cur_name: str, + node: Node, + code: List[str], + child: bool = False, + ) -> List[Variable]: + def create_variable( + token: Token | Node, name: str, value: str | None, child: bool = False + ) -> Variable: has_variable = ( (("{{" in value) and ("}}" in value)) if value != None else False ) @@ -117,48 +134,51 @@ def create_variable(token, name, value, child=False) -> Variable: unit_block.add_variable(v) return v - variables = [] - if isinstance(token, MappingNode): + variables: List[Variable] = [] + if isinstance(node, MappingNode): if cur_name == "": - for key, v in token.value: + for key, v in node.value: if hasattr(key, "value") and isinstance(key.value, str): AnsibleParser.__parse_vars( unit_block, key.value, v, code, child ) elif isinstance(key.value, MappingNode): AnsibleParser.__parse_vars( - unit_block, cur_name, key.value[0][0], code, child + unit_block, cur_name, key.value[0][0], code, child # type: ignore ) else: - var = create_variable(token, cur_name, None, child) - for key, v in token.value: + var = create_variable(node, cur_name, None, child) + for key, v in node.value: if hasattr(key, "value") and isinstance(key.value, str): var.keyvalues += AnsibleParser.__parse_vars( unit_block, key.value, v, code, True ) elif isinstance(key.value, MappingNode): var.keyvalues += AnsibleParser.__parse_vars( - unit_block, cur_name, key.value[0][0], code, True + unit_block, cur_name, key.value[0][0], code, True # type: ignore ) - elif isinstance(token, ScalarNode): - create_variable(token, cur_name, str(token.value), child) - elif isinstance(token, SequenceNode): - value = [] - for i, val in enumerate(token.value): + elif isinstance(node, ScalarNode): + create_variable(node, cur_name, str(node.value), child) + elif isinstance(node, SequenceNode): + value: List[Any] = [] + for i, val in enumerate(node.value): if isinstance(val, CollectionNode): variables += AnsibleParser.__parse_vars( unit_block, f"{cur_name}[{i}]", val, code, child ) else: value.append(val.value) - if len(value) > 0: - create_variable(val, cur_name, str(value), child) + + if len(value) > 0 and isinstance(node.value[-1], (Node, Token)): + create_variable(node.value[-1], cur_name, str(value), child) return variables @staticmethod - def __parse_attribute(cur_name, token, val, code): - def create_attribute(token, name, value) -> Attribute: + def __parse_attribute( + cur_name: str, token: Token | Node, val: Any, code: List[str] + ) -> List[Attribute]: + def create_attribute(token: Token | Node, name: str, value: Any) -> Attribute: has_variable = ( (("{{" in value) and ("}}" in value)) if value != None else False ) @@ -175,10 +195,10 @@ def create_attribute(token, name, value) -> Attribute: return a - attributes = [] + attributes: List[Attribute] = [] if isinstance(val, MappingNode): attribute = create_attribute(token, cur_name, None) - aux_attributes = [] + aux_attributes: List[KeyValue] = [] for aux, aux_val in val.value: aux_attributes += AnsibleParser.__parse_attribute( f"{aux.value}", aux, aux_val, code @@ -187,7 +207,7 @@ def create_attribute(token, name, value) -> Attribute: elif isinstance(val, ScalarNode): create_attribute(token, cur_name, str(val.value)) elif isinstance(val, SequenceNode): - value = [] + value: List[Any] = [] for i, v in enumerate(val.value): if not isinstance(v, ScalarNode): attributes += AnsibleParser.__parse_attribute( @@ -202,9 +222,10 @@ def create_attribute(token, name, value) -> Attribute: return attributes @staticmethod - def __parse_tasks(unit_block, tasks, code): + def __parse_tasks(unit_block: UnitBlock, tasks: Node, code: List[str]) -> None: for task in tasks.value: - atomic_units, attributes = [], [] + atomic_units: List[AtomicUnit] = [] + attributes: List[Attribute] = [] type, name, line = "", "", 0 is_block = False @@ -230,7 +251,7 @@ def __parse_tasks(unit_block, tasks, code): type = key.value line = task.start_mark.line + 1 - names = [n.strip() for n in name.split(",")] + names: List[str] = [n.strip() for n in name.split(",")] for name in names: if name == "": continue @@ -274,16 +295,21 @@ def __parse_tasks(unit_block, tasks, code): au.code = code[au.line - 1] unit_block.add_atomic_unit(au) - def __parse_playbook(self, name, file, parsed_file=None) -> UnitBlock: + def __parse_playbook( + self, name: str, file: TextIO, parsed_file: Optional[Node] = None + ) -> Optional[UnitBlock]: try: if parsed_file is None: - parsed_file = yaml.YAML().compose(file) + parsed_file = YAML().compose(file) unit_block = UnitBlock(name, UnitBlockType.script) unit_block.path = file.name file.seek(0, 0) code = file.readlines() code.append("") # HACK allows to parse code in the end of the file + if parsed_file is None: + return unit_block + for p in parsed_file.value: # Plays are unit blocks inside a unit block play = UnitBlock("", UnitBlockType.block) @@ -314,10 +340,12 @@ def __parse_playbook(self, name, file, parsed_file=None) -> UnitBlock: throw_exception(EXCEPTIONS["ANSIBLE_PLAYBOOK"], file.name) return None - def __parse_tasks_file(self, name, file, parsed_file=None) -> UnitBlock: + def __parse_tasks_file( + self, name: str, file: TextIO, parsed_file: Optional[Node] = None + ) -> Optional[UnitBlock]: try: if parsed_file is None: - parsed_file = yaml.YAML().compose(file) + parsed_file = YAML().compose(file) unit_block = UnitBlock(name, UnitBlockType.tasks) unit_block.path = file.name file.seek(0, 0) @@ -339,10 +367,12 @@ def __parse_tasks_file(self, name, file, parsed_file=None) -> UnitBlock: throw_exception(EXCEPTIONS["ANSIBLE_TASKS_FILE"], file.name) return None - def __parse_vars_file(self, name, file, parsed_file=None) -> UnitBlock: + def __parse_vars_file( + self, name: str, file: TextIO, parsed_file: Optional[Node] = None + ) -> Optional[UnitBlock]: try: if parsed_file is None: - parsed_file = yaml.YAML().compose(file) + parsed_file = YAML().compose(file) unit_block = UnitBlock(name, UnitBlockType.vars) unit_block.path = file.name file.seek(0, 0) @@ -365,7 +395,11 @@ def __parse_vars_file(self, name, file, parsed_file=None) -> UnitBlock: return None @staticmethod - def __apply_to_files(module, path, p_function): + def __apply_to_files( + module: Module | Project, + path: str, + p_function: Callable[[str, TextIO], Optional[UnitBlock]], + ) -> None: if os.path.exists(path) and os.path.isdir(path) and not os.path.islink(path): files = [ f @@ -406,7 +440,7 @@ def parse_module(self, path: str) -> Module: return res - def parse_folder(self, path: str, root=True) -> Project: + def parse_folder(self, path: str, root: bool = True) -> Project: """ It follows the sample directory layout found in: https://docs.ansible.com/ansible/latest/user_guide/sample_setup.html#sample-directory-layout @@ -449,18 +483,18 @@ def parse_folder(self, path: str, root=True) -> Project: return res - def parse_file(self, path: str, blocktype: UnitBlockType) -> UnitBlock: + def parse_file(self, path: str, type: UnitBlockType) -> Optional[UnitBlock]: with open(path) as f: try: - parsed_file = yaml.YAML().compose(f) + parsed_file = YAML().compose(f) f.seek(0, 0) except: throw_exception(EXCEPTIONS["ANSIBLE_COULD_NOT_PARSE"], path) return None - if blocktype == UnitBlockType.unknown: + if type == UnitBlockType.unknown: if isinstance(parsed_file, MappingNode): - blocktype = UnitBlockType.vars + type = UnitBlockType.vars elif ( isinstance(parsed_file, SequenceNode) and len(parsed_file.value) > 0 @@ -473,19 +507,19 @@ def parse_file(self, path: str, blocktype: UnitBlockType) -> UnitBlock: hosts = True break - blocktype = UnitBlockType.script if hosts else UnitBlockType.tasks + type = UnitBlockType.script if hosts else UnitBlockType.tasks elif ( isinstance(parsed_file, SequenceNode) and len(parsed_file.value) == 0 ): - blocktype = UnitBlockType.script + type = UnitBlockType.script else: throw_exception(EXCEPTIONS["ANSIBLE_FILE_TYPE"], path) return None - if blocktype == UnitBlockType.script: + if type == UnitBlockType.script: return self.__parse_playbook(path, f, parsed_file=parsed_file) - elif blocktype == UnitBlockType.tasks: + elif type == UnitBlockType.tasks: return self.__parse_tasks_file(path, f, parsed_file=parsed_file) - elif blocktype == UnitBlockType.vars: + elif type == UnitBlockType.vars: return self.__parse_vars_file(path, f, parsed_file=parsed_file) diff --git a/glitch/parsers/chef.py b/glitch/parsers/chef.py index 92f0163b..8bab97a2 100644 --- a/glitch/parsers/chef.py +++ b/glitch/parsers/chef.py @@ -1,24 +1,25 @@ import os +import sys import re import tempfile +import glitch.parsers.parser as p + from string import Template from pkg_resources import resource_filename -from glitch.exceptions import EXCEPTIONS, throw_exception - -import glitch.parsers.parser as p +from typing import Any, List, Tuple, Callable from glitch.repr.inter import * from glitch.parsers.ripper_parser import parser_yacc from glitch.helpers import remove_unmatched_brackets +from glitch.exceptions import EXCEPTIONS, throw_exception + +ChefValue = Tuple[str, str] | str | int | bool | List["ChefValue"] class ChefParser(p.Parser): class Node: - id: str - args: list - - def __init__(self, id, args) -> None: - self.id = id - self.args = args + def __init__(self, id: str, args: List[Any]) -> None: + self.id: str = id + self.args: List[Any] = args def __repr__(self) -> str: return str(self.id) @@ -30,15 +31,15 @@ def __reversed__(self): return reversed(self.args) @staticmethod - def _check_id(ast, ids): + def _check_id(ast: Any, ids: List[Any]) -> bool: return isinstance(ast, ChefParser.Node) and ast.id in ids @staticmethod - def _check_node(ast, ids, size): + def _check_node(ast: Any, ids: List[Any], size: int) -> bool: return ChefParser._check_id(ast, ids) and len(ast.args) == size @staticmethod - def _check_has_variable(ast): + def _check_has_variable(ast: Node) -> bool: references = ["vcall", "call", "aref", "fcall", "var_ref"] if ChefParser._check_id(ast, ["args_add_block"]): return ChefParser._check_id(ast.args[0][0], references) @@ -54,16 +55,16 @@ def _check_has_variable(ast): return ChefParser._check_id(ast, references) @staticmethod - def _get_content_bounds(ast, source): - def is_bounds(l): + def _get_content_bounds(ast: Any, source: List[str]) -> Tuple[int, int, int, int]: + def is_bounds(l: Any) -> bool: return ( isinstance(l, list) - and len(l) == 2 + and len(l) == 2 # type: ignore and isinstance(l[0], int) and isinstance(l[1], int) ) - start_line, start_column = float("inf"), float("inf") + start_line, start_column = sys.maxsize, sys.maxsize end_line, end_column = 0, 0 bounded_structures = [ "brace_block", @@ -111,7 +112,7 @@ def is_bounds(l): # We have to consider extra characters which correspond # to enclosing characters of these structures - if start_line != float("inf") and ChefParser._check_id( + if start_line != sys.maxsize and ChefParser._check_id( ast, bounded_structures ): r_brackets = ["}", ")", "]", '"', "'"] @@ -147,11 +148,11 @@ def is_bounds(l): return (start_line, start_column, end_line, end_column) @staticmethod - def _get_content(ast, source): + def _get_content(ast: Any, source: List[str]) -> str: empty_structures = {"string_literal": "", "hash": "{}", "array": "[]"} if isinstance(ast, list): - return "".join(list(map(lambda a: ChefParser._get_content(a, source), ast))) + return "".join(list(map(lambda a: ChefParser._get_content(a, source), ast))) # type: ignore if (ast.id in empty_structures and len(ast.args) == 0) or ( ast.id == "string_literal" and len(ast.args[0].args) == 0 @@ -161,7 +162,7 @@ def _get_content(ast, source): bounds = ChefParser._get_content_bounds(ast, source) res = "" - if bounds[0] == float("inf"): + if bounds[0] == sys.maxsize: return res for l in range(bounds[0] - 1, bounds[2]): @@ -184,16 +185,16 @@ def _get_content(ast, source): return remove_unmatched_brackets(res) @staticmethod - def _get_source(ast, source): + def _get_source(ast: Any, source: List[str]) -> str: bounds = ChefParser._get_content_bounds(ast, source) return "".join(source[bounds[0] - 1 : bounds[2]]) class Checker: - def __init__(self, source): - self.tests_ast_stack = [] + def __init__(self, source: List[str]) -> None: + self.tests_ast_stack: List[Tuple[List[Callable[[Any], bool]], Any]] = [] self.source = source - def check(self): + def check(self) -> bool: tests, ast = self.pop() for test in tests: if test(ast): @@ -207,19 +208,21 @@ def check_all(self): status = self.check() return status - def push(self, tests, ast): + def push(self, tests: List[Callable[[Any], bool]], ast: Any) -> None: self.tests_ast_stack.append((tests, ast)) def pop(self): return self.tests_ast_stack.pop() class ResourceChecker(Checker): - def __init__(self, atomic_unit, source, ast): + def __init__( + self, atomic_unit: AtomicUnit, source: List[str], ast: Any + ) -> None: super().__init__(source) self.push([self.is_block_resource, self.is_inline_resource], ast) self.atomic_unit = atomic_unit - def is_block_resource(self, ast): + def is_block_resource(self, ast: Any) -> bool: if ( ChefParser._check_node(ast, ["method_add_block"], 2) and ChefParser._check_node(ast.args[0], ["command"], 2) @@ -234,7 +237,7 @@ def is_block_resource(self, ast): return True return False - def is_inline_resource(self, ast): + def is_inline_resource(self, ast: Any) -> bool: if ( ChefParser._check_node(ast, ["command"], 2) and ChefParser._check_id(ast.args[0], ["@ident"]) @@ -255,7 +258,7 @@ def is_inline_resource(self, ast): return True return False - def is_resource_def(self, ast): + def is_resource_def(self, ast: Any) -> bool: if ChefParser._check_node( ast.args[0], ["@ident"], 2 ) and ChefParser._check_node(ast.args[1], ["args_add_block"], 2): @@ -264,7 +267,7 @@ def is_resource_def(self, ast): return True return False - def is_resource_type(self, ast): + def is_resource_type(self, ast: "ChefParser.Node") -> bool: if ( isinstance(ast.args[0], str) and isinstance(ast.args[1], list) @@ -282,7 +285,7 @@ def is_resource_type(self, ast): return True return False - def is_resource_name(self, ast): + def is_resource_name(self, ast: "ChefParser.Node") -> bool: if isinstance(ast.args[0][0], ChefParser.Node) and ast.args[1] is False: resource_id = ast.args[0][0] self.atomic_unit.name = ChefParser._get_content( @@ -291,7 +294,7 @@ def is_resource_name(self, ast): return True return False - def is_inline_resource_name(self, ast): + def is_inline_resource_name(self, ast: "ChefParser.Node") -> bool: if ( ChefParser._check_node(ast.args[0][0], ["method_add_block"], 2) and ast.args[1] is False @@ -304,13 +307,13 @@ def is_inline_resource_name(self, ast): return True return False - def is_resource_body(self, ast): + def is_resource_body(self, ast: "ChefParser.Node") -> bool: if ChefParser._check_id(ast.args[0], ["bodystmt"]): self.push([self.is_attribute], ast.args[0].args[0]) return True return False - def is_resource_body_without_attributes(self, ast): + def is_resource_body_without_attributes(self, ast: "ChefParser.Node") -> bool: if ( ChefParser._check_id(ast.args[0][0], ["string_literal"]) and ast.args[1] is False @@ -321,7 +324,7 @@ def is_resource_body_without_attributes(self, ast): return True return False - def is_attribute(self, ast): + def is_attribute(self, ast: Any) -> bool: if ChefParser._check_node( ast, ["method_add_arg"], 2 ) and ChefParser._check_id(ast.args[0], ["call"]): @@ -348,25 +351,33 @@ def is_attribute(self, ast): a.code = ChefParser._get_source(ast, self.source) self.atomic_unit.add_attribute(a) elif isinstance(ast, (ChefParser.Node, list)): - for arg in reversed(ast): + for arg in reversed(ast): # type: ignore self.push([self.is_attribute], arg) return True class VariableChecker(Checker): - def __init__(self, source, ast): + def __init__(self, source: List[str], ast: Any) -> None: super().__init__(source) - self.variables = [] + self.variables: List[Variable] = [] self.push([self.is_variable], ast) - def is_variable(self, ast): - def create_variable(key, name, value, has_variable): + def is_variable(self, ast: Any) -> bool: + def create_variable( + key: Any, name: str, value: str | None, has_variable: bool + ): variable = Variable(name, value, has_variable) variable.line = ChefParser._get_content_bounds(key, self.source)[0] variable.code = ChefParser._get_source(ast, self.source) return variable - def parse_variable(parent, ast, key, current_name, value_ast): + def parse_variable( + parent: KeyValue | None, + ast: Any, + key: Any, + current_name: str, + value_ast: "ChefParser.Node", + ) -> None: if ChefParser._check_node( value_ast, ["hash"], 1 ) and ChefParser._check_id(value_ast.args[0], ["assoclist_from_args"]): @@ -420,7 +431,7 @@ def parse_variable(parent, ast, key, current_name, value_ast): variable = create_variable(ast.args[0], name, None, False) if i == 0: self.variables.append(variable) - else: + elif parent is not None: parent.keyvalues.append(variable) parent = variable return True @@ -428,12 +439,12 @@ def parse_variable(parent, ast, key, current_name, value_ast): return False class IncludeChecker(Checker): - def __init__(self, source, ast): + def __init__(self, source: List[str], ast: Any) -> None: super().__init__(source) self.push([self.is_include], ast) self.code = "" - def is_include(self, ast): + def is_include(self, ast: Any) -> bool: if ( ChefParser._check_node(ast, ["command"], 2) and ChefParser._check_id(ast.args[0], ["@ident"]) @@ -445,7 +456,7 @@ def is_include(self, ast): return True return False - def is_include_type(self, ast): + def is_include_type(self, ast: "ChefParser.Node") -> bool: if ( isinstance(ast.args[0], str) and isinstance(ast.args[1], list) @@ -454,7 +465,7 @@ def is_include_type(self, ast): return True return False - def is_include_name(self, ast): + def is_include_name(self, ast: "ChefParser.Node") -> bool: if ( ChefParser._check_id(ast.args[0][0], ["string_literal"]) and ast.args[1] is False @@ -468,11 +479,11 @@ def is_include_name(self, ast): # FIXME only identifying case statement class ConditionChecker(Checker): - def __init__(self, source, ast): + def __init__(self, source: List[str], ast: Any) -> None: super().__init__(source) self.push([self.is_case], ast) - def is_case(self, ast): + def is_case(self, ast: Any) -> bool: if ChefParser._check_node(ast, ["case"], 2): self.case_head = ChefParser._get_content(ast.args[0], self.source) self.condition = None @@ -485,7 +496,7 @@ def is_case(self, ast): return True return False - def is_case_condition(self, ast): + def is_case_condition(self, ast: Any) -> bool: if ChefParser._check_node(ast, ["when"], 3) or ChefParser._check_node( ast, ["when"], 2 ): @@ -532,48 +543,50 @@ def is_case_condition(self, ast): return False @staticmethod - def __create_ast(l): - args = [] + def __create_ast(l: List[ChefValue | "ChefParser.Node"]) -> "ChefParser.Node": + args: List[Any] = [] for el in l[1:]: if isinstance(el, list): if len(el) > 0 and isinstance(el[0], tuple) and el[0][0] == "id": - args.append(ChefParser.__create_ast(el)) + args.append(ChefParser.__create_ast(el)) # type: ignore else: - arg = [] + arg: List["ChefParser.Node" | ChefValue] = [] for e in el: if ( isinstance(e, list) and isinstance(e[0], tuple) and e[0][0] == "id" ): - arg.append(ChefParser.__create_ast(e)) + arg.append(ChefParser.__create_ast(e)) # type: ignore else: arg.append(e) args.append(arg) else: args.append(el) - return ChefParser.Node(l[0][1], args) + return ChefParser.Node(l[0][1], args) # type: ignore @staticmethod - def __transverse_ast(ast, unit_block, source): - def get_var(parent_name, vars): + def __transverse_ast(ast: Any, unit_block: UnitBlock, source: List[str]) -> None: + def get_var(parent_name: str, vars: List[Variable]): for var in vars: if var.name == parent_name: return var return None - def add_variable_to_unit_block(variable, unit_block_vars): + def add_variable_to_unit_block( + variable: Variable, unit_block_vars: List[Variable] + ) -> None: var_name = variable.name var = get_var(var_name, unit_block_vars) if var and var.value == None and variable.value == None: for v in variable.keyvalues: - add_variable_to_unit_block(v, var.keyvalues) + add_variable_to_unit_block(v, var.keyvalues) # type: ignore else: unit_block_vars.append(variable) if isinstance(ast, list): - for arg in ast: + for arg in ast: # type: ignore if isinstance(arg, (ChefParser.Node, list)): ChefParser.__transverse_ast(arg, unit_block, source) else: @@ -599,7 +612,8 @@ def add_variable_to_unit_block(variable, unit_block_vars): if_checker = ChefParser.ConditionChecker(source, ast) if if_checker.check_all(): - unit_block.add_statement(if_checker.condition) + if if_checker.condition is not None: + unit_block.add_statement(if_checker.condition) # Check blocks inside ChefParser.__transverse_ast( ast.args[len(ast.args) - 1], unit_block, source @@ -611,7 +625,7 @@ def add_variable_to_unit_block(variable, unit_block_vars): ChefParser.__transverse_ast(arg, unit_block, source) @staticmethod - def __parse_recipe(path, file) -> UnitBlock: + def __parse_recipe(path: str, file: str) -> UnitBlock: with open(os.path.join(path, file)) as f: ripper = resource_filename( "glitch.parsers", "resources/comments.rb.template" @@ -635,6 +649,7 @@ def __parse_recipe(path, file) -> UnitBlock: throw_exception( EXCEPTIONS["CHEF_COULD_NOT_PARSE"], os.path.join(path, file) ) + return unit_block with tempfile.NamedTemporaryFile(mode="w+") as tmp: tmp.write(ripper_script) @@ -678,7 +693,7 @@ def __parse_recipe(path, file) -> UnitBlock: return unit_block def parse_module(self, path: str) -> Module: - def parse_folder(path: str): + def parse_folder(path: str) -> None: if os.path.exists(path): files = [ f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) diff --git a/glitch/parsers/docker.py b/glitch/parsers/docker.py index 8e850977..747092c7 100644 --- a/glitch/parsers/docker.py +++ b/glitch/parsers/docker.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Union -import bashlex +import bashlex # type: ignore from dockerfile_parse import DockerfileParser import glitch.parsers.parser as p @@ -69,7 +69,7 @@ def parse_file(self, path: str, type: UnitBlockType) -> UnitBlock: main_block.path = path return main_block except Exception as e: - throw_exception(EXCEPTIONS["DOCKER_UNKNOW_ERROR"], e) + throw_exception(EXCEPTIONS["DOCKER_UNKNOW_ERROR"], str(e)) main_block = UnitBlock(os.path.basename(path), type) main_block.path = path return main_block @@ -129,7 +129,7 @@ def __parse_stage( return u @staticmethod - def __parse_instruction(element: DFPStructure, unit_block: UnitBlock): + def __parse_instruction(element: DFPStructure, unit_block: UnitBlock) -> None: instruction = element.instruction if instruction in ["ENV", "USER", "ARG", "LABEL"]: unit_block.variables += DockerParser.__create_variable_block(element) @@ -175,7 +175,7 @@ def __parse_instruction(element: DFPStructure, unit_block: UnitBlock): def __get_stages( stage_indexes: List[int], structure: List[DFPStructure] ) -> List[Tuple[str, List[DFPStructure]]]: - stages = [] + stages: List[Tuple[str, List[DFPStructure]]] = [] for i, stage_i in enumerate(stage_indexes): stage_image = structure[stage_i].value.split(" ")[0] stage_start = stage_i if i != 0 else 0 @@ -202,7 +202,7 @@ def __get_stage_structure( @staticmethod def __create_variable_block(element: DFPStructure) -> List[Variable]: - variables = [] + variables: List[Variable] = [] if element.instruction == "USER": variables.append(Variable("user-profile", element.value, False)) elif element.instruction == "ARG": @@ -236,7 +236,7 @@ def __create_variable_block(element: DFPStructure) -> List[Variable]: def __parse_multiple_key_value_variables( content: str, base_line: int ) -> List[Variable]: - variables = [] + variables: List[Variable] = [] for i, line in enumerate(content.split("\n")): for match in re.finditer( r"([\w_]*)=(?:(?:'|\")([\w\. <>@]*)(?:\"|')|([\w\.]*))", line @@ -249,11 +249,7 @@ def __parse_multiple_key_value_variables( return variables @staticmethod - def __has_user_tag(structure: List[DFPStructure]) -> bool: - return bool(s for s in structure if s.instruction == "USER") - - @staticmethod - def __add_user_tag(structure: List[DFPStructure]): + def __add_user_tag(structure: List[DFPStructure]) -> None: if len([s for s in structure if s.instruction == "USER"]) > 0: return @@ -289,9 +285,9 @@ def to_atomic_unit(self) -> AtomicUnit: sudo.code = "sudo" sudo.line = self.line au.add_attribute(sudo) - for key, (value, code) in self.options.items(): + for key, (value, _) in self.options.items(): has_variable = "$" in value if isinstance(value, str) else False - attr = Attribute(key, value, has_variable) + attr = Attribute(key, value, has_variable) # type: ignore attr.code = self.code attr.line = self.line au.add_attribute(attr) @@ -299,7 +295,7 @@ def to_atomic_unit(self) -> AtomicUnit: class CommandParser: - def __init__(self, command: DFPStructure): + def __init__(self, command: DFPStructure) -> None: value = ( command.content.replace("RUN ", "") if command.instruction == "RUN" @@ -315,7 +311,7 @@ def __init__(self, command: DFPStructure): def parse_command(self) -> List[AtomicUnit]: # TODO: Fix get commands lines for scripts with multiline values commands = self.__get_sub_commands() - aus = [] + aus: List[AtomicUnit] = [] for line, c in commands: try: aus.append(self.__parse_single_command(c, line)) @@ -354,26 +350,26 @@ def __strip_shell_command(command: List[str]) -> Tuple[List[str], int]: i for i, c in enumerate(command) if c not in ["\n", "", " ", "\r"] ] if not non_empty_indexes: - return [] + return ([], 0) start, end = non_empty_indexes[0], non_empty_indexes[-1] return command[start : end + 1], sum(1 for c in command if c == "\n") @staticmethod - def __parse_shell_command(command: ShellCommand): + def __parse_shell_command(command: ShellCommand) -> None: if command.command == "chmod": reference = [arg for arg in command.args if "--reference" in arg] command.args = [arg for arg in command.args if not arg.startswith("-")] command.main_arg = command.args[-1] if reference: reference[0] - command.options["reference"] = (reference.split("=")[1], reference) + command.options["reference"] = (reference.split("=")[1], reference) # type: ignore else: command.options["mode"] = command.args[0], command.args[0] else: CommandParser.__parse_general_command(command) @staticmethod - def __parse_general_command(command: ShellCommand): + def __parse_general_command(command: ShellCommand) -> None: args = command.args.copy() # TODO: Solve issue where last argument is part of a parameter main_arg_index = -1 if not args[-1].startswith("-") else 0 @@ -400,8 +396,8 @@ def __parse_general_command(command: ShellCommand): command.options[o] = args[i + 1], f"{code} {args[i+1]}" def __get_sub_commands(self) -> List[Tuple[int, List[str]]]: - commands = [] - tmp = [] + commands: List[Tuple[int, List[str]]] = [] + tmp: List[str] = [] lines = ( self.command.split("\n") if not self.__contains_multi_line_values(self.command) @@ -409,13 +405,13 @@ def __get_sub_commands(self) -> List[Tuple[int, List[str]]]: ) current_line = self.line for i, line in enumerate(lines): - for part in bashlex.split(line): + for part in bashlex.split(line): # type: ignore if part in ["&&", "&", "|", ";"]: commands.append((current_line, tmp)) current_line = self.line + i tmp = [] continue - tmp.append(part) + tmp.append(part) # type: ignore commands.append((current_line, tmp)) return commands diff --git a/glitch/parsers/parser.py b/glitch/parsers/parser.py index 810db843..d363c2c6 100644 --- a/glitch/parsers/parser.py +++ b/glitch/parsers/parser.py @@ -1,12 +1,15 @@ import os from glitch.repr.inter import * from abc import ABC, abstractmethod +from typing import Optional from glitch.repr.inter import UnitBlockType class Parser(ABC): - def parse(self, path: str, type: UnitBlockType, is_module: bool) -> Module: + def parse( + self, path: str, type: UnitBlockType, is_module: bool + ) -> Optional[Module | Project | UnitBlock]: if is_module: return self.parse_module(path) elif os.path.isfile(path): @@ -15,7 +18,7 @@ def parse(self, path: str, type: UnitBlockType, is_module: bool) -> Module: return self.parse_folder(path) @abstractmethod - def parse_file(self, path: str, type: UnitBlockType) -> UnitBlock: + def parse_file(self, path: str, type: UnitBlockType) -> Optional[UnitBlock]: pass @abstractmethod @@ -26,7 +29,7 @@ def parse_folder(self, path: str) -> Project: def parse_module(self, path: str) -> Module: pass - def parse_file_structure(self, folder, path): + def parse_file_structure(self, folder: Folder, path: str) -> None: for f in os.listdir(path): if os.path.islink(os.path.join(path, f)): continue diff --git a/glitch/parsers/puppet.py b/glitch/parsers/puppet.py index 02684254..e13ce6b0 100644 --- a/glitch/parsers/puppet.py +++ b/glitch/parsers/puppet.py @@ -1,3 +1,5 @@ +# type: ignore +# TODO: The file needs a refactor so the types make sense import os import traceback from puppetparser.parser import parse as parse_puppet @@ -6,18 +8,23 @@ import glitch.parsers.parser as p from glitch.repr.inter import * +from typing import List, Any, Tuple, Dict class PuppetParser(p.Parser): @staticmethod - def __process_unitblock_component(ce, unit_block: UnitBlock): - def get_var(parent_name, vars): + def __process_unitblock_component( + ce: CodeElement | List[CodeElement], unit_block: UnitBlock + ) -> None: + def get_var(parent_name: str, vars: List[KeyValue]): for var in vars: if var.name == parent_name: return var return None - def add_variable_to_unit_block(variable, unit_block_vars): + def add_variable_to_unit_block( + variable: KeyValue, unit_block_vars: List[KeyValue] + ) -> None: var_name = variable.name var = get_var(var_name, unit_block_vars) if var and var.value == None and variable.value == None: @@ -29,7 +36,7 @@ def add_variable_to_unit_block(variable, unit_block_vars): if isinstance(ce, Dependency): unit_block.add_dependency(ce) elif isinstance(ce, Variable): - add_variable_to_unit_block(ce, unit_block.variables) + add_variable_to_unit_block(ce, unit_block.variables) # type: ignore elif isinstance(ce, AtomicUnit): unit_block.add_atomic_unit(ce) elif isinstance(ce, UnitBlock): @@ -43,8 +50,10 @@ def add_variable_to_unit_block(variable, unit_block_vars): PuppetParser.__process_unitblock_component(c, unit_block) @staticmethod - def __process_codeelement(codeelement, path, code): - def get_code(ce): + def __process_codeelement( + codeelement: puppetmodel.CodeElement, path: str, code: List[str] + ): + def get_code(ce: puppetmodel.CodeElement): if ce.line == ce.end_line: res = code[ce.line - 1][max(0, ce.col - 1) : ce.end_col - 1] else: @@ -58,7 +67,9 @@ def get_code(ce): return res - def process_hash_value(name: str, temp_value): + def process_hash_value( + name: str, temp_value: Any + ) -> Tuple[str, Dict[str, Any]]: if "[" in name and "]" in name: start = name.find("[") + 1 end = name.find("]") @@ -69,7 +80,7 @@ def process_hash_value(name: str, temp_value): d[key_name] = temp_value return n, d else: - new_d: dict = {} + new_d: Dict[str, Any] = {} new_d[key_name] = d return n, new_d else: @@ -89,18 +100,15 @@ def process_hash_value(name: str, temp_value): return str( PuppetParser.__process_codeelement(codeelement.value, path, code) ) - elif codeelement.value == None: + elif codeelement.value is None: return "" return str(codeelement.value) elif isinstance(codeelement, puppetmodel.Attribute): name = PuppetParser.__process_codeelement(codeelement.key, path, code) - if codeelement.value is not None: - temp_value = PuppetParser.__process_codeelement( - codeelement.value, path, code - ) - value = "" if temp_value == "undef" else temp_value - else: - value = None + temp_value = PuppetParser.__process_codeelement( + codeelement.value, path, code + ) + value = "" if temp_value == "undef" else temp_value has_variable = not isinstance(value, str) or value.startswith("$") attribute = Attribute(name, value, has_variable) attribute.line, attribute.column = codeelement.line, codeelement.col @@ -569,7 +577,7 @@ def parse_file(self, path: str, type: UnitBlockType) -> UnitBlock: PuppetParser.__process_codeelement(parsed_script, path, code), unit_block, ) - except Exception as e: + except Exception: traceback.print_exc() throw_exception(EXCEPTIONS["PUPPET_COULD_NOT_PARSE"], path) return unit_block diff --git a/glitch/parsers/ripper_parser.py b/glitch/parsers/ripper_parser.py index e2c560bb..ce318197 100644 --- a/glitch/parsers/ripper_parser.py +++ b/glitch/parsers/ripper_parser.py @@ -1,8 +1,9 @@ -from ply.lex import lex -from ply.yacc import yacc +# pyright: reportUnusedFunction=false, reportUnusedVariable=false +from ply.lex import lex, LexToken +from ply.yacc import yacc, YaccProduction -def parser_yacc(script_ast): +def parser_yacc(script_ast: str): tokens = ( "LPAREN", "RPAREN", @@ -23,38 +24,38 @@ def parser_yacc(script_ast): t_ignore_ANY = r"[nil\,\ \n]" t_PLUS = r"\+" - def t_INTEGER(t): + def t_INTEGER(t: LexToken): r"[0-9]+" - t.value = int(t.value) + t.value = int(t.value) # type: ignore return t - def t_STRING(t): + def t_STRING(t: LexToken): r"\"([^\\\n]|(\\.))*?\" " t.value = t.value[1:-1] return t - def t_begin_id(t): + def t_begin_id(t: LexToken) -> None: r"\:" t.lexer.begin("id") - def t_id_end(t): + def t_id_end(t: LexToken) -> None: r"[\,]" t.lexer.begin("INITIAL") - def t_id_RPAREN(t): + def t_id_RPAREN(t: LexToken): r"\]" t.lexer.begin("INITIAL") return t - def t_id_COMMENT(t): + def t_id_COMMENT(t: LexToken): r"@comment" return t - def t_id_ID(t): + def t_id_ID(t: LexToken): r"[^,\]]+" return t - def t_ANY_error(t): + def t_ANY_error(t: LexToken) -> None: print(f"Illegal character {t.value[0]!r}.") t.lexer.skip(1) @@ -62,70 +63,70 @@ def t_ANY_error(t): # Give the lexer some input lexer.input(script_ast) - def p_program(p): + def p_program(p: YaccProduction) -> None: r"program : comments list" p[0] = (p[1], p[2]) - def p_comments(p): + def p_comments(p: YaccProduction) -> None: r"comments : comments comment" p[0] = [p[2]] + p[1] - def p_comments_empty(p): + def p_comments_empty(p: YaccProduction) -> None: r"comments : empty" p[0] = [] - def p_comment(p): + def p_comment(p: YaccProduction) -> None: r"comment : LPAREN COMMENT STRING LPAREN INTEGER INTEGER RPAREN RPAREN" p[0] = (p[3], p[5]) - def p_list(p): + def p_list(p: YaccProduction) -> None: r"list : LPAREN args RPAREN" p[0] = p[2] - def p_args_value(p): + def p_args_value(p: YaccProduction) -> None: r"args : value args" p[0] = [p[1]] + p[2] - def p_args_list(p): + def p_args_list(p: YaccProduction) -> None: r"args : list args" p[0] = [p[1]] + p[2] - def p_args_empty(p): + def p_args_empty(p: YaccProduction) -> None: r"args : empty" p[0] = [] - def p_empty(p): + def p_empty(p: YaccProduction) -> None: r"empty :" - def p_value_string(p): + def p_value_string(p: YaccProduction) -> None: r"value : string" p[0] = p[1] - def p_multi_string(p): + def p_multi_string(p: YaccProduction) -> None: r"string : STRING PLUS string" p[0] = p[1] + p[3] - def p_string(p): + def p_string(p: YaccProduction) -> None: r"string : STRING" p[0] = p[1] - def p_value_integer(p): + def p_value_integer(p: YaccProduction) -> None: r"value : INTEGER" p[0] = p[1] - def p_value_false(p): + def p_value_false(p: YaccProduction) -> None: r"value : FALSE" p[0] = False - def p_value_true(p): + def p_value_true(p: YaccProduction) -> None: r"value : TRUE" p[0] = True - def p_value_id(p): + def p_value_id(p: YaccProduction) -> None: r"value : ID" p[0] = ("id", p[1]) # FIXME - def p_error(p): + def p_error(p: YaccProduction) -> None: print(f"Syntax error at {p.value!r}") # Build the parser diff --git a/glitch/parsers/terraform.py b/glitch/parsers/terraform.py index 02ae3de5..f76fe318 100644 --- a/glitch/parsers/terraform.py +++ b/glitch/parsers/terraform.py @@ -1,3 +1,5 @@ +# type: ignore +# TODO: The file needs a refactor so the types make sense import os import re import hcl2 @@ -5,21 +7,28 @@ from glitch.exceptions import EXCEPTIONS, throw_exception from glitch.repr.inter import * +from typing import Sequence, List, Dict, Any class TerraformParser(p.Parser): @staticmethod - def __get_element_code(start_line, end_line, code): + def __get_element_code(start_line: int, end_line: int, code: List[str]) -> str: lines = code[start_line - 1 : end_line] res = "" for line in lines: res += line return res - def parse_keyvalues(self, unit_block: UnitBlock, keyvalues, code, type: str): - def create_keyvalue(start_line, end_line, name: str, value: str): + def parse_keyvalues( + self, + unit_block: UnitBlock, + keyvalues: Dict[Any, Any], + code: List[str], + type: str, + ) -> List[KeyValue]: + def create_keyvalue(start_line: int, end_line: int, name: str, value: str): has_variable = ( - ("${" in f"{value}") and ("}" in f"{value}") if value != None else False + ("${" in f"{value}") and ("}" in f"{value}") if value != None else False # type: ignore ) pattern = r"^[+-]?\d+(\.\d+)?$" if has_variable and re.match(pattern, re.sub(r"^\${(.*)}$", r"\1", value)): @@ -33,18 +42,20 @@ def create_keyvalue(start_line, end_line, name: str, value: str): if type == "attribute": keyvalue = Attribute(str(name), value, has_variable) - elif type == "variable": + else: keyvalue = Variable(str(name), value, has_variable) + keyvalue.line = start_line keyvalue.code = TerraformParser.__get_element_code( start_line, end_line, code ) + return keyvalue - def process_list(name, value, start_line, end_line): + def process_list(name: str, value: str, start_line: int, end_line: int) -> None: for i, v in enumerate(value): if isinstance(v, dict): - k = create_keyvalue(start_line, end_line, name + f"[{i}]", None) + k = create_keyvalue(start_line, end_line, name + f"[{i}]", None) # type: ignore k.keyvalues = self.parse_keyvalues(unit_block, v, code, type) k_values.append(k) elif isinstance(v, list): @@ -53,7 +64,7 @@ def process_list(name, value, start_line, end_line): k = create_keyvalue(start_line, end_line, name + f"[{i}]", v) k_values.append(k) - k_values = [] + k_values: List[KeyValue] = [] for name, keyvalue in keyvalues.items(): if name == "__start_line__" or name == "__end_line__": continue @@ -64,7 +75,7 @@ def process_list(name, value, start_line, end_line): value = keyvalue["value"] if isinstance(value, dict): # (ex: labels = {}) k = create_keyvalue( - keyvalue["__start_line__"], keyvalue["__end_line__"], name, None + keyvalue["__start_line__"], keyvalue["__end_line__"], name, None # type: ignore ) k.keyvalues = self.parse_keyvalues(unit_block, value, code, type) k_values.append(k) @@ -85,7 +96,7 @@ def process_list(name, value, start_line, end_line): value, ) k_values.append(k) - elif isinstance(keyvalue, list) and type == "attribute": + elif isinstance(keyvalue, list) and type == "attribute": # type: ignore # block (ex: access {} or dynamic setting {}; blocks of attributes; not allowed inside local values (variables)) try: for block_attributes in keyvalue: @@ -115,16 +126,18 @@ def process_list(name, value, start_line, end_line): return k_values - def parse_atomic_unit(self, type: str, unit_block: UnitBlock, dict, code): + def parse_atomic_unit( + self, type: str, unit_block: UnitBlock, dict, code: List[str] + ) -> None: def create_atomic_unit( - start_line, end_line, type: str, name: str, code + start_line: int, end_line: int, type: str, name: str, code: List[str] ) -> AtomicUnit: au = AtomicUnit(name, type) au.line = start_line au.code = TerraformParser.__get_element_code(start_line, end_line, code) return au - def parse_resource(): + def parse_resource() -> None: for resource_type, resource in dict.items(): for name, attributes in resource.items(): au = create_atomic_unit( @@ -139,7 +152,7 @@ def parse_resource(): ) unit_block.add_atomic_unit(au) - def parse_simple_unit(): + def parse_simple_unit() -> None: for name, attributes in dict.items(): au = create_atomic_unit( attributes["__start_line__"], @@ -158,8 +171,10 @@ def parse_simple_unit(): elif type in ["variable", "module", "output"]: parse_simple_unit() - def parse_comments(self, unit_block: UnitBlock, comments, code): - def create_comment(value, start_line, end_line, code): + def parse_comments( + self, unit_block: UnitBlock, comments: Sequence[str], code: List[str] + ) -> None: + def create_comment(value: str, start_line: int, end_line: int, code: List[str]): c = Comment(value) c.line = start_line c.code = TerraformParser.__get_element_code(start_line, end_line, code) @@ -210,7 +225,7 @@ def parse_module(self, path: str) -> Module: f.path for f in os.scandir(f"{path}") if f.is_file() and not f.is_symlink() ] for f in files: - unit_block = self.parse_file(f, "unknown") + unit_block = self.parse_file(f, UnitBlockType.unknown) res.add_block(unit_block) return res diff --git a/glitch/repair/interactive/compiler/compiler.py b/glitch/repair/interactive/compiler/compiler.py index e59b5489..e94bbf20 100644 --- a/glitch/repair/interactive/compiler/compiler.py +++ b/glitch/repair/interactive/compiler/compiler.py @@ -13,28 +13,28 @@ class DeltaPCompiler: _condition = 0 class __Attributes: - def __init__(self, au_type: str, tech: Tech): + def __init__(self, au_type: str, tech: Tech) -> None: self.au_type = NamesDatabase.get_au_type(au_type, tech) self.__tech = tech self.__attributes: Dict[str, Tuple[PExpr, Attribute]] = {} - def add_attribute(self, attribute: Attribute): + def add_attribute(self, attribute: Attribute) -> None: attr_name = NamesDatabase.get_attr_name( attribute.name, self.au_type, self.__tech ) - if attr_name is not None: - self.__attributes[attr_name] = ( - DeltaPCompiler._compile_expr( - NamesDatabase.get_attr_value( - attribute.value, - attr_name, - self.au_type, - self.__tech, - ), + + self.__attributes[attr_name] = ( # type: ignore + DeltaPCompiler._compile_expr( + NamesDatabase.get_attr_value( + attribute.value, # type: ignore + attr_name, + self.au_type, self.__tech, ), - attribute, - ) + self.__tech, + ), + attribute, + ) def get_attribute(self, attr_name: str) -> Optional[Attribute]: return self.__attributes.get(attr_name, (None, None))[1] @@ -60,19 +60,21 @@ def create_label_var_pair( attr_name: str, atomic_unit: AtomicUnit, labeled_script: LabeledUnitBlock, - ) -> Optional[Tuple[str, str]]: + ) -> Tuple[int, str]: attr = self.get_attribute(attr_name) if attr is not None: label = labeled_script.get_label(attr) else: # Creates sketched attribute - if attr_name == "state": # HACK + if attr_name == "state" and isinstance( + DefaultValue.DEFAULT_STATE.const, PStr + ): # HACK attr = Attribute( attr_name, DefaultValue.DEFAULT_STATE.const.value, False ) else: - attr = Attribute(attr_name, PEUndef(), False) + attr = Attribute(attr_name, PEUndef(), False) # type: ignore attr.line, attr.column = ( DeltaPCompiler._sketched, @@ -86,7 +88,7 @@ def create_label_var_pair( return label, labeled_script.get_var(label) @staticmethod - def _compile_expr(expr: Optional[str], tech: Tech) -> PExpr: + def _compile_expr(expr: Optional[str], tech: Tech) -> Optional[PExpr]: # FIXME to fix this I need to extend GLITCH's IR if expr is None: return None @@ -104,7 +106,7 @@ def __handle_file( path = attributes["path"] # The path may be defined as the name of the atomic unit if path == PEUndef(): - path = PEConst(PStr(atomic_unit.name)) + path = PEConst(PStr(atomic_unit.name)) # type: ignore state_label, state_var = attributes.create_label_var_pair( "state", atomic_unit, labeled_script diff --git a/glitch/repair/interactive/compiler/labeler.py b/glitch/repair/interactive/compiler/labeler.py index 6cfbd68a..da39e6ab 100644 --- a/glitch/repair/interactive/compiler/labeler.py +++ b/glitch/repair/interactive/compiler/labeler.py @@ -5,7 +5,7 @@ class LabeledUnitBlock: - def __init__(self, script: UnitBlock, tech: Tech): + def __init__(self, script: UnitBlock, tech: Tech) -> None: """Initializes a new instance of a labeled unit block. Args: @@ -48,7 +48,7 @@ def add_label( def add_sketch_location( self, sketch_location: CodeElement, codeelement: CodeElement - ): + ) -> None: """Defines where a sketched code element is defined in the script. Args: @@ -79,7 +79,7 @@ def get_codeelement(self, label: int) -> CodeElement: """ return self.__label_to_codeelement[label] - def remove_label(self, codeelement: CodeElement): + def remove_label(self, codeelement: CodeElement) -> None: """Removes the label of the code element. Args: @@ -114,7 +114,7 @@ class GLITCHLabeler: @staticmethod def label_attribute( labeled: LabeledUnitBlock, atomic_unit: AtomicUnit, attribute: Attribute - ): + ) -> None: """Labels an attribute. Args: @@ -124,10 +124,10 @@ def label_attribute( """ type = NamesDatabase.get_au_type(atomic_unit.type, labeled.tech) name = NamesDatabase.get_attr_name(attribute.name, type, labeled.tech) - labeled.add_label(name, attribute) + labeled.add_label(name, attribute) # type: ignore @staticmethod - def label_atomic_unit(labeled: LabeledUnitBlock, atomic_unit: AtomicUnit): + def label_atomic_unit(labeled: LabeledUnitBlock, atomic_unit: AtomicUnit) -> None: """Labels an atomic unit. Args: @@ -138,7 +138,7 @@ def label_atomic_unit(labeled: LabeledUnitBlock, atomic_unit: AtomicUnit): GLITCHLabeler.label_attribute(labeled, atomic_unit, attribute) @staticmethod - def label_variable(labeled: LabeledUnitBlock, variable: Variable): + def label_variable(labeled: LabeledUnitBlock, variable: Variable) -> None: """Labels a variable. Args: @@ -148,7 +148,9 @@ def label_variable(labeled: LabeledUnitBlock, variable: Variable): labeled.add_label(variable.name, variable) @staticmethod - def label_conditional(labeled: LabeledUnitBlock, conditional: ConditionalStatement): + def label_conditional( + labeled: LabeledUnitBlock, conditional: ConditionalStatement + ) -> None: """Labels a conditional statement. Args: diff --git a/glitch/repair/interactive/compiler/names_database.py b/glitch/repair/interactive/compiler/names_database.py index db66d3d1..29577110 100644 --- a/glitch/repair/interactive/compiler/names_database.py +++ b/glitch/repair/interactive/compiler/names_database.py @@ -19,10 +19,12 @@ def get_au_type(type: str, tech: Tech) -> str: return "file" case "ansible.builtin.file", Tech.ansible: return "file" - return None + case _: + pass + return type @staticmethod - def reverse_attr_name(name: str, au_type: str, tech: Tech) -> Optional[str]: + def reverse_attr_name(name: str, au_type: str, tech: Tech) -> str: """Returns the technology-specific name of the attribute with the given name, atomic unit type and tech. Args: @@ -46,12 +48,12 @@ def reverse_attr_name(name: str, au_type: str, tech: Tech) -> Optional[str]: return "state" case "state", "file", Tech.puppet: return "ensure" - return None + case _: + pass + return name @staticmethod - def reverse_attr_value( - value: str, attr_name: str, au_type: str, tech: Tech - ) -> Optional[str]: + def reverse_attr_value(value: str, attr_name: str, au_type: str, tech: Tech) -> str: """Returns the technology-specific value of the attribute with the given value, attribute name, atomic unit type and tech. Args: @@ -66,10 +68,12 @@ def reverse_attr_value( match value, attr_name, au_type, tech: case "present", "state", "file", Tech.ansible: return "file" + case _: + pass return value @staticmethod - def get_attr_name(name: str, au_type: str, tech: Tech) -> Optional[str]: + def get_attr_name(name: str, au_type: str, tech: Tech) -> str: """Returns the generic name of the attribute with the given name, atomic unit type and tech. Args: @@ -95,8 +99,10 @@ def get_attr_name(name: str, au_type: str, tech: Tech) -> Optional[str]: return "state" case "state", "file", Tech.ansible: return "state" + case _: + pass - return None + return name @staticmethod def get_attr_value( @@ -119,5 +125,7 @@ def get_attr_value( return "present" case "touch", "state", "file", Tech.ansible: return "present" + case _: + pass return value diff --git a/glitch/repair/interactive/delta_p.py b/glitch/repair/interactive/delta_p.py index ce599b8f..52ad7e6f 100644 --- a/glitch/repair/interactive/delta_p.py +++ b/glitch/repair/interactive/delta_p.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from abc import ABC -from typing import Optional, List, Union +from typing import Optional, List, Union, Callable, Sequence from glitch.repair.interactive.filesystem import * @@ -121,17 +121,17 @@ class PConcat(PBinOp): class PStatement(ABC): - def __get_str(self, expr: PExpr, vars: Dict[str, PExpr]) -> Optional[str]: + def __get_str(self, expr: PExpr, vars: Dict[str, PExpr]) -> str: if isinstance(expr, PEConst) and isinstance(expr.const, PStr): return expr.const.value elif isinstance(expr, PEVar): return self.__get_str(vars[expr.id], vars) elif isinstance(expr, PEUndef): - return None + return None # type: ignore raise RuntimeError(f"Unsupported expression, got {expr}") - def __eval(self, expr: PExpr, vars: Dict[str, PExpr]) -> PExpr: + def __eval(self, expr: PExpr, vars: Dict[str, PExpr]) -> PExpr | None: if isinstance(expr, PEVar) and expr.id.startswith("dejavu-condition"): return expr if isinstance(expr, PEVar): @@ -144,6 +144,8 @@ def __eval(self, expr: PExpr, vars: Dict[str, PExpr]) -> PExpr: return PEConst(PBool(True)) else: return PEConst(PBool(False)) + + return None # TODO: Add support for other operators and expressions @staticmethod @@ -162,7 +164,7 @@ def minimize(statement: "PStatement", considered_paths: List[str]) -> "PStatemen """ def minimize_aux( - statement: "PStatement", considered_paths: List[PExpr] + statement: "PStatement", considered_paths: Sequence[PExpr] ) -> "PStatement": # FIXME compile statement.path if isinstance(statement, PMkdir) and statement.path in considered_paths: @@ -211,10 +213,10 @@ def minimize_aux( return PSkip() - considered_paths = list( + considered_paths_exprs: List[PEConst] = list( map(lambda path: PEConst(const=PStr(value=path)), considered_paths) ) - return minimize_aux(statement, considered_paths) + return minimize_aux(statement, considered_paths_exprs) def to_filesystems( self, @@ -229,9 +231,9 @@ def to_filesystems( if vars is None: vars = {} - res_fss = [] + res_fss: List[FileSystemState] = [] for fs in fss: - get_str = lambda expr: self.__get_str(expr, vars) + get_str: Callable[[PExpr], str] = lambda expr: self.__get_str(expr, vars) if isinstance(self, PSkip): pass @@ -241,20 +243,23 @@ def to_filesystems( fs.state[get_str(self.path)] = File(None, None, None) elif isinstance(self, PWrite): path, content = get_str(self.path), get_str(self.content) - if isinstance(fs.state[path], File): - fs.state[path].content = content + file = fs.state.get(path) + if isinstance(file, File): + file.content = content elif isinstance(self, PRm): fs.state[get_str(self.path)] = Nil() elif isinstance(self, PCp): fs.state[get_str(self.dst)] = fs.state[get_str(self.src)] elif isinstance(self, PChmod): path, mode = get_str(self.path), get_str(self.mode) - if isinstance(fs.state[path], (File, Dir)): - fs.state[path].mode = mode + file = fs.state.get(path) + if isinstance(file, (File, Dir)): + file.mode = mode elif isinstance(self, PChown): path, owner = get_str(self.path), get_str(self.owner) - if isinstance(fs.state[path], (File, Dir)): - fs.state[path].owner = owner + file = fs.state.get(path) + if isinstance(file, (File, Dir)): + file.owner = owner elif isinstance(self, PSeq): fss_lhs = self.lhs.to_filesystems(fs, vars) for fs_lhs in fss_lhs: diff --git a/glitch/repair/interactive/filesystem.py b/glitch/repair/interactive/filesystem.py index e607031d..d1145498 100644 --- a/glitch/repair/interactive/filesystem.py +++ b/glitch/repair/interactive/filesystem.py @@ -3,12 +3,6 @@ class State: - def is_dir(self) -> bool: - return isinstance(self, Dir) - - def is_file(self) -> bool: - return isinstance(self, File) - def __str__(self) -> str: return self.__class__.__name__.lower() @@ -32,7 +26,7 @@ class Nil(State): class FileSystemState: - def __init__(self): + def __init__(self) -> None: self.state: Dict[str, State] = {} def copy(self): diff --git a/glitch/repair/interactive/solver.py b/glitch/repair/interactive/solver.py index f7f138b0..5a15174a 100644 --- a/glitch/repair/interactive/solver.py +++ b/glitch/repair/interactive/solver.py @@ -1,13 +1,13 @@ import time from copy import deepcopy -from typing import List, Callable, Tuple +from typing import List, Callable, Tuple, Any from z3 import ( Solver, sat, - BoolRef, If, StringVal, + IntVal, String, Bool, And, @@ -16,12 +16,11 @@ Or, Sum, ModelRef, - Z3PPObject, Context, + ExprRef, ) from glitch.repair.interactive.filesystem import FileSystemState -from glitch.repair.interactive.tracer.transform import get_file_system_state from glitch.repair.interactive.filesystem import * from glitch.repair.interactive.delta_p import * from glitch.repair.interactive.values import DefaultValue, UNDEF @@ -30,7 +29,7 @@ from glitch.repair.interactive.compiler.labeler import GLITCHLabeler from glitch.repair.interactive.compiler.names_database import NamesDatabase -Fun = Callable[[PStatement], Z3PPObject] +Fun = Callable[[ExprRef], ExprRef] class PatchSolver: @@ -47,16 +46,16 @@ def __init__( filesystem: FileSystemState, timeout: int = 180, ctx: Optional[Context] = None, - ): + ) -> None: # FIXME: the filesystem in here should be generated from # checking the affected paths in statement self.solver = Solver(ctx=ctx) self.timeout = timeout self.statement = statement self.sum_var = Int("sum") - self.unchanged = {} - self.vars = {} - self.holes = {} + self.unchanged: Dict[int, ExprRef] = {} + self.vars: Dict[str, ExprRef] = {} + self.holes: Dict[str, ExprRef] = {} # FIXME: check the defaults self.__funs = PatchSolver.__Funs( @@ -84,14 +83,7 @@ def __init__( self.solver.add(Sum(list(self.unchanged.values())) == self.sum_var) - def __get_default_fs(self): - # Returns the current file system state for all the files affected by the script - # TODO: For now we will consider only the files defined in the script - fs = self.statement.to_filesystems() - affected_files = fs.state.keys() - return get_file_system_state(affected_files) - - def __collect_labels(self, statement: PStatement) -> List[str]: + def __collect_labels(self, statement: PStatement | PExpr) -> List[int]: if isinstance(statement, PSeq): return self.__collect_labels(statement.lhs) + self.__collect_labels( statement.rhs @@ -102,11 +94,11 @@ def __collect_labels(self, statement: PStatement) -> List[str]: + self.__collect_labels(statement.cons) + self.__collect_labels(statement.alt) ) - elif isinstance(statement, PLet): + elif isinstance(statement, PLet) and isinstance(statement.label, int): return [statement.label] + self.__collect_labels(statement.body) return [] - def __compile_expr(self, expr: PExpr): + def __compile_expr(self, expr: PExpr) -> ExprRef: if isinstance(expr, PEConst) and isinstance(expr.const, PStr): return StringVal(expr.const.value) elif isinstance(expr, PEVar) and expr.id.startswith("dejavu-condition-"): @@ -123,24 +115,28 @@ def __compile_expr(self, expr: PExpr): raise ValueError(f"Not supported {expr}") - def __generate_hard_constraints(self, filesystem: FileSystemState): + def __generate_hard_constraints(self, filesystem: FileSystemState) -> None: for path, state in filesystem.state.items(): - self.solver.add(self.__funs.state_fun(path) == StringVal(str(state))) + self.solver.add( + self.__funs.state_fun(StringVal(path)) == StringVal(str(state)) + ) content, mode, owner = UNDEF, UNDEF, UNDEF - if state.is_file(): + if isinstance(state, File): content = UNDEF if state.content is None else state.content - if state.is_file() or state.is_dir(): + if isinstance(state, File) or isinstance(state, Dir): mode = UNDEF if state.mode is None else state.mode owner = UNDEF if state.owner is None else state.owner - self.solver.add(self.__funs.contents_fun(path) == StringVal(content)) - self.solver.add(self.__funs.mode_fun(path) == StringVal(mode)) - self.solver.add(self.__funs.owner_fun(path) == StringVal(owner)) + self.solver.add( + self.__funs.contents_fun(StringVal(path) == StringVal(content)) + ) + self.solver.add(self.__funs.mode_fun(StringVal(path) == StringVal(mode))) + self.solver.add(self.__funs.owner_fun(StringVal(path) == StringVal(owner))) def __generate_soft_constraints( self, statement: PStatement, funs: __Funs - ) -> Tuple[List[Z3PPObject], __Funs,]: + ) -> Tuple[List[ExprRef], __Funs,]: # Avoids infinite recursion funs = deepcopy(funs) # NOTE: For now it doesn't make sense to update the funs for the @@ -149,7 +145,7 @@ def __generate_soft_constraints( previous_contents_fun = funs.contents_fun previous_mode_fun = funs.mode_fun previous_owner_fun = funs.owner_fun - constraints = [] + constraints: List[ExprRef] = [] if isinstance(statement, PMkdir): path = self.__compile_expr(statement.path) @@ -219,12 +215,12 @@ def __generate_soft_constraints( hole, var = String(f"loc-{statement.label}"), String(statement.id) self.holes[f"loc-{statement.label}"] = hole self.vars[statement.id] = var - unchanged = self.unchanged[statement.label] + unchanged = self.unchanged[statement.label] # type: ignore constraints.append( - Or( - And(unchanged == 1, var == self.__compile_expr(statement.expr)), - And(unchanged == 0, var == hole), - ) + Or( # type: ignore + And(unchanged == 1, var == self.__compile_expr(statement.expr)), # type: ignore + And(unchanged == 0, var == hole), # type: ignore + ) # type: ignore ) body_constraints, funs = self.__generate_soft_constraints( statement.body, funs @@ -270,7 +266,7 @@ def __generate_soft_constraints( return constraints, funs def solve(self) -> Optional[List[ModelRef]]: - models = [] + models: List[ModelRef] = [] start = time.time() elapsed = 0 @@ -282,7 +278,7 @@ def solve(self) -> Optional[List[ModelRef]]: while lo < hi and elapsed < self.timeout: mid = (lo + hi) // 2 self.solver.push() - self.solver.add(self.sum_var >= mid) + self.solver.add(self.sum_var >= IntVal(mid)) if self.solver.check() == sat: lo = mid + 1 model = self.solver.model() @@ -300,17 +296,18 @@ def solve(self) -> Optional[List[ModelRef]]: models.append(model) # Removes conditional variables that were not used - dvars = filter(lambda v: model[v] is not None, self.vars.values()) + dvars = filter(lambda v: model[v] is not None, self.vars.values()) # type: ignore self.solver.add(Not(And([v == model[v] for v in dvars]))) if elapsed >= self.timeout: return None return models + @staticmethod def __find_atomic_unit( labeled_script: LabeledUnitBlock, attribute: Attribute - ) -> AtomicUnit: - def aux_find_atomic_unit(code_element: CodeElement) -> AtomicUnit: + ) -> Optional[AtomicUnit]: + def aux_find_atomic_unit(code_element: CodeElement) -> Optional[AtomicUnit]: if ( isinstance(code_element, AtomicUnit) and attribute in code_element.attributes @@ -337,11 +334,13 @@ def aux_find_atomic_unit(code_element: CodeElement) -> AtomicUnit: def __is_sketch(self, codeelement: CodeElement) -> bool: return codeelement.line < 0 and codeelement.column < 0 - def apply_patch(self, model_ref: ModelRef, labeled_script: LabeledUnitBlock): - changed = [] + def apply_patch( + self, model_ref: ModelRef, labeled_script: LabeledUnitBlock + ) -> None: + changed: List[Tuple[int, Any]] = [] for label, unchanged in self.unchanged.items(): - if model_ref[unchanged] == 0: + if model_ref[unchanged] == 0: # type: ignore hole = self.holes[f"loc-{label}"] changed.append((label, model_ref[hole])) @@ -354,12 +353,16 @@ def apply_patch(self, model_ref: ModelRef, labeled_script: LabeledUnitBlock): if self.__is_sketch(codeelement): atomic_unit = labeled_script.get_sketch_location(codeelement) + if not isinstance(atomic_unit, AtomicUnit): + raise RuntimeError("Atomic unit not found") + atomic_unit_type = NamesDatabase.get_au_type( atomic_unit.type, labeled_script.tech ) - codeelement.name = NamesDatabase.reverse_attr_name( + name = NamesDatabase.reverse_attr_name( codeelement.name, atomic_unit_type, labeled_script.tech ) + codeelement.name = name atomic_unit.attributes.append(codeelement) # Remove sketch label and add regular label labeled_script.remove_label(codeelement) @@ -370,10 +373,10 @@ def apply_patch(self, model_ref: ModelRef, labeled_script: LabeledUnitBlock): ) # Remove attributes that are not defined - if value == UNDEF: + if value == UNDEF and isinstance(atomic_unit, AtomicUnit): atomic_unit.attributes.remove(codeelement) labeled_script.remove_label(codeelement) - else: + elif isinstance(atomic_unit, AtomicUnit): codeelement.value = NamesDatabase.reverse_attr_value( value, codeelement.name, atomic_unit.type, labeled_script.tech ) diff --git a/glitch/repair/interactive/tracer/parser.py b/glitch/repair/interactive/tracer/parser.py index 57ec227a..47106f43 100644 --- a/glitch/repair/interactive/tracer/parser.py +++ b/glitch/repair/interactive/tracer/parser.py @@ -1,15 +1,16 @@ +# pyright: reportUnusedFunction=false, reportUnusedVariable=false import logging from enum import Enum -from ply.lex import lex -from ply.yacc import yacc +from ply.lex import lex, LexToken +from ply.yacc import yacc, YaccProduction from dataclasses import dataclass -from typing import List +from typing import List, Any @dataclass class Syscall: cmd: str - args: List[str] + args: List[Any] exitCode: int @@ -59,74 +60,74 @@ class UnlinkFlag(Enum): AT_REMOVEDIR = 0 -def parse_tracer_output(tracer_output: str, debug=False) -> Syscall: +def parse_tracer_output(tracer_output: str, debug: bool = False) -> Syscall: # Tokens defined as functions preserve order - def t_ADDRESS(t): + def t_ADDRESS(t: LexToken): r"0[xX][0-9a-fA-F]+" return t - def t_PID(t): + def t_PID(t: LexToken): r"\[pid\s\d+\]" return t - def t_COMMA(t): + def t_COMMA(t: LexToken): r"," return t - def t_EQUAL(t): + def t_EQUAL(t: LexToken): r"=" return t - def t_PIPE(t): + def t_PIPE(t: LexToken): r"\|" return t - def t_LCURLY(t): + def t_LCURLY(t: LexToken): r"\{" return t - def t_RCURLY(t): + def t_RCURLY(t: LexToken): r"\}" return t - def t_LPARENS(t): + def t_LPARENS(t: LexToken): r"\(" return t - def t_RPARENS(t): + def t_RPARENS(t: LexToken): r"\)" return t - def t_LPARENSR(t): + def t_LPARENSR(t: LexToken): r"\[" return t - def t_RPARENSR(t): + def t_RPARENSR(t: LexToken): r"\]" return t - def t_POSITIVE_NUMBER(t): + def t_POSITIVE_NUMBER(t: LexToken): r"[0-9]+" return t - def t_NEGATIVE_NUMBER(t): + def t_NEGATIVE_NUMBER(t: LexToken): "-[0-9]+" return t - def t_ID(t): + def t_ID(t: LexToken): r"[a-zA-Z][a-zA-Z0-9_]*" if t.value in [flag.name for flag in OpenFlag]: t.type = "OPEN_FLAG" - t.value = OpenFlag[t.value] + t.value = OpenFlag[t.value] # type: ignore elif t.value in [flag.name for flag in ORedFlag]: t.type = "ORED_FLAG" - t.value = ORedFlag[t.value] + t.value = ORedFlag[t.value] # type: ignore elif t.value in [flag.name for flag in UnlinkFlag]: t.type = "UNLINK_FLAG" - t.value = UnlinkFlag[t.value] + t.value = UnlinkFlag[t.value] # type: ignore return t - def t_STRING(t): + def t_STRING(t: LexToken): r"(\'([^\\]|(\\(\n|.)))*?\')|(\"([^\\]|(\\(\n|.)))*?\")" t.value = t.value[1:-1] t.lexer.lineno += t.value.count("\n") @@ -141,11 +142,11 @@ def t_STRING(t): t_ignore_ANY = r"[\t\ \n]" - def t_COMMENT(t): + def t_COMMENT(t: LexToken) -> None: r"/\*.*?\*/" # Ignore comments - def t_ANY_error(t): + def t_ANY_error(t: LexToken) -> None: logging.error(f"Illegal character {t.value[0]!r}.") t.lexer.skip(1) @@ -161,136 +162,136 @@ def t_ANY_error(t): break print(tok) - def p_syscalls_pid(p): + def p_syscalls_pid(p: YaccProduction) -> None: r"syscalls : PID syscall" p[0] = p[2] - def p_syscalls_exit(p): + def p_syscalls_exit(p: YaccProduction) -> None: r"syscalls : PID syscall exit_message" p[0] = p[2] - def p_syscalls_no_pid(p): + def p_syscalls_no_pid(p: YaccProduction) -> None: r"syscalls : syscall exit_message" p[0] = p[1] - def p_syscalls(p): + def p_syscalls(p: YaccProduction) -> None: r"syscalls : syscall" p[0] = p[1] - def p_exit_message(p): + def p_exit_message(p: YaccProduction) -> None: r"exit_message : ID LPARENS ids RPARENS" p[0] = p[1] - def p_ids(p): + def p_ids(p: YaccProduction) -> None: r"ids : ids ID" p[0] = p[1] + [p[2]] - def p_ids_single(p): + def p_ids_single(p: YaccProduction) -> None: r"ids : ID" p[0] = [p[1]] - def p_syscall(p): + def p_syscall(p: YaccProduction) -> None: r"syscall : ID LPARENS terms RPARENS EQUAL number" p[0] = Syscall(p[1], p[3], int(p[6])) - def p_terms(p): + def p_terms(p: YaccProduction) -> None: r"terms : terms COMMA term" p[0] = p[1] + [p[3]] - def p_terms_single(p): + def p_terms_single(p: YaccProduction) -> None: r"terms : term" p[0] = [p[1]] - def p_term_number(p): + def p_term_number(p: YaccProduction) -> None: r"term : number" p[0] = p[1] - def p_term_id(p): + def p_term_id(p: YaccProduction) -> None: r"term : ID" p[0] = p[1] - def p_call(p): + def p_call(p: YaccProduction) -> None: r"term : ID LPARENS terms RPARENS" p[0] = Call(p[1], p[3]) - def p_term_address(p): + def p_term_address(p: YaccProduction) -> None: r"term : ADDRESS" p[0] = p[1] - def p_term_string(p): + def p_term_string(p: YaccProduction) -> None: r"term : STRING" p[0] = p[1] - def p_term_open_flags(p): + def p_term_open_flags(p: YaccProduction) -> None: r"term : open_flags" p[0] = p[1] - def p_term_ored_flags(p): + def p_term_ored_flags(p: YaccProduction) -> None: r"term : ored_flags" p[0] = p[1] - def p_term_unlink_flags(p): + def p_term_unlink_flags(p: YaccProduction) -> None: r"term : unlink_flags" p[0] = p[1] - def p_term_list(p): + def p_term_list(p: YaccProduction) -> None: r"term : LPARENSR terms RPARENSR" p[0] = p[2] - def p_term_dict(p): + def p_term_dict(p: YaccProduction) -> None: r"term : LCURLY key_values RCURLY" p[0] = p[2] - def p_term_or(p): + def p_term_or(p: YaccProduction) -> None: r"term : term PIPE term" p[0] = BinaryOperation(p[1], p[3], "|") - def p_key_values(p): + def p_key_values(p: YaccProduction) -> None: r"key_values : key_values COMMA key_value" p[1].update({p[3][0]: p[3][1]}) p[0] = p[1] - def p_key_values_single(p): + def p_key_values_single(p: YaccProduction) -> None: r"key_values : key_value" p[0] = {p[1][0]: p[1][1]} - def p_key_value(p): + def p_key_value(p: YaccProduction) -> None: r"key_value : ID EQUAL term" p[0] = (p[1], p[3]) - def p_number(p): + def p_number(p: YaccProduction) -> None: r"number : POSITIVE_NUMBER" p[0] = p[1] - def p_number_negative(p): + def p_number_negative(p: YaccProduction) -> None: r"number : NEGATIVE_NUMBER" p[0] = p[1] - def p_open_flags_single(p): + def p_open_flags_single(p: YaccProduction) -> None: r"open_flags : OPEN_FLAG" p[0] = [p[1]] - def p_open_flags(p): + def p_open_flags(p: YaccProduction) -> None: r"open_flags : open_flags PIPE OPEN_FLAG" p[0] = p[1] + [p[3]] - def p_ored_flags_single(p): + def p_ored_flags_single(p: YaccProduction) -> None: r"ored_flags : ORED_FLAG" p[0] = [p[1]] - def p_ored_flags(p): + def p_ored_flags(p: YaccProduction) -> None: r"ored_flags : ored_flags PIPE ORED_FLAG" p[0] = p[1] + [p[3]] - def p_unlink_flags_single(p): + def p_unlink_flags_single(p: YaccProduction) -> None: r"unlink_flags : UNLINK_FLAG" p[0] = [p[1]] - def p_unlink_flags(p): + def p_unlink_flags(p: YaccProduction) -> None: r"unlink_flags : unlink_flags PIPE UNLINK_FLAG" p[0] = p[1] + [p[3]] - def p_error(p): + def p_error(p: YaccProduction) -> None: logging.error(f"Syntax error at {p.value!r}") # Build the parser diff --git a/glitch/repair/interactive/tracer/tracer.py b/glitch/repair/interactive/tracer/tracer.py index 35ddf4d9..faabb34d 100644 --- a/glitch/repair/interactive/tracer/tracer.py +++ b/glitch/repair/interactive/tracer/tracer.py @@ -4,19 +4,15 @@ from typing import List from glitch.repair.interactive.tracer.parser import parse_tracer_output from glitch.repair.interactive.tracer.model import get_syscall_with_type, Syscall -from glitch.repair.interactive.tracer.transform import ( - get_affected_paths, - get_file_system_state, -) class STrace(threading.Thread): - def __init__(self, pid: str): + def __init__(self, pid: str) -> None: threading.Thread.__init__(self) self.syscalls: List[Syscall] = [] self.pid = pid - def run(self) -> None: + def run(self) -> List[Syscall]: # type: ignore proc = subprocess.Popen( [ "sudo", @@ -36,6 +32,9 @@ def run(self) -> None: universal_newlines=True, ) + if proc.stdout is None: + return self.syscalls + for line in proc.stdout: if ( line.startswith("strace: Process") diff --git a/glitch/repair/interactive/tracer/transform.py b/glitch/repair/interactive/tracer/transform.py index 3f5ac380..4314cf3c 100644 --- a/glitch/repair/interactive/tracer/transform.py +++ b/glitch/repair/interactive/tracer/transform.py @@ -1,6 +1,6 @@ import os from pwd import getpwuid -from typing import Set +from typing import Set, Callable from glitch.repair.interactive.tracer.model import * from glitch.repair.interactive.filesystem import * @@ -17,7 +17,7 @@ def get_affected_paths(workdir: str, syscalls: List[Syscall]) -> Set[str]: Set[str]: A set of all paths affected by the given syscalls. """ - def abspath(workdir, path): + def abspath(workdir: str, path: str): if os.path.isabs(path): return path return os.path.realpath(os.path.join(workdir, path)) @@ -61,8 +61,8 @@ def get_file_system_state(files: Set[str]) -> FileSystemState: FileSystemState: The file system state. """ fs = FileSystemState() - get_owner = lambda f: getpwuid(os.stat(f).st_uid).pw_name - get_mode = lambda f: oct(os.stat(f).st_mode & 0o777)[2:] + get_owner: Callable[[str], str] = lambda f: getpwuid(os.stat(f).st_uid).pw_name + get_mode: Callable[[str], str] = lambda f: oct(os.stat(f).st_mode & 0o777)[2:] for file in files: if not os.path.exists(file): diff --git a/glitch/repr/inter.py b/glitch/repr/inter.py index 13cfb5f7..13811d57 100644 --- a/glitch/repr/inter.py +++ b/glitch/repr/inter.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum +from typing import List, Union class CodeElement(ABC): @@ -20,16 +21,16 @@ def __str__(self) -> str: return self.__repr__() @abstractmethod - def print(self, tab): + def print(self, tab: int) -> str: pass class Block(CodeElement): def __init__(self) -> None: super().__init__() - self.statements = [] + self.statements: List[CodeElement] = [] - def add_statement(self, statement): + def add_statement(self, statement: "ConditionalStatement") -> None: self.statements.append(statement) @@ -38,17 +39,22 @@ class ConditionType(Enum): IF = 1 SWITCH = 2 - def __init__(self, condition: str, type, is_default=False) -> None: + def __init__( + self, + condition: str, + type: "ConditionalStatement.ConditionType", + is_default: bool = False, + ) -> None: super().__init__() self.condition: str = condition - self.else_statement = None + self.else_statement: ConditionalStatement | None = None self.is_default = is_default self.type = type def __repr__(self) -> str: return self.code.strip().split("\n")[0] - def print(self, tab) -> str: + def print(self, tab: int) -> str: res = ( (tab * "\t") + str(self.type) @@ -81,16 +87,16 @@ def __init__(self, content: str) -> None: def __repr__(self) -> str: return self.content - def print(self, tab) -> str: + def print(self, tab: int) -> str: return (tab * "\t") + self.content + " (on line " + str(self.line) + ")" class KeyValue(CodeElement): - def __init__(self, name: str, value: str, has_variable: bool): + def __init__(self, name: str, value: str | None, has_variable: bool) -> None: self.name: str = name - self.value: str = value + self.value: str | None = value self.has_variable: bool = has_variable - self.keyvalues: list = [] + self.keyvalues: List[KeyValue] = [] def __repr__(self) -> str: value = repr(self.value).split("\n")[0] @@ -101,10 +107,10 @@ def __repr__(self) -> str: class Variable(KeyValue): - def __init__(self, name: str, value: str, has_variable: bool) -> None: + def __init__(self, name: str, value: str | None, has_variable: bool) -> None: super().__init__(name, value, has_variable) - def print(self, tab) -> str: + def print(self, tab: int) -> str: if isinstance(self.value, str): return ( (tab * "\t") @@ -143,7 +149,7 @@ class Attribute(KeyValue): def __init__(self, name: str, value: str, has_variable: bool) -> None: super().__init__(name, value, has_variable) - def print(self, tab) -> str: + def print(self, tab: int) -> str: if isinstance(self.value, str): return ( (tab * "\t") @@ -179,9 +185,9 @@ def print(self, tab) -> str: class AtomicUnit(Block): - def __init__(self, name: str, type: str) -> None: + def __init__(self, name: str | None, type: str) -> None: super().__init__() - self.name: str = name + self.name: str | None = name self.type: str = type self.attributes: list[Attribute] = [] @@ -191,15 +197,11 @@ def add_attribute(self, a: Attribute) -> None: def __repr__(self) -> str: return f"{self.name} {self.type}" - def print(self, tab) -> str: + def print(self, tab: int) -> str: res = ( - (tab * "\t") - + self.type - + " " - + self.name - + " (on line " - + str(self.line) - + ")\n" + (tab * "\t") + self.type + " " + self.name + if self.name is not None + else "" + " (on line " + str(self.line) + ")\n" ) for attribute in self.attributes: @@ -221,7 +223,7 @@ def __init__(self, name: str) -> None: def __repr__(self) -> str: return self.name - def print(self, tab) -> str: + def print(self, tab: int) -> str: return (tab * "\t") + self.name + " (on line " + str(self.line) + ")" @@ -242,12 +244,12 @@ def __init__(self, name: str, type: UnitBlockType) -> None: self.atomic_units: list[AtomicUnit] = [] self.unit_blocks: list["UnitBlock"] = [] self.attributes: list[Attribute] = [] - self.name: str = name + self.name: str | None = name self.path: str = "" self.type: UnitBlockType = type def __repr__(self) -> str: - return self.name + return self.name if self.name is not None else "" def add_dependency(self, d: Dependency) -> None: self.dependencies.append(d) @@ -267,8 +269,11 @@ def add_unit_block(self, u: "UnitBlock") -> None: def add_attribute(self, a: Attribute) -> None: self.attributes.append(a) - def print(self, tab) -> str: - res = (tab * "\t") + self.name + "\n" + def print(self, tab: int) -> str: + if self.name is not None: + res = (tab * "\t") + self.name + "\n" + else: + res = "" res += (tab * "\t") + "\tdependencies:\n" for dependency in self.dependencies: @@ -302,16 +307,16 @@ def print(self, tab) -> str: class File: - def __init__(self, name) -> None: + def __init__(self, name: str) -> None: self.name: str = name - def print(self, tab) -> str: + def print(self, tab: int) -> str: return (tab * "\t") + self.name class Folder: - def __init__(self, name) -> None: - self.content: list = [] + def __init__(self, name: str) -> None: + self.content: List[Union["Folder", File]] = [] self.name: str = name def add_folder(self, folder: "Folder") -> None: @@ -320,7 +325,7 @@ def add_folder(self, folder: "Folder") -> None: def add_file(self, file: File) -> None: self.content.append(file) - def print(self, tab) -> str: + def print(self, tab: int) -> str: res = (tab * "\t") + self.name + "\n" for c in self.content: @@ -331,10 +336,11 @@ def print(self, tab) -> str: class Module: - def __init__(self, name, path) -> None: + def __init__(self, name: str, path: str) -> None: self.name: str = name self.path: str = path self.blocks: list[UnitBlock] = [] + self.modules: list[Module] = [] self.folder: Folder = Folder(name) def __repr__(self) -> str: @@ -343,7 +349,7 @@ def __repr__(self) -> str: def add_block(self, u: UnitBlock) -> None: self.blocks.append(u) - def print(self, tab) -> str: + def print(self, tab: int) -> str: res = (tab * "\t") + self.name + "\n" res += (tab * "\t") + "\tblocks:\n" @@ -357,7 +363,7 @@ def print(self, tab) -> str: class Project: - def __init__(self, name) -> None: + def __init__(self, name: str) -> None: self.name: str = name self.modules: list[Module] = [] self.blocks: list[UnitBlock] = [] @@ -365,13 +371,13 @@ def __init__(self, name) -> None: def __repr__(self) -> str: return self.name - def add_module(self, m: Module): + def add_module(self, m: Module) -> None: self.modules.append(m) - def add_block(self, u: UnitBlock): + def add_block(self, u: UnitBlock) -> None: self.blocks.append(u) - def print(self, tab) -> str: + def print(self, tab: int) -> str: res = self.name + "\n" res += (tab * "\t") + "\tmodules:\n" diff --git a/glitch/stats/print.py b/glitch/stats/print.py index a64457b2..a9ff18ea 100644 --- a/glitch/stats/print.py +++ b/glitch/stats/print.py @@ -1,13 +1,16 @@ -from glob import escape -import pandas as pd -from glitch.analysis.rules import Error +import pandas as pd # type: ignore from prettytable import PrettyTable +from typing import List, Dict, Set, Tuple +from glitch.analysis.rules import Error +from glitch.stats.stats import FileStats -def print_stats(errors, smells, file_stats, format): +def print_stats( + errors: List[Error], smells: List[str], file_stats: FileStats, format: str +) -> None: total_files = len(file_stats.files) - occurrences = {} - files_with_the_smell = {"Combined": set()} + occurrences: Dict[str, int] = {} + files_with_the_smell: Dict[str, Set[str]] = {"Combined": set()} for smell in smells: occurrences[smell] = 0 @@ -18,27 +21,27 @@ def print_stats(errors, smells, file_stats, format): files_with_the_smell[error.code].add(error.path) files_with_the_smell["Combined"].add(error.path) - stats_info = [] + stats_info: List[Tuple[str, int, float, float]] = [] total_occur = 0 total_smell_density = 0 for code, n in occurrences.items(): total_occur += n total_smell_density += round(n / (file_stats.loc / 1000), 2) stats_info.append( - [ + ( Error.ALL_ERRORS[code], n, round(n / (file_stats.loc / 1000), 2), round((len(files_with_the_smell[code]) / total_files) * 100, 2), - ] + ) ) stats_info.append( - [ + ( "Combined", total_occur, total_smell_density, round((len(files_with_the_smell["Combined"]) / total_files) * 100, 2), - ] + ) ) if format == "prettytable": @@ -49,13 +52,17 @@ def print_stats(errors, smells, file_stats, format): "Smell density (Smell/KLoC)", "Proportion of scripts (%)", ] - table.align["Smell"] = "r" - table.align["Occurrences"] = "l" - table.align["Smell density (Smell/KLoC)"] = "l" - table.align["Proportion of scripts (%)"] = "l" + + table.align["Smell"] = "r" # type: ignore + table.align["Occurrences"] = "l" # type: ignore + table.align["Smell density (Smell/KLoC)"] = "l" # type: ignore + table.align["Proportion of scripts (%)"] = "l" # type: ignore smells_info = stats_info[:-1] - for smell in smells_info: - smell[0] = smell[0].split(" - ")[0] + + smells_info = map( + lambda smell: (smell[0].split(" - ")[0], smell[1], smell[2], smell[3]), + smells_info, + ) smells_info = sorted(smells_info, key=lambda x: x[0]) biggest_value = [len(name) for name in table.field_names] @@ -64,22 +71,26 @@ def print_stats(errors, smells, file_stats, format): if len(str(s)) > biggest_value[i]: biggest_value[i] = len(str(s)) - table.add_row(stats) + table.add_row(stats) # type: ignore div_row = [i * "-" for i in biggest_value] - table.add_row(div_row) - table.add_row(stats_info[-1]) + table.add_row(div_row) # type: ignore + table.add_row(stats_info[-1]) # type: ignore print(table) attributes = PrettyTable() attributes.field_names = ["Total IaC files", "Lines of Code"] - attributes.add_row([total_files, file_stats.loc]) + attributes.add_row([total_files, file_stats.loc]) # type: ignore print(attributes) elif format == "latex": smells_info = stats_info[:-1] smells_info = sorted(smells_info, key=lambda x: x[0]) - for smell in smells_info: - smell[0] = smell[0].split(" - ")[0] + smells_info = list( + map( + lambda smell: (smell[0].split(" - ")[0], smell[1], smell[2], smell[3]), + smells_info, + ) + ) smells_info.append(stats_info[-1]) table = pd.DataFrame( smells_info, @@ -87,24 +98,24 @@ def print_stats(errors, smells, file_stats, format): "\\textbf{Smell}", "\\textbf{Occurrences}", "\\textbf{Smell density (Smell/KLoC)}", - "\\textbf{Proportion of scripts (\%)}", + "\\textbf{Proportion of scripts (%)}", ], ) - latex = ( - table.style.hide(axis="index") - .format(escape=None, precision=2, thousands=",") - .to_latex() + latex = ( # type: ignore + table.style.hide(axis="index") # type: ignore + .format(escape=None, precision=2, thousands=",") # type: ignore + .to_latex() # type: ignore ) - combined = latex[: latex.rfind("\\\\")].rfind("\\\\") - latex = latex[:combined] + "\\\\\n\midrule\n" + latex[combined + 3 :] - print(latex) + combined = latex[: latex.rfind("\\\\")].rfind("\\\\") # type: ignore + latex = latex[:combined] + "\\\\\n\\midrule\n" + latex[combined + 3 :] # type: ignore + print(latex) # type: ignore attributes = pd.DataFrame( [[total_files, file_stats.loc]], - columns=["\\textbf{Total IaC files}", "\\textbf{Lines of Code}"], + columns=["\\textbf{Total IaC files}", "\\textbf{Lines of Code}"], # type: ignore ) print( - attributes.style.hide(axis="index") - .format(escape=None, precision=2, thousands=",") - .to_latex() + attributes.style.hide(axis="index") # type: ignore + .format(escape=None, precision=2, thousands=",") # type: ignore + .to_latex() # type: ignore ) diff --git a/glitch/stats/stats.py b/glitch/stats/stats.py index 8acd56af..92318280 100644 --- a/glitch/stats/stats.py +++ b/glitch/stats/stats.py @@ -1,11 +1,16 @@ import os +from typing import Union, Set from abc import ABC, abstractmethod from glitch.repr.inter import * +CodeElementDict = dict[ + Union["CodeElementDict", CodeElement], Union["CodeElementDict", CodeElement] +] + class Stats(ABC): - def compute(self, c): + def compute(self, c: CodeElement | Project | Module | CodeElementDict) -> None: if isinstance(c, Project): self.compute_project(c) elif isinstance(c, Module): @@ -69,16 +74,16 @@ def compute_comment(self, c: Comment): class FileStats(Stats): def __init__(self) -> None: super().__init__() - self.files = set() + self.files: Set[str] = set() self.loc = 0 - def compute_project(self, p: Project): + def compute_project(self, p: Project) -> None: for m in p.modules: self.compute(m) for u in p.blocks: self.compute(u) - def compute_module(self, m: Module): + def compute_module(self, m: Module) -> None: for u in m.blocks: self.compute(u) if os.path.isfile(m.path) and m.path not in self.files: @@ -86,7 +91,7 @@ def compute_module(self, m: Module): with open(m.path, "r") as f: self.loc += len(f.readlines()) - def compute_unitblock(self, u: UnitBlock): + def compute_unitblock(self, u: UnitBlock) -> None: for ub in u.unit_blocks: self.compute(ub) if os.path.isfile(u.path) and u.path not in self.files: @@ -97,20 +102,20 @@ def compute_unitblock(self, u: UnitBlock): except UnicodeDecodeError: pass - def compute_atomicunit(self, au: AtomicUnit): + def compute_atomicunit(self, au: AtomicUnit) -> None: pass - def compute_dependency(self, d: Dependency): + def compute_dependency(self, d: Dependency) -> None: pass - def compute_attribute(self, a: Attribute): + def compute_attribute(self, a: Attribute) -> None: pass - def compute_variable(self, v: Variable): + def compute_variable(self, v: Variable) -> None: pass - def compute_condition(self, c: ConditionalStatement): + def compute_condition(self, c: ConditionalStatement) -> None: pass - def compute_comment(self, c: Comment): + def compute_comment(self, c: Comment) -> None: pass diff --git a/glitch/tests/design/ansible/test_design.py b/glitch/tests/design/ansible/test_design.py index 22de353f..e4789626 100644 --- a/glitch/tests/design/ansible/test_design.py +++ b/glitch/tests/design/ansible/test_design.py @@ -6,7 +6,7 @@ class TestDesign(unittest.TestCase): - def __help_test(self, path, type, n_errors, codes, lines): + def __help_test(self, path, type, n_errors: int, codes, lines) -> None: parser = AnsibleParser() inter = parser.parse(path, type, False) analysis = DesignVisitor(Tech.ansible) @@ -24,7 +24,7 @@ def __help_test(self, path, type, n_errors, codes, lines): self.assertEqual(errors[i].code, codes[i]) self.assertEqual(errors[i].line, lines[i]) - def test_ansible_long_statement(self): + def test_ansible_long_statement(self) -> None: self.__help_test( "tests/design/ansible/files/long_statement.yml", "tasks", @@ -34,7 +34,7 @@ def test_ansible_long_statement(self): ) # Tabs - def test_ansible_improper_alignment(self): + def test_ansible_improper_alignment(self) -> None: self.__help_test( "tests/design/ansible/files/improper_alignment.yml", "tasks", @@ -48,7 +48,7 @@ def test_ansible_improper_alignment(self): [2, 4, 5, 6], ) - def test_ansible_duplicate_block(self): + def test_ansible_duplicate_block(self) -> None: self.__help_test( "tests/design/ansible/files/duplicate_block.yml", "tasks", @@ -62,7 +62,7 @@ def test_ansible_duplicate_block(self): [2, 10, 25, 33], ) - def test_ansible_avoid_comments(self): + def test_ansible_avoid_comments(self) -> None: self.__help_test( "tests/design/ansible/files/avoid_comments.yml", "tasks", @@ -73,7 +73,7 @@ def test_ansible_avoid_comments(self): [11], ) - def test_ansible_long_resource(self): + def test_ansible_long_resource(self) -> None: self.__help_test( "tests/design/ansible/files/long_resource.yml", "tasks", @@ -85,7 +85,7 @@ def test_ansible_long_resource(self): [2, 2], ) - def test_ansible_multifaceted_abstraction(self): + def test_ansible_multifaceted_abstraction(self) -> None: self.__help_test( "tests/design/ansible/files/multifaceted_abstraction.yml", "tasks", @@ -96,7 +96,7 @@ def test_ansible_multifaceted_abstraction(self): [2, 2], ) - def test_ansible_too_many_variables(self): + def test_ansible_too_many_variables(self) -> None: self.__help_test( "tests/design/ansible/files/too_many_variables.yml", "script", diff --git a/glitch/tests/design/chef/test_design.py b/glitch/tests/design/chef/test_design.py index 131d251d..bb2177fc 100644 --- a/glitch/tests/design/chef/test_design.py +++ b/glitch/tests/design/chef/test_design.py @@ -6,7 +6,7 @@ class TestDesign(unittest.TestCase): - def __help_test(self, path, n_errors, codes, lines): + def __help_test(self, path, n_errors: int, codes, lines) -> None: parser = ChefParser() inter = parser.parse(path, "script", False) analysis = DesignVisitor(Tech.chef) @@ -24,7 +24,7 @@ def __help_test(self, path, n_errors, codes, lines): self.assertEqual(errors[i].code, codes[i]) self.assertEqual(errors[i].line, lines[i]) - def test_chef_long_statement(self): + def test_chef_long_statement(self) -> None: self.__help_test( "tests/design/chef/files/long_statement.rb", 1, @@ -32,7 +32,7 @@ def test_chef_long_statement(self): [6], ) - def test_chef_improper_alignment(self): + def test_chef_improper_alignment(self) -> None: self.__help_test( "tests/design/chef/files/improper_alignment.rb", 1, @@ -40,7 +40,7 @@ def test_chef_improper_alignment(self): [1], ) - def test_chef_duplicate_block(self): + def test_chef_duplicate_block(self) -> None: self.__help_test( "tests/design/chef/files/duplicate_block.rb", 4, @@ -53,7 +53,7 @@ def test_chef_duplicate_block(self): [3, 4, 9, 10], ) - def test_chef_avoid_comments(self): + def test_chef_avoid_comments(self) -> None: self.__help_test( "tests/design/chef/files/avoid_comments.rb", 1, @@ -63,7 +63,7 @@ def test_chef_avoid_comments(self): [7], ) - def test_chef_long_resource(self): + def test_chef_long_resource(self) -> None: self.__help_test( "tests/design/chef/files/long_resource.rb", 1, @@ -73,7 +73,7 @@ def test_chef_long_resource(self): [1], ) - def test_chef_multifaceted_abstraction(self): + def test_chef_multifaceted_abstraction(self) -> None: self.__help_test( "tests/design/chef/files/multifaceted_abstraction.rb", 1, @@ -83,7 +83,7 @@ def test_chef_multifaceted_abstraction(self): [1], ) - def test_chef_misplaced_attribute(self): + def test_chef_misplaced_attribute(self) -> None: self.__help_test( "tests/design/chef/files/misplaced_attribute.rb", 1, @@ -93,7 +93,7 @@ def test_chef_misplaced_attribute(self): [1], ) - def test_chef_too_many_variables(self): + def test_chef_too_many_variables(self) -> None: self.__help_test( "tests/design/chef/files/too_many_variables.rb", 1, diff --git a/glitch/tests/design/docker/test_design.py b/glitch/tests/design/docker/test_design.py index 923a9de4..35ef1a61 100644 --- a/glitch/tests/design/docker/test_design.py +++ b/glitch/tests/design/docker/test_design.py @@ -6,7 +6,7 @@ class TestDesign(unittest.TestCase): - def __help_test(self, path, n_errors, codes, lines): + def __help_test(self, path, n_errors: int, codes, lines) -> None: parser = DockerParser() inter = parser.parse(path, "script", False) analysis = DesignVisitor(Tech.docker) @@ -24,7 +24,7 @@ def __help_test(self, path, n_errors, codes, lines): self.assertEqual(errors[i].code, codes[i]) self.assertEqual(errors[i].line, lines[i]) - def test_docker_long_statement(self): + def test_docker_long_statement(self) -> None: self.__help_test( "tests/design/docker/files/long_statement.Dockerfile", 1, @@ -32,7 +32,7 @@ def test_docker_long_statement(self): [4], ) - def test_docker_improper_alignment(self): + def test_docker_improper_alignment(self) -> None: # TODO: Fix smell, due to docker parsing method the attributes are not # detected in differents lines, making it impossible to trigger alignment pass @@ -44,7 +44,7 @@ def test_docker_improper_alignment(self): # ], [1] # ) - def test_docker_duplicate_block(self): + def test_docker_duplicate_block(self) -> None: self.__help_test( "tests/design/docker/files/duplicate_block.Dockerfile", 2, @@ -55,7 +55,7 @@ def test_docker_duplicate_block(self): [1, 9], ) - def test_docker_avoid_comments(self): + def test_docker_avoid_comments(self) -> None: self.__help_test( "tests/design/docker/files/avoid_comments.Dockerfile", 1, @@ -65,7 +65,7 @@ def test_docker_avoid_comments(self): [1], ) - def test_docker_too_many_variables(self): + def test_docker_too_many_variables(self) -> None: self.__help_test( "tests/design/docker/files/too_many_variables.Dockerfile", 1, diff --git a/glitch/tests/design/puppet/test_design.py b/glitch/tests/design/puppet/test_design.py index 07058194..f5cd3fe7 100644 --- a/glitch/tests/design/puppet/test_design.py +++ b/glitch/tests/design/puppet/test_design.py @@ -6,7 +6,7 @@ class TestDesign(unittest.TestCase): - def __help_test(self, path, n_errors, codes, lines): + def __help_test(self, path, n_errors: int, codes, lines) -> None: parser = PuppetParser() inter = parser.parse(path, "script", False) analysis = DesignVisitor(Tech.puppet) @@ -24,7 +24,7 @@ def __help_test(self, path, n_errors, codes, lines): self.assertEqual(errors[i].code, codes[i]) self.assertEqual(errors[i].line, lines[i]) - def test_puppet_long_statement(self): + def test_puppet_long_statement(self) -> None: self.__help_test( "tests/design/puppet/files/long_statement.pp", 1, @@ -32,7 +32,7 @@ def test_puppet_long_statement(self): [6], ) - def test_puppet_improper_alignment(self): + def test_puppet_improper_alignment(self) -> None: self.__help_test( "tests/design/puppet/files/improper_alignment.pp", 1, @@ -40,7 +40,7 @@ def test_puppet_improper_alignment(self): [1], ) - def test_puppet_duplicate_block(self): + def test_puppet_duplicate_block(self) -> None: self.__help_test( "tests/design/puppet/files/duplicate_block.pp", 2, @@ -51,7 +51,7 @@ def test_puppet_duplicate_block(self): [1, 10], ) - def test_puppet_avoid_comments(self): + def test_puppet_avoid_comments(self) -> None: self.__help_test( "tests/design/puppet/files/avoid_comments.pp", 1, @@ -61,7 +61,7 @@ def test_puppet_avoid_comments(self): [5], ) - def test_puppet_long_resource(self): + def test_puppet_long_resource(self) -> None: self.__help_test( "tests/design/puppet/files/long_resource.pp", 1, @@ -71,7 +71,7 @@ def test_puppet_long_resource(self): [1], ) - def test_puppet_multifaceted_abstraction(self): + def test_puppet_multifaceted_abstraction(self) -> None: self.__help_test( "tests/design/puppet/files/multifaceted_abstraction.pp", 2, @@ -79,7 +79,7 @@ def test_puppet_multifaceted_abstraction(self): [1, 2], ) - def test_puppet_unguarded_variable(self): + def test_puppet_unguarded_variable(self) -> None: self.__help_test( "tests/design/puppet/files/unguarded_variable.pp", 1, @@ -89,7 +89,7 @@ def test_puppet_unguarded_variable(self): [12], ) - def test_puppet_misplaced_attribute(self): + def test_puppet_misplaced_attribute(self) -> None: self.__help_test( "tests/design/puppet/files/misplaced_attribute.pp", 1, @@ -99,7 +99,7 @@ def test_puppet_misplaced_attribute(self): [1], ) - def test_puppet_too_many_variables(self): + def test_puppet_too_many_variables(self) -> None: self.__help_test( "tests/design/puppet/files/too_many_variables.pp", 1, diff --git a/glitch/tests/design/terraform/test_design.py b/glitch/tests/design/terraform/test_design.py index 7d7d07bd..bacfc8a9 100644 --- a/glitch/tests/design/terraform/test_design.py +++ b/glitch/tests/design/terraform/test_design.py @@ -6,7 +6,7 @@ class TestDesign(unittest.TestCase): - def __help_test(self, path, n_errors, codes, lines): + def __help_test(self, path, n_errors: int, codes, lines) -> None: parser = TerraformParser() inter = parser.parse(path, "script", False) analysis = DesignVisitor(Tech.terraform) @@ -24,7 +24,7 @@ def __help_test(self, path, n_errors, codes, lines): self.assertEqual(errors[i].code, codes[i]) self.assertEqual(errors[i].line, lines[i]) - def test_terraform_long_statement(self): + def test_terraform_long_statement(self) -> None: self.__help_test( "tests/design/terraform/files/long_statement.tf", 1, @@ -32,7 +32,7 @@ def test_terraform_long_statement(self): [6], ) - def test_terraform_improper_alignment(self): + def test_terraform_improper_alignment(self) -> None: self.__help_test( "tests/design/terraform/files/improper_alignment.tf", 1, @@ -40,7 +40,7 @@ def test_terraform_improper_alignment(self): [1], ) - def test_terraform_duplicate_block(self): + def test_terraform_duplicate_block(self) -> None: self.__help_test( "tests/design/terraform/files/duplicate_block.tf", 2, @@ -51,7 +51,7 @@ def test_terraform_duplicate_block(self): [1, 10], ) - def test_terraform_avoid_comments(self): + def test_terraform_avoid_comments(self) -> None: self.__help_test( "tests/design/terraform/files/avoid_comments.tf", 2, @@ -62,7 +62,7 @@ def test_terraform_avoid_comments(self): [2, 8], ) - def test_terraform_too_many_variables(self): + def test_terraform_too_many_variables(self) -> None: self.__help_test( "tests/design/terraform/files/too_many_variables.tf", 1, diff --git a/glitch/tests/hierarchical/test_parsers.py b/glitch/tests/hierarchical/test_parsers.py index 5ba484e3..f9c0bd1b 100644 --- a/glitch/tests/hierarchical/test_parsers.py +++ b/glitch/tests/hierarchical/test_parsers.py @@ -5,7 +5,7 @@ class TestAnsible(unittest.TestCase): - def __test_parse_vars(self, path, vars): + def __test_parse_vars(self, path, vars) -> None: with open(path, "r") as file: unitblock = AnsibleParser._AnsibleParser__parse_vars_file( self, "test", file @@ -13,18 +13,18 @@ def __test_parse_vars(self, path, vars): self.assertEqual(str(unitblock.variables), vars) file.close() - def __test_parse_attributes(self, path, attributes): + def __test_parse_attributes(self, path, attributes) -> None: with open(path, "r") as file: unitblock = AnsibleParser._AnsibleParser__parse_playbook(self, "test", file) play = unitblock.unit_blocks[0] self.assertEqual(str(play.attributes), attributes) file.close() - def test_hierarchichal_vars(self): + def test_hierarchichal_vars(self) -> None: vars = "[test[0]:None:[test1[0]:\"['1', '2']\"], test[1]:\"['3', '4']\", test:\"['x', 'y', '23']\", test2[0]:\"['2', '5', '6']\", vars:None:[factorial_of:'5', factorial_value:'1']]" self.__test_parse_vars("tests/hierarchical/ansible/vars.yml", vars) - def test_hierarchical_attributes(self): + def test_hierarchical_attributes(self) -> None: attributes = "[hosts:'localhost', debug:None:[msg:'The factorial of 5 is {{ factorial_value }}', seq[0]:None:[test:'something'], seq:\"['y', 'z']\", hash:None:[test1:'1', test2:'2']]]" self.__test_parse_attributes( "tests/hierarchical/ansible/attributes.yml", attributes @@ -32,21 +32,21 @@ def test_hierarchical_attributes(self): class TestPuppet(unittest.TestCase): - def __test_parse_vars(self, path, vars): + def __test_parse_vars(self, path, vars) -> None: unitblock = PuppetParser().parse_file(path, None) self.assertEqual(str(unitblock.variables), vars) - def test_hierarchical_vars(self): + def test_hierarchical_vars(self) -> None: vars = "[$my_hash:None:[key1:None:[test1:'1', test2:'2'], key2:'value2', key3:'value3', key4:None:[key5:'value5']], $configdir:'${boxen::config::configdir}/php', $datadir:'${boxen::config::datadir}/php', $pluginsdir:'${root}/plugins', $cachedir:'${php::config::datadir}/cache', $extensioncachedir:'${php::config::datadir}/cache/extensions']" self.__test_parse_vars("tests/hierarchical/puppet/vars.pp", vars) class TestChef(unittest.TestCase): - def __test_parse_vars(self, path, vars): + def __test_parse_vars(self, path, vars) -> None: unitblock = ChefParser().parse_file(path, None) self.assertEqual(str(unitblock.variables), vars) - def test_hierarchical_vars(self): + def test_hierarchical_vars(self) -> None: vars = "[grades:None:[Jane Doe:'10', Jim Doe:'6'], default:None:[zabbix:None:[database:None:[password:''], test:None:[name:'something']]]]" self.__test_parse_vars("tests/hierarchical/chef/vars.rb", vars) diff --git a/glitch/tests/parser/puppet/test_parser.py b/glitch/tests/parser/puppet/test_parser.py index 0abcb7cc..626484ff 100644 --- a/glitch/tests/parser/puppet/test_parser.py +++ b/glitch/tests/parser/puppet/test_parser.py @@ -4,7 +4,7 @@ class TestPuppetParser(unittest.TestCase): - def test_puppet_parser_if(self): + def test_puppet_parser_if(self) -> None: unit_block = PuppetParser().parse_file("tests/parser/puppet/files/if.pp", None) assert len(unit_block.statements) == 1 assert isinstance(unit_block.statements[0], ConditionalStatement) diff --git a/glitch/tests/parser/terraform/test_parser.py b/glitch/tests/parser/terraform/test_parser.py index 121e026c..97fd9f94 100644 --- a/glitch/tests/parser/terraform/test_parser.py +++ b/glitch/tests/parser/terraform/test_parser.py @@ -1,63 +1,64 @@ import unittest from glitch.parsers.terraform import TerraformParser +from typing import Sequence class TestTerraform(unittest.TestCase): - def __help_test(self, path, attributes): + def __help_test(self, path, attributes) -> None: unitblock = TerraformParser().parse_file(path, None) au = unitblock.atomic_units[0] self.assertEqual(str(au.attributes), attributes) - def __help_test_comments(self, path, comments): + def __help_test_comments(self, path, comments: Sequence[str]) -> None: unitblock = TerraformParser().parse_file(path, None) self.assertEqual(str(unitblock.comments), comments) - def test_terraform_null_value(self): + def test_terraform_null_value(self) -> None: attributes = "[account_id:'']" self.__help_test( "tests/parser/terraform/files/null_value_assign.tf", attributes ) - def test_terraform_empty_string(self): + def test_terraform_empty_string(self) -> None: attributes = "[account_id:'']" self.__help_test( "tests/parser/terraform/files/empty_string_assign.tf", attributes ) - def test_terraform_boolean_value(self): + def test_terraform_boolean_value(self) -> None: attributes = "[account_id:'True']" self.__help_test( "tests/parser/terraform/files/boolean_value_assign.tf", attributes ) - def test_terraform_multiline_string(self): + def test_terraform_multiline_string(self) -> None: attributes = "[user_data:' #!/bin/bash\\n sudo apt-get update\\n sudo apt-get install -y apache2\\n sudo systemctl start apache2']" self.__help_test( "tests/parser/terraform/files/multiline_string_assign.tf", attributes ) - def test_terraform_value_has_variable(self): + def test_terraform_value_has_variable(self) -> None: attributes = "[access:None:[user_by_email:'${google_service_account.bqowner.email}'], test:'${var.value1}']" self.__help_test( "tests/parser/terraform/files/value_has_variable.tf", attributes ) - def test_terraform_dict_value(self): + def test_terraform_dict_value(self) -> None: attributes = "[labels:None:[env:'default']]" self.__help_test( "tests/parser/terraform/files/dict_value_assign.tf", attributes ) - def test_terraform_list_value(self): + def test_terraform_list_value(self) -> None: attributes = "[keys[0]:'value1', keys[1][0]:'1', keys[1][1]:None:[key2:'value2'], keys[2]:None:[key3:'value3']]" self.__help_test( "tests/parser/terraform/files/list_value_assign.tf", attributes ) - def test_terraform_dynamic_block(self): + def test_terraform_dynamic_block(self) -> None: attributes = "[dynamic.setting:None:[content:None:[namespace:'${setting.value[\"namespace\"]}']]]" self.__help_test("tests/parser/terraform/files/dynamic_block.tf", attributes) - def test_terraform_comments(self): + def test_terraform_comments(self) -> None: comments = "[#comment1\n, //comment2\n, /*comment3\n default_table_expiration_ms = 3600000\n \n finish comment3 */, #comment4\n, #comment5\n, #comment inside dict\n, //comment2 inside dict\n]" self.__help_test_comments("tests/parser/terraform/files/comments.tf", comments) diff --git a/glitch/tests/repair/interactive/test_delta_p.py b/glitch/tests/repair/interactive/test_delta_p.py index fdb98361..2a58bc73 100644 --- a/glitch/tests/repair/interactive/test_delta_p.py +++ b/glitch/tests/repair/interactive/test_delta_p.py @@ -8,7 +8,7 @@ from tempfile import NamedTemporaryFile -def test_delta_p_compiler_puppet(): +def test_delta_p_compiler_puppet() -> None: puppet_script = """ file { '/var/www/customers/public_html/index.php': path => '/var/www/customers/public_html/index.php', @@ -37,7 +37,7 @@ def test_delta_p_compiler_puppet(): assert statement == delta_p_puppet -def test_delta_p_compiler_puppet_2(): +def test_delta_p_compiler_puppet_2() -> None: puppet_script = """ file {'/usr/sbin/policy-rc.d': ensure => absent, @@ -62,7 +62,7 @@ def test_delta_p_compiler_puppet_2(): assert statement == delta_p_puppet_2 -def test_delta_p_compiler_puppet_if(): +def test_delta_p_compiler_puppet_if() -> None: puppet_script = """ if $x == 'absent' { file {'/usr/sbin/policy-rc.d': @@ -85,7 +85,7 @@ def test_delta_p_compiler_puppet_if(): assert statement == delta_p_puppet_if -def test_delta_p_compiler_puppet_default_state(): +def test_delta_p_compiler_puppet_default_state() -> None: puppet_script = """ file { '/root/.ssh/config': content => template('fuel/root_ssh_config.erb'), @@ -104,7 +104,7 @@ def test_delta_p_compiler_puppet_default_state(): assert statement == delta_p_puppet_default_state -def test_delta_p_to_filesystems(): +def test_delta_p_to_filesystems() -> None: statement = delta_p_puppet fss = statement.to_filesystems() assert len(fss) == 1 @@ -115,14 +115,14 @@ def test_delta_p_to_filesystems(): } -def test_delta_p_to_filesystems_2(): +def test_delta_p_to_filesystems_2() -> None: statement = delta_p_puppet_2 fss = statement.to_filesystems() assert len(fss) == 1 assert fss[0].state == {"/usr/sbin/policy-rc.d": Nil()} -def test_delta_p_to_filesystems_if(): +def test_delta_p_to_filesystems_if() -> None: statement = delta_p_puppet_if fss = statement.to_filesystems() assert len(fss) == 2 @@ -130,7 +130,7 @@ def test_delta_p_to_filesystems_if(): assert fss[1].state == {"/usr/sbin/policy-rc.d": File(None, None, None)} -def test_delta_p_to_filesystems_default_state(): +def test_delta_p_to_filesystems_default_state() -> None: statement = delta_p_puppet_default_state fss = statement.to_filesystems() assert len(fss) == 1 diff --git a/glitch/tests/repair/interactive/test_delta_p_minimize.py b/glitch/tests/repair/interactive/test_delta_p_minimize.py index 1d64ec1e..47477d02 100644 --- a/glitch/tests/repair/interactive/test_delta_p_minimize.py +++ b/glitch/tests/repair/interactive/test_delta_p_minimize.py @@ -1,7 +1,7 @@ from glitch.repair.interactive.delta_p import * -def test_delta_p_minimize_let(): +def test_delta_p_minimize_let() -> None: statement = PLet( "x", "test1", @@ -13,7 +13,7 @@ def test_delta_p_minimize_let(): assert isinstance(minimized, PSkip) -def test_delta_p_minimize_seq(): +def test_delta_p_minimize_seq() -> None: statement = PSeq( PCreate(PEConst(const=PStr(value="test1"))), PCreate(PEConst(const=PStr(value="test2"))), @@ -31,7 +31,7 @@ def test_delta_p_minimize_seq(): assert isinstance(minimized, PSkip) -def test_delta_p_minimize_if(): +def test_delta_p_minimize_if() -> None: statement = PIf( PBool(True), PCreate(PEConst(const=PStr(value="test2"))), diff --git a/glitch/tests/repair/interactive/test_patch_solver.py b/glitch/tests/repair/interactive/test_patch_solver.py index e7f576fa..a8e7b23e 100644 --- a/glitch/tests/repair/interactive/test_patch_solver.py +++ b/glitch/tests/repair/interactive/test_patch_solver.py @@ -79,7 +79,7 @@ def setup_patch_solver( parser: Parser, script_type: UnitBlockType, tech: Tech, -): +) -> None: global labeled_script, statement DeltaPCompiler._condition = 0 with NamedTemporaryFile() as f: @@ -96,7 +96,7 @@ def patch_solver_apply( filesystem: FileSystemState, tech: Tech, n_filesystems: int = 1, -): +) -> None: solver.apply_patch(model, labeled_script) statement = DeltaPCompiler.compile(labeled_script, tech) filesystems = statement.to_filesystems() @@ -107,7 +107,7 @@ def patch_solver_apply( # TODO: Refactor tests -def test_patch_solver_if(): +def test_patch_solver_if() -> None: setup_patch_solver( puppet_script_4, PuppetParser(), UnitBlockType.script, Tech.puppet ) @@ -133,7 +133,7 @@ def test_patch_solver_if(): patch_solver_apply(solver, models[0], filesystem, Tech.puppet, n_filesystems=2) -def test_patch_solver_mode(): +def test_patch_solver_mode() -> None: setup_patch_solver( puppet_script_1, PuppetParser(), UnitBlockType.script, Tech.puppet ) @@ -163,7 +163,7 @@ def test_patch_solver_mode(): patch_solver_apply(solver, model, filesystem, Tech.puppet) -def test_patch_solver_owner(): +def test_patch_solver_owner() -> None: setup_patch_solver( puppet_script_2, PuppetParser(), UnitBlockType.script, Tech.puppet ) @@ -186,7 +186,7 @@ def test_patch_solver_owner(): patch_solver_apply(solver, model, filesystem, Tech.puppet) -def test_patch_solver_two_files(): +def test_patch_solver_two_files() -> None: setup_patch_solver( puppet_script_3, PuppetParser(), UnitBlockType.script, Tech.puppet ) @@ -202,7 +202,7 @@ def test_patch_solver_two_files(): patch_solver_apply(solver, model, filesystem, Tech.puppet) -def test_patch_solver_delete_file(): +def test_patch_solver_delete_file() -> None: setup_patch_solver( puppet_script_1, PuppetParser(), UnitBlockType.script, Tech.puppet ) @@ -225,7 +225,7 @@ def test_patch_solver_delete_file(): patch_solver_apply(solver, model, filesystem, Tech.puppet) -def test_patch_solver_remove_content(): +def test_patch_solver_remove_content() -> None: setup_patch_solver( puppet_script_1, PuppetParser(), UnitBlockType.script, Tech.puppet ) @@ -250,7 +250,7 @@ def test_patch_solver_remove_content(): patch_solver_apply(solver, model, filesystem, Tech.puppet) -def test_patch_solver_mode_ansible(): +def test_patch_solver_mode_ansible() -> None: setup_patch_solver( ansible_script_1, AnsibleParser(), UnitBlockType.tasks, Tech.ansible ) @@ -277,7 +277,7 @@ def test_patch_solver_mode_ansible(): patch_solver_apply(solver, model, filesystem, Tech.ansible) -def test_patch_solver_new_attribute_difficult_name(): +def test_patch_solver_new_attribute_difficult_name() -> None: """ This test requires the solver to create a new attribute "state". However, the attribute "state" should be called "ensure" in Puppet, diff --git a/glitch/tests/repair/interactive/test_tracer_model.py b/glitch/tests/repair/interactive/test_tracer_model.py index a1cb950c..3144be83 100644 --- a/glitch/tests/repair/interactive/test_tracer_model.py +++ b/glitch/tests/repair/interactive/test_tracer_model.py @@ -2,7 +2,7 @@ from glitch.repair.interactive.tracer.parser import * -def test_tracer_model_rename(): +def test_tracer_model_rename() -> None: syscall = Syscall("rename", ["test", "test~"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SRename) @@ -10,7 +10,7 @@ def test_tracer_model_rename(): assert typed_syscall.dst == "test~" -def test_tracer_model_open(): +def test_tracer_model_open() -> None: syscall = Syscall("open", ["test", [OpenFlag.O_RDONLY]], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SOpen) @@ -18,7 +18,7 @@ def test_tracer_model_open(): assert typed_syscall.flags == [OpenFlag.O_RDONLY] -def test_tracer_model_openat(): +def test_tracer_model_openat() -> None: syscall = Syscall("openat", ["/", "test", [OpenFlag.O_RDONLY]], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SOpenAt) @@ -27,7 +27,7 @@ def test_tracer_model_openat(): assert typed_syscall.flags == [OpenFlag.O_RDONLY] -def test_tracer_model_stat(): +def test_tracer_model_stat() -> None: syscall = Syscall("stat", ["test", "0x7fffc2269490"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SStat) @@ -35,7 +35,7 @@ def test_tracer_model_stat(): assert typed_syscall.flags == "0x7fffc2269490" -def test_tracer_model_fstat(): +def test_tracer_model_fstat() -> None: syscall = Syscall("fstat", ["3", "0x7fffc2269490"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SFStat) @@ -43,7 +43,7 @@ def test_tracer_model_fstat(): assert typed_syscall.flags == "0x7fffc2269490" -def test_tracer_model_lstat(): +def test_tracer_model_lstat() -> None: syscall = Syscall("lstat", ["test", "0x7fffc2269490"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SStat) @@ -51,7 +51,7 @@ def test_tracer_model_lstat(): assert typed_syscall.flags == "0x7fffc2269490" -def test_tracer_model_newfstatat(): +def test_tracer_model_newfstatat() -> None: syscall = Syscall("newfstatat", ["1", "test", "0x7fffc2269490", "0"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SFStatAt) @@ -61,14 +61,14 @@ def test_tracer_model_newfstatat(): assert typed_syscall.oredFlags == "0" -def test_tracer_model_unlink(): +def test_tracer_model_unlink() -> None: syscall = Syscall("unlink", ["test"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SUnlink) assert typed_syscall.path == "test" -def test_tracer_model_unlinkat(): +def test_tracer_model_unlinkat() -> None: syscall = Syscall("unlinkat", ["1", "test", [UnlinkFlag.AT_REMOVEDIR]], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SUnlinkAt) @@ -77,7 +77,7 @@ def test_tracer_model_unlinkat(): assert typed_syscall.flags == [UnlinkFlag.AT_REMOVEDIR] -def test_tracer_model_mkdir(): +def test_tracer_model_mkdir() -> None: syscall = Syscall("mkdir", ["test", "0777"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SMkdir) @@ -85,7 +85,7 @@ def test_tracer_model_mkdir(): assert typed_syscall.mode == "0777" -def test_tracer_model_mkdirat(): +def test_tracer_model_mkdirat() -> None: syscall = Syscall("mkdirat", ["AT_FDCWD", "test", "0777"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SMkdirAt) @@ -94,14 +94,14 @@ def test_tracer_model_mkdirat(): assert typed_syscall.mode == "0777" -def test_tracer_model_rmdir(): +def test_tracer_model_rmdir() -> None: syscall = Syscall("rmdir", ["test"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SRmdir) assert typed_syscall.path == "test" -def test_tracer_model_chdir(): +def test_tracer_model_chdir() -> None: syscall = Syscall("chdir", ["test"], 0) typed_syscall = get_syscall_with_type(syscall) assert isinstance(typed_syscall, SChdir) diff --git a/glitch/tests/repair/interactive/test_tracer_parser.py b/glitch/tests/repair/interactive/test_tracer_parser.py index 168a0677..2777d4bd 100644 --- a/glitch/tests/repair/interactive/test_tracer_parser.py +++ b/glitch/tests/repair/interactive/test_tracer_parser.py @@ -1,7 +1,7 @@ from glitch.repair.interactive.tracer.parser import * -def test_tracer_parser_rename(): +def test_tracer_parser_rename() -> None: parsed = parse_tracer_output('[pid 18040] rename("test", "test~") = 0') assert isinstance(parsed, Syscall) assert parsed.cmd == "rename" @@ -10,7 +10,7 @@ def test_tracer_parser_rename(): assert parsed.exitCode == 0 -def test_tracer_parser_stat(): +def test_tracer_parser_stat() -> None: parsed = parse_tracer_output( '[pid 255] stat("/usr/share/vim/vimfiles/after/scripts.vim", 0x7fffc2269490) = -1 ENOENT (No such file or directory)' ) @@ -21,7 +21,7 @@ def test_tracer_parser_stat(): assert parsed.exitCode == -1 -def test_tracer_parser_open(): +def test_tracer_parser_open() -> None: parsed = parse_tracer_output( '[pid 255] open("/lib/x86_64-linux-gnu/libpthread.so.0", O_RDONLY|O_CLOEXEC) = 3' ) @@ -32,7 +32,7 @@ def test_tracer_parser_open(): assert parsed.exitCode == 3 -def test_tracer_parser_open_mode(): +def test_tracer_parser_open_mode() -> None: parsed = parse_tracer_output( '[pid 105] open("/var/lib/apt/extended_states.tmp", O_WRONLY|O_CREAT|O_TRUNC, 0666) = 25' ) @@ -44,7 +44,7 @@ def test_tracer_parser_open_mode(): assert parsed.exitCode == 25 -def test_tracer_parser_openat(): +def test_tracer_parser_openat() -> None: parsed = parse_tracer_output( '[pid 33096] openat(AT_FDCWD, "/usr/lib/python3/dist-packages/mercurial/__pycache__/error.cpython-310.pyc", O_RDONLY|O_CLOEXEC) = 3' ) @@ -59,7 +59,7 @@ def test_tracer_parser_openat(): assert parsed.exitCode == 3 -def test_tracer_parser_openat_mode(): +def test_tracer_parser_openat_mode() -> None: parsed = parse_tracer_output( '[pid 33096] openat(AT_FDCWD, "/usr/lib/python3/dist-packages/mercurial/__pycache__/error.cpython-310.pyc", O_RDONLY|O_CLOEXEC, 0666) = 3' ) @@ -75,7 +75,7 @@ def test_tracer_parser_openat_mode(): assert parsed.exitCode == 3 -def test_tracer_parser_newfstatat(): +def test_tracer_parser_newfstatat() -> None: parsed = parse_tracer_output( '[pid 33096] newfstatat(AT_FDCWD, "/usr/lib/python3/dist-packages/mercurial/error.py", {st_dev=makedev(0x103, 0x3), st_ino=14852531, st_mode=S_IFREG|0644, st_nlink=1, st_uid=0, st_gid=0, st_blksize=4096, st_blocks=24, st_size=8377, st_atime=1684251578 /* 2023-05-16T16:39:38.917367753+0100 */, st_atime_nsec=917367753, st_mtime=1684251578 /* 2023-05-16T16:39:38.917367753+0100 */, st_mtime_nsec=917367753, st_ctime=1684251578 /* 2023-05-16T16:39:38.917367753+0100 */, st_ctime_nsec=917367753}, 0) = 0' ) @@ -104,7 +104,7 @@ def test_tracer_parser_newfstatat(): assert parsed.exitCode == 0 -def test_tracer_parser_newfstatat_empty_path(): +def test_tracer_parser_newfstatat_empty_path() -> None: parsed = parse_tracer_output( '[pid 33096] newfstatat(3, "", {st_dev=makedev(0x103, 0x3), st_ino=14852531, st_mode=S_IFREG|0644, st_nlink=1, st_uid=0, st_gid=0, st_blksize=4096, st_blocks=24, st_size=8377, st_atime=1684251578 /* 2023-05-16T16:39:38.917367753+0100 */, st_atime_nsec=917367753, st_mtime=1684251578 /* 2023-05-16T16:39:38.917367753+0100 */, st_mtime_nsec=917367753, st_ctime=1684251578 /* 2023-05-16T16:39:38.917367753+0100 */, st_ctime_nsec=917367753}, AT_EMPTY_PATH) = 0' ) @@ -133,7 +133,7 @@ def test_tracer_parser_newfstatat_empty_path(): assert parsed.exitCode == 0 -def test_tracer_parser_no_pid(): +def test_tracer_parser_no_pid() -> None: parsed = parse_tracer_output('openat(AT_FDCWD, "/dev/null", O_RDWR|O_NOCTTY) = 0') assert isinstance(parsed, Syscall) assert parsed.cmd == "openat" @@ -143,7 +143,7 @@ def test_tracer_parser_no_pid(): assert parsed.exitCode == 0 -def test_tracer_parser_write(): +def test_tracer_parser_write() -> None: parsed = parse_tracer_output('write(2, "o", 1) = 1') assert isinstance(parsed, Syscall) assert parsed.cmd == "write" @@ -151,7 +151,7 @@ def test_tracer_parser_write(): assert parsed.exitCode == 1 -def test_tracer_parser_execve(): +def test_tracer_parser_execve() -> None: parsed = parse_tracer_output( 'execve("/usr/bin/ls", ["ls"], ["SHELL=/bin/zsh", "LSCOLORS=Gxfxcxdxbxegedabagacad"]) = 0' ) @@ -165,7 +165,7 @@ def test_tracer_parser_execve(): assert parsed.exitCode == 0 -def test_tracer_parser_faccessat2(): +def test_tracer_parser_faccessat2() -> None: parsed = parse_tracer_output( '[pid 47072] faccessat2(AT_FDCWD, "/usr/lib/command-not-found", X_OK, AT_EACCESS) = 0' ) @@ -180,7 +180,7 @@ def test_tracer_parser_faccessat2(): assert parsed.exitCode == 0 -def test_tracer_parser_mkdirat(): +def test_tracer_parser_mkdirat() -> None: parsed = parse_tracer_output('[pid 36388] mkdirat(AT_FDCWD, "test23456", 0777) = 0') assert isinstance(parsed, Syscall) assert parsed.cmd == "mkdirat" @@ -188,7 +188,7 @@ def test_tracer_parser_mkdirat(): assert parsed.exitCode == 0 -def test_tracer_parser_mkdir(): +def test_tracer_parser_mkdir() -> None: parsed = parse_tracer_output('[pid 36388] mkdir("test23456", 0777) = 0') assert isinstance(parsed, Syscall) assert parsed.cmd == "mkdir" @@ -196,7 +196,7 @@ def test_tracer_parser_mkdir(): assert parsed.exitCode == 0 -def test_tracer_parser_rmdir(): +def test_tracer_parser_rmdir() -> None: parsed = parse_tracer_output('[pid 36152] rmdir("test23456") = 0') assert isinstance(parsed, Syscall) assert parsed.cmd == "rmdir" @@ -204,7 +204,7 @@ def test_tracer_parser_rmdir(): assert parsed.exitCode == 0 -def test_tracer_parser_unlink(): +def test_tracer_parser_unlink() -> None: parsed = parse_tracer_output( '[pid 47072] unlink("/home/test/.zsh_history.LOCK") = 0' ) @@ -214,7 +214,7 @@ def test_tracer_parser_unlink(): assert parsed.exitCode == 0 -def test_tracer_parser_unlinkat(): +def test_tracer_parser_unlinkat() -> None: parsed = parse_tracer_output( '[pid 32850] unlinkat(AT_FDCWD, "test234", AT_REMOVEDIR) = 0' ) @@ -224,7 +224,7 @@ def test_tracer_parser_unlinkat(): assert parsed.exitCode == 0 -def test_tracer_parser_unlinkat_0(): +def test_tracer_parser_unlinkat_0() -> None: parsed = parse_tracer_output('[pid 32850] unlinkat(AT_FDCWD, "test234", 0) = 0') assert isinstance(parsed, Syscall) assert parsed.cmd == "unlinkat" @@ -232,7 +232,7 @@ def test_tracer_parser_unlinkat_0(): assert parsed.exitCode == 0 -def test_tracer_parser_chdir(): +def test_tracer_parser_chdir() -> None: parsed = parse_tracer_output('[pid 32850] chdir("/home/test") = 0') assert isinstance(parsed, Syscall) assert parsed.cmd == "chdir" diff --git a/glitch/tests/repair/interactive/test_tracer_transform.py b/glitch/tests/repair/interactive/test_tracer_transform.py index 33f809ce..da4b4f58 100644 --- a/glitch/tests/repair/interactive/test_tracer_transform.py +++ b/glitch/tests/repair/interactive/test_tracer_transform.py @@ -11,7 +11,7 @@ from glitch.repair.interactive.filesystem import * -def test_get_affected_paths(): +def test_get_affected_paths() -> None: sys_calls = [ SOpen("open", ["file1", [OpenFlag.O_WRONLY]], 0), SOpenAt("openat", ["0", "file2", [OpenFlag.O_WRONLY]], 0), @@ -63,7 +63,7 @@ def teardown_file_system(): shutil.rmtree(temp_dir.name) -def test_get_file_system_state(setup_file_system, teardown_file_system): +def test_get_file_system_state(setup_file_system, teardown_file_system) -> None: file4 = os.path.join(dir1, "file4") files = {dir1, file2, file3, file4} diff --git a/glitch/tests/security/ansible/test_security.py b/glitch/tests/security/ansible/test_security.py index ed995e66..7778f3ad 100644 --- a/glitch/tests/security/ansible/test_security.py +++ b/glitch/tests/security/ansible/test_security.py @@ -6,7 +6,7 @@ class TestSecurity(unittest.TestCase): - def __help_test(self, path, type, n_errors, codes, lines): + def __help_test(self, path, type, n_errors: int, codes, lines) -> None: parser = AnsibleParser() inter = parser.parse(path, type, False) analysis = SecurityVisitor(Tech.ansible) @@ -20,17 +20,17 @@ def __help_test(self, path, type, n_errors, codes, lines): self.assertEqual(errors[i].code, codes[i]) self.assertEqual(errors[i].line, lines[i]) - def test_ansible_http(self): + def test_ansible_http(self) -> None: self.__help_test( "tests/security/ansible/files/http.yml", "tasks", 1, ["sec_https"], [4] ) - def test_ansible_susp_comment(self): + def test_ansible_susp_comment(self) -> None: self.__help_test( "tests/security/ansible/files/susp.yml", "vars", 1, ["sec_susp_comm"], [9] ) - def test_ansible_def_admin(self): + def test_ansible_def_admin(self) -> None: self.__help_test( "tests/security/ansible/files/admin.yml", "tasks", @@ -39,7 +39,7 @@ def test_ansible_def_admin(self): [3, 3, 3], ) - def test_ansible_empt_pass(self): + def test_ansible_empt_pass(self) -> None: self.__help_test( "tests/security/ansible/files/empty.yml", "tasks", @@ -48,7 +48,7 @@ def test_ansible_empt_pass(self): [8], ) - def test_ansible_weak_crypt(self): + def test_ansible_weak_crypt(self) -> None: self.__help_test( "tests/security/ansible/files/weak_crypt.yml", "tasks", @@ -57,7 +57,7 @@ def test_ansible_weak_crypt(self): [4, 7], ) - def test_ansible_hard_secr(self): + def test_ansible_hard_secr(self) -> None: self.__help_test( "tests/security/ansible/files/hard_secr.yml", "tasks", @@ -66,7 +66,7 @@ def test_ansible_hard_secr(self): [7, 7, 8, 8], ) - def test_ansible_invalid_bind(self): + def test_ansible_invalid_bind(self) -> None: self.__help_test( "tests/security/ansible/files/inv_bind.yml", "tasks", @@ -75,7 +75,7 @@ def test_ansible_invalid_bind(self): [7], ) - def test_ansible_int_check(self): + def test_ansible_int_check(self) -> None: self.__help_test( "tests/security/ansible/files/int_check.yml", "tasks", @@ -84,7 +84,7 @@ def test_ansible_int_check(self): [5], ) - def test_ansible_full_perm(self): + def test_ansible_full_perm(self) -> None: self.__help_test( "tests/security/ansible/files/full_permission.yml", "tasks", @@ -93,7 +93,7 @@ def test_ansible_full_perm(self): [7], ) - def test_ansible_obs_command(self): + def test_ansible_obs_command(self) -> None: self.__help_test( "tests/security/ansible/files/obs_command.yml", "tasks", diff --git a/glitch/tests/security/chef/test_security.py b/glitch/tests/security/chef/test_security.py index 9de30689..d2d8e65a 100644 --- a/glitch/tests/security/chef/test_security.py +++ b/glitch/tests/security/chef/test_security.py @@ -6,7 +6,7 @@ class TestSecurity(unittest.TestCase): - def __help_test(self, path, n_errors, codes, lines): + def __help_test(self, path, n_errors: int, codes, lines) -> None: parser = ChefParser() inter = parser.parse(path, "script", False) analysis = SecurityVisitor(Tech.chef) @@ -20,13 +20,13 @@ def __help_test(self, path, n_errors, codes, lines): self.assertEqual(errors[i].code, codes[i]) self.assertEqual(errors[i].line, lines[i]) - def test_chef_http(self): + def test_chef_http(self) -> None: self.__help_test("tests/security/chef/files/http.rb", 1, ["sec_https"], [3]) - def test_chef_susp_comment(self): + def test_chef_susp_comment(self) -> None: self.__help_test("tests/security/chef/files/susp.rb", 1, ["sec_susp_comm"], [1]) - def test_chef_def_admin(self): + def test_chef_def_admin(self) -> None: self.__help_test( "tests/security/chef/files/admin.rb", 3, @@ -34,17 +34,17 @@ def test_chef_def_admin(self): [8, 8, 8], ) - def test_chef_empt_pass(self): + def test_chef_empt_pass(self) -> None: self.__help_test( "tests/security/chef/files/empty.rb", 1, ["sec_empty_pass"], [1] ) - def test_chef_weak_crypt(self): + def test_chef_weak_crypt(self) -> None: self.__help_test( "tests/security/chef/files/weak_crypt.rb", 1, ["sec_weak_crypt"], [4] ) - def test_chef_hard_secr(self): + def test_chef_hard_secr(self) -> None: self.__help_test( "tests/security/chef/files/hard_secr.rb", 2, @@ -52,17 +52,17 @@ def test_chef_hard_secr(self): [8, 8], ) - def test_chef_invalid_bind(self): + def test_chef_invalid_bind(self) -> None: self.__help_test( "tests/security/chef/files/inv_bind.rb", 1, ["sec_invalid_bind"], [7] ) - def test_chef_int_check(self): + def test_chef_int_check(self) -> None: self.__help_test( "tests/security/chef/files/int_check.rb", 1, ["sec_no_int_check"], [1] ) - def test_chef_missing_default(self): + def test_chef_missing_default(self) -> None: self.__help_test( "tests/security/chef/files/missing_default.rb", 1, @@ -70,7 +70,7 @@ def test_chef_missing_default(self): [2], ) - def test_chef_full_permission(self): + def test_chef_full_permission(self) -> None: self.__help_test( "tests/security/chef/files/full_permission.rb", 1, @@ -78,7 +78,7 @@ def test_chef_full_permission(self): [3], ) - def test_chef_obs_command(self): + def test_chef_obs_command(self) -> None: self.__help_test( "tests/security/chef/files/obs_command.rb", 1, ["sec_obsolete_command"], [2] ) diff --git a/glitch/tests/security/docker/test_security.py b/glitch/tests/security/docker/test_security.py index 960a740a..33d90dee 100644 --- a/glitch/tests/security/docker/test_security.py +++ b/glitch/tests/security/docker/test_security.py @@ -8,7 +8,7 @@ class TestSecurity(unittest.TestCase): - def __help_test(self, path, n_errors, codes, lines): + def __help_test(self, path, n_errors: int, codes, lines) -> None: parser = DockerParser() inter = parser.parse(path, UnitBlockType.script, False) analysis = SecurityVisitor(Tech.docker) @@ -27,7 +27,7 @@ def tearDown(self) -> None: if os.path.exists("Dockerfile"): os.remove("Dockerfile") - def test_docker_admin(self): + def test_docker_admin(self) -> None: self.__help_test( "tests/security/docker/files/admin.Dockerfile", 2, @@ -35,13 +35,13 @@ def test_docker_admin(self): [2, 4], ) - def test_docker_empty(self): + def test_docker_empty(self) -> None: self.__help_test( "tests/security/docker/files/empty.Dockerfile", 1, ["sec_empty_pass"], [4] ) pass - def test_docker_full_permission(self): + def test_docker_full_permission(self) -> None: self.__help_test( "tests/security/docker/files/full_permission.Dockerfile", 1, @@ -49,7 +49,7 @@ def test_docker_full_permission(self): [3], ) - def test_docker_hard_secret(self): + def test_docker_hard_secret(self) -> None: self.__help_test( "tests/security/docker/files/hard_secr.Dockerfile", 2, @@ -57,12 +57,12 @@ def test_docker_hard_secret(self): [3, 3], ) - def test_docker_http(self): + def test_docker_http(self) -> None: self.__help_test( "tests/security/docker/files/http.Dockerfile", 1, ["sec_https"], [5] ) - def test_docker_int_check(self): + def test_docker_int_check(self) -> None: self.__help_test( "tests/security/docker/files/int_check.Dockerfile", 1, @@ -70,7 +70,7 @@ def test_docker_int_check(self): [4], ) - def test_docker_inv_bind(self): + def test_docker_inv_bind(self) -> None: self.__help_test( "tests/security/docker/files/inv_bind.Dockerfile", 1, @@ -78,7 +78,7 @@ def test_docker_inv_bind(self): [4], ) - def test_docker_non_official_image(self): + def test_docker_non_official_image(self) -> None: self.__help_test( "tests/security/docker/files/non_off_image.Dockerfile", 1, @@ -86,7 +86,7 @@ def test_docker_non_official_image(self): [1], ) - def test_docker_obs_command(self): + def test_docker_obs_command(self) -> None: self.__help_test( "tests/security/docker/files/obs_command.Dockerfile", 1, @@ -94,12 +94,12 @@ def test_docker_obs_command(self): [4], ) - def test_docker_susp(self): + def test_docker_susp(self) -> None: self.__help_test( "tests/security/docker/files/susp.Dockerfile", 1, ["sec_susp_comm"], [3] ) - def test_docker_weak_crypt(self): + def test_docker_weak_crypt(self) -> None: self.__help_test( "tests/security/docker/files/weak_crypt.Dockerfile", 1, diff --git a/glitch/tests/security/puppet/test_security.py b/glitch/tests/security/puppet/test_security.py index 199c2e9b..918ebf17 100644 --- a/glitch/tests/security/puppet/test_security.py +++ b/glitch/tests/security/puppet/test_security.py @@ -6,7 +6,7 @@ class TestSecurity(unittest.TestCase): - def __help_test(self, path, n_errors, codes, lines): + def __help_test(self, path, n_errors: int, codes, lines) -> None: parser = PuppetParser() inter = parser.parse(path, "script", False) analysis = SecurityVisitor(Tech.puppet) @@ -20,15 +20,15 @@ def __help_test(self, path, n_errors, codes, lines): self.assertEqual(errors[i].code, codes[i]) self.assertEqual(errors[i].line, lines[i]) - def test_puppet_http(self): + def test_puppet_http(self) -> None: self.__help_test("tests/security/puppet/files/http.pp", 1, ["sec_https"], [2]) - def test_puppet_susp_comment(self): + def test_puppet_susp_comment(self) -> None: self.__help_test( "tests/security/puppet/files/susp.pp", 1, ["sec_susp_comm"], [19] ) - def test_puppet_def_admin(self): + def test_puppet_def_admin(self) -> None: self.__help_test( "tests/security/puppet/files/admin.pp", 3, @@ -36,17 +36,17 @@ def test_puppet_def_admin(self): [7, 7, 7], ) - def test_puppet_empt_pass(self): + def test_puppet_empt_pass(self) -> None: self.__help_test( "tests/security/puppet/files/empty.pp", 1, ["sec_empty_pass"], [1] ) - def test_puppet_weak_crypt(self): + def test_puppet_weak_crypt(self) -> None: self.__help_test( "tests/security/puppet/files/weak_crypt.pp", 1, ["sec_weak_crypt"], [12] ) - def test_puppet_hard_secr(self): + def test_puppet_hard_secr(self) -> None: self.__help_test( "tests/security/puppet/files/hard_secr.pp", 2, @@ -54,17 +54,17 @@ def test_puppet_hard_secr(self): [2, 2], ) - def test_puppet_invalid_bind(self): + def test_puppet_invalid_bind(self) -> None: self.__help_test( "tests/security/puppet/files/inv_bind.pp", 1, ["sec_invalid_bind"], [12] ) - def test_puppet_int_check(self): + def test_puppet_int_check(self) -> None: self.__help_test( "tests/security/puppet/files/int_check.pp", 1, ["sec_no_int_check"], [5] ) - def test_puppet_missing_default(self): + def test_puppet_missing_default(self) -> None: self.__help_test( "tests/security/puppet/files/missing_default.pp", 2, @@ -72,7 +72,7 @@ def test_puppet_missing_default(self): [2, 7], ) - def test_puppet_full_perm(self): + def test_puppet_full_perm(self) -> None: self.__help_test( "tests/security/puppet/files/full_permission.pp", 1, @@ -80,7 +80,7 @@ def test_puppet_full_perm(self): [4], ) - def test_puppet_obs_command(self): + def test_puppet_obs_command(self) -> None: self.__help_test( "tests/security/puppet/files/obs_command.pp", 1, diff --git a/glitch/tests/security/terraform/test_security.py b/glitch/tests/security/terraform/test_security.py index cb89d44c..d3b11dcf 100644 --- a/glitch/tests/security/terraform/test_security.py +++ b/glitch/tests/security/terraform/test_security.py @@ -6,7 +6,7 @@ class TestSecurity(unittest.TestCase): - def __help_test(self, path, n_errors, codes, lines): + def __help_test(self, path, n_errors: int, codes, lines) -> None: parser = TerraformParser() inter = parser.parse(path, "script", False) analysis = SecurityVisitor(Tech.terraform) @@ -21,17 +21,17 @@ def __help_test(self, path, n_errors, codes, lines): self.assertEqual(errors[i].line, lines[i]) # testing previous implemented code smells - def test_terraform_http(self): + def test_terraform_http(self) -> None: self.__help_test( "tests/security/terraform/files/http.tf", 1, ["sec_https"], [2] ) - def test_terraform_susp_comment(self): + def test_terraform_susp_comment(self) -> None: self.__help_test( "tests/security/terraform/files/susp.tf", 1, ["sec_susp_comm"], [8] ) - def test_terraform_def_admin(self): + def test_terraform_def_admin(self) -> None: self.__help_test( "tests/security/terraform/files/admin.tf", 3, @@ -39,17 +39,17 @@ def test_terraform_def_admin(self): [2, 2, 2], ) - def test_terraform_empt_pass(self): + def test_terraform_empt_pass(self) -> None: self.__help_test( "tests/security/terraform/files/empty.tf", 1, ["sec_empty_pass"], [5] ) - def test_terraform_weak_crypt(self): + def test_terraform_weak_crypt(self) -> None: self.__help_test( "tests/security/terraform/files/weak_crypt.tf", 1, ["sec_weak_crypt"], [4] ) - def test_terraform_hard_secr(self): + def test_terraform_hard_secr(self) -> None: self.__help_test( "tests/security/terraform/files/hard_secr.tf", 2, @@ -57,14 +57,14 @@ def test_terraform_hard_secr(self): [5, 5], ) - def test_terraform_invalid_bind(self): + def test_terraform_invalid_bind(self) -> None: self.__help_test( "tests/security/terraform/files/inv_bind.tf", 1, ["sec_invalid_bind"], [19] ) # testing new implemented code smells, or previous ones with new rules for Terraform - def test_terraform_insecure_access_control(self): + def test_terraform_insecure_access_control(self) -> None: self.__help_test( "tests/security/terraform/files/insecure-access-control/access-to-bigquery-dataset.tf", 1, @@ -246,7 +246,7 @@ def test_terraform_insecure_access_control(self): [37, 44], ) - def test_terraform_invalid_ip_binding(self): + def test_terraform_invalid_ip_binding(self) -> None: self.__help_test( "tests/security/terraform/files/invalid-ip-binding/aws-ec2-vpc-no-public-egress-sgr.tf", 2, @@ -332,7 +332,7 @@ def test_terraform_invalid_ip_binding(self): [27], ) - def test_terraform_disabled_authentication(self): + def test_terraform_disabled_authentication(self) -> None: self.__help_test( "tests/security/terraform/files/disabled-authentication/azure-app-service-authentication-activated.tf", 2, @@ -364,7 +364,7 @@ def test_terraform_disabled_authentication(self): [7, 53], ) - def test_terraform_missing_encryption(self): + def test_terraform_missing_encryption(self) -> None: self.__help_test( "tests/security/terraform/files/missing-encryption/athena-enable-at-rest-encryption.tf", 2, @@ -556,7 +556,7 @@ def test_terraform_missing_encryption(self): [1, 1, 4, 8, 13, 14], ) - def test_terraform_hard_coded_secrets(self): + def test_terraform_hard_coded_secrets(self) -> None: self.__help_test( "tests/security/terraform/files/hard-coded-secrets/encryption-key-in-plaintext.tf", 1, @@ -606,7 +606,7 @@ def test_terraform_hard_coded_secrets(self): [9], ) - def test_terraform_public_ip(self): + def test_terraform_public_ip(self) -> None: self.__help_test( "tests/security/terraform/files/public-ip/google-compute-intance-with-public-ip.tf", 1, @@ -632,7 +632,7 @@ def test_terraform_public_ip(self): [3], ) - def test_terraform_use_of_http_without_tls(self): + def test_terraform_use_of_http_without_tls(self) -> None: self.__help_test( "tests/security/terraform/files/use-of-http-without-tls/azure-appservice-enforce-https.tf", 2, @@ -676,7 +676,7 @@ def test_terraform_use_of_http_without_tls(self): [8], ) - def test_terraform_ssl_tls_mtls_policy(self): + def test_terraform_ssl_tls_mtls_policy(self) -> None: self.__help_test( "tests/security/terraform/files/ssl-tls-mtls-policy/api-gateway-secure-tls-policy.tf", 2, @@ -744,7 +744,7 @@ def test_terraform_ssl_tls_mtls_policy(self): [1, 45], ) - def test_terraform_use_of_dns_without_dnssec(self): + def test_terraform_use_of_dns_without_dnssec(self) -> None: self.__help_test( "tests/security/terraform/files/use-of-dns-without-dnssec/cloud-dns-without-dnssec.tf", 2, @@ -752,7 +752,7 @@ def test_terraform_use_of_dns_without_dnssec(self): [1, 6], ) - def test_terraform_firewall_misconfiguration(self): + def test_terraform_firewall_misconfiguration(self) -> None: self.__help_test( "tests/security/terraform/files/firewall-misconfiguration/alb-drop-invalid-headers.tf", 2, @@ -816,7 +816,7 @@ def test_terraform_firewall_misconfiguration(self): [1, 1, 10], ) - def test_terraform_missing_threats_detection_and_alerts(self): + def test_terraform_missing_threats_detection_and_alerts(self) -> None: self.__help_test( "tests/security/terraform/files/missing-threats-detection-and-alerts/azure-database-disabled-alerts.tf", 1, @@ -860,7 +860,7 @@ def test_terraform_missing_threats_detection_and_alerts(self): [1, 19], ) - def test_terraform_weak_password_key_policy(self): + def test_terraform_weak_password_key_policy(self) -> None: self.__help_test( "tests/security/terraform/files/weak-password-key-policy/aws-iam-no-password-reuse.tf", 2, @@ -922,7 +922,7 @@ def test_terraform_weak_password_key_policy(self): [1, 5], ) - def test_terraform_integrity_policy(self): + def test_terraform_integrity_policy(self) -> None: self.__help_test( "tests/security/terraform/files/integrity-policy/aws-ecr-immutable-repo.tf", 2, @@ -942,7 +942,7 @@ def test_terraform_integrity_policy(self): [3], ) - def test_terraform_sensitive_action_by_iam(self): + def test_terraform_sensitive_action_by_iam(self) -> None: self.__help_test( "tests/security/terraform/files/sensitive-action-by-iam/aws-iam-no-policy-wildcards.tf", 3, @@ -954,7 +954,7 @@ def test_terraform_sensitive_action_by_iam(self): [7, 8, 20], ) - def test_terraform_key_management(self): + def test_terraform_key_management(self) -> None: self.__help_test( "tests/security/terraform/files/key-management/aws-cloudtrail-encryption-use-cmk.tf", 2, @@ -1100,7 +1100,7 @@ def test_terraform_key_management(self): [1, 8], ) - def test_terraform_network_security_rules(self): + def test_terraform_network_security_rules(self) -> None: self.__help_test( "tests/security/terraform/files/network-security-rules/aws-vpc-ec2-use-tcp.tf", 1, @@ -1166,7 +1166,7 @@ def test_terraform_network_security_rules(self): [1, 5], ) - def test_terraform_permission_of_iam_policies(self): + def test_terraform_permission_of_iam_policies(self) -> None: self.__help_test( "tests/security/terraform/files/permission-of-iam-policies/default-service-account-not-used-at-folder-level.tf", 2, @@ -1216,7 +1216,7 @@ def test_terraform_permission_of_iam_policies(self): [7], ) - def test_terraform_logging(self): + def test_terraform_logging(self) -> None: self.__help_test( "tests/security/terraform/files/logging/aws-api-gateway-enable-access-logging.tf", 4, @@ -1430,7 +1430,7 @@ def test_terraform_logging(self): [11], ) - def test_terraform_attached_resource(self): + def test_terraform_attached_resource(self) -> None: self.__help_test( "tests/security/terraform/files/attached-resource/aws_route53_attached_resource.tf", 2, @@ -1438,7 +1438,7 @@ def test_terraform_attached_resource(self): [12, 16], ) - def test_terraform_versioning(self): + def test_terraform_versioning(self) -> None: self.__help_test( "tests/security/terraform/files/versioning/aws-s3-enable-versioning.tf", 2, @@ -1452,7 +1452,7 @@ def test_terraform_versioning(self): [1, 7], ) - def test_terraform_naming(self): + def test_terraform_naming(self) -> None: self.__help_test( "tests/security/terraform/files/naming/aws-ec2-description-to-security-group-rule.tf", 2, @@ -1490,7 +1490,7 @@ def test_terraform_naming(self): [1, 19], ) - def test_terraform_replication(self): + def test_terraform_replication(self) -> None: self.__help_test( "tests/security/terraform/files/replication/s3-bucket-cross-region-replication.tf", 2, diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 41874c3b..00000000 --- a/mypy.ini +++ /dev/null @@ -1,4 +0,0 @@ -[mypy] -python_version = 3.9 -warn_return_any = True -warn_unused_configs = True \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 28515d7b..d7bafcc1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -549,16 +549,17 @@ tests = ["pytest", "pytest-cov", "pytest-lazy-fixture"] [[package]] name = "puppetparser" -version = "0.2.1" +version = "0.2.4" description = "A parser from Puppet to an object model" optional = false -python-versions = ">=3.9, <4" +python-versions = "<4.0,>=3.9" files = [ - {file = "puppetparser-0.2.1.tar.gz", hash = "sha256:e8b50e47fce93529ec2c4dfab36cfea8889a6cd3daa681c6fcc27a36ef0bd198"}, + {file = "puppetparser-0.2.4-py3-none-any.whl", hash = "sha256:3c625c7f8826f705b61c21aef5b59a759dd0ba79e72779ec9e664f900d8c713c"}, + {file = "puppetparser-0.2.4.tar.gz", hash = "sha256:2db6273653e94f018582aa87da2780a5b7e2b2320cfa1485e312e4a16029accd"}, ] [package.dependencies] -ply = "*" +ply = "3.11" [[package]] name = "pytest" @@ -894,4 +895,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "2882412fa633629315acba716a2d0a7fc1798aa69d664bb46d8e8aef4fb0f527" +content-hash = "451b6f64e0d301878a4d461f4d0811289f06267e213e76dcef72d9bf94caeb3f" diff --git a/pyproject.toml b/pyproject.toml index 554604d6..4ce83c6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ alive-progress = "3.0.1" prettytable = "3.6.0" pandas = "1.5.3" configparser = "5.3.0" -puppetparser = "0.2.1" +puppetparser = "0.2.4" Jinja2 = "3.1.2" glitch-python-hcl2 = "0.1.4" dockerfile-parse = "2.0.0" @@ -33,3 +33,8 @@ glitch = "glitch.__main__:main" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.pyright] +typeCheckingMode = "strict" +stubPath = "stubs" +exclude = ["glitch/tests", ".venv", "scripts/"] \ No newline at end of file diff --git a/stubs/ply/lex.pyi b/stubs/ply/lex.pyi new file mode 100644 index 00000000..7bac93db --- /dev/null +++ b/stubs/ply/lex.pyi @@ -0,0 +1,27 @@ +import re +from typing import Any + +def lex( + module: Any = None, + object: Any = None, + debug: bool = False, + optimize: bool = False, + lextab: str = "lextab", + reflags: int = int(re.VERBOSE), + nowarn: bool = False, + outputdir: str | None = None, + debuglog: str | None = None, + errorlog: str | None = None, +) -> Any: ... + +class Lexer: + lineno: int + + def begin(self, state: str) -> None: ... + def skip(self, n: int) -> None: ... + +class LexToken: + value: str + lexer: Lexer + lexpos: int + type: str diff --git a/stubs/ply/yacc.pyi b/stubs/ply/yacc.pyi new file mode 100644 index 00000000..0ab54be7 --- /dev/null +++ b/stubs/ply/yacc.pyi @@ -0,0 +1,26 @@ +from typing import Any + +def yacc( + method: str = "LALR", + debug: bool = True, + module: Any = None, + tabmodule: str = "parsetab", + start: Any = None, + check_recursion: bool = True, + optimize: bool = False, + write_tables: bool = True, + debugfile: str = "parser.out", + outputdir: str | None = None, + debuglog: str | None = None, + errorlog: str | None = None, + picklefile: str | None = None, +) -> Any: ... + +class YaccProduction: + value: str + + def __getitem__(self, n: int) -> Any: ... + def __setitem__(self, n: int, v: Any) -> Any: ... + def lineno(self, n: int) -> int: ... + def lexpos(self, n: int) -> int: ... + def set_lineno(self, n: int, lineno: int) -> None: ... diff --git a/stubs/ruamel/__init__.pyi b/stubs/ruamel/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/stubs/ruamel/yaml/__init__.pyi b/stubs/ruamel/yaml/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/stubs/ruamel/yaml/error.pyi b/stubs/ruamel/yaml/error.pyi new file mode 100644 index 00000000..d1ebf5be --- /dev/null +++ b/stubs/ruamel/yaml/error.pyi @@ -0,0 +1,3 @@ +class StreamMark: + line: int + column: int diff --git a/stubs/ruamel/yaml/main.pyi b/stubs/ruamel/yaml/main.pyi new file mode 100644 index 00000000..aa982061 --- /dev/null +++ b/stubs/ruamel/yaml/main.pyi @@ -0,0 +1,5 @@ +from ruamel.yaml.nodes import Node +from typing import TextIO + +class YAML: + def compose(self, stream: str | TextIO) -> Node | None: ... diff --git a/stubs/ruamel/yaml/nodes.pyi b/stubs/ruamel/yaml/nodes.pyi new file mode 100644 index 00000000..30826046 --- /dev/null +++ b/stubs/ruamel/yaml/nodes.pyi @@ -0,0 +1,16 @@ +from ruamel.yaml.tokens import Token +from typing import List, Union, Optional, Any +from ruamel.yaml.error import StreamMark + +RecursiveTokenList = List[Union[Token, "RecursiveTokenList", None]] + +class Node: + comment: Optional[RecursiveTokenList] + value: Any + start_mark: StreamMark + end_mark: StreamMark + +class ScalarNode(Node): ... +class MappingNode(Node): ... +class SequenceNode(Node): ... +class CollectionNode(Node): ... diff --git a/stubs/ruamel/yaml/tokens.pyi b/stubs/ruamel/yaml/tokens.pyi new file mode 100644 index 00000000..d1f43778 --- /dev/null +++ b/stubs/ruamel/yaml/tokens.pyi @@ -0,0 +1,11 @@ +from ruamel.yaml.error import StreamMark + +class Token: + start_mark: StreamMark + end_mark: StreamMark + + @property + def column(self) -> int: ... + +class CommentToken(Token): + value: str diff --git a/stubs/z3.pyi b/stubs/z3.pyi new file mode 100644 index 00000000..457e6c29 --- /dev/null +++ b/stubs/z3.pyi @@ -0,0 +1,52 @@ +from typing import List + +class CheckSatResult: ... + +sat: CheckSatResult + +class Z3PPObject: ... +class AstRef(Z3PPObject): ... +class FuncDeclRef(AstRef): ... + +class ModelRef(Z3PPObject): + def __getitem__(self, idx: AstRef) -> FuncDeclRef: ... + +class ExprRef(AstRef): + def __eq__(self, __value: object) -> BoolRef: ... # type: ignore + +class BoolRef(ExprRef): ... + +class ArithRef(ExprRef): + def __ge__(self, __value: object) -> BoolRef: ... + +class IntNumRef(ArithRef): ... +class SeqRef(ExprRef): ... +class Context: ... +class Tactic: ... +class Probe: ... + +class Solver: + def __init__( + self, + solver: Solver | None = None, + ctx: Context | None = None, + logFile: str | None = None, + ) -> None: ... + def add(self, *args: Z3PPObject) -> None: ... + def push(self) -> None: ... + def pop(self) -> None: ... + def check(self) -> CheckSatResult: ... + def model(self) -> ModelRef: ... + +def If( + a: Probe | Z3PPObject, b: Z3PPObject, c: Z3PPObject, ctx: Context | None = None +) -> ExprRef: ... +def StringVal(s: str, ctx: Context | None = None) -> SeqRef: ... +def String(name: str, ctx: Context | None = None) -> SeqRef: ... +def Int(name: str, ctx: Context | None = None) -> ArithRef: ... +def IntVal(val: int, ctx: Context | None = None) -> IntNumRef: ... +def Bool(name: str, ctx: Context | None = None) -> BoolRef: ... +def And(*args: Z3PPObject | List[Z3PPObject]) -> BoolRef: ... +def Or(*args: Z3PPObject | List[Z3PPObject]) -> BoolRef: ... +def Not(a: Z3PPObject) -> BoolRef: ... +def Sum(*args: Z3PPObject | int | List[Z3PPObject | int]) -> ArithRef: ...