diff --git a/example.py b/example.py index a613eaf..085f75d 100644 --- a/example.py +++ b/example.py @@ -31,21 +31,21 @@ class SimpleInvoiceSchema(BaseModel): print("Available models:", list(list_models().keys())) # Use default model (claude-haiku) -# h = Harvestor(model="claude-haiku") +h = Harvestor(model="claude-haiku", validate=True) -# output = h.harvest_file( -# source="data/uploads/keep_for_test.jpg", schema=SimpleInvoiceSchema -# ) - -# print(output.to_summary()) +output = h.harvest_file( + source="data/uploads/keep_for_test.jpg", schema=SimpleInvoiceSchema +) +print(output.to_summary()) +print(output.validation) # Alternative: use OpenAI # h_openai = Harvestor(model="gpt-4o-mini") # output = h_openai.harvest_file("data/uploads/keep_for_test.jpg", schema=SimpleInvoiceSchema) # Alternative: use local Ollama (free) or cloud Ollama -h_ollama = Harvestor(model="gemma3:4b-cloud") -output = h_ollama.harvest_file( - "data/uploads/keep_for_test.jpg", schema=SimpleInvoiceSchema -) -print(output.to_summary()) +# h_ollama = Harvestor(model="gemma3:4b-cloud") +# output = h_ollama.harvest_file( +# "data/uploads/keep_for_test.jpg", schema=SimpleInvoiceSchema +# ) +# print(output.to_summary()) diff --git a/src/harvestor/__init__.py b/src/harvestor/__init__.py index f6b5668..c69eb12 100644 --- a/src/harvestor/__init__.py +++ b/src/harvestor/__init__.py @@ -36,6 +36,7 @@ ValidationResult, ) from .schemas.defaults import InvoiceData, LineItem, ReceiptData +from .validators import BaseValidationRule, RuleFinding, RuleSeverity, ValidationEngine __all__ = [ "__version__", @@ -48,6 +49,11 @@ "ExtractionStrategy", "HarvestResult", "ValidationResult", + # Validation + "ValidationEngine", + "BaseValidationRule", + "RuleFinding", + "RuleSeverity", # Output schemas "InvoiceData", "ReceiptData", diff --git a/src/harvestor/cli/main.py b/src/harvestor/cli/main.py index a4302f9..a34b2e1 100644 --- a/src/harvestor/cli/main.py +++ b/src/harvestor/cli/main.py @@ -55,6 +55,11 @@ def build_parser(): action="store_true", help="List available schemas and exit", ) + parser.add_argument( + "--validate", + action="store_true", + help="Run validation rules on extracted data", + ) return parser @@ -151,6 +156,7 @@ def main(): source=args.file_path, schema=schema, model=args.model, + validate=args.validate, ) if not result.success: @@ -158,7 +164,23 @@ def main(): sys.exit(1) indent = 2 if args.pretty else None - output = json.dumps(result.data, indent=indent, default=str) + + if result.validation: + full_output = { + "data": result.data, + "validation": { + "is_valid": result.validation.is_valid, + "confidence": result.validation.confidence, + "fraud_risk": result.validation.fraud_risk, + "errors": result.validation.errors, + "warnings": result.validation.warnings, + "fraud_reasons": result.validation.fraud_reasons, + "rules_checked": result.validation.rules_checked, + }, + } + output = json.dumps(full_output, indent=indent, default=str) + else: + output = json.dumps(result.data, indent=indent, default=str) if args.output: args.output.write_text(output) diff --git a/src/harvestor/core/harvestor.py b/src/harvestor/core/harvestor.py index 7e8eda6..b817224 100644 --- a/src/harvestor/core/harvestor.py +++ b/src/harvestor/core/harvestor.py @@ -38,6 +38,8 @@ def __init__( cost_limit_per_doc: float = 0.10, daily_cost_limit: Optional[float] = None, base_url: Optional[str] = None, + validate: bool = False, + validation_rules: Optional[List] = None, ): """ Initialize Harvestor. @@ -48,6 +50,8 @@ def __init__( cost_limit_per_doc: Maximum cost per document (default: $0.10) daily_cost_limit: Optional daily cost limit base_url: Optional base URL override for the provider + validate: Run validation rules on extracted data (default: False) + validation_rules: Custom validation rules (used with validate=True) """ self.model_name = model self.api_key = api_key @@ -61,6 +65,24 @@ def __init__( # Initialize LLM parser (handles provider selection) self.llm_parser = LLMParser(model=model, api_key=api_key, base_url=base_url) + # Initialize validation engine if enabled + self._validate = validate + self._validation_engine = None + if validate: + from ..validators import ValidationEngine + + self._validation_engine = ValidationEngine(rules=validation_rules) + + def _maybe_validate( + self, result: HarvestResult, schema: Type[BaseModel] + ) -> HarvestResult: + """Run validation if enabled and extraction succeeded.""" + if self._validate and self._validation_engine and result.success: + result.validation = self._validation_engine.validate( + data=result.data, schema=schema + ) + return result + @staticmethod def get_doc_type_from_schema(schema: Type[BaseModel]) -> str: """ @@ -120,7 +142,7 @@ def harvest_text( total_time = time.time() - start_time - return HarvestResult( + result = HarvestResult( success=extraction_result.success, document_id=document_id, document_type=doc_type, @@ -134,6 +156,7 @@ def harvest_text( error=extraction_result.error, language=language, ) + return self._maybe_validate(result, schema) def harvest_file( self, @@ -349,7 +372,7 @@ def _harvest_image( processing_time = time.time() - start_time - return HarvestResult( + result = HarvestResult( success=extraction_result.success, document_id=document_id, document_type=doc_type, @@ -363,6 +386,7 @@ def _harvest_image( error=extraction_result.error, language=language, ) + return self._maybe_validate(result, schema) def harvest_batch( self, @@ -417,6 +441,8 @@ def harvest( api_key: Optional[str] = None, filename: Optional[str] = None, base_url: Optional[str] = None, + validate: bool = False, + validation_rules: Optional[List] = None, ) -> HarvestResult: """ One-liner function for quick extraction. @@ -435,8 +461,9 @@ def harvest( # With OpenAI result = harvest("invoice.jpg", schema=InvoiceData, model="gpt-4o-mini") - # With local Ollama - result = harvest("invoice.txt", schema=InvoiceData, model="llama3") + # With validation + result = harvest("invoice.pdf", schema=InvoiceData, validate=True) + print(result.validation.fraud_risk) ``` Args: @@ -448,11 +475,19 @@ def harvest( api_key: API key (uses env var if not provided) filename: Original filename (required when source is bytes/file-like) base_url: Optional base URL override + validate: Run validation rules on extracted data (default: False) + validation_rules: Custom validation rules (used with validate=True) Returns: HarvestResult with extracted data """ - harvestor = Harvestor(api_key=api_key, model=model, base_url=base_url) + harvestor = Harvestor( + api_key=api_key, + model=model, + base_url=base_url, + validate=validate, + validation_rules=validation_rules, + ) return harvestor.harvest_file( source=source, schema=schema, diff --git a/src/harvestor/validators/__init__.py b/src/harvestor/validators/__init__.py index e69de29..5e1df96 100644 --- a/src/harvestor/validators/__init__.py +++ b/src/harvestor/validators/__init__.py @@ -0,0 +1,36 @@ +""" +Validation engine for extracted document data. + +Provides rule-based validation, fraud detection, and anomaly flagging +for data extracted by Harvestor. +""" + +from ..schemas.base import ValidationResult +from .base import BaseValidationRule, RuleFinding, RuleSeverity +from .engine import ValidationEngine + + +def validate(data, schema, rules=None, include_defaults=True) -> ValidationResult: + """ + One-liner validation function. + + Args: + data: Extracted data dict + schema: Pydantic schema class + rules: Optional custom rules + include_defaults: Include built-in rules (default: True) + + Returns: + ValidationResult + """ + engine = ValidationEngine(rules=rules, include_defaults=include_defaults) + return engine.validate(data, schema) + + +__all__ = [ + "BaseValidationRule", + "RuleFinding", + "RuleSeverity", + "ValidationEngine", + "validate", +] diff --git a/src/harvestor/validators/base.py b/src/harvestor/validators/base.py new file mode 100644 index 0000000..3a2a16f --- /dev/null +++ b/src/harvestor/validators/base.py @@ -0,0 +1,73 @@ +"""Base abstractions for validation rules.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Type + +from pydantic import BaseModel + + +class RuleSeverity(str, Enum): + """Severity level for a rule finding.""" + + ERROR = "error" + WARNING = "warning" + INFO = "info" + + +@dataclass +class RuleFinding: + """A single finding from a validation rule.""" + + rule_name: str + severity: RuleSeverity + message: str + field_name: Optional[str] = None + confidence_impact: float = 0.0 + is_fraud_signal: bool = False + fraud_weight: float = 0.0 + + +class BaseValidationRule(ABC): + """Abstract base class for all validation rules.""" + + @property + @abstractmethod + def name(self) -> str: + """Unique name for this rule.""" + ... + + @property + @abstractmethod + def description(self) -> str: + """Human-readable description of what this rule checks.""" + ... + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + """Set of schema types this rule applies to. None means all schemas.""" + return None + + def applies_to(self, schema: Type[BaseModel]) -> bool: + """Check if this rule applies to the given schema.""" + supported = self.supported_schemas + if supported is None: + return True + return any(issubclass(schema, s) for s in supported) + + @abstractmethod + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + """ + Run this rule against extracted data. + + Args: + data: The extracted data dict (from HarvestResult.data) + schema: The Pydantic schema class used for extraction + + Returns: + List of findings (empty list means rule passed) + """ + ... diff --git a/src/harvestor/validators/engine.py b/src/harvestor/validators/engine.py new file mode 100644 index 0000000..4122866 --- /dev/null +++ b/src/harvestor/validators/engine.py @@ -0,0 +1,124 @@ +"""Validation engine that runs rules and produces ValidationResult.""" + +from datetime import datetime +from typing import Any, Dict, List, Optional, Type + +from pydantic import BaseModel + +from ..schemas.base import ValidationResult +from .base import BaseValidationRule, RuleFinding, RuleSeverity + + +class ValidationEngine: + """ + Runs validation rules against extracted data and produces a ValidationResult. + + Usage: + engine = ValidationEngine() # loads built-in rules + result = engine.validate(data, schema=InvoiceData) + + # Custom rules only: + engine = ValidationEngine(rules=[MyRule()], include_defaults=False) + """ + + def __init__( + self, + rules: Optional[List[BaseValidationRule]] = None, + include_defaults: bool = True, + ): + self._rules: List[BaseValidationRule] = [] + if include_defaults: + self._rules.extend(self._get_default_rules()) + if rules: + self._rules.extend(rules) + + @staticmethod + def _get_default_rules() -> List[BaseValidationRule]: + from .rules import get_all_default_rules + + return get_all_default_rules() + + @property + def rules(self) -> List[BaseValidationRule]: + return list(self._rules) + + def add_rule(self, rule: BaseValidationRule) -> None: + self._rules.append(rule) + + def remove_rule(self, rule_name: str) -> None: + self._rules = [r for r in self._rules if r.name != rule_name] + + def validate( + self, + data: Dict[str, Any], + schema: Type[BaseModel], + ) -> ValidationResult: + all_findings: List[RuleFinding] = [] + rules_checked: List[str] = [] + + for rule in self._rules: + if not rule.applies_to(schema): + continue + rules_checked.append(rule.name) + try: + findings = rule.validate(data, schema) + all_findings.extend(findings) + except Exception: + all_findings.append( + RuleFinding( + rule_name=rule.name, + severity=RuleSeverity.WARNING, + message=f"Rule '{rule.name}' raised an exception and was skipped", + ) + ) + + return self._build_result(all_findings, rules_checked) + + def _build_result( + self, + findings: List[RuleFinding], + rules_checked: List[str], + ) -> ValidationResult: + errors = [f.message for f in findings if f.severity == RuleSeverity.ERROR] + warnings = [f.message for f in findings if f.severity == RuleSeverity.WARNING] + + confidence = 1.0 + for f in findings: + confidence -= f.confidence_impact + confidence = max(0.0, min(1.0, confidence)) + + fraud_findings = [f for f in findings if f.is_fraud_signal] + fraud_checked = len(rules_checked) > 0 + fraud_reasons = [f.message for f in fraud_findings] + fraud_risk = self._calculate_fraud_risk(fraud_findings) + + return ValidationResult( + is_valid=len(errors) == 0, + confidence=confidence, + errors=errors, + warnings=warnings, + fraud_checked=fraud_checked, + fraud_risk=fraud_risk, + fraud_reasons=fraud_reasons, + cost=0.0, + rules_checked=rules_checked, + timestamp=datetime.now(), + ) + + @staticmethod + def _calculate_fraud_risk(fraud_findings: List[RuleFinding]) -> str: + if not fraud_findings: + return "clean" + + total_weight = min(1.0, sum(f.fraud_weight for f in fraud_findings)) + + if total_weight < 0.01: + return "clean" + elif total_weight < 0.2: + return "low" + elif total_weight < 0.5: + return "medium" + elif total_weight < 0.8: + return "high" + else: + return "critical" diff --git a/src/harvestor/validators/rules/__init__.py b/src/harvestor/validators/rules/__init__.py new file mode 100644 index 0000000..8015c42 --- /dev/null +++ b/src/harvestor/validators/rules/__init__.py @@ -0,0 +1,48 @@ +"""Built-in validation rules.""" + +from .anomaly_rules import DuplicateLineItemRule, ExtremeQuantityRule, RoundNumberRule +from .business_rules import ( + AmountThresholdRule, + DueDateAfterIssueDateRule, + EmptyLineItemsRule, + NegativeAmountsRule, + RequiredFieldsRule, +) +from .format_rules import ( + CardLastFourRule, + CurrencyCodeRule, + DateFormatRule, + TaxIdFormatRule, +) +from .math_rules import ( + LineItemMathRule, + LineItemsSumRule, + SubtotalTaxTotalRule, + TaxConsistencyRule, +) + + +def get_all_default_rules(): + """Instantiate all built-in rules with default configuration.""" + return [ + # Math + LineItemsSumRule(), + SubtotalTaxTotalRule(), + LineItemMathRule(), + TaxConsistencyRule(), + # Format + DateFormatRule(), + CurrencyCodeRule(), + TaxIdFormatRule(), + CardLastFourRule(), + # Business + RequiredFieldsRule(), + DueDateAfterIssueDateRule(), + NegativeAmountsRule(), + AmountThresholdRule(), + EmptyLineItemsRule(), + # Anomaly + RoundNumberRule(), + DuplicateLineItemRule(), + ExtremeQuantityRule(), + ] diff --git a/src/harvestor/validators/rules/anomaly_rules.py b/src/harvestor/validators/rules/anomaly_rules.py new file mode 100644 index 0000000..184fde8 --- /dev/null +++ b/src/harvestor/validators/rules/anomaly_rules.py @@ -0,0 +1,180 @@ +"""Anomaly detection rules for fraud signals.""" + +from typing import Any, Dict, List, Optional, Set, Type + +from pydantic import BaseModel + +from ...schemas.defaults import InvoiceData, ReceiptData +from ..base import BaseValidationRule, RuleFinding, RuleSeverity + + +class RoundNumberRule(BaseValidationRule): + """Flag suspiciously round total amounts.""" + + def __init__(self, min_amount: float = 1000.0): + self.min_amount = min_amount + + @property + def name(self) -> str: + return "round_number_anomaly" + + @property + def description(self) -> str: + return "Flags suspiciously round total amounts above threshold" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + total_field = "total_amount" if issubclass(schema, InvoiceData) else "total" + total = data.get(total_field) + + if total is None or not isinstance(total, (int, float)): + return findings + + if total < self.min_amount: + return findings + + # Check if the amount is exactly round (no cents, divisible by 1000) + if total == int(total) and int(total) % 1000 == 0: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=( + f"Total amount ({total:.2f}) is a suspiciously round number" + ), + field_name=total_field, + confidence_impact=0.05, + is_fraud_signal=True, + fraud_weight=0.15, + ) + ) + + return findings + + +class DuplicateLineItemRule(BaseValidationRule): + """Detect duplicate line items with identical name and amount.""" + + @property + def name(self) -> str: + return "duplicate_line_items" + + @property + def description(self) -> str: + return "Detects line items with identical name and amount" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + items_key = "line_items" if "line_items" in data else "items" + items = data.get(items_key) + if not items or not isinstance(items, list): + return findings + + seen = {} + for i, item in enumerate(items): + if not isinstance(item, dict): + continue + key = (item.get("name"), item.get("amount")) + if key[0] is None and key[1] is None: + continue + if key in seen: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=( + f"Duplicate line item: '{key[0]}' with amount {key[1]} " + f"appears at positions {seen[key]} and {i}" + ), + field_name=f"{items_key}[{i}]", + confidence_impact=0.05, + is_fraud_signal=True, + fraud_weight=0.2, + ) + ) + else: + seen[key] = i + + return findings + + +class ExtremeQuantityRule(BaseValidationRule): + """Flag line items with extreme quantities.""" + + def __init__(self, max_quantity: float = 10_000.0): + self.max_quantity = max_quantity + + @property + def name(self) -> str: + return "extreme_quantity" + + @property + def description(self) -> str: + return "Flags line items with quantities outside normal range" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + items_key = "line_items" if "line_items" in data else "items" + items = data.get(items_key) + if not items or not isinstance(items, list): + return findings + + for i, item in enumerate(items): + if not isinstance(item, dict): + continue + quantity = item.get("quantity") + if quantity is None or not isinstance(quantity, (int, float)): + continue + + item_name = item.get("name", f"item #{i + 1}") + + if quantity < 0: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"Line item '{item_name}' has negative quantity: {quantity}", + field_name=f"{items_key}[{i}].quantity", + confidence_impact=0.1, + is_fraud_signal=True, + fraud_weight=0.15, + ) + ) + elif quantity > self.max_quantity: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=( + f"Line item '{item_name}' has extreme quantity: " + f"{quantity} (threshold: {self.max_quantity})" + ), + field_name=f"{items_key}[{i}].quantity", + confidence_impact=0.05, + is_fraud_signal=True, + fraud_weight=0.15, + ) + ) + + return findings diff --git a/src/harvestor/validators/rules/business_rules.py b/src/harvestor/validators/rules/business_rules.py new file mode 100644 index 0000000..ceee676 --- /dev/null +++ b/src/harvestor/validators/rules/business_rules.py @@ -0,0 +1,258 @@ +"""Business logic validation rules.""" + +from typing import Any, Dict, List, Optional, Set, Type + +from pydantic import BaseModel + +from ...schemas.defaults import InvoiceData, ReceiptData +from ..base import BaseValidationRule, RuleFinding, RuleSeverity + + +def _try_parse_date(value: str): + """Try to parse a date string into a comparable tuple (year, month, day).""" + import re + from datetime import datetime + + if not isinstance(value, str): + return None + + # Try ISO format first: YYYY-MM-DD + m = re.match(r"(\d{4})-(\d{2})-(\d{2})", value) + if m: + return (int(m.group(1)), int(m.group(2)), int(m.group(3))) + + # Try common strptime formats + for fmt in [ + "%m/%d/%Y", + "%d/%m/%Y", + "%m-%d-%Y", + "%d-%m-%Y", + "%B %d, %Y", + "%B %d %Y", + "%d %B %Y", + "%b %d, %Y", + "%b %d %Y", + "%d %b %Y", + "%d.%m.%Y", + ]: + try: + dt = datetime.strptime(value.strip().replace(",", ","), fmt) + return (dt.year, dt.month, dt.day) + except ValueError: + continue + + return None + + +class RequiredFieldsRule(BaseValidationRule): + """Check that critical fields are present and non-null.""" + + @property + def name(self) -> str: + return "required_fields_present" + + @property + def description(self) -> str: + return "Verifies that critical fields are present and non-null" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + if issubclass(schema, InvoiceData): + required = ["invoice_number", "date", "total_amount", "vendor_name"] + else: + required = ["merchant_name", "date", "total"] + + for field_name in required: + value = data.get(field_name) + if value is None or (isinstance(value, str) and not value.strip()): + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"Required field '{field_name}' is missing or empty", + field_name=field_name, + confidence_impact=0.05, + ) + ) + + return findings + + +class DueDateAfterIssueDateRule(BaseValidationRule): + """Check that due_date is on or after issue date.""" + + @property + def name(self) -> str: + return "due_date_after_issue_date" + + @property + def description(self) -> str: + return "Verifies that due_date is on or after the issue date" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + date_str = data.get("date") + due_date_str = data.get("due_date") + + if date_str is None or due_date_str is None: + return findings + + issue_date = _try_parse_date(str(date_str)) + due_date = _try_parse_date(str(due_date_str)) + + if issue_date is None or due_date is None: + return findings + + if due_date < issue_date: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"Due date ({due_date_str}) is before issue date ({date_str})", + field_name="due_date", + confidence_impact=0.1, + is_fraud_signal=True, + fraud_weight=0.15, + ) + ) + + return findings + + +class NegativeAmountsRule(BaseValidationRule): + """Check that monetary amounts are non-negative.""" + + @property + def name(self) -> str: + return "no_negative_amounts" + + @property + def description(self) -> str: + return "Verifies that monetary amounts are non-negative" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + if issubclass(schema, InvoiceData): + amount_fields = ["total_amount", "subtotal", "tax_amount"] + else: + amount_fields = ["total", "subtotal", "tax"] + + for field_name in amount_fields: + value = data.get(field_name) + if value is not None and isinstance(value, (int, float)) and value < 0: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.ERROR, + message=f"Field '{field_name}' has negative value: {value}", + field_name=field_name, + confidence_impact=0.15, + is_fraud_signal=True, + fraud_weight=0.3, + ) + ) + + return findings + + +class AmountThresholdRule(BaseValidationRule): + """Flag documents with unusually high total amounts.""" + + def __init__(self, threshold: float = 100_000.0): + self.threshold = threshold + + @property + def name(self) -> str: + return "amount_threshold" + + @property + def description(self) -> str: + return f"Flags documents with total amount exceeding {self.threshold}" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + total_field = "total_amount" if issubclass(schema, InvoiceData) else "total" + total = data.get(total_field) + + if ( + total is not None + and isinstance(total, (int, float)) + and total > self.threshold + ): + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"Total amount ({total:.2f}) exceeds threshold ({self.threshold:.2f})", + field_name=total_field, + confidence_impact=0.05, + is_fraud_signal=True, + fraud_weight=0.1, + ) + ) + + return findings + + +class EmptyLineItemsRule(BaseValidationRule): + """Check that line items list is not empty when present.""" + + @property + def name(self) -> str: + return "line_items_not_empty" + + @property + def description(self) -> str: + return "Verifies that line items list is not empty when present" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + items_key = "line_items" if "line_items" in data else "items" + items = data.get(items_key) + + if items is not None and isinstance(items, list) and len(items) == 0: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"'{items_key}' is present but empty", + field_name=items_key, + confidence_impact=0.05, + ) + ) + + return findings diff --git a/src/harvestor/validators/rules/format_rules.py b/src/harvestor/validators/rules/format_rules.py new file mode 100644 index 0000000..96cf8b2 --- /dev/null +++ b/src/harvestor/validators/rules/format_rules.py @@ -0,0 +1,232 @@ +"""Date, currency, tax ID, and card format validation rules.""" + +import re +from typing import Any, Dict, List, Optional, Set, Type + +from pydantic import BaseModel + +from ...schemas.defaults import InvoiceData, ReceiptData +from ..base import BaseValidationRule, RuleFinding, RuleSeverity + +# Common date patterns that LLMs produce +_DATE_PATTERNS = [ + r"\d{4}-\d{2}-\d{2}", # 2024-01-15 + r"\d{2}/\d{2}/\d{4}", # 01/15/2024 or 15/01/2024 + r"\d{2}-\d{2}-\d{4}", # 01-15-2024 + r"\d{1,2}\s+\w+\s+\d{4}", # 15 January 2024 + r"\w+\s+\d{1,2},?\s+\d{4}", # January 15, 2024 + r"\d{2}\.\d{2}\.\d{4}", # 15.01.2024 +] + +_DATE_REGEX = re.compile("|".join(f"(?:{p})" for p in _DATE_PATTERNS)) + +# Common ISO 4217 currency codes +_VALID_CURRENCIES = frozenset( + { + "USD", + "EUR", + "GBP", + "JPY", + "CHF", + "CAD", + "AUD", + "NZD", + "CNY", + "HKD", + "SGD", + "SEK", + "NOK", + "DKK", + "KRW", + "INR", + "BRL", + "MXN", + "ZAR", + "RUB", + "TRY", + "PLN", + "CZK", + "HUF", + "ILS", + "THB", + "MYR", + "PHP", + "IDR", + "TWD", + "AED", + "SAR", + "ARS", + "CLP", + "COP", + "PEN", + "EGP", + "NGN", + "KES", + "MAD", + } +) + + +class DateFormatRule(BaseValidationRule): + """Check that date fields are parseable.""" + + @property + def name(self) -> str: + return "date_format_valid" + + @property + def description(self) -> str: + return "Verifies that date fields match common date patterns" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + date_fields = ["date"] + if issubclass(schema, InvoiceData): + date_fields.append("due_date") + + for field_name in date_fields: + value = data.get(field_name) + if value is None: + continue + if not isinstance(value, str): + continue + if not _DATE_REGEX.search(value): + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"Field '{field_name}' value '{value}' does not match common date formats", + field_name=field_name, + confidence_impact=0.05, + ) + ) + + return findings + + +class CurrencyCodeRule(BaseValidationRule): + """Check that currency code is a valid ISO 4217 code.""" + + @property + def name(self) -> str: + return "currency_code_valid" + + @property + def description(self) -> str: + return "Verifies that currency field is a valid ISO 4217 code" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + currency = data.get("currency") + if currency is None: + return findings + + if not isinstance(currency, str): + return findings + + normalized = currency.strip().upper() + if normalized not in _VALID_CURRENCIES: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"Currency code '{currency}' is not a recognized ISO 4217 code", + field_name="currency", + confidence_impact=0.05, + ) + ) + + return findings + + +class TaxIdFormatRule(BaseValidationRule): + """Check that vendor tax ID has a reasonable format.""" + + @property + def name(self) -> str: + return "tax_id_format_valid" + + @property + def description(self) -> str: + return "Verifies that vendor tax ID is non-empty and has minimum length" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + tax_id = data.get("vendor_tax_id") + if tax_id is None: + return findings + + if not isinstance(tax_id, str): + return findings + + cleaned = tax_id.strip() + if len(cleaned) < 5: + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"Vendor tax ID '{tax_id}' is suspiciously short (< 5 characters)", + field_name="vendor_tax_id", + confidence_impact=0.05, + is_fraud_signal=True, + fraud_weight=0.1, + ) + ) + + return findings + + +class CardLastFourRule(BaseValidationRule): + """Check that card_last_four is exactly 4 digits.""" + + @property + def name(self) -> str: + return "card_last_four_format" + + @property + def description(self) -> str: + return "Verifies that card_last_four is exactly 4 digits" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + card = data.get("card_last_four") + if card is None: + return findings + + card_str = str(card).strip() + if not re.fullmatch(r"\d{4}", card_str): + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=f"card_last_four '{card}' is not exactly 4 digits", + field_name="card_last_four", + confidence_impact=0.05, + ) + ) + + return findings diff --git a/src/harvestor/validators/rules/math_rules.py b/src/harvestor/validators/rules/math_rules.py new file mode 100644 index 0000000..5d1ff6d --- /dev/null +++ b/src/harvestor/validators/rules/math_rules.py @@ -0,0 +1,245 @@ +"""Cross-field arithmetic consistency rules.""" + +from typing import Any, Dict, List, Optional, Set, Type + +from pydantic import BaseModel + +from ...schemas.defaults import InvoiceData, ReceiptData +from ..base import BaseValidationRule, RuleFinding, RuleSeverity + + +class LineItemsSumRule(BaseValidationRule): + """Check that line item amounts sum to subtotal.""" + + def __init__(self, tolerance: float = 0.02): + self.tolerance = tolerance + + @property + def name(self) -> str: + return "line_items_sum_to_subtotal" + + @property + def description(self) -> str: + return "Verifies that the sum of line item amounts equals the subtotal" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + items_key = "line_items" if "line_items" in data else "items" + items = data.get(items_key) + subtotal = data.get("subtotal") + + if items is None or subtotal is None: + return findings + + computed_sum = sum( + item.get("amount", 0) or 0 for item in items if isinstance(item, dict) + ) + + diff = abs(computed_sum - subtotal) + if diff > self.tolerance: + severity = RuleSeverity.ERROR if diff > 1.0 else RuleSeverity.WARNING + findings.append( + RuleFinding( + rule_name=self.name, + severity=severity, + message=( + f"Line items sum ({computed_sum:.2f}) does not match " + f"subtotal ({subtotal:.2f}), diff={diff:.2f}" + ), + field_name="subtotal", + confidence_impact=0.15 if severity == RuleSeverity.ERROR else 0.05, + is_fraud_signal=severity == RuleSeverity.ERROR, + fraud_weight=0.2 if severity == RuleSeverity.ERROR else 0.0, + ) + ) + + return findings + + +class SubtotalTaxTotalRule(BaseValidationRule): + """Check that subtotal + tax - discount = total.""" + + def __init__(self, tolerance: float = 0.02): + self.tolerance = tolerance + + @property + def name(self) -> str: + return "subtotal_plus_tax_equals_total" + + @property + def description(self) -> str: + return "Verifies that subtotal + tax - discount equals total" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + subtotal = data.get("subtotal") + + # InvoiceData uses tax_amount/total_amount/discount, ReceiptData uses tax/total + if issubclass(schema, InvoiceData): + tax = data.get("tax_amount") + total = data.get("total_amount") + discount = data.get("discount") or 0.0 + else: + tax = data.get("tax") + total = data.get("total") + discount = 0.0 + + if subtotal is None or total is None: + return findings + + tax = tax or 0.0 + expected_total = subtotal + tax - discount + + diff = abs(expected_total - total) + if diff > self.tolerance: + severity = RuleSeverity.ERROR if diff > 1.0 else RuleSeverity.WARNING + findings.append( + RuleFinding( + rule_name=self.name, + severity=severity, + message=( + f"Subtotal ({subtotal:.2f}) + tax ({tax:.2f}) - discount ({discount:.2f}) " + f"= {expected_total:.2f}, but total is {total:.2f}, diff={diff:.2f}" + ), + field_name="total_amount" + if issubclass(schema, InvoiceData) + else "total", + confidence_impact=0.15 if severity == RuleSeverity.ERROR else 0.05, + is_fraud_signal=severity == RuleSeverity.ERROR, + fraud_weight=0.25 if severity == RuleSeverity.ERROR else 0.0, + ) + ) + + return findings + + +class LineItemMathRule(BaseValidationRule): + """Check that quantity * unit_price ~= amount for each line item.""" + + def __init__(self, tolerance: float = 0.02): + self.tolerance = tolerance + + @property + def name(self) -> str: + return "line_item_internal_math" + + @property + def description(self) -> str: + return "Verifies that quantity * unit_price equals amount for each line item" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + items_key = "line_items" if "line_items" in data else "items" + items = data.get(items_key) + if not items: + return findings + + for i, item in enumerate(items): + if not isinstance(item, dict): + continue + + quantity = item.get("quantity") + amount = item.get("amount") + unit_price = item.get("unit_price_with_taxes") or item.get( + "unit_price_without_taxes" + ) + + if quantity is not None and unit_price is not None and amount is not None: + expected = quantity * unit_price + diff = abs(expected - amount) + if diff > self.tolerance: + item_name = item.get("name", f"item #{i + 1}") + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=( + f"Line item '{item_name}': quantity ({quantity}) * " + f"unit_price ({unit_price:.2f}) = {expected:.2f}, " + f"but amount is {amount:.2f}" + ), + field_name=f"{items_key}[{i}].amount", + confidence_impact=0.05, + ) + ) + + return findings + + +class TaxConsistencyRule(BaseValidationRule): + """Check that taxes match taxes_percentage * base price for line items.""" + + def __init__(self, tolerance: float = 0.02): + self.tolerance = tolerance + + @property + def name(self) -> str: + return "tax_percentage_consistency" + + @property + def description(self) -> str: + return "Verifies that taxes match taxes_percentage * base price for line items" + + @property + def supported_schemas(self) -> Optional[Set[Type[BaseModel]]]: + return {InvoiceData, ReceiptData} + + def validate( + self, data: Dict[str, Any], schema: Type[BaseModel] + ) -> List[RuleFinding]: + findings = [] + + items_key = "line_items" if "line_items" in data else "items" + items = data.get(items_key) + if not items: + return findings + + for i, item in enumerate(items): + if not isinstance(item, dict): + continue + + taxes = item.get("taxes") + taxes_pct = item.get("taxes_percentage") + base_price = item.get("unit_price_without_taxes") or item.get("amount") + + if taxes is not None and taxes_pct is not None and base_price is not None: + expected_taxes = base_price * taxes_pct / 100.0 + diff = abs(expected_taxes - taxes) + if diff > self.tolerance: + item_name = item.get("name", f"item #{i + 1}") + findings.append( + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message=( + f"Line item '{item_name}': expected taxes " + f"{expected_taxes:.2f} ({taxes_pct}% of {base_price:.2f}), " + f"but got {taxes:.2f}" + ), + field_name=f"{items_key}[{i}].taxes", + confidence_impact=0.05, + ) + ) + + return findings diff --git a/tests/conftest.py b/tests/conftest.py index fa0eb55..ae380ee 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -144,6 +144,58 @@ def temp_output_dir(tmp_path) -> Path: return output_dir +@pytest.fixture +def valid_invoice_data() -> Dict: + """Provide a mathematically consistent invoice data dict.""" + return { + "invoice_number": "INV-2024-001", + "date": "2024-01-15", + "due_date": "2024-02-15", + "vendor_name": "Tech Solutions Inc.", + "vendor_tax_id": "US123456789", + "customer_name": "Acme Corp", + "customer_address": "123 Business St", + "line_items": [ + { + "name": "Service A", + "amount": 100.00, + "quantity": 2, + "unit_price_with_taxes": 50.00, + }, + { + "name": "Service B", + "amount": 200.00, + "quantity": 1, + "unit_price_with_taxes": 200.00, + }, + ], + "subtotal": 300.00, + "tax_amount": 24.00, + "discount": 0.0, + "total_amount": 324.00, + "currency": "USD", + } + + +@pytest.fixture +def valid_receipt_data() -> Dict: + """Provide a mathematically consistent receipt data dict.""" + return { + "merchant_name": "Coffee Shop", + "date": "2024-03-15", + "time": "14:30", + "items": [ + {"name": "Latte", "amount": 5.50, "quantity": 1}, + {"name": "Muffin", "amount": 3.50, "quantity": 2}, + ], + "subtotal": 9.00, + "tax": 0.72, + "total": 9.72, + "payment_method": "credit_card", + "card_last_four": "4242", + } + + @pytest.fixture(autouse=True) def reset_cost_tracker(): """Automatically reset cost tracker before each test.""" diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..6cc84bb --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,607 @@ +"""Tests for the validation rules engine.""" + +from unittest.mock import MagicMock, patch + +from harvestor import InvoiceData +from harvestor.schemas.base import ValidationResult +from harvestor.schemas.defaults import ReceiptData +from harvestor.validators import ValidationEngine, validate +from harvestor.validators.base import BaseValidationRule, RuleFinding, RuleSeverity +from harvestor.validators.rules.anomaly_rules import ( + DuplicateLineItemRule, + ExtremeQuantityRule, + RoundNumberRule, +) +from harvestor.validators.rules.business_rules import ( + AmountThresholdRule, + DueDateAfterIssueDateRule, + EmptyLineItemsRule, + NegativeAmountsRule, + RequiredFieldsRule, +) +from harvestor.validators.rules.format_rules import ( + CardLastFourRule, + CurrencyCodeRule, + DateFormatRule, + TaxIdFormatRule, +) +from harvestor.validators.rules.math_rules import ( + LineItemMathRule, + LineItemsSumRule, + SubtotalTaxTotalRule, + TaxConsistencyRule, +) + + +# --------------------------------------------------------------------------- +# Base abstractions +# --------------------------------------------------------------------------- + + +class TestRuleFindingAndSeverity: + def test_severity_values(self): + assert RuleSeverity.ERROR == "error" + assert RuleSeverity.WARNING == "warning" + assert RuleSeverity.INFO == "info" + + def test_rule_finding_construction(self): + f = RuleFinding( + rule_name="test", + severity=RuleSeverity.ERROR, + message="something wrong", + field_name="total", + confidence_impact=0.1, + is_fraud_signal=True, + fraud_weight=0.3, + ) + assert f.rule_name == "test" + assert f.severity == RuleSeverity.ERROR + assert f.is_fraud_signal is True + + def test_rule_finding_defaults(self): + f = RuleFinding(rule_name="x", severity=RuleSeverity.INFO, message="ok") + assert f.field_name is None + assert f.confidence_impact == 0.0 + assert f.is_fraud_signal is False + assert f.fraud_weight == 0.0 + + +# --------------------------------------------------------------------------- +# Validation Engine +# --------------------------------------------------------------------------- + + +class _DummyRule(BaseValidationRule): + """A simple test rule that always produces one warning.""" + + @property + def name(self): + return "dummy_rule" + + @property + def description(self): + return "dummy" + + def validate(self, data, schema): + return [ + RuleFinding( + rule_name=self.name, + severity=RuleSeverity.WARNING, + message="dummy warning", + ) + ] + + +class _CrashingRule(BaseValidationRule): + @property + def name(self): + return "crashing_rule" + + @property + def description(self): + return "always crashes" + + def validate(self, data, schema): + raise RuntimeError("boom") + + +class TestValidationEngine: + def test_engine_loads_default_rules(self): + engine = ValidationEngine() + assert len(engine.rules) > 0 + + def test_engine_custom_rules_only(self): + engine = ValidationEngine(rules=[_DummyRule()], include_defaults=False) + assert len(engine.rules) == 1 + assert engine.rules[0].name == "dummy_rule" + + def test_engine_add_remove_rule(self): + engine = ValidationEngine(include_defaults=False) + assert len(engine.rules) == 0 + + engine.add_rule(_DummyRule()) + assert len(engine.rules) == 1 + + engine.remove_rule("dummy_rule") + assert len(engine.rules) == 0 + + def test_engine_returns_validation_result(self, valid_invoice_data): + engine = ValidationEngine() + result = engine.validate(valid_invoice_data, InvoiceData) + assert isinstance(result, ValidationResult) + + def test_engine_handles_crashing_rule(self, valid_invoice_data): + engine = ValidationEngine(rules=[_CrashingRule()], include_defaults=False) + result = engine.validate(valid_invoice_data, InvoiceData) + assert isinstance(result, ValidationResult) + assert any("exception" in w for w in result.warnings) + + def test_engine_records_rules_checked(self, valid_invoice_data): + engine = ValidationEngine(rules=[_DummyRule()], include_defaults=False) + result = engine.validate(valid_invoice_data, InvoiceData) + assert "dummy_rule" in result.rules_checked + + def test_clean_data_is_valid(self, valid_invoice_data): + engine = ValidationEngine() + result = engine.validate(valid_invoice_data, InvoiceData) + assert result.is_valid is True + assert result.fraud_risk == "clean" + + +# --------------------------------------------------------------------------- +# Math Rules +# --------------------------------------------------------------------------- + + +class TestMathRules: + def test_line_items_sum_valid(self, valid_invoice_data): + rule = LineItemsSumRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_line_items_sum_mismatch_error(self, valid_invoice_data): + valid_invoice_data["subtotal"] = 999.99 + rule = LineItemsSumRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 1 + assert findings[0].severity == RuleSeverity.ERROR + + def test_line_items_sum_small_mismatch_warning(self, valid_invoice_data): + valid_invoice_data["subtotal"] = 300.50 # diff = 0.50, < 1.0 + rule = LineItemsSumRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 1 + assert findings[0].severity == RuleSeverity.WARNING + + def test_line_items_sum_skips_missing_fields(self): + data = {"invoice_number": "INV-001"} + rule = LineItemsSumRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 0 + + def test_line_items_sum_receipt(self, valid_receipt_data): + rule = LineItemsSumRule() + findings = rule.validate(valid_receipt_data, ReceiptData) + assert len(findings) == 0 + + def test_subtotal_tax_total_valid(self, valid_invoice_data): + rule = SubtotalTaxTotalRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_subtotal_tax_total_mismatch(self, valid_invoice_data): + valid_invoice_data["total_amount"] = 9999.00 + rule = SubtotalTaxTotalRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 1 + assert findings[0].severity == RuleSeverity.ERROR + + def test_subtotal_tax_total_receipt(self, valid_receipt_data): + rule = SubtotalTaxTotalRule() + findings = rule.validate(valid_receipt_data, ReceiptData) + assert len(findings) == 0 + + def test_line_item_math_valid(self, valid_invoice_data): + rule = LineItemMathRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_line_item_math_mismatch(self, valid_invoice_data): + valid_invoice_data["line_items"][0]["amount"] = 999.00 # 2 * 50 != 999 + rule = LineItemMathRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 1 + assert "Service A" in findings[0].message + + def test_tax_consistency_valid(self): + data = { + "line_items": [ + { + "name": "Item", + "unit_price_without_taxes": 100.00, + "taxes": 20.00, + "taxes_percentage": 20.0, + } + ] + } + rule = TaxConsistencyRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 0 + + def test_tax_consistency_mismatch(self): + data = { + "line_items": [ + { + "name": "Item", + "unit_price_without_taxes": 100.00, + "taxes": 50.00, # should be 20.00 + "taxes_percentage": 20.0, + } + ] + } + rule = TaxConsistencyRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + + def test_tolerance_boundary(self): + data = { + "line_items": [{"name": "A", "amount": 100.01}], + "subtotal": 100.00, # diff = 0.01, within default tolerance + } + rule = LineItemsSumRule(tolerance=0.02) + findings = rule.validate(data, InvoiceData) + assert len(findings) == 0 + + +# --------------------------------------------------------------------------- +# Format Rules +# --------------------------------------------------------------------------- + + +class TestFormatRules: + def test_valid_date_formats(self, valid_invoice_data): + rule = DateFormatRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_invalid_date_format(self): + data = {"date": "not-a-date", "due_date": "also bad"} + rule = DateFormatRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 2 + + def test_various_valid_date_formats(self): + for date_str in [ + "2024-01-15", + "01/15/2024", + "15 January 2024", + "January 15, 2024", + "15.01.2024", + ]: + data = {"date": date_str} + rule = DateFormatRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 0, f"Failed for date format: {date_str}" + + def test_valid_currency_code(self, valid_invoice_data): + rule = CurrencyCodeRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_invalid_currency_code(self): + data = {"currency": "FAKE"} + rule = CurrencyCodeRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + + def test_currency_case_insensitive(self): + data = {"currency": "usd"} + rule = CurrencyCodeRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 0 + + def test_tax_id_valid(self, valid_invoice_data): + rule = TaxIdFormatRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_tax_id_too_short(self): + data = {"vendor_tax_id": "AB"} + rule = TaxIdFormatRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + assert findings[0].is_fraud_signal is True + + def test_valid_card_last_four(self, valid_receipt_data): + rule = CardLastFourRule() + findings = rule.validate(valid_receipt_data, ReceiptData) + assert len(findings) == 0 + + def test_invalid_card_last_four(self): + data = {"card_last_four": "12AB"} + rule = CardLastFourRule() + findings = rule.validate(data, ReceiptData) + assert len(findings) == 1 + + def test_card_last_four_too_short(self): + data = {"card_last_four": "12"} + rule = CardLastFourRule() + findings = rule.validate(data, ReceiptData) + assert len(findings) == 1 + + +# --------------------------------------------------------------------------- +# Business Rules +# --------------------------------------------------------------------------- + + +class TestBusinessRules: + def test_required_fields_all_present(self, valid_invoice_data): + rule = RequiredFieldsRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_required_fields_missing(self): + data = { + "subtotal": 100.00 + } # missing invoice_number, date, total_amount, vendor_name + rule = RequiredFieldsRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 4 + + def test_required_fields_receipt(self, valid_receipt_data): + rule = RequiredFieldsRule() + findings = rule.validate(valid_receipt_data, ReceiptData) + assert len(findings) == 0 + + def test_required_fields_receipt_missing(self): + data = {} + rule = RequiredFieldsRule() + findings = rule.validate(data, ReceiptData) + assert len(findings) == 3 # merchant_name, date, total + + def test_due_date_after_issue_date_valid(self, valid_invoice_data): + rule = DueDateAfterIssueDateRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_due_date_before_issue_date(self): + data = {"date": "2024-06-15", "due_date": "2024-01-01"} + rule = DueDateAfterIssueDateRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + assert findings[0].is_fraud_signal is True + + def test_due_date_skips_unparseable(self): + data = {"date": "not-a-date", "due_date": "also-bad"} + rule = DueDateAfterIssueDateRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 0 + + def test_negative_amount_detected(self): + data = {"total_amount": -500.00} + rule = NegativeAmountsRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + assert findings[0].severity == RuleSeverity.ERROR + assert findings[0].is_fraud_signal is True + assert findings[0].fraud_weight == 0.3 + + def test_positive_amounts_clean(self, valid_invoice_data): + rule = NegativeAmountsRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_amount_threshold_exceeded(self): + data = {"total_amount": 200_000.00} + rule = AmountThresholdRule(threshold=100_000.0) + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + assert findings[0].is_fraud_signal is True + + def test_amount_threshold_within_range(self, valid_invoice_data): + rule = AmountThresholdRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_empty_line_items_warning(self): + data = {"line_items": []} + rule = EmptyLineItemsRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + + def test_non_empty_line_items_clean(self, valid_invoice_data): + rule = EmptyLineItemsRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + +# --------------------------------------------------------------------------- +# Anomaly Rules +# --------------------------------------------------------------------------- + + +class TestAnomalyRules: + def test_round_number_flagged(self): + data = {"total_amount": 10_000.00} + rule = RoundNumberRule(min_amount=1000.0) + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + assert findings[0].is_fraud_signal is True + + def test_non_round_number_clean(self, valid_invoice_data): + rule = RoundNumberRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_round_number_below_min_amount_clean(self): + data = {"total_amount": 100.00} # round but below min_amount + rule = RoundNumberRule(min_amount=1000.0) + findings = rule.validate(data, InvoiceData) + assert len(findings) == 0 + + def test_duplicate_line_items_detected(self): + data = { + "line_items": [ + {"name": "Widget", "amount": 50.00}, + {"name": "Widget", "amount": 50.00}, + ] + } + rule = DuplicateLineItemRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + assert findings[0].is_fraud_signal is True + + def test_unique_line_items_clean(self, valid_invoice_data): + rule = DuplicateLineItemRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + def test_extreme_quantity_flagged(self): + data = {"line_items": [{"name": "Bolts", "quantity": 50_000, "amount": 100.00}]} + rule = ExtremeQuantityRule(max_quantity=10_000.0) + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + assert findings[0].is_fraud_signal is True + + def test_negative_quantity_flagged(self): + data = {"line_items": [{"name": "Refund", "quantity": -5, "amount": 100.00}]} + rule = ExtremeQuantityRule() + findings = rule.validate(data, InvoiceData) + assert len(findings) == 1 + + def test_normal_quantity_clean(self, valid_invoice_data): + rule = ExtremeQuantityRule() + findings = rule.validate(valid_invoice_data, InvoiceData) + assert len(findings) == 0 + + +# --------------------------------------------------------------------------- +# Fraud Risk Calculation +# --------------------------------------------------------------------------- + + +class TestFraudRiskCalculation: + def test_clean_data_clean_risk(self, valid_invoice_data): + result = validate(valid_invoice_data, InvoiceData) + assert result.fraud_risk == "clean" + + def test_single_low_signal(self): + data = { + "total_amount": 150_123.45, + "subtotal": 135_123.45, + "tax_amount": 15_000.00, + "date": "2024-01-01", + "invoice_number": "INV-001", + "vendor_name": "Vendor Co", + "line_items": [{"name": "Big project", "amount": 135_123.45}], + } + result = validate(data, InvoiceData) + # Only amount_threshold rule fires (fraud_weight=0.1) -> "low" + assert result.fraud_risk == "low" + + def test_multiple_fraud_signals_escalate_risk(self): + data = { + "total_amount": -500.00, # negative = fraud_weight 0.3 + "subtotal": -500.00, # another negative = 0.3 + "tax_amount": -50.00, # another negative = 0.3 + "date": "2024-01-01", + "invoice_number": "X", + "vendor_name": "V", + } + result = validate(data, InvoiceData) + assert result.fraud_risk in ("high", "critical") + assert result.is_valid is False + + +# --------------------------------------------------------------------------- +# Convenience function +# --------------------------------------------------------------------------- + + +class TestConvenienceFunction: + def test_validate_function(self, valid_invoice_data): + result = validate(valid_invoice_data, InvoiceData) + assert isinstance(result, ValidationResult) + assert result.is_valid is True + + def test_validate_with_custom_rules(self, valid_invoice_data): + result = validate( + valid_invoice_data, + InvoiceData, + rules=[_DummyRule()], + include_defaults=False, + ) + assert len(result.warnings) == 1 + assert "dummy warning" in result.warnings[0] + + def test_validate_receipt(self, valid_receipt_data): + result = validate(valid_receipt_data, ReceiptData) + assert isinstance(result, ValidationResult) + assert result.is_valid is True + + +# --------------------------------------------------------------------------- +# Integration with Harvestor +# --------------------------------------------------------------------------- + + +class TestHarvestorValidationIntegration: + @patch("harvestor.providers.anthropic.Anthropic") + def test_validation_populates_harvest_result( + self, mock_anthropic, sample_invoice_text, mock_anthropic_response, api_key + ): + from harvestor import Harvestor + from harvestor.schemas.base import HarvestResult + + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_anthropic_response + mock_anthropic.return_value = mock_client + + harvestor = Harvestor(api_key=api_key, validate=True) + result = harvestor.harvest_text( + sample_invoice_text, schema=InvoiceData, doc_type="invoice" + ) + + assert isinstance(result, HarvestResult) + assert result.validation is not None + assert isinstance(result.validation, ValidationResult) + assert result.validation.fraud_checked is True + assert len(result.validation.rules_checked) > 0 + + @patch("harvestor.providers.anthropic.Anthropic") + def test_validation_disabled_by_default( + self, mock_anthropic, sample_invoice_text, mock_anthropic_response, api_key + ): + from harvestor import Harvestor + + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_anthropic_response + mock_anthropic.return_value = mock_client + + harvestor = Harvestor(api_key=api_key) + result = harvestor.harvest_text( + sample_invoice_text, schema=InvoiceData, doc_type="invoice" + ) + + assert result.validation is None + + @patch("harvestor.providers.anthropic.Anthropic") + def test_custom_rule_in_harvest( + self, mock_anthropic, sample_invoice_text, mock_anthropic_response, api_key + ): + from harvestor import Harvestor + + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_anthropic_response + mock_anthropic.return_value = mock_client + + harvestor = Harvestor( + api_key=api_key, validate=True, validation_rules=[_DummyRule()] + ) + result = harvestor.harvest_text( + sample_invoice_text, schema=InvoiceData, doc_type="invoice" + ) + + assert result.validation is not None + assert "dummy_rule" in result.validation.rules_checked