diff --git a/.gitignore b/.gitignore index 2623317ef4..d42a3db34c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,9 +10,6 @@ __pycache__/ # C extensions *.so -# .csv files -*.csv - # Distribution / packaging .Python build/ diff --git a/keep/api/bl/enrichments_bl.py b/keep/api/bl/enrichments_bl.py index 74a4afd9eb..49303631f8 100644 --- a/keep/api/bl/enrichments_bl.py +++ b/keep/api/bl/enrichments_bl.py @@ -11,9 +11,11 @@ import json5 from fastapi import HTTPException from sqlalchemy import func +from sqlalchemy.orm import defer from sqlalchemy_utils import UUIDType from sqlmodel import Session, select +from keep.api.bl.mapping_rule_matcher import MappingRuleMatcher from keep.api.core.config import config from keep.api.core.db import batch_enrich, get_incidents_by_alert_fingerprint from keep.api.core.db import enrich_entity as enrich_alert_db, get_last_alert_by_fingerprint, \ @@ -101,7 +103,7 @@ def __init__(self, tenant_id: str, db: Session | None = None): self.db_session = None self.elastic_client = None - def run_mapping_rule_by_id(self, rule_id: int, alert_id: UUID) -> AlertDto: + def run_mapping_rule_by_id(self, rule_id: int, alert_id: UUID) -> bool: rule = get_mapping_rule_by_id(self.tenant_id, rule_id, session=self.db_session) if not rule: raise HTTPException(status_code=404, detail="Mapping rule not found") @@ -316,6 +318,7 @@ def run_mapping_rules(self, alert: AlertDto) -> AlertDto: self.db_session.query(MappingRule) .filter(MappingRule.tenant_id == self.tenant_id) .filter(MappingRule.disabled == False) + .options(defer(MappingRule.rows)) .order_by(MappingRule.priority.desc()) .all() ) @@ -419,14 +422,37 @@ def check_if_match_and_enrich(self, alert: AlertDto, rule: MappingRule) -> bool: enrichments.pop("id", None) elif rule.type == "csv": if not rule.is_multi_level: - for row in rule.rows: - if any( - self._check_matcher(alert, row, matcher) - for matcher in rule.matchers + # Create an alert values dictionary for SQL-based matching + alert_dict = self._extract_alert_values_for_matchers( + alert, rule.matchers + ) + + # Use direct SQL-based matching to find matching row + try: + dialect_name = None + if ( + self.db_session + and hasattr(self.db_session, "bind") + and self.db_session.bind ): + dialect_name = self.db_session.bind.dialect.name + + matcher = MappingRuleMatcher( + dialect_name=dialect_name, session=self.db_session + ) + + self._add_enrichment_log( + f"Using SQL-based matching with dialect: {dialect_name or 'fallback'}", + "debug", + {"rule_id": rule.id}, + ) + + matched_row = matcher.get_matching_row(rule, alert_dict) + + if matched_row: # Extract enrichments from the matched row enrichments = {} - for key, value in row.items(): + for key, value in matched_row.items(): if value is not None: is_matcher = False for matcher in rule.matchers: @@ -439,7 +465,33 @@ def check_if_match_and_enrich(self, alert: AlertDto, rule: MappingRule) -> bool: if isinstance(value, str): value = value.strip() enrichments[key.strip()] = value - break + except Exception as e: + self._add_enrichment_log( + f"Error using SQL matcher, falling back to in-memory iteration: {str(e)}", + "warning", + {"rule_id": rule.id}, + ) + # Fall back to the original in-memory matching + if rule.rows: + for row in rule.rows: + if any( + self._check_matcher(alert, row, matcher) + for matcher in rule.matchers + ): + # Extract enrichments from the matched row + enrichments = {} + for key, value in row.items(): + if value is not None: + is_matcher = False + for matcher in rule.matchers: + if key in matcher: + 
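# key is one of the rule's matcher columns; flag it so the lookup keys themselves are not copied into the enrichments +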
is_matcher = True + break + if not is_matcher: + if isinstance(value, str): + value = value.strip() + enrichments[key.strip()] = value + break else: # Multi-level mapping # We can assume that the matcher is only a single key. i.e., [['customers']] @@ -451,23 +503,93 @@ def check_if_match_and_enrich(self, alert: AlertDto, rule: MappingRule) -> bool: else: if isinstance(matcher_values, str): matcher_values = json5.loads(matcher_values) - for matcher in matcher_values: - if rule.prefix_to_remove: - matcher = matcher.replace(rule.prefix_to_remove, "") - for row in rule.rows: - if self._check_explicit_match(row, key, matcher): - if rule.new_property_name not in enrichments: - enrichments[rule.new_property_name] = {} - - if matcher not in enrichments[rule.new_property_name]: - enrichments[rule.new_property_name][matcher] = {} - - for enrichment_key, enrichment_value in row.items(): - if enrichment_value is not None: - enrichments[rule.new_property_name][matcher][ - enrichment_key.strip() - ] = enrichment_value.strip() - break + + # Try SQL-based multi-level matching + try: + dialect_name = None + if ( + self.db_session + and hasattr(self.db_session, "bind") + and self.db_session.bind + ): + dialect_name = self.db_session.bind.dialect.name + + matcher = MappingRuleMatcher( + dialect_name=dialect_name, session=self.db_session + ) + + self._add_enrichment_log( + f"Using SQL-based multi-level matching with dialect: {dialect_name or 'fallback'}", + "debug", + {"rule_id": rule.id}, + ) + + # Convert matcher_values to list of strings if needed + string_values = [] + for val in matcher_values: + if isinstance(val, str): + string_values.append(val) + else: + string_values.append(str(val)) + + matches = matcher.get_matching_rows_multi_level( + rule, key, string_values + ) + + if matches: + if rule.new_property_name not in enrichments: + enrichments[rule.new_property_name] = {} + + for matcher_key, match_data in matches.items(): + enrichments[rule.new_property_name][ + matcher_key + ] = match_data + except Exception as e: + self._add_enrichment_log( + f"Error using SQL multi-level matcher, falling back to in-memory iteration: {str(e)}", + "warning", + {"rule_id": rule.id}, + ) + # Fall back to the original implementation + if rule.rows: + for matcher in matcher_values: + matcher_str = ( + str(matcher) + if not isinstance(matcher, str) + else matcher + ) + if rule.prefix_to_remove: + matcher_str = matcher_str.replace( + rule.prefix_to_remove, "" + ) + for row in rule.rows: + if self._check_explicit_match( + row, key, matcher_str + ): + if rule.new_property_name not in enrichments: + enrichments[rule.new_property_name] = {} + + if ( + matcher_str + not in enrichments[rule.new_property_name] + ): + enrichments[rule.new_property_name][ + matcher_str + ] = {} + + for ( + enrichment_key, + enrichment_value, + ) in row.items(): + if enrichment_value is not None: + enrichments[rule.new_property_name][ + matcher_str + ][enrichment_key.strip()] = ( + enrichment_value.strip() + if isinstance(enrichment_value, str) + else enrichment_value + ) + break if enrichments: # Enrich the alert with the matched data from the row for key, matcher in enrichments.items(): @@ -962,3 +1084,32 @@ def check_incident_resolution(self, alert: Alert | AlertDto): incident.status = IncidentStatus.RESOLVED.value self.db_session.add(incident) self.db_session.commit() + + def _extract_alert_values_for_matchers( + self, alert: AlertDto, matchers: list[list[str]] + ) -> dict: + """ + Extract alert values that match the matchers for SQL-based 
matching. + + Args: + alert: AlertDto to extract values from + matchers: List of matcher rules + + Returns: + Dictionary of alert values needed for matching + """ + alert_values = {} + + # Get unique attributes across all matchers + all_attributes = set() + for matcher_group in matchers: + for attr in matcher_group: + all_attributes.add(attr.strip()) + + # Extract values for each attribute + for attr in all_attributes: + value = get_nested_attribute(alert, attr) + if value is not None: + alert_values[attr] = value + + return alert_values diff --git a/keep/api/bl/mapping_rule_matcher.py b/keep/api/bl/mapping_rule_matcher.py new file mode 100644 index 0000000000..1599efde81 --- /dev/null +++ b/keep/api/bl/mapping_rule_matcher.py @@ -0,0 +1,471 @@ +import json +import logging +from typing import Any, Dict, List, Optional + +from sqlalchemy import text +from sqlalchemy.orm import Session + +from keep.api.models.db.mapping import MappingRule + + +class MappingRuleMatcher: + """ + Class for matching mapping rules using SQL queries instead of in-memory iteration. + """ + + def __init__( + self, dialect_name: Optional[str] = None, session: Optional[Session] = None + ): + """ + Initialize the matcher with the database dialect and session. + + Args: + dialect_name: Database dialect name + session: SQLAlchemy session + """ + self.logger = logging.getLogger(__name__) + self.dialect_name = dialect_name + self.session = session + + def get_matching_row( + self, rule: MappingRule, alert_values: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + """ + Get the first matching row for a rule using SQL directly on the MappingRule.rows field. + + Args: + rule: MappingRule to check + alert_values: Dict of alert attribute values + + Returns: + First matching row as a dict if found, None otherwise + """ + if not rule.rows or not alert_values: + return None + + conditions = [] + params = {} + + params["rule_id"] = rule.id + + # Build the query based on the dialect + if self.dialect_name == "postgresql": + # Build SQL conditions from matchers + for i, and_group in enumerate(rule.matchers): + and_conditions = [] + + for j, attr in enumerate(and_group): + attr_value = alert_values.get(attr) + if attr_value is not None: + param_name = f"val_{i}_{j}" + + # Handle different value types + if isinstance(attr_value, str): + # String comparison with wildcard support + and_conditions.append( + f"(rows->>'{attr}' = :{param_name} OR rows->>'{attr}' = '*')" + ) + params[param_name] = attr_value + elif isinstance(attr_value, (int, float, bool)): + # For numeric or boolean values + and_conditions.append( + f"(rows->>'{attr}' = :{param_name} OR rows->>'{attr}' = '*')" + ) + params[param_name] = str( + attr_value + ) # Convert to string for JSON text comparison + else: + # For complex types, convert to JSON string + and_conditions.append( + f"(rows->>'{attr}' = :{param_name} OR rows->>'{attr}' = '*')" + ) + params[param_name] = json.dumps(attr_value) + + if and_conditions: + conditions.append(f"({' AND '.join(and_conditions)})") + + if not conditions: + return None + # PostgreSQL version + query = f""" + SELECT rows::jsonb + FROM ( + SELECT jsonb_array_elements(rows::jsonb) AS rows + FROM mappingrule + WHERE id = :rule_id + ) AS expanded_rows + WHERE {' OR '.join(conditions)} + LIMIT 1 + """ + elif self.dialect_name == "mysql": + # MySQL version with proper JSON_TABLE syntax + mysql_conditions = [] + + for i, and_group in enumerate(rule.matchers): + and_conditions = [] + + for j, attr in enumerate(and_group): + attr_value = 
alert_values.get(attr) + if attr_value is not None: + param_name = f"val_{i}_{j}" + + if isinstance(attr_value, str): + params[param_name] = attr_value + and_conditions.append( + f"""(JSON_EXTRACT(jt.json_object, '$.{attr}') = :{param_name} + OR JSON_EXTRACT(jt.json_object, '$.{attr}') = '"*"')""" + ) + elif isinstance(attr_value, (int, float)): + # Bind as string so the CAST ... AS CHAR comparison works + params[param_name] = str(attr_value) + and_conditions.append( + f"""(CAST(JSON_EXTRACT(jt.json_object, '$.{attr}') AS CHAR) = :{param_name} + OR JSON_EXTRACT(jt.json_object, '$.{attr}') = '"*"')""" + ) + else: + # Complex types are compared via their JSON string form + params[param_name] = json.dumps(attr_value) + and_conditions.append( + f"""(JSON_EXTRACT(jt.json_object, '$.{attr}') = :{param_name} + OR JSON_EXTRACT(jt.json_object, '$.{attr}') = '"*"')""" + ) + + if and_conditions: + mysql_conditions.append(f"({' AND '.join(and_conditions)})") + + if not mysql_conditions: + return None + + query = f""" + SELECT jt.json_object + FROM mappingrule, + JSON_TABLE( + mappingrule.rows, + '$[*]' COLUMNS ( + sequence_number FOR ORDINALITY, + json_object JSON PATH '$' + ) + ) AS jt + WHERE mappingrule.id = :rule_id + AND ({' OR '.join(mysql_conditions)}) + LIMIT 1 + """ + elif self.dialect_name == "sqlite": + # SQLite version using json_each() function + # Build match conditions for each attribute + match_conditions = [] + + for i, and_group in enumerate(rule.matchers): + and_conditions = [] + + for j, attr in enumerate(and_group): + attr_value = alert_values.get(attr) + if attr_value is not None: + param_name = f"val_{i}_{j}" + + # Convert value to string for comparison + if isinstance(attr_value, (int, float, bool)): + attr_value = str(attr_value) + elif not isinstance(attr_value, str): + attr_value = json.dumps(attr_value) + + params[param_name] = attr_value + + # Parenthesize each equality/wildcard pair so AND does not bind tighter than the OR + and_conditions.append( + f""" + (json_extract(row_data, '$."{attr}"') = :{param_name} + OR json_extract(row_data, '$."{attr}"') = '*') + """ + ) + + if and_conditions: + match_conditions.append(f"({' AND '.join(and_conditions)})") + + if not match_conditions: + return None + + query = f""" + WITH flattened AS ( + SELECT + value AS row_data + FROM mappingrule, json_each(mappingrule.rows) + WHERE mappingrule.id = :rule_id + ) + SELECT row_data FROM flattened + WHERE {' OR '.join(match_conditions)} + LIMIT 1 + """ + else: + # Default implementation (fallback to Python) + return self._fallback_get_matching_row(rule, alert_values) + + try: + if not self.session: + return self._fallback_get_matching_row(rule, alert_values) + + result = self.session.execute(text(query), params).first() + + if result: + result_dict = result[0] + if isinstance(result_dict, str): + result_dict = json.loads(result_dict) + return result_dict + return None + except Exception as e: + self.logger.exception( + f"Failed to query {self.dialect_name} for mapping rule {rule.id} due to {e}, falling back.", + extra={ + "tenant_id": rule.tenant_id, + "rule_id": rule.id, + "rule_name": rule.name, + }, + ) + # Fallback to in-memory implementation + return self._fallback_get_matching_row(rule, alert_values) + + def get_matching_rows_multi_level( + self, rule: MappingRule, key: str, values: List[str] + ) -> Dict[str, Dict[str, Any]]: + """ + Get matching rows for multi-level mapping rules using SQL.
+ + Args: + rule: MappingRule to check + key: The key to match on + values: List of values to match + + Returns: + Dict of matched values and their corresponding enrichments + """ + if not rule.rows or not values: + return {} + + result = {} + + # Process matcher values and build IN condition + clean_values = [] + value_mapping = {} + + for v in values: + clean_v = v + if rule.prefix_to_remove: + clean_v = v.replace(rule.prefix_to_remove, "") + clean_values.append(clean_v) + value_mapping[clean_v] = v + + if not clean_values: + return {} + + params = {"rule_id": rule.id} + placeholder_list = [] + + for i, val in enumerate(clean_values): + param_name = f"val_{i}" + params[param_name] = val + placeholder_list.append(f":{param_name}") + + # Build the query based on the dialect + if self.dialect_name == "postgresql": + # PostgreSQL version + in_clause = ", ".join(placeholder_list) + query = f""" + SELECT rows::jsonb + FROM ( + SELECT jsonb_array_elements(rows::jsonb) AS rows + FROM mappingrule + WHERE id = :rule_id + ) AS expanded_rows + WHERE rows->>'{key}' IN ({in_clause}) + """ + elif self.dialect_name == "mysql": + # MySQL version with proper JSON_TABLE syntax for multi-level matching + in_clause = ", ".join(placeholder_list) + + # Handle the @@ syntax by replacing with . and wrapping in quotes + json_key = key + if "@@" in json_key: + json_key = json_key.replace("@@", ".") + json_key = f'"{json_key}"' + + query = f""" + SELECT jt.json_object + FROM mappingrule, + JSON_TABLE( + mappingrule.rows, + '$[*]' COLUMNS ( + sequence_number FOR ORDINALITY, + json_object JSON PATH '$' + ) + ) AS jt + WHERE mappingrule.id = :rule_id + AND JSON_UNQUOTE(JSON_EXTRACT(jt.json_object, '$.{json_key}')) IN ({in_clause}) + """ + elif self.dialect_name == "sqlite": + # SQLite version using json_each + in_clause = ", ".join(placeholder_list) + + # Handle @@ and escaping in key for json_extract + json_key = key + if "@@" in json_key: + json_key = json_key.replace("@@", ".") + json_key = json_key.replace('"', '""') # Escape quotes for SQLite JSON path + + query = f""" + WITH flattened AS ( + SELECT + value AS row_data + FROM mappingrule, json_each(mappingrule.rows) + WHERE mappingrule.id = :rule_id + ) + SELECT row_data FROM flattened + WHERE json_extract(row_data, '$."{json_key}"') IN ({in_clause}) + """ + else: + # Fallback to Python implementation for other dialects + return self._fallback_get_matching_rows_multi_level(rule, key, values) + + try: + if not self.session: + return self._fallback_get_matching_rows_multi_level(rule, key, values) + + rows = self.session.execute(text(query), params).all() + + for row in rows: + row_dict = row[0] + if isinstance(row_dict, str): + row_dict = json.loads(row_dict) + match_key = row_dict.get(key) + + if match_key in clean_values: + match_data = {} + + for enrichment_key, enrichment_value in row_dict.items(): + if enrichment_value is not None and enrichment_key != key: + match_data[enrichment_key.strip()] = ( + enrichment_value.strip() + if isinstance(enrichment_value, str) + else enrichment_value + ) + + result[match_key] = match_data + + return result + except Exception as e: + self.logger.exception( + f"Failed to query multi-level {self.dialect_name} for mapping rule {rule.id} due to {e}, falling back.", + extra={ + "tenant_id": rule.tenant_id, + "rule_id": rule.id, + "rule_name": rule.name, + }, + ) + # Fallback to Python implementation + return self._fallback_get_matching_rows_multi_level(rule, key, values) + + def _fallback_get_matching_row( + self, rule: 
MappingRule, alert_values: Dict + ) -> Optional[Dict]: + """ + Fallback method to get matching row using in-memory iteration. + + Args: + rule: MappingRule to check + alert_values: Dict of alert attribute values + + Returns: + First matching row if found, None otherwise + """ + if not rule.rows: + return None + + for row in rule.rows or []: + if any( + self._check_matcher(alert_values, row, matcher) + for matcher in rule.matchers + ): + return row + return None + + def _fallback_get_matching_rows_multi_level( + self, rule: MappingRule, key: str, values: List[str] + ) -> Dict[str, Dict]: + """ + Fallback method to get matching rows for multi-level mapping using in-memory iteration. + + Args: + rule: MappingRule to check + key: The key to match on + values: List of values to match + + Returns: + Dict mapping matched values to their enrichments + """ + result = {} + + if not rule.rows: + return result + + for matcher_value in values: + clean_value = matcher_value + if rule.prefix_to_remove: + clean_value = matcher_value.replace(rule.prefix_to_remove, "") + + for row in rule.rows: + if row.get(key) == clean_value: + match_data = {} + for enrichment_key, enrichment_value in row.items(): + if enrichment_value is not None and enrichment_key != key: + match_data[enrichment_key.strip()] = ( + enrichment_value.strip() + if isinstance(enrichment_value, str) + else enrichment_value + ) + result[clean_value] = match_data + break + + return result + + def _check_matcher(self, alert_values: Dict, row: Dict, matcher: List[str]) -> bool: + """ + Check if an alert matches the conditions in a matcher. + + Args: + alert_values: Alert values dict + row: Row from the mapping rule + matcher: List of attributes to match (AND condition) + + Returns: + True if matched, False otherwise + """ + try: + return all( + self._is_match( + alert_values.get(attribute.strip()), + row.get(attribute.strip()), + ) + or alert_values.get(attribute.strip()) == row.get(attribute.strip()) + or row.get(attribute.strip()) == "*" # Wildcard match + for attribute in matcher + ) + except TypeError: + return False + + @staticmethod + def _is_match(value, pattern): + """ + Check if a value matches a pattern. 
+ + Args: + value: Value to check + pattern: Pattern to match against + + Returns: + True if matched, False otherwise + """ + import re + + if value is None or pattern is None: + return False + + # Add start and end anchors to pattern to ensure exact match + if isinstance(pattern, str) and isinstance(value, str): + # Only add anchors if they're not already there + if not pattern.startswith("^"): + pattern = f"^{pattern}" + if not pattern.endswith("$"): + pattern = f"{pattern}$" + + try: + return re.search(pattern, value) is not None + except re.error: + # Cells that are not valid regexes (e.g. a literal "*" wildcard, + # which anchors to the invalid pattern "^*$") are handled by the + # explicit equality and wildcard checks in _check_matcher + return False diff --git a/keep/api/core/cel_to_sql/mapping_rule_matcher.py b/keep/api/core/cel_to_sql/mapping_rule_matcher.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keep/api/core/db.py b/keep/api/core/db.py index e413fbfd97..c39fd63fb7 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -40,7 +40,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.exc import IntegrityError, OperationalError -from sqlalchemy.orm import joinedload, subqueryload +from sqlalchemy.orm import defer, joinedload, subqueryload from sqlalchemy.orm.exc import StaleDataError from sqlalchemy.sql import exists, expression from sqlmodel import Session, SQLModel, col, or_, select, text @@ -248,11 +248,13 @@ def create_workflow_execution( def get_mapping_rule_by_id( - tenant_id: str, rule_id: str, session: Optional[Session] = None + tenant_id: str, rule_id: int, session: Optional[Session] = None ) -> MappingRule | None: with existed_or_new_session(session) as session: - query = select(MappingRule).where( - MappingRule.tenant_id == tenant_id, MappingRule.id == rule_id + query = ( + select(MappingRule) + .where(MappingRule.tenant_id == tenant_id, MappingRule.id == rule_id) + .options(defer(MappingRule.rows)) ) return session.exec(query).first() diff --git a/tests/conftest.py b/tests/conftest.py index fb6577d94f..d940522346 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -229,6 +229,7 @@ def db_session(request, monkeypatch): # sqlite else: db_connection_string = "sqlite:///:memory:" + # db_connection_string = "sqlite:///state/db.sqlite3?check_same_thread=False" mock_engine = create_engine( db_connection_string, connect_args={"check_same_thread": False}, @@ -446,7 +447,7 @@ def elastic_container(docker_ip, docker_services): @pytest.fixture def elastic_client(request): - if hasattr(request, 'param') and request.param is False: + if hasattr(request, "param") and request.param is False: yield None else: # this is so if any other module initialized Elasticsearch, it will be deleted diff --git a/tests/e2e_tests/complex.csv b/tests/e2e_tests/complex.csv new file mode 100644 index 0000000000..fc2faa4e7a --- /dev/null +++ b/tests/e2e_tests/complex.csv @@ -0,0 +1,2 @@ +a@@b,c,d +hello,sf,hey diff --git a/tests/e2e_tests/multi.csv b/tests/e2e_tests/multi.csv new file mode 100644 index 0000000000..0932371eea --- /dev/null +++ b/tests/e2e_tests/multi.csv @@ -0,0 +1,4 @@ +a,b,c,d +1,2,3,4 +5,6,7,8 +9,10,11,12 diff --git a/tests/e2e_tests/simple.csv b/tests/e2e_tests/simple.csv new file mode 100644 index 0000000000..cba67ea1ac --- /dev/null +++ b/tests/e2e_tests/simple.csv @@ -0,0 +1,2 @@ +a,b.c,c,d +hello,to,all,sf diff --git a/tests/e2e_tests/test_end_to_end_mapping.py b/tests/e2e_tests/test_end_to_end_mapping.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_enrichments.py b/tests/test_enrichments.py index 649b76c8c1..4e2d330f1d 100644 ---
a/tests/test_enrichments.py +++ b/tests/test_enrichments.py @@ -227,7 +227,7 @@ def test_run_mapping_rules_applies(mock_session, mock_alert_dto): disabled=False, type="csv", ) - mock_session.query.return_value.filter.return_value.filter.return_value.order_by.return_value.all.return_value = [ + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value.order_by.return_value.all.return_value = [ rule ] @@ -252,7 +252,7 @@ def test_run_mapping_rules_with_regex_match(mock_session, mock_alert_dto): disabled=False, type="csv", ) - mock_session.query.return_value.filter.return_value.filter.return_value.order_by.return_value.all.return_value = [ + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value.order_by.return_value.all.return_value = [ rule ] @@ -296,7 +296,7 @@ def test_run_mapping_rules_no_match(mock_session, mock_alert_dto): disabled=False, type="csv", ) - mock_session.query.return_value.filter.return_value.filter.return_value.order_by.return_value.all.return_value = [ + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value.order_by.return_value.all.return_value = [ rule ] del mock_alert_dto.service @@ -322,7 +322,7 @@ def test_check_matcher_with_and_condition(mock_session, mock_alert_dto): disabled=False, type="csv", ) - mock_session.query.return_value.filter.return_value.filter.return_value.order_by.return_value.all.return_value = [ + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value.order_by.return_value.all.return_value = [ rule ] @@ -362,7 +362,7 @@ def test_check_matcher_with_or_condition(mock_session, mock_alert_dto): disabled=False, type="csv", ) - mock_session.query.return_value.filter.return_value.filter.return_value.order_by.return_value.all.return_value = [ + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value.order_by.return_value.all.return_value = [ rule ] @@ -614,7 +614,9 @@ def test_topology_mapping_rule_enrichment(mock_session, mock_alert_dto): ) # Mock the session to return this topology mapping rule - mock_session.query.return_value.filter.return_value.all.return_value = [rule] + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value.order_by.return_value.all.return_value = [ + rule + ] # Initialize the EnrichmentsBl class with the mock session enrichment_bl = EnrichmentsBl(tenant_id="test_tenant", db=mock_session) @@ -679,7 +681,7 @@ def test_run_mapping_rules_with_complex_matchers(mock_session, mock_alert_dto): disabled=False, type="csv", ) - mock_session.query.return_value.filter.return_value.filter.return_value.order_by.return_value.all.return_value = [ + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value.order_by.return_value.all.return_value = [ rule ] @@ -733,7 +735,7 @@ def test_run_mapping_rules_enrichments_filtering(mock_session, mock_alert_dto): disabled=False, type="csv", ) - mock_session.query.return_value.filter.return_value.filter.return_value.order_by.return_value.all.return_value = [ + mock_session.query.return_value.filter.return_value.filter.return_value.options.return_value.order_by.return_value.all.return_value = [ rule ] @@ -881,7 +883,7 @@ def test_batch_enrichment(db_session, client, test_app, create_alert, elastic_cl ) assert response.status_code == 200 - assert response.json() == {"status":"ok"} + assert response.json() == {"status": "ok"} time.sleep(1) diff --git 
a/tests/test_sql_mapping_complex_json.py b/tests/test_sql_mapping_complex_json.py new file mode 100644 index 0000000000..59bb9f1ac7 --- /dev/null +++ b/tests/test_sql_mapping_complex_json.py @@ -0,0 +1,456 @@ +import json + +import pytest +from sqlmodel import Session + +from keep.api.bl.enrichments_bl import EnrichmentsBl, get_nested_attribute +from keep.api.bl.mapping_rule_matcher import MappingRuleMatcher +from keep.api.core.dependencies import SINGLE_TENANT_UUID +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus +from keep.api.models.db.mapping import MappingRule +from tests.fixtures.client import test_app # noqa + + +@pytest.fixture +def complex_mapping_rule(db_session: Session): + """Create a mapping rule with complex JSON matchers.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Complex JSON Matcher", + description="Rule for matching complex nested JSON structures", + matchers=[["attributes.environment"], ["metadata.region", "metadata.zone"]], + rows=[ + { + "attributes.environment": "production", + "owner": "prod-team", + "sla": "99.9%", + "priority": "critical", + }, + { + "attributes.environment": "staging", + "owner": "dev-team", + "sla": "99.5%", + "priority": "high", + }, + { + "metadata.region": "us-west", + "metadata.zone": "us-west-1", + "owner": "west-team", + "datacenter": "dc-west", + "backup": "daily", + }, + { + "metadata.region": "us-east", + "metadata.zone": "us-east-1", + "owner": "east-team", + "datacenter": "dc-east", + "backup": "hourly", + }, + ], + type="csv", + file_name="", + created_by="test", + condition="", + new_property_name="", + prefix_to_remove="", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def complex_multi_level_rule(db_session: Session): + """Create a multi-level mapping rule with complex JSON matchers.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Complex Multi-level JSON Matcher", + description="Multi-level rule for complex nested JSON structures", + matchers=[["services.ids"]], + rows=[ + { + "service_id": "svc-1", + "details.type": "web", + "details.tier": "frontend", + "owner": "team-a", + "contact": {"email": "team-a@example.com", "slack": "#team-a-alerts"}, + }, + { + "service_id": "svc-2", + "details.type": "api", + "details.tier": "backend", + "owner": "team-b", + "contact": {"email": "team-b@example.com", "slack": "#team-b-alerts"}, + }, + { + "service_id": "svc-3", + "details.type": "database", + "details.tier": "data", + "owner": "team-db", + "contact": { + "email": "db-team@example.com", + "phone": "555-123-4567", + "slack": "#db-alerts", + }, + }, + ], + type="csv", + is_multi_level=True, + new_property_name="service_details", + file_name="", + created_by="test", + condition="", + prefix_to_remove="", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def complex_alert_dto(): + """Create an alert with complex nested attributes.""" + alert = AlertDto( + id="test-complex-id", + name="Complex Test Alert", + status=AlertStatus.FIRING, + severity=AlertSeverity.HIGH, + lastReceived="2023-01-01T00:00:00Z", + source=["test-source"], + fingerprint="test-complex-fingerprint", + ) + + # Add complex nested attributes using json strings + attributes_json = json.dumps( + { + "environment": "production", + "application": { + "name": "payment-service", + "version": "2.3.1", + "components": ["api", "processor", "database"], + }, + "tags": 
["finance", "critical", "monitored"], + } + ) + setattr(alert, "attributes", json.loads(attributes_json)) + + metadata_json = json.dumps( + { + "region": "us-west", + "zone": "us-west-1", + "instance": { + "id": "i-12345abcdef", + "type": "m5.large", + "launchTime": "2023-01-01T00:00:00Z", + }, + "network": {"vpc": "vpc-abc123", "subnet": "subnet-def456"}, + } + ) + setattr(alert, "metadata", json.loads(metadata_json)) + + # Add service IDs as an array attribute + services_json = json.dumps( + {"ids": ["svc-1", "svc-3"], "types": ["web", "database"]} + ) + setattr(alert, "services", json.loads(services_json)) + + return alert + + +@pytest.fixture +def dotted_attribute_rule(db_session: Session): + """Create a mapping rule with a matcher for attributes containing dots in their names.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Dotted Attribute Matcher", + description="Rule for matching attributes with dots in their names", + matchers=[["config.aws@@region"]], + rows=[ + { + "config.aws@@region": "us-west-2", + "owner": "west-team", + "support": "24/7", + }, + { + "config.aws@@region": "us-east-1", + "owner": "east-team", + "support": "business hours", + }, + ], + type="csv", + file_name="", + created_by="test", + condition="", + new_property_name="", + prefix_to_remove="", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def dotted_alert_dto(): + """Create an alert with attributes that contain dots in their names.""" + alert = AlertDto( + id="test-dotted-id", + name="Dotted Test Alert", + status=AlertStatus.FIRING, + severity=AlertSeverity.HIGH, + lastReceived="2023-01-01T00:00:00Z", + source=["test-source"], + fingerprint="test-dotted-fingerprint", + ) + + # Add a nested config object with a key that contains a dot + config = { + "aws.region": "us-west-2", + "instance_type": "t2.micro", + "subnet_id": "subnet-12345", + } + setattr(alert, "config", config) + + return alert + + +@pytest.fixture +def regex_pattern_rule(db_session: Session): + """Create a mapping rule for testing regex pattern matching issues.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Regex Pattern Rule", + description="Rule for testing regex pattern matching behavior", + matchers=[["service_id"]], + rows=[ + {"service_id": "customer-9", "owner": "team-9", "region": "us-west-9"}, + {"service_id": "customer-99", "owner": "team-99", "region": "us-west-99"}, + ], + type="csv", + file_name="", + created_by="test", + condition="", + new_property_name="", + prefix_to_remove="", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +def test_match_nested_json_first_matcher_group( + db_session: Session, complex_mapping_rule, complex_alert_dto +): + """Test matching against the first matcher group with nested JSON.""" + dialect_name = None + if ( + hasattr(db_session, "bind") + and db_session.bind is not None + and hasattr(db_session.bind, "dialect") + ): + dialect_name = db_session.bind.dialect.name + + matcher = MappingRuleMatcher(dialect_name=dialect_name, session=db_session) + + # Extract flattened alert values + alert_values = {} + if hasattr(complex_alert_dto, "attributes") and isinstance( + complex_alert_dto.attributes, dict + ): + for key, value in complex_alert_dto.attributes.items(): + alert_values[f"attributes.{key}"] = value + + if hasattr(complex_alert_dto, "metadata") and isinstance( + complex_alert_dto.metadata, dict + ): + for key, value in 
complex_alert_dto.metadata.items(): + alert_values[f"metadata.{key}"] = value + + # Get matching row using matcher + matched_row = matcher.get_matching_row(complex_mapping_rule, alert_values) + + # Verify correct match is found (production environment) + assert matched_row is not None + assert matched_row["owner"] == "prod-team" + assert matched_row["sla"] == "99.9%" + assert matched_row["priority"] == "critical" + + +def test_match_nested_json_second_matcher_group( + db_session: Session, complex_mapping_rule, complex_alert_dto +): + """Test matching against the second matcher group with nested JSON.""" + dialect_name = None + if ( + hasattr(db_session, "bind") + and db_session.bind is not None + and hasattr(db_session.bind, "dialect") + ): + dialect_name = db_session.bind.dialect.name + + matcher = MappingRuleMatcher(dialect_name=dialect_name, session=db_session) + + # Modify alert to not match the first matcher group + if hasattr(complex_alert_dto, "attributes") and isinstance( + complex_alert_dto.attributes, dict + ): + complex_alert_dto.attributes["environment"] = "unknown" + + # Extract flattened alert values + alert_values = {} + if hasattr(complex_alert_dto, "attributes") and isinstance( + complex_alert_dto.attributes, dict + ): + for key, value in complex_alert_dto.attributes.items(): + alert_values[f"attributes.{key}"] = value + + if hasattr(complex_alert_dto, "metadata") and isinstance( + complex_alert_dto.metadata, dict + ): + for key, value in complex_alert_dto.metadata.items(): + alert_values[f"metadata.{key}"] = value + + # Get matching row using matcher + matched_row = matcher.get_matching_row(complex_mapping_rule, alert_values) + + # Verify correct match is found (us-west region and us-west-1 zone) + assert matched_row is not None + assert matched_row["owner"] == "west-team" + assert matched_row["datacenter"] == "dc-west" + assert matched_row["backup"] == "daily" + + +def test_multi_level_complex_json_matching( + db_session: Session, complex_multi_level_rule, complex_alert_dto +): + """Test multi-level matching with complex nested JSON structures.""" + dialect_name = None + if ( + hasattr(db_session, "bind") + and db_session.bind is not None + and hasattr(db_session.bind, "dialect") + ): + dialect_name = db_session.bind.dialect.name + + matcher = MappingRuleMatcher(dialect_name=dialect_name, session=db_session) + + # Extract service IDs for multi-level matching + service_ids = None + if hasattr(complex_alert_dto, "services") and isinstance( + complex_alert_dto.services, dict + ): + if "ids" in complex_alert_dto.services: + service_ids = complex_alert_dto.services["ids"] + + assert service_ids is not None, "Service IDs not found in alert" + + # Get matching rows using multi-level matcher + matches = matcher.get_matching_rows_multi_level( + complex_multi_level_rule, "service_id", service_ids + ) + + # Verify correct matches are found + assert len(matches) == 2 # Should match svc-1 and svc-3 + + # Check svc-1 details + assert "svc-1" in matches + assert matches["svc-1"]["owner"] == "team-a" + assert matches["svc-1"]["details.tier"] == "frontend" + assert isinstance(matches["svc-1"]["contact"], dict) + assert matches["svc-1"]["contact"]["email"] == "team-a@example.com" + + # Check svc-3 details + assert "svc-3" in matches + assert matches["svc-3"]["owner"] == "team-db" + assert matches["svc-3"]["details.tier"] == "data" + assert isinstance(matches["svc-3"]["contact"], dict) + assert matches["svc-3"]["contact"]["email"] == "db-team@example.com" + assert 
matches["svc-3"]["contact"]["phone"] == "555-123-4567" + + +def test_dotted_attribute_direct_access(): + """Test the get_nested_attribute function with attributes containing dots.""" + # Create a simple object with a nested attribute containing a dot + alert = AlertDto( + id="test-id", + name="Test Alert", + status=AlertStatus.FIRING, + severity=AlertSeverity.HIGH, + lastReceived="2023-01-01T00:00:00Z", + source=["test-source"], + fingerprint="test-fingerprint", + ) + + # Add a nested config object with a key that contains a dot + config = {"aws.region": "us-west-2", "instance_type": "t2.micro"} + setattr(alert, "config", config) + + # Test direct access with the @@ placeholder + value = get_nested_attribute(alert, "config.aws@@region") + assert value == "us-west-2" + + # Test access without the placeholder (should fail or return None) + value = get_nested_attribute(alert, "config.aws.region") + assert value is None # Since "aws" is not a nested object in config + + +def test_dotted_attribute_mapping_rule( + db_session: Session, dotted_attribute_rule, dotted_alert_dto +): + """Test that mapping rules correctly handle attributes with dots in their names.""" + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Check if the rule matches and enriches the alert + result = enrichment_bl.check_if_match_and_enrich( + dotted_alert_dto, dotted_attribute_rule + ) + + # Verify enrichment worked + assert result is True + assert hasattr(dotted_alert_dto, "owner") + assert dotted_alert_dto.owner == "west-team" + assert hasattr(dotted_alert_dto, "support") + assert dotted_alert_dto.support == "24/7" + + +def test_dotted_attribute_direct_matcher_access( + db_session: Session, dotted_attribute_rule, dotted_alert_dto +): + """Test that the matcher correctly handles attributes with dots in their names.""" + dialect_name = None + if ( + hasattr(db_session, "bind") + and db_session.bind is not None + and hasattr(db_session.bind, "dialect") + ): + dialect_name = db_session.bind.dialect.name + + matcher = MappingRuleMatcher(dialect_name=dialect_name, session=db_session) + + # Extract the alert values including the dotted attribute + alert_values = {} + if hasattr(dotted_alert_dto, "config") and isinstance( + dotted_alert_dto.config, dict + ): + for key, value in dotted_alert_dto.config.items(): + # If the key contains a dot, replace it with @@ + if "." 
in key: + formatted_key = key.replace(".", "@@") + alert_values[f"config.{formatted_key}"] = value + else: + alert_values[f"config.{key}"] = value + + # Get matching row using matcher + matched_row = matcher.get_matching_row(dotted_attribute_rule, alert_values) + + # Verify correct match is found + assert matched_row is not None + assert matched_row["owner"] == "west-team" + assert matched_row["support"] == "24/7" diff --git a/tests/test_sql_mapping_enrichment.py b/tests/test_sql_mapping_enrichment.py new file mode 100644 index 0000000000..aaa83d9720 --- /dev/null +++ b/tests/test_sql_mapping_enrichment.py @@ -0,0 +1,330 @@ +from datetime import datetime +from typing import List + +import pytest +from sqlmodel import Session + +from keep.api.bl.enrichments_bl import EnrichmentsBl +from keep.api.core.dependencies import SINGLE_TENANT_UUID +from keep.api.models.alert import AlertDto +from keep.api.models.db.mapping import MappingRule +from tests.fixtures.client import test_app # noqa + + +@pytest.fixture +def alert_dto(): + """Create a test AlertDto for testing.""" + return AlertDto( + id="test-alert-id", + name="Test Alert", + status="firing", + severity="high", + lastReceived=datetime.utcnow().isoformat(), + source=["test-source"], + fingerprint="test-fingerprint", + service="test-service", + environment="test-environment", + ) + + +@pytest.fixture +def simple_mapping_rule(db_session: Session): + """Create a simple mapping rule for testing.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Simple Mapping Rule", + description="A simple mapping rule for testing", + matchers=[["service"]], + rows=[{"service": "test-service", "team": "test-team", "owner": "test-owner"}], + type="csv", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def multi_matcher_rule(db_session: Session): + """Create a mapping rule with multiple matchers.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Multi Matcher Rule", + description="A rule with multiple matchers", + matchers=[["service", "environment"], ["source"]], + rows=[ + { + "service": "test-service", + "environment": "test-environment", + "team": "test-team", + "owner": "test-owner", + "tier": "tier-1", + }, + { + "source": "test-source", + "team": "source-team", + "owner": "source-owner", + "tier": "tier-2", + }, + ], + type="csv", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def multi_level_rule(db_session: Session): + """Create a multi-level mapping rule.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Multi-Level Rule", + description="A multi-level mapping rule", + matchers=[["tags"]], + rows=[ + {"tags": "tag1", "contact": "contact1", "team": "team1"}, + {"tags": "tag2", "contact": "contact2", "team": "team2"}, + {"tags": "tag3", "contact": "contact3", "team": "team3"}, + ], + is_multi_level=True, + new_property_name="tag_info", + type="csv", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def wildcard_rule(db_session: Session): + """Create a mapping rule with wildcard matcher.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Wildcard Rule", + description="A rule with wildcard matcher", + matchers=[["service"]], + rows=[ + { + "service": "*", # Matches any service + "team": "default-team", + "owner": "default-owner", + } + ], + type="csv", + ) + 
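# A literal "*" cell matches any alert value for that column, in both the SQL queries and the in-memory fallback +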
db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def all_mapping_rules( + db_session: Session, + simple_mapping_rule: MappingRule, + multi_matcher_rule: MappingRule, + multi_level_rule: MappingRule, + wildcard_rule: MappingRule, +) -> List[MappingRule]: + """Return all created mapping rules.""" + return [simple_mapping_rule, multi_matcher_rule, multi_level_rule, wildcard_rule] + + +def test_simple_mapping_rule_match( + db_session: Session, alert_dto: AlertDto, simple_mapping_rule: MappingRule +): + """Test that a simple mapping rule correctly matches and enriches an alert.""" + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Run the matching logic + result = enrichment_bl.check_if_match_and_enrich(alert_dto, simple_mapping_rule) + + # Verify the result + assert result is True + assert alert_dto.team == "test-team" + assert alert_dto.owner == "test-owner" + + +def test_multi_matcher_rule_first_matcher( + db_session: Session, alert_dto: AlertDto, multi_matcher_rule: MappingRule +): + """Test that the first matcher group in a multi-matcher rule matches correctly.""" + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Alert already has service="test-service" and environment="test-environment" + result = enrichment_bl.check_if_match_and_enrich(alert_dto, multi_matcher_rule) + + # Verify the result + assert result is True + assert alert_dto.team == "test-team" + assert alert_dto.owner == "test-owner" + assert alert_dto.tier == "tier-1" + + +def test_multi_matcher_rule_second_matcher( + db_session: Session, alert_dto: AlertDto, multi_matcher_rule: MappingRule +): + """Test that the second matcher group in a multi-matcher rule matches correctly.""" + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Change service to not match the first matcher + alert_dto.service = "different-service" + # source still matches the second matcher + + result = enrichment_bl.check_if_match_and_enrich(alert_dto, multi_matcher_rule) + + # Verify the result + assert result is True + assert alert_dto.team == "source-team" + assert alert_dto.owner == "source-owner" + assert alert_dto.tier == "tier-2" + + +def test_no_match( + db_session: Session, alert_dto: AlertDto, simple_mapping_rule: MappingRule +): + """Test that no match returns False and doesn't enrich the alert.""" + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Modify the alert to not match any rule + alert_dto.service = "non-matching-service" + alert_dto.source = ["non-matching-source"] + + # Store original attribute values + original_team = getattr(alert_dto, "team", None) + original_owner = getattr(alert_dto, "owner", None) + + result = enrichment_bl.check_if_match_and_enrich(alert_dto, simple_mapping_rule) + + # Verify the result + assert result is False + assert getattr(alert_dto, "team", None) == original_team + assert getattr(alert_dto, "owner", None) == original_owner + + +def test_wildcard_match( + db_session: Session, alert_dto: AlertDto, wildcard_rule: MappingRule +): + """Test that wildcard matching works correctly.""" + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Modify the alert to have a service that doesn't explicitly match + alert_dto.service = "any-service" + + result = enrichment_bl.check_if_match_and_enrich(alert_dto, wildcard_rule) + + # Verify the result + assert result is True + assert alert_dto.team == "default-team" + 
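# both enrichment columns come from the single wildcard row +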
assert alert_dto.owner == "default-owner" + + +def test_multi_level_mapping( + db_session: Session, alert_dto: AlertDto, multi_level_rule: MappingRule +): + """Test that multi-level mapping works correctly.""" + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Set tags attribute to a list of values + alert_dto.tags = ["tag1", "tag3", "non-existent-tag"] + + result = enrichment_bl.check_if_match_and_enrich(alert_dto, multi_level_rule) + + # Verify the result + assert result is True + assert hasattr(alert_dto, "tag_info") + + # Verify the nested structure + tag_info = alert_dto.tag_info + assert "tag1" in tag_info + assert "tag3" in tag_info + assert "non-existent-tag" not in tag_info + + assert tag_info["tag1"]["contact"] == "contact1" + assert tag_info["tag1"]["team"] == "team1" + assert tag_info["tag3"]["contact"] == "contact3" + assert tag_info["tag3"]["team"] == "team3" + + +def test_run_mapping_rules( + db_session: Session, alert_dto: AlertDto, all_mapping_rules: List[MappingRule] +): + """Test running all mapping rules on an alert.""" + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Set up alert to match a specific rule + alert_dto.service = "test-service" + alert_dto.environment = "test-environment" + alert_dto.source = ["test-source"] + alert_dto.tags = ["tag1", "tag2"] + + # Run all mapping rules + result_alert = enrichment_bl.run_mapping_rules(alert_dto) + + # Verify that the alert was enriched + assert result_alert == alert_dto # Same object + assert hasattr(alert_dto, "team") + assert hasattr(alert_dto, "owner") + # Multi-level enrichment should also have happened + assert hasattr(alert_dto, "tag_info") + assert "tag1" in alert_dto.tag_info + assert "tag2" in alert_dto.tag_info + + +def test_rule_ordering_by_priority(db_session: Session, alert_dto: AlertDto): + """Test that rules are applied in order of priority.""" + # Create two rules with different priorities that would apply different enrichments + high_priority_rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=10, # Higher priority + name="High Priority Rule", + description="Rule with high priority", + matchers=[["service"]], + rows=[ + { + "service": "test-service", + "team": "high-priority-team", + "owner": "high-priority-owner", + } + ], + type="csv", + ) + + low_priority_rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=5, # Lower priority + name="Low Priority Rule", + description="Rule with low priority", + matchers=[["service"]], + rows=[ + { + "service": "test-service", + "team": "low-priority-team", + "owner": "low-priority-owner", + } + ], + type="csv", + ) + + db_session.add(high_priority_rule) + db_session.add(low_priority_rule) + db_session.commit() + + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Set up alert to match both rules + alert_dto.service = "test-service" + + # Run all mapping rules + result_alert = enrichment_bl.run_mapping_rules(alert_dto) + + # Rules run in descending priority, so the low priority rule matches last + # and its enrichments overwrite the high priority rule's + assert result_alert.team == "low-priority-team" + assert result_alert.owner == "low-priority-owner" diff --git a/tests/test_sql_mapping_error_handling.py b/tests/test_sql_mapping_error_handling.py new file mode 100644 index 0000000000..04ae0de020 --- /dev/null +++ b/tests/test_sql_mapping_error_handling.py @@ -0,0 +1,253 @@ +from unittest.mock import patch + +import pytest +from sqlalchemy.exc import SQLAlchemyError +from sqlmodel import Session + +from keep.api.bl.enrichments_bl
import EnrichmentsBl +from keep.api.bl.mapping_rule_matcher import MappingRuleMatcher +from keep.api.core.dependencies import SINGLE_TENANT_UUID +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus +from keep.api.models.db.mapping import MappingRule +from tests.fixtures.client import test_app # noqa + + +@pytest.fixture +def mock_mapping_rule(): + """Create a simple mapping rule for testing.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Test Rule", + description="A test rule for error handling", + matchers=[["service"]], + rows=[ + {"service": "web", "owner": "team-a", "email": "team-a@example.com"}, + {"service": "api", "owner": "team-b", "email": "team-b@example.com"}, + ], + type="csv", + file_name="", + created_by="test", + condition="", + new_property_name="", + prefix_to_remove="", + ) + return rule + + +@pytest.fixture +def alert_dto(): + """Create a test alert DTO.""" + alert = AlertDto( + id="test-id", + name="Test Alert", + status=AlertStatus.FIRING, + severity=AlertSeverity.HIGH, + lastReceived="2023-01-01T00:00:00Z", + source=["test-source"], + fingerprint="test-fingerprint", + ) + # Add service attribute + setattr(alert, "service", "web") + return alert + + +@pytest.fixture +def alert_dto_multi_level(): + """Create a test alert DTO.""" + alert = AlertDto( + id="test-id", + name="Test Alert", + status=AlertStatus.FIRING, + severity=AlertSeverity.HIGH, + lastReceived="2023-01-01T00:00:00Z", + source=["test-source"], + fingerprint="test-fingerprint", + services=["web"], + ) + return alert + + +def test_fallback_on_sql_error(db_session: Session, mock_mapping_rule, alert_dto): + """Test that matcher falls back to in-memory matching when SQL query fails.""" + # Mock the execute method to raise an exception + with patch( + "sqlmodel.Session.execute", side_effect=SQLAlchemyError("Mocked SQL error") + ): + # Create the matcher with the mocked session + matcher = MappingRuleMatcher( + dialect_name="sqlite", session=db_session # Use a known dialect name + ) + + # Call get_matching_row which should trigger the SQL error and fall back + alert_values = {"service": "web"} + matched_row = matcher.get_matching_row(mock_mapping_rule, alert_values) + + # Should still get a result from the fallback method + assert matched_row is not None + assert matched_row["owner"] == "team-a" + assert matched_row["email"] == "team-a@example.com" + + +def test_fallback_on_multi_level_sql_error(db_session: Session): + """Test that matcher falls back on multi-level matching when SQL query fails.""" + # Create a multi-level mapping rule + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Multi-level Test Rule", + description="A multi-level test rule for error handling", + matchers=[["services"]], + rows=[ + {"service_id": "web", "owner": "team-a", "email": "team-a@example.com"}, + {"service_id": "api", "owner": "team-b", "email": "team-b@example.com"}, + ], + type="csv", + is_multi_level=True, + new_property_name="service_info", + file_name="", + created_by="test", + condition="", + prefix_to_remove="", + ) + + # Mock the execute method to raise an exception + with patch( + "sqlmodel.Session.execute", side_effect=SQLAlchemyError("Mocked SQL error") + ): + # Create the matcher with the mocked session + matcher = MappingRuleMatcher( + dialect_name="sqlite", session=db_session # Use a known dialect name + ) + + # Call get_matching_rows_multi_level which should trigger the SQL error and fall back + service_ids = ["web", "api"] + matches = 
matcher.get_matching_rows_multi_level(rule, "service_id", service_ids) + + # Should still get results from the fallback method + assert len(matches) == 2 + assert matches["web"]["owner"] == "team-a" + assert matches["api"]["owner"] == "team-b" + + +def test_enrichment_bl_fallback_on_error( + db_session: Session, mock_mapping_rule, alert_dto +): + """Test that EnrichmentsBl falls back when MappingRuleMatcher fails.""" + # Set up a mock that will raise an exception when get_matching_row is called + with patch( + "keep.api.bl.mapping_rule_matcher.MappingRuleMatcher.get_matching_row", + side_effect=Exception("Matcher error"), + ): + # Create an instance of EnrichmentsBl + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Call check_if_match_and_enrich which should catch the exception and fall back + result = enrichment_bl.check_if_match_and_enrich(alert_dto, mock_mapping_rule) + + # Should still succeed with the fallback method + assert result is True + assert hasattr(alert_dto, "owner") + assert alert_dto.owner == "team-a" + assert hasattr(alert_dto, "email") + assert alert_dto.email == "team-a@example.com" + + +def test_unsupported_dialect(mock_mapping_rule): + """Test handling of unsupported database dialect.""" + # Create matcher with an unsupported dialect + matcher = MappingRuleMatcher(dialect_name="unsupported_dialect", session=None) + + # Should fall back to Python implementation + alert_values = {"service": "web"} + matched_row = matcher.get_matching_row(mock_mapping_rule, alert_values) + + # Should still get a result + assert matched_row is not None + assert matched_row["owner"] == "team-a" + assert matched_row["email"] == "team-a@example.com" + + +def test_null_session_handling(mock_mapping_rule): + """Test handling of null database session.""" + # Create matcher with a valid dialect but null session + matcher = MappingRuleMatcher(dialect_name="sqlite", session=None) + + # Should fall back to Python implementation + alert_values = {"service": "web"} + matched_row = matcher.get_matching_row(mock_mapping_rule, alert_values) + + # Should still get a result + assert matched_row is not None + assert matched_row["owner"] == "team-a" + assert matched_row["email"] == "team-a@example.com" + + +def test_missing_alert_values(db_session: Session, mock_mapping_rule): + """Test behavior when alert values don't have required attributes.""" + dialect_name = None + if ( + hasattr(db_session, "bind") + and db_session.bind is not None + and hasattr(db_session.bind, "dialect") + ): + dialect_name = db_session.bind.dialect.name + + matcher = MappingRuleMatcher(dialect_name=dialect_name, session=db_session) + + # Empty alert values - should not match anything + alert_values = {} + matched_row = matcher.get_matching_row(mock_mapping_rule, alert_values) + + # Should not match anything + assert matched_row is None + + # Partial alert values - missing the required "service" field + alert_values = {"other_field": "value"} + matched_row = matcher.get_matching_row(mock_mapping_rule, alert_values) + + # Should not match anything + assert matched_row is None + + +def test_multi_level_enrichment_fallback(db_session: Session, alert_dto_multi_level): + """Test that multi-level enrichment falls back when matcher fails.""" + # Create a multi-level mapping rule + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Multi-level Test Rule", + description="A multi-level test rule for error handling", + matchers=[["services"]], + rows=[ + {"services": "web", "owner": 
"team-a", "email": "team-a@example.com"}, + {"services": "api", "owner": "team-b", "email": "team-b@example.com"}, + ], + type="csv", + is_multi_level=True, + new_property_name="service_info", + file_name="", + created_by="test", + condition="", + prefix_to_remove="", + ) + + # Set up a mock that will raise an exception when get_matching_rows_multi_level is called + with patch( + "keep.api.bl.mapping_rule_matcher.MappingRuleMatcher.get_matching_rows_multi_level", + side_effect=Exception("Matcher error"), + ): + # Create an instance of EnrichmentsBl + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Call check_if_match_and_enrich which should catch the exception and fall back + result = enrichment_bl.check_if_match_and_enrich(alert_dto_multi_level, rule) + + # Should still succeed with the fallback method + assert result is True + assert hasattr(alert_dto_multi_level, "service_info") + assert len(alert_dto_multi_level.service_info) == 1 + assert alert_dto_multi_level.service_info["web"]["owner"] == "team-a" + assert ( + alert_dto_multi_level.service_info["web"]["email"] == "team-a@example.com" + ) diff --git a/tests/test_sql_mapping_matcher.py b/tests/test_sql_mapping_matcher.py new file mode 100644 index 0000000000..ed2ea5afdc --- /dev/null +++ b/tests/test_sql_mapping_matcher.py @@ -0,0 +1,259 @@ +import pytest +from sqlmodel import Session, select + +from keep.api.bl.mapping_rule_matcher import MappingRuleMatcher +from keep.api.core.dependencies import SINGLE_TENANT_UUID +from keep.api.models.db.mapping import MappingRule +from tests.fixtures.client import test_app # noqa + + +@pytest.fixture +def mapping_rule(db_session: Session): + """Create a mapping rule with test data.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Test SQL Matcher Rule", + description="Test rule for SQL-based matching", + matchers=[["service", "severity"], ["source"]], + rows=[ + { + "service": "backend", + "severity": "high", + "team": "backend-team", + "owner": "backend-owner", + }, + { + "service": "frontend", + "severity": "medium", + "team": "frontend-team", + "owner": "frontend-owner", + }, + { + "source": "prometheus", + "team": "monitoring-team", + "owner": "monitoring-owner", + }, + { + "service": "*", # Wildcard + "severity": "critical", + "team": "sre-team", + "owner": "sre-owner", + }, + ], + type="csv", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def multi_level_rule(db_session: Session): + """Create a multi-level mapping rule with test data.""" + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Multi-Level SQL Matcher Rule", + description="Test rule for multi-level SQL-based matching", + matchers=[["customer"]], + rows=[ + {"customer": "customer-1", "contact": "contact-1", "priority": "high"}, + {"customer": "customer-2", "contact": "contact-2", "priority": "medium"}, + {"customer": "customer-3", "contact": "contact-3", "priority": "low"}, + ], + type="csv", + is_multi_level=True, + new_property_name="customer_data", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +def test_get_matching_row_exact_match(db_session: Session, mapping_rule: MappingRule): + """Test that exact matching works correctly with DB session.""" + matcher = MappingRuleMatcher( + dialect_name=db_session.bind.dialect.name, session=db_session + ) + + # Test case 1: Exact match on service and severity + alert_values = {"service": 
"backend", "severity": "high"} + + matched_row = matcher.get_matching_row(mapping_rule, alert_values) + + assert matched_row is not None + assert matched_row["team"] == "backend-team" + assert matched_row["owner"] == "backend-owner" + + +def test_get_matching_row_wildcard_match( + db_session: Session, mapping_rule: MappingRule +): + """Test that wildcard matching works correctly with DB session.""" + matcher = MappingRuleMatcher( + dialect_name=db_session.bind.dialect.name, session=db_session + ) + + # Test with wildcard row + alert_values = { + "service": "unknown-service", # Not directly in the rule + "severity": "critical", + } + + matched_row = matcher.get_matching_row(mapping_rule, alert_values) + + assert matched_row is not None + assert matched_row["team"] == "sre-team" + assert matched_row["owner"] == "sre-owner" + + +def test_get_matching_row_alternative_matcher( + db_session: Session, mapping_rule: MappingRule +): + """Test that alternative matchers (OR conditions) work correctly with DB session.""" + matcher = MappingRuleMatcher( + dialect_name=db_session.bind.dialect.name, session=db_session + ) + + # Test with source matcher (alternative matcher) + alert_values = {"source": "prometheus"} + + matched_row = matcher.get_matching_row(mapping_rule, alert_values) + + assert matched_row is not None + assert matched_row["team"] == "monitoring-team" + assert matched_row["owner"] == "monitoring-owner" + + +def test_get_matching_row_no_match(db_session: Session, mapping_rule: MappingRule): + """Test handling of no matches with DB session.""" + matcher = MappingRuleMatcher( + dialect_name=db_session.bind.dialect.name, session=db_session + ) + + # Test with values that shouldn't match anything + alert_values = { + "service": "unknown-service", + "severity": "low", # Not critical, so won't match the wildcard row + } + + matched_row = matcher.get_matching_row(mapping_rule, alert_values) + + assert matched_row is None + + +def test_get_matching_rows_multi_level( + db_session: Session, multi_level_rule: MappingRule +): + """Test multi-level matching with DB session.""" + matcher = MappingRuleMatcher( + dialect_name=db_session.bind.dialect.name, session=db_session + ) + + # Test with multiple customer values + customers = ["customer-1", "customer-3", "customer-not-exists"] + + matches = matcher.get_matching_rows_multi_level( + multi_level_rule, "customer", customers + ) + + assert len(matches) == 2 # Should match two customers + assert "customer-1" in matches + assert "customer-3" in matches + assert matches["customer-1"]["contact"] == "contact-1" + assert matches["customer-1"]["priority"] == "high" + assert matches["customer-3"]["contact"] == "contact-3" + assert matches["customer-3"]["priority"] == "low" + + +def test_multiple_rules_priority(db_session: Session): + """Test that rule priority works correctly.""" + # Create two rules with different priorities + rule1 = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=10, # Higher priority + name="High Priority Rule", + description="Rule with high priority", + matchers=[["service"]], + rows=[{"service": "shared-service", "result": "high-priority"}], + type="csv", + ) + + rule2 = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=5, # Lower priority + name="Low Priority Rule", + description="Rule with low priority", + matchers=[["service"]], + rows=[{"service": "shared-service", "result": "low-priority"}], + type="csv", + ) + + db_session.add(rule1) + db_session.add(rule2) + db_session.commit() + db_session.refresh(rule1) + 
db_session.refresh(rule2) + + # Use a direct query to verify the priority ordering + rules = db_session.exec( + select(MappingRule) + .filter(MappingRule.tenant_id == SINGLE_TENANT_UUID) + .filter(MappingRule.name.in_(["High Priority Rule", "Low Priority Rule"])) + .order_by(MappingRule.priority.desc()) + ).all() + + assert len(rules) == 2 + assert rules[0].name == "High Priority Rule" + assert rules[1].name == "Low Priority Rule" + + # Also run the matcher against the high-priority rule and confirm its result + matcher = MappingRuleMatcher( + dialect_name=db_session.bind.dialect.name, session=db_session + ) + + alert_values = {"service": "shared-service"} + + # Querying rule1 directly should return its high-priority row + matched_row = matcher.get_matching_row(rule1, alert_values) + assert matched_row is not None + assert matched_row["result"] == "high-priority" + + +def test_large_dataset(db_session: Session): + """Test with a large dataset to ensure SQL optimization works.""" + # Create a rule with many rows + large_rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Large Dataset Rule", + description="Rule with many rows to test SQL performance", + matchers=[["id"]], + rows=[{"id": f"id-{i}", "value": f"value-{i}"} for i in range(1000)], + type="csv", + ) + + db_session.add(large_rule) + db_session.commit() + db_session.refresh(large_rule) + + matcher = MappingRuleMatcher( + dialect_name=db_session.bind.dialect.name, session=db_session + ) + + # Test finding a specific ID in the large dataset + alert_values = {"id": "id-500"} + + matched_row = matcher.get_matching_row(large_rule, alert_values) + assert matched_row is not None + assert matched_row["value"] == "value-500" + + # Test finding the last entry to ensure full scan works + alert_values = {"id": "id-999"} + + matched_row = matcher.get_matching_row(large_rule, alert_values) + assert matched_row is not None + assert matched_row["value"] == "value-999" diff --git a/tests/test_sql_mapping_performance.py b/tests/test_sql_mapping_performance.py new file mode 100644 index 0000000000..fd9a8c0b8a --- /dev/null +++ b/tests/test_sql_mapping_performance.py @@ -0,0 +1,282 @@ +import random +import string +import time + +import pytest +from sqlmodel import Session + +from keep.api.bl.enrichments_bl import EnrichmentsBl +from keep.api.bl.mapping_rule_matcher import MappingRuleMatcher +from keep.api.core.dependencies import SINGLE_TENANT_UUID +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus +from keep.api.models.db.mapping import MappingRule +from tests.fixtures.client import test_app # noqa + + +def random_string(length=10): + """Generate a random string.""" + return "".join(random.choice(string.ascii_letters) for _ in range(length)) + + +@pytest.fixture +def large_mapping_rule(db_session: Session): + """Create a mapping rule with a large number of rows.""" + # Generate 1000 rows with unique keys and values + rows = [] + for i in range(1000): + rows.append( + { + "customer_id": f"customer-{i:04d}", + "name": f"Customer {i}", + "email": f"customer{i}@example.com", + "phone": f"555-{i:04d}", + "type": random.choice(["enterprise", "smb", "startup"]), + "region": random.choice(["us-east", "us-west", "eu", "asia"]), + "tier": random.choice(["free", "basic", "premium", "enterprise"]), + "support_level": random.choice(["basic", "standard", "premium"]), + "account_manager": random_string(), + "metadata": random_string(20), + } + ) + + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Large Customer Dataset", + description="A 
rule with many customer records", + matchers=[["customer_id"]], + rows=rows, + type="csv", + file_name="", + created_by="test", + condition="", + new_property_name="", + prefix_to_remove="", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +@pytest.fixture +def large_multi_level_rule(db_session: Session): + """Create a multi-level mapping rule with a large number of rows.""" + # Generate 1000 rows with unique keys and values + rows = [] + for i in range(1000): + service_id = f"service-{i}" + rows.append( + { + "service_id": service_id, + "name": f"Service {i}", + "owner": f"team-{i % 20}", + "status": random.choice(["active", "deprecated", "in-development"]), + "url": f"https://service-{i}.example.com", + "port": 8000 + i % 1000, + "version": f"{random.randint(1, 5)}.{random.randint(0, 9)}.{random.randint(0, 9)}", + "dependencies": random_string(), + "description": random_string(30), + } + ) + + rule = MappingRule( + tenant_id=SINGLE_TENANT_UUID, + priority=1, + name="Large Service Dataset", + description="A multi-level rule with many service records", + matchers=[["services"]], + rows=rows, + type="csv", + is_multi_level=True, + new_property_name="service_info", + file_name="", + created_by="test", + condition="", + prefix_to_remove="", + ) + db_session.add(rule) + db_session.commit() + db_session.refresh(rule) + return rule + + +def test_large_dataset_performance( + db_session: Session, large_mapping_rule: MappingRule +): + """Test performance with a large dataset.""" + dialect_name = None + if ( + hasattr(db_session, "bind") + and db_session.bind is not None + and hasattr(db_session.bind, "dialect") + ): + dialect_name = db_session.bind.dialect.name + + matcher = MappingRuleMatcher(dialect_name=dialect_name, session=db_session) + + # Test case 1: Match an item at the beginning of the dataset + start_time = time.time() + alert_values = {"customer_id": "customer-0010"} + matched_row = matcher.get_matching_row(large_mapping_rule, alert_values) + beginning_time = time.time() - start_time + + assert matched_row is not None + assert matched_row["name"] == "Customer 10" + + # Test case 2: Match an item in the middle of the dataset + start_time = time.time() + alert_values = {"customer_id": "customer-0541"} + matched_row = matcher.get_matching_row(large_mapping_rule, alert_values) + middle_time = time.time() - start_time + + assert matched_row is not None + assert matched_row["name"] == "Customer 541" + + # Test case 3: Match an item at the end of the dataset + start_time = time.time() + alert_values = {"customer_id": "customer-0990"} + matched_row = matcher.get_matching_row(large_mapping_rule, alert_values) + end_time = time.time() - start_time + + assert matched_row is not None + assert matched_row["name"] == "Customer 990" + + # Test case 4: No match + start_time = time.time() + alert_values = {"customer_id": "non-existent"} + matched_row = matcher.get_matching_row(large_mapping_rule, alert_values) + no_match_time = time.time() - start_time + + assert matched_row is None + + # Log performance metrics + print("\nPerformance metrics for large dataset (1000 rows):") + print(f"Beginning match time: {beginning_time:.6f} seconds") + print(f"Middle match time: {middle_time:.6f} seconds") + print(f"End match time: {end_time:.6f} seconds") + print(f"No match time: {no_match_time:.6f} seconds") + + # Assert that all operations are reasonably fast (< 0.1 seconds) + # Adjust this threshold as needed based on your environment + assert beginning_time < 0.1, 
"Beginning match too slow" + assert middle_time < 0.1, "Middle match too slow" + assert end_time < 0.1, "End match too slow" + assert no_match_time < 0.1, "No match case too slow" + + +def test_multi_level_large_dataset_performance( + db_session: Session, large_multi_level_rule: MappingRule +): + """Test multi-level mapping performance with a large dataset.""" + dialect_name = None + if ( + hasattr(db_session, "bind") + and db_session.bind is not None + and hasattr(db_session.bind, "dialect") + ): + dialect_name = db_session.bind.dialect.name + + matcher = MappingRuleMatcher(dialect_name=dialect_name, session=db_session) + + # Generate a list of 100 service IDs to match + service_ids = [f"service-{i}" for i in random.sample(range(1000), 100)] + + # Test multi-level matching performance + start_time = time.time() + matches = matcher.get_matching_rows_multi_level( + large_multi_level_rule, "service_id", service_ids + ) + multi_level_time = time.time() - start_time + + assert len(matches) == 100 + for service_id in service_ids: + assert service_id in matches + assert "name" in matches[service_id] + assert "owner" in matches[service_id] + + # Log performance metrics + print("\nMulti-level matching performance (100 matches out of 1000 rows):") + print(f"Matching time: {multi_level_time:.6f} seconds") + + # Assert that the operation is reasonably fast (< 0.5 seconds) + # Multi-level matching will be slower than single matches + assert multi_level_time < 0.5, "Multi-level matching too slow" + + +def test_comparison_with_fallback(db_session: Session, large_mapping_rule: MappingRule): + """Compare SQL-based matching with fallback Python implementation.""" + dialect_name = None + if ( + hasattr(db_session, "bind") + and db_session.bind is not None + and hasattr(db_session.bind, "dialect") + ): + dialect_name = db_session.bind.dialect.name + + matcher = MappingRuleMatcher(dialect_name=dialect_name, session=db_session) + + # Test with SQL-based matching + start_time = time.time() + alert_values = {"customer_id": "customer-0534"} + sql_matched_row = matcher.get_matching_row(large_mapping_rule, alert_values) + sql_time = time.time() - start_time + + # Test with fallback Python implementation + start_time = time.time() + fallback_matched_row = matcher._fallback_get_matching_row( + large_mapping_rule, alert_values + ) + fallback_time = time.time() - start_time + + # Verify both methods return the same result + assert sql_matched_row is not None + assert fallback_matched_row is not None + assert sql_matched_row["name"] == fallback_matched_row["name"] + + # Log performance comparison + print("\nSQL vs. 
Fallback performance comparison:") + print(f"SQL-based match time: {sql_time:.6f} seconds") + print(f"Fallback match time: {fallback_time:.6f} seconds") + print(f"Speed improvement: {fallback_time / sql_time:.2f}x") + + # Assert that SQL is faster than fallback + # The difference should be significant for large datasets + assert sql_time < fallback_time, "SQL-based matching should be faster than fallback" + + +def test_end_to_end_performance(db_session: Session, large_mapping_rule: MappingRule): + """Test end-to-end performance using EnrichmentsBl.""" + # Create an alert DTO + alert = AlertDto( + id="test-id", + name="Test Alert", + status=AlertStatus.FIRING, + severity=AlertSeverity.HIGH, + lastReceived="2023-01-01T00:00:00Z", + source=["test-source"], + fingerprint="test-fingerprint", + ) + + # Add customer_id as a dynamic attribute + setattr(alert, "customer_id", "customer-0999") # Match the last row in the dataset + + enrichment_bl = EnrichmentsBl(tenant_id=SINGLE_TENANT_UUID, db=db_session) + + # Time the full enrichment process + start_time = time.time() + result = enrichment_bl.check_if_match_and_enrich(alert, large_mapping_rule) + end_to_end_time = time.time() - start_time + + # Verify enrichment worked + assert result is True + assert hasattr(alert, "name") + assert hasattr(alert, "email") + assert hasattr(alert, "phone") + + # Log end-to-end performance + print("\nEnd-to-end enrichment performance:") + print(f"Enrichment time: {end_to_end_time:.6f} seconds") + + # Assert reasonable performance for end-to-end process + assert end_to_end_time < 0.2, "End-to-end enrichment too slow"
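
Reviewer note: for readers who want the shape of the in-memory path these tests keep falling back to, below is a minimal sketch of the semantics the suite appears to assume — matcher groups are OR-ed together while keys inside a group are AND-ed, and a row value of "*" acts as a wildcard (the behavior test_get_matching_row_wildcard_match and the fallback tests rely on). This is illustrative only; the shipped implementations are get_matching_row/_fallback_get_matching_row and get_matching_rows_multi_level in keep/api/bl/mapping_rule_matcher.py, and the function names and exact edge cases here are assumptions, not that code.

# Illustrative sketch of the assumed fallback semantics -- not the shipped implementation.
from typing import Optional


def fallback_get_matching_row(rule, alert_values: dict) -> Optional[dict]:
    """Return the first row with a matcher group that fully matches the alert values."""
    for row in rule.rows or []:
        for group in rule.matchers:
            # Keys inside a group are AND-ed; matching any one group is enough (OR).
            if all(
                key in alert_values
                and row.get(key) is not None
                and (row[key] == "*" or row[key] == alert_values[key])
                for key in group
            ):
                return row
    return None


def fallback_multi_level(rule, key: str, values: list) -> dict:
    """Map each requested value to the first row whose `key` column equals it."""
    matches = {}
    for value in values:
        for row in rule.rows or []:
            if row.get(key) == value:
                # Whether the matcher key itself stays in the payload is an
                # implementation detail; the tests above only assert the other columns.
                matches[value] = {k: v for k, v in row.items() if k != key}
                break
    return matches

Pushing this same predicate down into a JSON-aware SQL query is what MappingRuleMatcher does on supported dialects; the Python loops above are the safety net that the error-handling tests pin down.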