diff --git a/docs/branch-notes/steps1-3-hardening.md b/docs/branch-notes/steps1-3-hardening.md new file mode 100644 index 0000000..88873e2 --- /dev/null +++ b/docs/branch-notes/steps1-3-hardening.md @@ -0,0 +1,85 @@ +# Branch Handoff: `codex/steps1-3-hardening` + +## Purpose +This branch hardens steps 1-3 of the pipeline (spec/scenario/sample validation and sampler consistency) and related persona consistency checks. + +## Commits Included (in order) +1. `969ea9d` - fix(scenario): include extended attributes in validator references (#117) +2. `de2298c` - fix(scenario): validate condition literals against categorical options (#118) +3. `0755fae` - fix(cli): add persona validation path and robust validate type detection (#110) +4. `72af9a4` - fix(sampler): reconcile household-derived attributes after assignment (#114) +5. `696c81c` - refactor(sampler): make partner correlation policy-driven with metadata fallback (#123) +6. `84a4ef3` - feat(validator): detect ambiguous categorical/boolean modifier overlaps (#122) +7. `f85393d` - fix(sampler): surface modifier condition eval failures with strict/permissive modes (#121) +8. `ee7a262` - feat(sample): enforce expression constraints and promoted warning gates without new flags (#119 #120) +9. `b790597` - fix(persona): apply semantic context to avoid unemployed/occupation contradictions (#113) + +## Exact Fixes by Area + +### Scenario Validation +- Validator now resolves references against both base population attributes and scenario `extended_attributes`. +- Validator now checks condition string literals against known categorical options, preventing case/value drift (e.g., invalid enum labels). +- Timeline exposure rule conditions now get the same attribute/literal validation checks. + +### CLI Validation +- `extropy validate` spec-type detection is more robust (population vs scenario vs persona). +- Persona config validation path added with structural checks and context-aware checks when scenario/population can be resolved. + +### Sampling / Household Integrity +- Post-sampling reconciliation pass aligns household-derived fields to actual sampled household composition: + - household size + - has-children flags + - children count fields (when present) + - marital consistency for partnered households/dependent agents +- Sampling stats are recomputed after reconciliation so summary stats reflect final values. + +### Partner Correlation Generalization +- Added explicit `partner_correlation_policy` support in population models. +- Correlation algorithm resolution is now policy/metadata driven first, with legacy-name fallback for backward compatibility. +- Added semantic warning when `partner_correlated` attrs lack policy/semantic metadata. + +### Semantic Validator Enhancements +- Added overlap analysis for categorical/boolean modifiers (`MODIFIER_OVERLAP`, `MODIFIER_OVERLAP_EXCLUSIVE`). +- Added partner-policy completeness warnings (`PARTNER_POLICY`). + +### Sampler Failure Visibility / Strictness +- Modifier condition evaluation failures are now surfaced with strict/permissive behavior: + - strict mode: fail sampling + - permissive mode: collect warnings +- `sample` now enforces expression constraints in normal mode (without `--skip-validation`). +- Some semantic warnings are promoted to blocking during strict sampling paths. + +### Persona Rendering Consistency +- Added semantic-context phrase override so non-working agents are not rendered with contradictory active-employment phrasing. +- Simulation persona generation now passes semantic metadata to renderer. + +## Test Coverage Added/Updated +- `tests/test_scenario_validator.py` +- `tests/test_cli.py` +- `tests/test_household_sampling.py` +- `tests/test_validator.py` +- `tests/test_sampler.py` +- `tests/test_persona_renderer.py` + +## Current Known Gaps (not fixed in this branch) +- Spec overlap volume can be high; overlap warnings require spec-side cleanup for deterministic behavior. +- Household type labels can still be semantically mismatched in some edge cases (`couple_with_kids`/`single_parent` labels vs zero dependents after reconciliation). +- Some implausible demographic combinations remain spec-driven (not engine bugs), e.g. age/education/employment combinations where conditional coverage is incomplete. + +## Issue Tracking Guidance (keep issues open for merge coordination) +Do **not** close issues until merge + verification in shared integration branch. +Suggested per-issue state update: +- Add comment: "Implemented on `codex/steps1-3-hardening`, pending integration verification". +- Link commit hash(es) above. +- Keep status open with label like `ready-for-merge-test`. + +Mapped issues on this branch: +- #110, #113, #114, #117, #118, #119, #120, #121, #122, #123 + +## Merge Safety Notes +- This branch intentionally increases strictness in sampling validation; expect older specs to fail faster. +- If another branch modifies validator/sampler internals, resolve conflicts by preserving: + - extended-attr aware scenario validation + - post-household reconciliation pass + - strict/permissive condition handling + - expression constraint enforcement behavior diff --git a/extropy/cli/commands/sample.py b/extropy/cli/commands/sample.py index 2eb1c44..560e61f 100644 --- a/extropy/cli/commands/sample.py +++ b/extropy/cli/commands/sample.py @@ -193,6 +193,56 @@ def sample_command( out.text("[dim]Use --skip-validation to sample anyway[/dim]") raise typer.Exit(out.finish()) else: + promoted_warning_categories = { + "CONDITION_VALUE", + "MODIFIER_OVERLAP_EXCLUSIVE", + "PARTNER_POLICY", + } + promoted_warnings = [ + w + for w in validation_result.warnings + if w.category in promoted_warning_categories + ] + if promoted_warnings: + out.set_data( + "promoted_warning_categories", + sorted(promoted_warning_categories), + ) + out.set_data( + "promoted_warnings", + [ + { + "location": w.location, + "category": w.category, + "message": w.message, + "suggestion": w.suggestion, + } + for w in promoted_warnings + ], + ) + + if skip_validation: + out.warning( + f"Spec has {len(promoted_warnings)} promoted warning(s) - skipping validation" + ) + else: + out.error( + f"Merged spec has {len(promoted_warnings)} promoted warning(s)", + exit_code=ExitCode.VALIDATION_ERROR, + ) + if not agent_mode: + for warn in promoted_warnings[:5]: + out.text( + f" [red]✗[/red] [{warn.category}] {warn.location}: {warn.message}" + ) + if len(promoted_warnings) > 5: + out.text( + f" [dim]... and {len(promoted_warnings) - 5} more[/dim]" + ) + out.blank() + out.text("[dim]Use --skip-validation to sample anyway[/dim]") + raise typer.Exit(out.finish()) + if validation_result.warnings: out.success( f"Spec validated with {len(validation_result.warnings)} warning(s)" @@ -205,6 +255,8 @@ def sample_command( sampling_start = time.time() result = None sampling_error = None + strict_condition_errors = not skip_validation + enforce_expression_constraints = not skip_validation show_progress = count >= 100 and not agent_mode @@ -238,6 +290,8 @@ def on_progress(current: int, total: int): on_progress=on_progress, household_config=household_config, agent_focus_mode=agent_focus_mode, + strict_condition_errors=strict_condition_errors, + enforce_expression_constraints=enforce_expression_constraints, ) except SamplingError as e: sampling_error = e @@ -251,6 +305,8 @@ def on_progress(current: int, total: int): seed=seed, household_config=household_config, agent_focus_mode=agent_focus_mode, + strict_condition_errors=strict_condition_errors, + enforce_expression_constraints=enforce_expression_constraints, ) except SamplingError as e: sampling_error = e @@ -262,6 +318,8 @@ def on_progress(current: int, total: int): seed=seed, household_config=household_config, agent_focus_mode=agent_focus_mode, + strict_condition_errors=strict_condition_errors, + enforce_expression_constraints=enforce_expression_constraints, ) except SamplingError as e: sampling_error = e @@ -282,6 +340,18 @@ def on_progress(current: int, total: int): sampling_time_seconds=sampling_elapsed, ) + if result.stats.condition_warnings: + warning_count = len(result.stats.condition_warnings) + out.warning( + f"{warning_count} modifier condition evaluation warning(s) encountered during sampling" + ) + out.set_data("condition_warning_count", warning_count) + if report and not agent_mode: + for warning in result.stats.condition_warnings[:3]: + out.text(f" [yellow]⚠[/yellow] {warning}") + if warning_count > 3: + out.text(f" [dim]... and {warning_count - 3} more[/dim]") + # Report if agent_mode or report: out.set_data("stats", format_sampling_stats_for_json(result.stats, merged_spec)) diff --git a/extropy/cli/commands/validate.py b/extropy/cli/commands/validate.py index 5d46ce9..d7c6a39 100644 --- a/extropy/cli/commands/validate.py +++ b/extropy/cli/commands/validate.py @@ -4,8 +4,9 @@ from pathlib import Path import typer +import yaml -from ...core.models import PopulationSpec +from ...core.models import PopulationSpec, ScenarioSpec from ...population.validator import validate_spec from ..app import app, console, get_json_mode, is_agent_mode from ..utils import Output, ExitCode, format_validation_for_json @@ -30,6 +31,461 @@ def _is_scenario_file(path: Path) -> bool: return False +def _is_persona_file(path: Path) -> bool: + """Check if file is a persona config based on naming convention.""" + name = path.name + # Legacy patterns + if name.endswith(".persona.yaml") or name.endswith(".persona.yml"): + return True + if name in {"persona.yaml", "persona.yml"}: + return True + # Versioned pattern: persona.v{N}.yaml or persona.v{N}.yml + if re.match(r"^persona\.v\d+\.ya?ml$", name): + return True + return False + + +def _detect_spec_type(path: Path) -> str: + """Detect whether a file is population/scenario/persona. + + Uses filename conventions first, then falls back to top-level YAML keys. + """ + if _is_persona_file(path): + return "persona" + if _is_scenario_file(path): + return "scenario" + + try: + data = yaml.safe_load(path.read_text()) or {} + except Exception: + return "population" + + if isinstance(data, dict): + keys = set(data.keys()) + if {"intro_template", "treatments", "groups", "phrasings"}.issubset(keys): + return "persona" + if { + "event", + "seed_exposure", + "interaction", + "spread", + "outcomes", + "simulation", + }.issubset(keys): + return "scenario" + + return "population" + + +def _extract_version(path: Path, prefix: str) -> int | None: + """Extract version from versioned filename (e.g., persona.v2.yaml).""" + match = re.match(rf"^{re.escape(prefix)}\.v(\d+)\.ya?ml$", path.name) + if not match: + return None + return int(match.group(1)) + + +def _find_scenario_for_persona(persona_path: Path) -> Path | None: + """Find the most likely scenario YAML alongside a persona config.""" + scenario_dir = persona_path.parent + persona_version = _extract_version(persona_path, "persona") + + # Prefer matching version if persona.vN and scenario.vN exist together. + if persona_version is not None: + for ext in ("yaml", "yml"): + candidate = scenario_dir / f"scenario.v{persona_version}.{ext}" + if candidate.exists(): + return candidate + + # Otherwise prefer highest versioned scenario. + versioned: list[tuple[int, Path]] = [] + for candidate in scenario_dir.iterdir(): + if not candidate.is_file(): + continue + match = re.match(r"^scenario\.v(\d+)\.ya?ml$", candidate.name) + if match: + versioned.append((int(match.group(1)), candidate)) + if versioned: + versioned.sort(key=lambda x: x[0], reverse=True) + return versioned[0][1] + + # Legacy names. + for name in ("scenario.yaml", "scenario.yml"): + candidate = scenario_dir / name + if candidate.exists(): + return candidate + + # Legacy suffix forms. + suffix_matches = sorted( + [ + p + for p in scenario_dir.iterdir() + if p.is_file() + and (p.name.endswith(".scenario.yaml") or p.name.endswith(".scenario.yml")) + ] + ) + return suffix_matches[0] if suffix_matches else None + + +def _resolve_population_from_scenario( + scenario_spec: ScenarioSpec, scenario_path: Path +) -> tuple[PopulationSpec | None, str | None]: + """Resolve and load the population spec referenced by a scenario.""" + try: + pop_name, pop_version = scenario_spec.meta.get_population_ref() + except ValueError as e: + return None, str(e) + + pop_path: Path + if scenario_spec.meta.base_population: + if pop_version is None: + return ( + None, + f"Unsupported base_population reference: {scenario_spec.meta.base_population}", + ) + # Expected layout: {study_root}/scenario/{scenario_name}/scenario.vN.yaml + scenario_dir = scenario_path.parent + scenarios_dir = scenario_dir.parent + if scenarios_dir.name == "scenario": + study_root = scenarios_dir.parent + else: + study_root = scenario_dir + pop_path = study_root / f"{pop_name}.v{pop_version}.yaml" + elif scenario_spec.meta.population_spec: + from ...utils import resolve_relative_to + + pop_path = resolve_relative_to( + scenario_spec.meta.population_spec, scenario_path + ) + else: + return None, "Scenario does not define base_population or population_spec" + + if not pop_path.exists(): + return None, f"Population spec not found: {pop_path}" + + try: + return PopulationSpec.from_yaml(pop_path), None + except Exception as e: + return None, f"Failed to load population spec: {e}" + + +def _categorical_options_for_attribute(attr_spec) -> set[str]: + """Return categorical options for an attribute when available.""" + dist = attr_spec.sampling.distribution + if dist is None: + return set() + if getattr(dist, "type", None) != "categorical": + return set() + options = getattr(dist, "options", None) + if not options: + return set() + return {str(opt) for opt in options} + + +def _validate_persona_config(spec_file: Path, out: Output) -> int: + """Validate a persona rendering config.""" + from ...population.persona import PersonaConfig + from ...population.persona.renderer import extract_intro_attributes + + # Load config + if not _is_json_output(): + with console.status("[cyan]Loading persona config...[/cyan]"): + try: + config = PersonaConfig.from_yaml(spec_file) + except Exception as e: + out.error( + f"Failed to load persona config: {e}", + exit_code=ExitCode.VALIDATION_ERROR, + ) + return out.finish() + else: + try: + config = PersonaConfig.from_yaml(spec_file) + except Exception as e: + out.error( + f"Failed to load persona config: {e}", + exit_code=ExitCode.VALIDATION_ERROR, + ) + return out.finish() + + out.success( + "Loaded persona config", + spec_file=str(spec_file), + treatment_count=len(config.treatments), + group_count=len(config.groups), + ) + out.blank() + + errors: list[str] = [] + warnings: list[str] = [] + + # Best-effort context resolution for cross-file validation. + attribute_specs_by_name: dict[str, object] = {} + scenario_path = _find_scenario_for_persona(spec_file) + if scenario_path is None: + warnings.append( + "No sibling scenario file found; running structural-only persona validation" + ) + else: + try: + scenario_spec = ScenarioSpec.from_yaml(scenario_path) + population_spec, pop_error = _resolve_population_from_scenario( + scenario_spec, scenario_path + ) + if pop_error: + warnings.append(pop_error) + elif population_spec: + merged_attributes = list(population_spec.attributes) + if scenario_spec.extended_attributes: + merged_attributes.extend(scenario_spec.extended_attributes) + attribute_specs_by_name = {a.name: a for a in merged_attributes} + out.set_data("resolved_scenario", str(scenario_path)) + out.set_data("resolved_attribute_count", len(attribute_specs_by_name)) + except Exception as e: + warnings.append(f"Failed to resolve scenario context: {e}") + + # Structural checks. + group_names: list[str] = [g.name for g in config.groups] + group_name_set = set(group_names) + duplicate_group_names = sorted( + {name for name in group_names if group_names.count(name) > 1} + ) + for name in duplicate_group_names: + errors.append(f"Duplicate group name: {name}") + + treatment_by_attr: dict[str, object] = {} + for treatment in config.treatments: + if treatment.attribute in treatment_by_attr: + errors.append(f"Duplicate treatment for attribute: {treatment.attribute}") + treatment_by_attr[treatment.attribute] = treatment + if treatment.group not in group_name_set: + errors.append( + f"Treatment for {treatment.attribute} references unknown group: {treatment.group}" + ) + + grouped_attrs: dict[str, set[str]] = {} + for group in config.groups: + seen_in_group: set[str] = set() + for attr in group.attributes: + if attr in seen_in_group: + errors.append( + f"Group {group.name} contains duplicate attribute: {attr}" + ) + seen_in_group.add(attr) + + if attr not in treatment_by_attr: + errors.append( + f"Group {group.name} references attribute without treatment: {attr}" + ) + grouped_attrs.setdefault(attr, set()).add(group.name) + + for attr, owning_groups in grouped_attrs.items(): + if len(owning_groups) > 1: + groups_str = ", ".join(sorted(owning_groups)) + errors.append(f"Attribute {attr} appears in multiple groups: {groups_str}") + + ungrouped = sorted(set(treatment_by_attr) - set(grouped_attrs)) + if ungrouped: + errors.append( + f"Attributes have treatments but are not present in any group: {', '.join(ungrouped)}" + ) + + phrasing_attr_to_kind: dict[str, str] = {} + phrasing_kind_sets: dict[str, set[str]] = { + "boolean": set(), + "categorical": set(), + "relative": set(), + "concrete": set(), + } + + def _register_phrasing(attr: str, kind: str) -> None: + if attr in phrasing_kind_sets[kind]: + errors.append(f"Duplicate {kind} phrasing for attribute: {attr}") + phrasing_kind_sets[kind].add(attr) + existing = phrasing_attr_to_kind.get(attr) + if existing and existing != kind: + errors.append( + f"Attribute {attr} has multiple phrasing kinds: {existing}, {kind}" + ) + phrasing_attr_to_kind[attr] = kind + + for phrasing in config.phrasings.boolean: + _register_phrasing(phrasing.attribute, "boolean") + if not phrasing.true_phrase.strip(): + errors.append( + f"Boolean phrasing for {phrasing.attribute} has empty true_phrase" + ) + if not phrasing.false_phrase.strip(): + errors.append( + f"Boolean phrasing for {phrasing.attribute} has empty false_phrase" + ) + + for phrasing in config.phrasings.categorical: + _register_phrasing(phrasing.attribute, "categorical") + if not phrasing.phrases: + errors.append( + f"Categorical phrasing for {phrasing.attribute} has no option phrases" + ) + + for phrasing in config.phrasings.relative: + _register_phrasing(phrasing.attribute, "relative") + labels = phrasing.labels + if not all( + [ + labels.much_below.strip(), + labels.below.strip(), + labels.average.strip(), + labels.above.strip(), + labels.much_above.strip(), + ] + ): + errors.append( + f"Relative phrasing for {phrasing.attribute} has one or more empty labels" + ) + + concrete_template_pattern = re.compile(r"\{value(?::[^}]*)?\}") + for phrasing in config.phrasings.concrete: + _register_phrasing(phrasing.attribute, "concrete") + if not concrete_template_pattern.search(phrasing.template): + errors.append( + f"Concrete phrasing for {phrasing.attribute} template must include {{value}}" + ) + if phrasing.format_spec: + if phrasing.format_spec not in {"time12", "time24"}: + try: + format(1234.567, phrasing.format_spec) + except Exception: + errors.append( + f"Concrete phrasing for {phrasing.attribute} has invalid format_spec: {phrasing.format_spec}" + ) + + intro_attrs = extract_intro_attributes(config.intro_template) + + # Cross-file semantic checks when scenario + population can be resolved. + if attribute_specs_by_name: + known_attrs = set(attribute_specs_by_name) + + for source, attr_set in [ + ("treatments", set(treatment_by_attr)), + ("groups", set(grouped_attrs)), + ("phrasings", set(phrasing_attr_to_kind)), + ("intro_template", set(intro_attrs)), + ]: + unknown = sorted(attr_set - known_attrs) + if unknown: + errors.append( + f"{source} references unknown attributes: {', '.join(unknown)}" + ) + + missing_treatments = sorted(known_attrs - set(treatment_by_attr)) + if missing_treatments: + errors.append( + f"Missing treatments for attributes: {', '.join(missing_treatments)}" + ) + + missing_phrasings = sorted(known_attrs - set(phrasing_attr_to_kind)) + if missing_phrasings: + errors.append( + f"Missing phrasing entries for attributes: {', '.join(missing_phrasings)}" + ) + + for attr_name, attr_spec in attribute_specs_by_name.items(): + attr_type = attr_spec.type + treatment = treatment_by_attr.get(attr_name) + phrasing_kind = phrasing_attr_to_kind.get(attr_name) + + if treatment is not None and treatment.treatment.value == "relative": + if attr_type not in {"int", "float"}: + errors.append( + f"Attribute {attr_name} uses relative treatment but has non-numeric type: {attr_type}" + ) + if phrasing_kind != "relative": + errors.append( + f"Attribute {attr_name} uses relative treatment but has {phrasing_kind or 'no'} phrasing" + ) + elif attr_type in {"int", "float"} and phrasing_kind not in { + None, + "concrete", + "relative", + }: + errors.append( + f"Numeric attribute {attr_name} has incompatible phrasing kind: {phrasing_kind}" + ) + + if attr_type == "boolean" and phrasing_kind not in {None, "boolean"}: + errors.append( + f"Boolean attribute {attr_name} must use boolean phrasing (found {phrasing_kind})" + ) + if attr_type == "categorical" and phrasing_kind not in { + None, + "categorical", + }: + errors.append( + f"Categorical attribute {attr_name} must use categorical phrasing (found {phrasing_kind})" + ) + + for phrasing in config.phrasings.categorical: + attr_spec = attribute_specs_by_name.get(phrasing.attribute) + if attr_spec is None: + continue + if attr_spec.type != "categorical": + errors.append( + f"Attribute {phrasing.attribute} has categorical phrasing but type is {attr_spec.type}" + ) + continue + + options = _categorical_options_for_attribute(attr_spec) + if not options: + warnings.append( + f"Could not verify option coverage for {phrasing.attribute} (no categorical options found in spec)" + ) + continue + + covered = set(phrasing.phrases.keys()) | set(phrasing.null_options) + missing = sorted(options - covered) + extra = sorted(covered - options) + if missing: + errors.append( + f"Categorical phrasing for {phrasing.attribute} missing options: {', '.join(missing)}" + ) + if extra: + errors.append( + f"Categorical phrasing for {phrasing.attribute} includes unknown options: {', '.join(extra)}" + ) + + # Emit warnings. + for warning in warnings: + out.warning(warning) + + # Emit errors. + if errors: + out.error( + f"Persona config has {len(errors)} error(s)", + exit_code=ExitCode.VALIDATION_ERROR, + ) + if _is_json_output(): + for err in errors: + out.error(err, exit_code=ExitCode.VALIDATION_ERROR) + else: + for err in errors[:15]: + out.text(f" [red]✗[/red] {err}") + if len(errors) > 15: + out.text(f" [dim]... and {len(errors) - 15} more[/dim]") + return out.finish() + + if warnings: + out.success(f"Persona config validated with {len(warnings)} warning(s)") + else: + out.success("Persona config validated") + + out.blank() + out.divider() + out.text("[green]Validation passed[/green]") + out.divider() + + return out.finish() + + def _validate_population_spec(spec_file: Path, strict: bool, out: Output) -> int: """Validate a population spec.""" # Load spec @@ -263,17 +719,18 @@ def _validate_scenario_spec(spec_file: Path, out: Output) -> int: @app.command("validate") def validate_command( spec_file: Path = typer.Argument( - ..., help="Spec file to validate (.yaml or .scenario.yaml)" + ..., help="Spec file to validate (.yaml, scenario.vN.yaml, or persona.vN.yaml)" ), strict: bool = typer.Option( False, "--strict", help="Treat warnings as errors (population specs only)" ), ): """ - Validate a population spec or scenario spec. + Validate a population, scenario, or persona spec. Auto-detects file type based on naming: - *.scenario.yaml → scenario spec validation + - *.persona.yaml → persona config validation - *.yaml → population spec validation EXIT CODES: @@ -284,6 +741,7 @@ def validate_command( EXAMPLES: extropy validate surgeons.yaml # Population spec extropy validate surgeons.scenario.yaml # Scenario spec + extropy validate scenario/ai-adoption/persona.v1.yaml # Persona config extropy validate surgeons.yaml --strict # Treat warnings as errors """ out = Output(console=console, json_mode=_is_json_output()) @@ -299,7 +757,10 @@ def validate_command( raise typer.Exit(out.finish()) # Route to appropriate validator - if _is_scenario_file(spec_file): + spec_type = _detect_spec_type(spec_file) + if spec_type == "persona": + exit_code = _validate_persona_config(spec_file, out) + elif spec_type == "scenario": exit_code = _validate_scenario_spec(spec_file, out) else: exit_code = _validate_population_spec(spec_file, strict, out) diff --git a/extropy/cli/utils.py b/extropy/cli/utils.py index a19ed96..985f54c 100644 --- a/extropy/cli/utils.py +++ b/extropy/cli/utils.py @@ -344,4 +344,9 @@ def format_sampling_stats_for_json(stats, spec) -> dict[str, Any]: if stats.constraint_violations: result["constraint_violations"] = stats.constraint_violations + # Condition evaluation warnings (permissive mode) + if stats.condition_warnings: + result["condition_warning_count"] = len(stats.condition_warnings) + result["condition_warnings"] = stats.condition_warnings[:50] + return result diff --git a/extropy/core/models/population.py b/extropy/core/models/population.py index 0483cdb..5af935a 100644 --- a/extropy/core/models/population.py +++ b/extropy/core/models/population.py @@ -349,6 +349,14 @@ class SamplingConfig(BaseModel): default_factory=list, description="Conditional modifiers (for conditional strategy)", ) + modifier_overlap_policy: Literal["exclusive", "ordered_override"] | None = Field( + default=None, + description=( + "Optional policy for overlapping conditional modifiers on categorical/" + "boolean attributes. exclusive=conditions should be mutually exclusive; " + "ordered_override=last-match precedence is intentional." + ), + ) # ============================================================================= @@ -409,6 +417,16 @@ class AttributeSpec(BaseModel): default=None, description="For partner_correlated scope: probability (0-1) that partner has same value. None uses type-specific defaults (age uses gaussian, race uses per-group rates).", ) + partner_correlation_policy: Literal[ + "gaussian_offset", "same_group_rate", "same_value_probability", None + ] = Field( + default=None, + description=( + "Optional explicit policy for scope=partner_correlated. " + "When unset, sampler resolves a policy from semantic_type/identity_type " + "with legacy-name fallback." + ), + ) semantic_type: Literal[ "age", "income", "education", "employment", "occupation", None ] = Field( @@ -706,6 +724,9 @@ class DiscoveredAttribute(BaseModel): default=None, description="For partner_correlated scope: probability (0-1) that partner has same value", ) + partner_correlation_policy: Literal[ + "gaussian_offset", "same_group_rate", "same_value_probability", None + ] = Field(default=None) semantic_type: Literal[ "age", "income", "education", "employment", "occupation", None ] = Field( @@ -766,6 +787,9 @@ class HydratedAttribute(BaseModel): default=None, description="For partner_correlated scope: probability (0-1) that partner has same value", ) + partner_correlation_policy: Literal[ + "gaussian_offset", "same_group_rate", "same_value_probability", None + ] = Field(default=None) semantic_type: Literal[ "age", "income", "education", "employment", "occupation", None ] = Field( diff --git a/extropy/population/persona/renderer.py b/extropy/population/persona/renderer.py index 34dba55..984ad3e 100644 --- a/extropy/population/persona/renderer.py +++ b/extropy/population/persona/renderer.py @@ -279,6 +279,48 @@ def _ensure_period(phrase: str) -> str: return phrase +def _is_non_working_status(value: Any) -> bool: + """Heuristic for employment statuses that imply no current job.""" + if value is None: + return False + text = str(value).lower().replace("_", " ").replace("-", " ") + non_working_tokens = ( + "unemployed", + "not employed", + "not in labor", + "retired", + "homemaker", + "stay at home", + "disabled", + ) + return any(token in text for token in non_working_tokens) + + +def _apply_contextual_phrase_overrides( + attr_name: str, + phrase: str, + *, + semantic_type_map: dict[str, str] | None = None, + not_currently_working: bool = False, +) -> str: + """Adjust phrases using semantic context to avoid contradictions.""" + if not phrase: + return phrase + if not semantic_type_map: + return phrase + + semantic_type = semantic_type_map.get(attr_name) + if semantic_type == "occupation" and not_currently_working: + lowered = phrase.lower() + if lowered.startswith("i work in "): + return "My background is in " + phrase[10:] + if lowered.startswith("i work as "): + return "My background is as " + phrase[10:] + if lowered.startswith("i work "): + return "My work background is " + phrase[7:] + return phrase + + def render_intro( agent: dict[str, Any], config: PersonaConfig, @@ -529,6 +571,7 @@ def render_persona( config: PersonaConfig, decision_relevant_attributes: list[str] | None = None, display_format_map: dict[str, str] | None = None, + semantic_type_map: dict[str, str] | None = None, ) -> str: """Render complete first-person persona for an agent. @@ -540,12 +583,25 @@ def render_persona( "Most Relevant to This Decision" section. display_format_map: Optional mapping of attr_name -> display_format (from AttributeSpec.display_format, set by LLM during spec creation) + semantic_type_map: Optional mapping of attr_name -> semantic_type + (from AttributeSpec.semantic_type, set by LLM during spec creation) Returns: Complete persona as markdown string """ sections = [] + employment_attrs = [] + if semantic_type_map: + employment_attrs = [ + name + for name, semantic in semantic_type_map.items() + if semantic == "employment" + ] + not_currently_working = any( + _is_non_working_status(agent.get(attr_name)) for attr_name in employment_attrs + ) + # Render intro intro = render_intro(agent, config, display_format_map) if intro: @@ -570,6 +626,12 @@ def render_persona( for attr_name in decision_relevant_attributes: value = agent.get(attr_name) phrase = render_attribute(attr_name, value, config) + phrase = _apply_contextual_phrase_overrides( + attr_name, + phrase, + semantic_type_map=semantic_type_map, + not_currently_working=not_currently_working, + ) if phrase: decision_phrases.append(_ensure_period(phrase)) if decision_phrases: @@ -591,13 +653,19 @@ def render_persona( # render_intro() already emits "## Who I Am" — skip duplicate if group_obj.label == "Who I Am": - lines = [""] + lines = ["## More About Me", ""] else: lines = [f"## {group_obj.label}", ""] phrases = [] for attr_name in remaining_attrs: value = agent.get(attr_name) phrase = render_attribute(attr_name, value, config) + phrase = _apply_contextual_phrase_overrides( + attr_name, + phrase, + semantic_type_map=semantic_type_map, + not_currently_working=not_currently_working, + ) if phrase: phrases.append(_ensure_period(phrase)) diff --git a/extropy/population/sampler/core.py b/extropy/population/sampler/core.py index 4e76d69..5796935 100644 --- a/extropy/population/sampler/core.py +++ b/extropy/population/sampler/core.py @@ -34,7 +34,7 @@ generate_dependents, ) from .modifiers import apply_modifiers_and_sample -from ...utils.eval_safe import eval_formula, FormulaError +from ...utils.eval_safe import ConditionError, eval_formula, FormulaError from ..names import generate_name from ..names.generator import age_to_birth_decade @@ -183,6 +183,8 @@ def sample_population( on_progress: ItemProgressCallback | None = None, household_config: HouseholdConfig | None = None, agent_focus_mode: Literal["primary_only", "couples", "all"] | None = None, + strict_condition_errors: bool = False, + enforce_expression_constraints: bool = False, ) -> SamplingResult: """ Generate agents from a PopulationSpec. @@ -197,6 +199,10 @@ def sample_population( seed: Random seed for reproducibility (None = random) on_progress: Optional callback(current, total) for progress updates household_config: Household composition config (required if spec has household attributes) + strict_condition_errors: If True, modifier condition evaluation failures + raise SamplingError; if False, failures are recorded as warnings. + enforce_expression_constraints: If True, fail sampling when expression + constraints are violated by any sampled agent. Returns: SamplingResult with agents list, metadata, and statistics @@ -257,18 +263,43 @@ def sample_population( on_progress, hh_config, agent_focus_mode=agent_focus_mode, + strict_condition_errors=strict_condition_errors, ) else: agents = _sample_population_independent( - spec, attr_map, rng, n, id_width, stats, numeric_values, on_progress + spec, + attr_map, + rng, + n, + id_width, + stats, + numeric_values, + on_progress, + strict_condition_errors=strict_condition_errors, ) households = [] + # Rebuild value-distribution stats from finalized agents. This keeps stats + # accurate when post-sampling reconciliation mutates sampled values. + numeric_values = _rebuild_stats_from_agents(spec, agents, stats) + # Compute final statistics _finalize_stats(stats, numeric_values, len(agents)) # Check expression constraints - _check_expression_constraints(spec, agents, stats) + total_constraint_violations = _check_expression_constraints(spec, agents, stats) + if enforce_expression_constraints and total_constraint_violations > 0: + top = sorted( + stats.constraint_violations.items(), + key=lambda kv: kv[1], + reverse=True, + )[:3] + details = "; ".join([f"{k} ({v})" for k, v in top]) + raise SamplingError( + "Expression constraint violations detected: " + f"{total_constraint_violations} total violation(s). " + f"Top constraints: {details}" + ) # Build metadata meta: dict[str, Any] = { @@ -303,12 +334,20 @@ def _sample_population_independent( stats: SamplingStats, numeric_values: dict[str, list[float]], on_progress: ItemProgressCallback | None = None, + strict_condition_errors: bool = False, ) -> list[dict[str, Any]]: """Sample N agents independently (legacy path).""" agents: list[dict[str, Any]] = [] for i in range(n): agent = _sample_single_agent( - spec, attr_map, rng, i, id_width, stats, numeric_values + spec, + attr_map, + rng, + i, + id_width, + stats, + numeric_values, + strict_condition_errors=strict_condition_errors, ) agents.append(agent) if on_progress: @@ -334,20 +373,32 @@ def _generate_npc_partner( # Always include gender partner["gender"] = rng.choice(["male", "female"]) - # Always correlate age if present (essential for NPC identity, regardless of scope) - if "age" in primary: - partner["age"] = correlate_partner_attribute( - "age", + # Always correlate age if present (essential for NPC identity, regardless of scope). + age_attr_name = next( + (name for name, attr in attr_map.items() if attr.semantic_type == "age"), + "age", + ) + age_attr = attr_map.get(age_attr_name) + if age_attr_name in primary: + partner_age = correlate_partner_attribute( + age_attr_name, "int", - primary["age"], + primary[age_attr_name], None, # Uses gaussian offset rng, config, + semantic_type=age_attr.semantic_type if age_attr else "age", + identity_type=age_attr.identity_type if age_attr else None, + partner_correlation_policy=( + age_attr.partner_correlation_policy if age_attr else None + ), ) + partner[age_attr_name] = partner_age + partner["age"] = partner_age # Process attributes based on their scope for attr_name, attr in attr_map.items(): - if attr_name not in primary or attr_name == "age": + if attr_name not in primary or attr_name == age_attr_name: continue if attr.scope == "household": @@ -363,6 +414,9 @@ def _generate_npc_partner( rng, config, available_options=categorical_options.get(attr_name), + semantic_type=attr.semantic_type, + identity_type=attr.identity_type, + partner_correlation_policy=attr.partner_correlation_policy, ) # Individual scope: skip for NPC (not enough data to sample fully) @@ -399,6 +453,7 @@ def _sample_dependent_as_agent( dependent: Any, parent: dict[str, Any], household_id: str, + strict_condition_errors: bool = False, ) -> dict[str, Any]: """Promote a dependent to a full agent with all attributes sampled. @@ -406,7 +461,14 @@ def _sample_dependent_as_agent( then samples remaining attributes normally. """ agent = _sample_single_agent( - spec, attr_map, rng, index, id_width, stats, numeric_values + spec, + attr_map, + rng, + index, + id_width, + stats, + numeric_values, + strict_condition_errors=strict_condition_errors, ) # Override with dependent's known attributes @@ -443,6 +505,7 @@ def _sample_population_households( on_progress: ItemProgressCallback | None = None, config: HouseholdConfig | None = None, agent_focus_mode: Literal["primary_only", "couples", "all"] | None = None, + strict_condition_errors: bool = False, ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """Sample agents in household units with correlated demographics. @@ -480,7 +543,14 @@ def _sample_population_households( # Sample Adult 1 (primary) — always an agent adult1 = _sample_single_agent( - spec, attr_map, rng, agent_index, id_width, stats, numeric_values + spec, + attr_map, + rng, + agent_index, + id_width, + stats, + numeric_values, + strict_condition_errors=strict_condition_errors, ) adult1_age = adult1.get("age", 35) agent_index += 1 @@ -523,6 +593,7 @@ def _sample_population_households( household_attrs, categorical_options, config, + strict_condition_errors=strict_condition_errors, ) adult2["household_id"] = household_id adult2["household_role"] = "adult_secondary" @@ -579,6 +650,7 @@ def _sample_population_households( dep, adult1, household_id, + strict_condition_errors=strict_condition_errors, ) agents.append(kid_agent) adult_ids.append(kid_agent["_id"]) @@ -625,9 +697,254 @@ def _sample_population_households( a["partner_id"] = None agents = agents[:target_n] + _reconcile_household_attributes(agents, households, attr_map) + return agents, households +def _resolve_household_size_attribute(attr_map: dict[str, AttributeSpec]) -> str | None: + """Identify the attribute representing total household membership size.""" + if "household_size" in attr_map: + return "household_size" + + for attr_name, attr in attr_map.items(): + if attr.type not in ("int", "float"): + continue + if attr.scope != "household": + continue + lowered = attr_name.lower() + if "household" in lowered and ("size" in lowered or "count" in lowered): + return attr_name + return None + + +def _resolve_has_children_attribute(attr_map: dict[str, AttributeSpec]) -> str | None: + """Identify a boolean parenthood indicator attribute (e.g., has_children).""" + if "has_children" in attr_map and attr_map["has_children"].type == "boolean": + return "has_children" + + for attr_name, attr in attr_map.items(): + if attr.type != "boolean": + continue + if attr.identity_type == "parental_status": + return attr_name + lowered = attr_name.lower() + if "child" in lowered or "kid" in lowered: + return attr_name + return None + + +def _resolve_children_count_attribute(attr_map: dict[str, AttributeSpec]) -> str | None: + """Identify an integer child-count attribute when present.""" + for attr_name, attr in attr_map.items(): + if attr.type not in ("int", "float"): + continue + lowered = attr_name.lower() + if ( + ("child" in lowered or "kid" in lowered) + and ("count" in lowered or "num" in lowered) + ) or "dependent_count" in lowered: + return attr_name + return None + + +def _resolve_marital_attribute( + attr_map: dict[str, AttributeSpec], +) -> tuple[str | None, str | None, str | None, set[str]]: + """Identify marital-status attribute and representative option values. + + Returns: + (attr_name, partnered_value, unpartnered_value, partnered_values_set) + """ + candidate_attr: str | None = None + partnered_value: str | None = None + unpartnered_value: str | None = None + partnered_values: set[str] = set() + + partnered_priority = [ + "married", + "domestic partner", + "partnered", + "civil union", + "cohab", + ] + unpartnered_priority = [ + "single", + "never married", + "divorc", + "widow", + "separat", + "unmarried", + ] + + for attr_name, attr in attr_map.items(): + if attr.type != "categorical": + continue + dist = attr.sampling.distribution + if not dist or not hasattr(dist, "options") or not dist.options: + continue + + options = [str(o) for o in dist.options] + normalized = [o.lower().replace("_", " ").replace("-", " ") for o in options] + + local_partnered: list[str] = [] + local_unpartnered: list[str] = [] + for raw, norm in zip(options, normalized, strict=False): + if any(token in norm for token in partnered_priority): + local_partnered.append(raw) + if any(token in norm for token in unpartnered_priority): + local_unpartnered.append(raw) + + if local_partnered and local_unpartnered: + candidate_attr = attr_name + partnered_values = set(local_partnered) + + for token in partnered_priority: + matched = next( + ( + raw + for raw, norm in zip(options, normalized, strict=False) + if token in norm + ), + None, + ) + if matched: + partnered_value = matched + break + if partnered_value is None: + partnered_value = local_partnered[0] + + for token in unpartnered_priority: + matched = next( + ( + raw + for raw, norm in zip(options, normalized, strict=False) + if token in norm + ), + None, + ) + if matched: + unpartnered_value = matched + break + if unpartnered_value is None: + unpartnered_value = local_unpartnered[0] + break + + return candidate_attr, partnered_value, unpartnered_value, partnered_values + + +def _count_child_dependents(dependents: list[dict[str, Any]]) -> int: + """Count child dependents from dependent records.""" + child_count = 0 + for dep in dependents: + relationship = str(dep.get("relationship", "")).lower() + if any( + token in relationship + for token in ("son", "daughter", "child", "kid", "stepchild") + ): + child_count += 1 + return child_count + + +def _reconcile_household_attributes( + agents: list[dict[str, Any]], + households: list[dict[str, Any]], + attr_map: dict[str, AttributeSpec], +) -> None: + """Reconcile household-derived attributes after partner/dependent assignment.""" + if not agents: + return + + household_size_attr = _resolve_household_size_attribute(attr_map) + has_children_attr = _resolve_has_children_attribute(attr_map) + children_count_attr = _resolve_children_count_attribute(attr_map) + marital_attr, partnered_value, unpartnered_value, partnered_values = ( + _resolve_marital_attribute(attr_map) + ) + + members_by_household: dict[str, list[dict[str, Any]]] = {} + for agent in agents: + hh_id = agent.get("household_id") + if not hh_id: + continue + members_by_household.setdefault(str(hh_id), []).append(agent) + + household_record_by_id = { + str(hh.get("id")): hh for hh in households if isinstance(hh, dict) + } + + for hh_id, members in members_by_household.items(): + primary = next( + (m for m in members if m.get("household_role") == "adult_primary"), + members[0], + ) + npc_dependents = primary.get("dependents") + if not isinstance(npc_dependents, list): + npc_dependents = [] + + promoted_child_agents = sum( + 1 + for m in members + if str(m.get("household_role", "")).startswith("dependent_") + and any( + token in str(m.get("relationship_to_primary", "")).lower() + for token in ("son", "daughter", "child", "kid") + ) + ) + + child_dependents = ( + _count_child_dependents(npc_dependents) + promoted_child_agents + ) + actual_size = len(members) + len(npc_dependents) + + for member in members: + is_dependent_agent = str(member.get("household_role", "")).startswith( + "dependent_" + ) + has_partner = bool(member.get("partner_id")) or bool( + member.get("partner_npc") + ) + + if household_size_attr and household_size_attr in attr_map: + size_attr = attr_map[household_size_attr] + if size_attr.type == "float": + member[household_size_attr] = float(actual_size) + else: + member[household_size_attr] = int(actual_size) + + if has_children_attr: + member[has_children_attr] = ( + False if is_dependent_agent else child_dependents > 0 + ) + + if children_count_attr and children_count_attr in attr_map: + count_attr = attr_map[children_count_attr] + value = 0 if is_dependent_agent else child_dependents + if count_attr.type == "float": + member[children_count_attr] = float(value) + else: + member[children_count_attr] = int(value) + + if marital_attr and partnered_value: + if is_dependent_agent and unpartnered_value: + member[marital_attr] = unpartnered_value + elif has_partner and member.get(marital_attr) not in partnered_values: + member[marital_attr] = partnered_value + + hh_record = household_record_by_id.get(hh_id) + if not hh_record: + continue + shared = hh_record.get("shared_attributes") + if not isinstance(shared, dict): + continue + if household_size_attr and household_size_attr in shared: + shared[household_size_attr] = int(actual_size) + if has_children_attr and has_children_attr in shared: + shared[has_children_attr] = child_dependents > 0 + if children_count_attr and children_count_attr in shared: + shared[children_count_attr] = int(child_dependents) + + def _sample_partner_agent( spec: PopulationSpec, attr_map: dict[str, AttributeSpec], @@ -640,6 +957,7 @@ def _sample_partner_agent( household_attrs: set[str], categorical_options: dict[str, list[str]], config: HouseholdConfig | None = None, + strict_condition_errors: bool = False, ) -> dict[str, Any]: """Sample a partner agent with correlated demographics. @@ -670,12 +988,21 @@ def _sample_partner_agent( rng, config, available_options=categorical_options.get(attr_name), + semantic_type=attr.semantic_type, + identity_type=attr.identity_type, + partner_correlation_policy=attr.partner_correlation_policy, ) else: # Individual scope: sample independently try: - value = _sample_attribute(attr, rng, agent, stats) - except FormulaError as e: + value = _sample_attribute( + attr, + rng, + agent, + stats, + strict_condition_errors=strict_condition_errors, + ) + except (FormulaError, ConditionError) as e: raise SamplingError( f"Agent {index}: Failed to sample '{attr_name}': {e}" ) from e @@ -712,6 +1039,7 @@ def _sample_single_agent( id_width: int, stats: SamplingStats, numeric_values: dict[str, list[float]], + strict_condition_errors: bool = False, ) -> dict[str, Any]: """Sample a single agent following the sampling order.""" agent: dict[str, Any] = {"_id": f"agent_{index:0{id_width}d}"} @@ -723,8 +1051,14 @@ def _sample_single_agent( continue try: - value = _sample_attribute(attr, rng, agent, stats) - except FormulaError as e: + value = _sample_attribute( + attr, + rng, + agent, + stats, + strict_condition_errors=strict_condition_errors, + ) + except (FormulaError, ConditionError) as e: raise SamplingError( f"Agent {index}: Failed to sample '{attr_name}': {e}" ) from e @@ -766,6 +1100,7 @@ def _sample_attribute( rng: random.Random, agent: dict[str, Any], stats: SamplingStats, + strict_condition_errors: bool = False, ) -> Any: """Sample a single attribute based on its strategy.""" strategy = attr.sampling.strategy @@ -800,6 +1135,8 @@ def _sample_attribute( attr.sampling.modifiers, rng, agent, + strict_condition_errors=strict_condition_errors, + condition_warnings=stats.condition_warnings, ) # Update modifier trigger stats @@ -849,6 +1186,46 @@ def _update_stats( stats.boolean_counts[attr.name][bool_value] += 1 +def _rebuild_stats_from_agents( + spec: PopulationSpec, + agents: list[dict[str, Any]], + stats: SamplingStats, +) -> dict[str, list[float]]: + """Rebuild distribution stats from finalized agent records.""" + numeric_values: dict[str, list[float]] = {} + + for attr in spec.attributes: + if attr.type in ("int", "float"): + stats.attribute_means[attr.name] = 0.0 + stats.attribute_stds[attr.name] = 0.0 + numeric_values[attr.name] = [] + elif attr.type == "categorical": + stats.categorical_counts[attr.name] = {} + elif attr.type == "boolean": + stats.boolean_counts[attr.name] = {True: 0, False: 0} + + for agent in agents: + for attr in spec.attributes: + if attr.name not in agent: + continue + + value = agent[attr.name] + if attr.type in ("int", "float") and isinstance(value, (int, float)): + numeric_values[attr.name].append(float(value)) + elif attr.type == "categorical": + str_value = str(value) + stats.categorical_counts[attr.name][str_value] = ( + stats.categorical_counts[attr.name].get(str_value, 0) + 1 + ) + elif attr.type == "boolean": + bool_value = bool(value) + stats.boolean_counts[attr.name][bool_value] = ( + stats.boolean_counts[attr.name].get(bool_value, 0) + 1 + ) + + return numeric_values + + def _finalize_stats( stats: SamplingStats, numeric_values: dict[str, list[float]], @@ -871,7 +1248,7 @@ def _check_expression_constraints( spec: PopulationSpec, agents: list[dict[str, Any]], stats: SamplingStats, -) -> None: +) -> int: """Check expression constraints and count violations. Only checks constraints with type='expression' (agent-level constraints). @@ -880,6 +1257,8 @@ def _check_expression_constraints( """ from ...utils.eval_safe import eval_condition + total_violations = 0 + for attr in spec.attributes: for constraint in attr.constraints: # Only check agent-level expression constraints @@ -902,6 +1281,9 @@ def _check_expression_constraints( if violation_count > 0: key = f"{attr.name}: {constraint.expression}" stats.constraint_violations[key] = violation_count + total_violations += violation_count + + return total_violations def save_json(result: SamplingResult, path: Path | str) -> None: diff --git a/extropy/population/sampler/households.py b/extropy/population/sampler/households.py index 9eeb04b..cf3ebca 100644 --- a/extropy/population/sampler/households.py +++ b/extropy/population/sampler/households.py @@ -14,7 +14,7 @@ import math import random -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, Literal from ...core.models.population import Dependent, HouseholdConfig, HouseholdType from ..names.generator import generate_name @@ -23,6 +23,17 @@ from ...core.models.population import NameConfig +# Legacy name-based mapping for backward compatibility only. +# New specs should drive partner-correlation behavior via metadata/policy. +_LEGACY_POLICY_HINTS: dict[str, dict[str, str]] = { + "age": {"semantic_type": "age"}, + "race_ethnicity": {"identity_type": "race_ethnicity"}, + "race": {"identity_type": "race_ethnicity"}, + "ethnicity": {"identity_type": "race_ethnicity"}, + "country": {"identity_type": "citizenship"}, +} + + def _age_bracket(age: int, config: HouseholdConfig) -> str: """Map age to bracket key using config age brackets.""" for upper_bound, label in config.age_brackets: @@ -78,13 +89,20 @@ def correlate_partner_attribute( rng: random.Random, config: HouseholdConfig, available_options: list[str] | None = None, + semantic_type: str | None = None, + identity_type: str | None = None, + partner_correlation_policy: Literal[ + "gaussian_offset", "same_group_rate", "same_value_probability" + ] + | None = None, ) -> Any: """Produce a correlated value for a partner based on the primary's value. - Uses the correlation_rate from the attribute spec. Special handling: - - age (int/float): Gaussian offset using config.partner_age_gap_mean/std - - race_ethnicity-like attrs: Per-group rates from config.same_group_rates - - Other categorical/boolean: Simple probability of same value + Policy resolution order: + 1. Explicit `partner_correlation_policy` from attribute metadata + 2. semantic_type / identity_type metadata + 3. Legacy name-based compatibility mapping + 4. Default same-value probability behavior Args: attr_name: Name of the attribute @@ -95,11 +113,24 @@ def correlate_partner_attribute( config: HouseholdConfig with default rates available_options: For categorical attrs, list of valid options to sample from + semantic_type: Optional semantic type from AttributeSpec + identity_type: Optional identity type from AttributeSpec + partner_correlation_policy: Optional explicit policy from AttributeSpec + Returns: - The correlated value for the partner. + Correlated value for the partner. """ - # Age uses gaussian offset, not simple correlation - if attr_name == "age" and attr_type in ("int", "float"): + policy = _resolve_partner_policy( + attr_name=attr_name, + attr_type=attr_type, + correlation_rate=correlation_rate, + config=config, + semantic_type=semantic_type, + identity_type=identity_type, + partner_correlation_policy=partner_correlation_policy, + ) + + if policy == "gaussian_offset": partner_age = int( round( rng.gauss( @@ -110,8 +141,7 @@ def correlate_partner_attribute( ) return max(config.min_adult_age, partner_age) - # Race/ethnicity uses per-group rates from config - if attr_name in ("race_ethnicity", "ethnicity", "race"): + if policy == "same_group_rate": same_rate = config.same_group_rates.get( str(primary_value).lower(), config.default_same_group_rate ) @@ -123,26 +153,11 @@ def correlate_partner_attribute( return rng.choice(others) return primary_value - # Country uses same_country_rate from config if no explicit rate - if attr_name == "country": - rate = ( - correlation_rate - if correlation_rate is not None - else config.same_country_rate - ) - if rng.random() < rate: - return primary_value - if available_options: - others = [o for o in available_options if o != primary_value] - if others: - return rng.choice(others) - return primary_value - - # For all other attributes, use the explicit correlation_rate or a default - rate = ( - correlation_rate - if correlation_rate is not None - else config.default_same_group_rate + rate = _resolve_same_value_rate( + attr_name=attr_name, + correlation_rate=correlation_rate, + config=config, + identity_type=identity_type, ) if rng.random() < rate: return primary_value @@ -156,6 +171,73 @@ def correlate_partner_attribute( return primary_value +def _resolve_partner_policy( + attr_name: str, + attr_type: str, + correlation_rate: float | None, + config: HouseholdConfig, + semantic_type: str | None = None, + identity_type: str | None = None, + partner_correlation_policy: Literal[ + "gaussian_offset", "same_group_rate", "same_value_probability" + ] + | None = None, +) -> Literal["gaussian_offset", "same_group_rate", "same_value_probability"]: + """Resolve which partner-correlation algorithm to use.""" + if partner_correlation_policy is not None: + return partner_correlation_policy + + inferred_semantic = semantic_type + inferred_identity = identity_type + + # Backward compatibility: infer missing metadata from legacy names. + if inferred_semantic is None and inferred_identity is None: + hints = _LEGACY_POLICY_HINTS.get(attr_name.lower()) + if hints: + inferred_semantic = hints.get("semantic_type") + inferred_identity = hints.get("identity_type") + + if inferred_semantic == "age" and attr_type in ("int", "float"): + return "gaussian_offset" + + if inferred_identity == "race_ethnicity": + return "same_group_rate" + + # Treat citizenship-like identity as same-country behavior. + if inferred_identity == "citizenship": + return "same_value_probability" + + # Default fallback remains same-value probability. + _ = correlation_rate + _ = config + return "same_value_probability" + + +def _resolve_same_value_rate( + attr_name: str, + correlation_rate: float | None, + config: HouseholdConfig, + identity_type: str | None = None, +) -> float: + """Resolve same-value probability for partner correlation.""" + if correlation_rate is not None: + return correlation_rate + + inferred_identity = identity_type + if inferred_identity is None: + hints = _LEGACY_POLICY_HINTS.get(attr_name.lower()) + if hints: + inferred_identity = hints.get("identity_type") + + if inferred_identity == "citizenship": + return config.same_country_rate + + if attr_name in config.assortative_mating: + return config.assortative_mating[attr_name] + + return config.default_same_group_rate + + def generate_dependents( household_type: HouseholdType, household_size: int, diff --git a/extropy/population/sampler/modifiers.py b/extropy/population/sampler/modifiers.py index 54c2e5f..f2108e8 100644 --- a/extropy/population/sampler/modifiers.py +++ b/extropy/population/sampler/modifiers.py @@ -20,7 +20,7 @@ CategoricalDistribution, BooleanDistribution, ) -from ...utils.eval_safe import eval_condition +from ...utils.eval_safe import ConditionError, eval_condition from .distributions import ( _sample_normal, _sample_lognormal, @@ -39,6 +39,9 @@ def apply_modifiers_and_sample( modifiers: list[Modifier], rng: random.Random, agent: dict[str, Any], + *, + strict_condition_errors: bool = False, + condition_warnings: list[str] | None = None, ) -> tuple[Any, list[int]]: """ Apply matching modifiers to a distribution and sample. @@ -62,8 +65,12 @@ def apply_modifiers_and_sample( matching_modifiers.append((i, mod)) triggered_indices.append(i) except Exception as e: - # Log warning but continue - condition failure means modifier doesn't apply - logger.warning(f"Modifier condition '{mod.when}' failed: {e}") + message = f"Modifier condition '{mod.when}' failed: {e}" + if strict_condition_errors: + raise ConditionError(message) from e + logger.warning(message) + if condition_warnings is not None: + condition_warnings.append(message) # Route to type-specific handler if isinstance(dist, (NormalDistribution, LognormalDistribution)): diff --git a/extropy/population/validator/llm_response.py b/extropy/population/validator/llm_response.py index 42fb190..d5fd8e5 100644 --- a/extropy/population/validator/llm_response.py +++ b/extropy/population/validator/llm_response.py @@ -38,6 +38,73 @@ def _make_error( ) +def _coerce_float(value: Any) -> float | None: + """Coerce value to float, returning None when not numeric-like.""" + if isinstance(value, bool): + return None + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + stripped = value.strip() + if not stripped: + return None + try: + return float(stripped) + except ValueError: + return None + return None + + +def _coerce_numeric_list( + values: list[Any], + field: str, + errors: list[ValidationIssue], +) -> list[float] | None: + """Coerce a list of values to floats, recording validation errors.""" + coerced: list[float] = [] + invalid = False + for idx, value in enumerate(values): + numeric = _coerce_float(value) + if numeric is None: + invalid = True + errors.append( + _make_error( + field=f"{field}[{idx}]", + value=str(value), + error="value must be numeric", + suggestion="Use numeric values (e.g., 0.25)", + ) + ) + continue + coerced.append(numeric) + return None if invalid else coerced + + +def _coerce_numeric_mapping( + values: dict[str, Any], + field: str, + errors: list[ValidationIssue], +) -> dict[str, float] | None: + """Coerce mapping values to floats, recording validation errors.""" + coerced: dict[str, float] = {} + invalid = False + for key, value in values.items(): + numeric = _coerce_float(value) + if numeric is None: + invalid = True + errors.append( + _make_error( + field=f"{field}.{key}", + value=str(value), + error="value must be numeric", + suggestion="Use numeric values (e.g., 0.25)", + ) + ) + continue + coerced[key] = numeric + return None if invalid else coerced + + # Spec-level variable patterns that should use spec_expression, not expression SPEC_LEVEL_PATTERNS = {"weights", "options"} @@ -232,34 +299,76 @@ def validate_distribution_data( # Check std is positive if present std = dist_data.get("std") - if std is not None and std < 0: - errors.append( - _make_error( - field=f"{attr_name}.distribution.std", - value=str(std), - error="standard deviation cannot be negative", - suggestion="Use a positive value for std", + if std is not None: + std_num = _coerce_float(std) + if std_num is None: + errors.append( + _make_error( + field=f"{attr_name}.distribution.std", + value=str(std), + error="standard deviation must be numeric", + suggestion="Use a numeric value for std", + ) + ) + elif std_num < 0: + errors.append( + _make_error( + field=f"{attr_name}.distribution.std", + value=str(std), + error="standard deviation cannot be negative", + suggestion="Use a positive value for std", + ) ) - ) # Check min < max if both present min_val = dist_data.get("min") max_val = dist_data.get("max") - if min_val is not None and max_val is not None and min_val >= max_val: - errors.append( - _make_error( - field=f"{attr_name}.distribution.min/max", - value=f"min={min_val}, max={max_val}", - error="min must be less than max", - suggestion="Swap min and max values", + if min_val is not None and max_val is not None: + min_num = _coerce_float(min_val) + max_num = _coerce_float(max_val) + if min_num is None: + errors.append( + _make_error( + field=f"{attr_name}.distribution.min", + value=str(min_val), + error="min must be numeric", + suggestion="Use a numeric value for min", + ) + ) + if max_num is None: + errors.append( + _make_error( + field=f"{attr_name}.distribution.max", + value=str(max_val), + error="max must be numeric", + suggestion="Use a numeric value for max", + ) + ) + if min_num is not None and max_num is not None and min_num >= max_num: + errors.append( + _make_error( + field=f"{attr_name}.distribution.min/max", + value=f"min={min_val}, max={max_val}", + error="min must be less than max", + suggestion="Swap min and max values", + ) ) - ) elif dist_type == "beta": alpha = dist_data.get("alpha") beta = dist_data.get("beta") - if alpha is None or alpha <= 0: + alpha_num = _coerce_float(alpha) if alpha is not None else None + if alpha is None or alpha_num is None: + errors.append( + _make_error( + field=f"{attr_name}.distribution.alpha", + value=str(alpha), + error="alpha must be numeric", + suggestion="Use a positive value like 2.0", + ) + ) + elif alpha_num <= 0: errors.append( _make_error( field=f"{attr_name}.distribution.alpha", @@ -269,7 +378,17 @@ def validate_distribution_data( ) ) - if beta is None or beta <= 0: + beta_num = _coerce_float(beta) if beta is not None else None + if beta is None or beta_num is None: + errors.append( + _make_error( + field=f"{attr_name}.distribution.beta", + value=str(beta), + error="beta must be numeric", + suggestion="Use a positive value like 5.0", + ) + ) + elif beta_num <= 0: errors.append( _make_error( field=f"{attr_name}.distribution.beta", @@ -283,15 +402,36 @@ def validate_distribution_data( min_val = dist_data.get("min") max_val = dist_data.get("max") - if min_val is not None and max_val is not None and min_val >= max_val: - errors.append( - _make_error( - field=f"{attr_name}.distribution.min/max", - value=f"min={min_val}, max={max_val}", - error="min must be less than max", - suggestion="Swap min and max values", + if min_val is not None and max_val is not None: + min_num = _coerce_float(min_val) + max_num = _coerce_float(max_val) + if min_num is None: + errors.append( + _make_error( + field=f"{attr_name}.distribution.min", + value=str(min_val), + error="min must be numeric", + suggestion="Use a numeric value for min", + ) + ) + if max_num is None: + errors.append( + _make_error( + field=f"{attr_name}.distribution.max", + value=str(max_val), + error="max must be numeric", + suggestion="Use a numeric value for max", + ) + ) + if min_num is not None and max_num is not None and min_num >= max_num: + errors.append( + _make_error( + field=f"{attr_name}.distribution.min/max", + value=f"min={min_val}, max={max_val}", + error="min must be less than max", + suggestion="Swap min and max values", + ) ) - ) elif dist_type == "categorical": options = dist_data.get("options") @@ -316,28 +456,43 @@ def validate_distribution_data( ) ) elif weights: - weight_sum = sum(weights) - if abs(weight_sum - 1.0) > 0.02: - errors.append( - _make_error( - field=f"{attr_name}.distribution.weights", - value=f"sum={weight_sum:.3f}", - error="weights must sum to 1.0", - suggestion="Normalize weights to sum to 1.0", + numeric_weights = _coerce_numeric_list( + list(weights), f"{attr_name}.distribution.weights", errors + ) + if numeric_weights is not None: + weight_sum = sum(numeric_weights) + if abs(weight_sum - 1.0) > 0.02: + errors.append( + _make_error( + field=f"{attr_name}.distribution.weights", + value=f"sum={weight_sum:.3f}", + error="weights must sum to 1.0", + suggestion="Normalize weights to sum to 1.0", + ) ) - ) elif dist_type == "boolean": prob = dist_data.get("probability_true") - if prob is not None and (prob < 0 or prob > 1): - errors.append( - _make_error( - field=f"{attr_name}.distribution.probability_true", - value=str(prob), - error="probability must be between 0 and 1", - suggestion="Use a value like 0.5 or 0.75", + if prob is not None: + prob_num = _coerce_float(prob) + if prob_num is None: + errors.append( + _make_error( + field=f"{attr_name}.distribution.probability_true", + value=str(prob), + error="probability must be numeric", + suggestion="Use a value like 0.5 or 0.75", + ) + ) + elif prob_num < 0 or prob_num > 1: + errors.append( + _make_error( + field=f"{attr_name}.distribution.probability_true", + value=str(prob), + error="probability must be between 0 and 1", + suggestion="Use a value like 0.5 or 0.75", + ) ) - ) return errors @@ -439,29 +594,46 @@ def validate_modifier_data( # Validate probability_override range prob_override = modifier_data.get("probability_override") - if prob_override is not None and (prob_override < 0 or prob_override > 1): - errors.append( - _make_error( - field=f"{attr_name}.modifiers[{modifier_index}].probability_override", - value=str(prob_override), - error="probability_override must be between 0 and 1", - suggestion="Use a value like 0.75", + if prob_override is not None: + prob_override_num = _coerce_float(prob_override) + if prob_override_num is None: + errors.append( + _make_error( + field=f"{attr_name}.modifiers[{modifier_index}].probability_override", + value=str(prob_override), + error="probability_override must be numeric", + suggestion="Use a value like 0.75", + ) + ) + elif prob_override_num < 0 or prob_override_num > 1: + errors.append( + _make_error( + field=f"{attr_name}.modifiers[{modifier_index}].probability_override", + value=str(prob_override), + error="probability_override must be between 0 and 1", + suggestion="Use a value like 0.75", + ) ) - ) # Validate weight_overrides sum to 1.0 weight_overrides = modifier_data.get("weight_overrides") if weight_overrides and isinstance(weight_overrides, dict): - weight_sum = sum(weight_overrides.values()) - if abs(weight_sum - 1.0) > 0.02: - errors.append( - _make_error( - field=f"{attr_name}.modifiers[{modifier_index}].weight_overrides", - value=f"sum={weight_sum:.3f}", - error="weight_overrides must sum to 1.0", - suggestion="Normalize weights to sum to 1.0", + numeric_weight_overrides = _coerce_numeric_mapping( + weight_overrides, + f"{attr_name}.modifiers[{modifier_index}].weight_overrides", + errors, + ) + if numeric_weight_overrides is not None: + weight_sum = sum(numeric_weight_overrides.values()) + if abs(weight_sum - 1.0) > 0.02: + errors.append( + _make_error( + field=f"{attr_name}.modifiers[{modifier_index}].weight_overrides", + value=f"sum={weight_sum:.3f}", + error="weight_overrides must sum to 1.0", + suggestion="Normalize weights to sum to 1.0", + ) ) - ) return errors diff --git a/extropy/population/validator/semantic.py b/extropy/population/validator/semantic.py index 92589ab..9eff6c9 100644 --- a/extropy/population/validator/semantic.py +++ b/extropy/population/validator/semantic.py @@ -4,6 +4,8 @@ They help identify potential issues but don't indicate structural problems. """ +import ast + from ...core.models.validation import Severity, ValidationIssue from ...core.models import ( PopulationSpec, @@ -29,6 +31,8 @@ def run_semantic_checks(spec: PopulationSpec) -> list[ValidationIssue]: 10. No-Op Detection 11. Modifier Stacking Analysis 12. Condition Value Validity + 13. Partner-Correlation Policy Completeness + 14. Modifier Overlap Ambiguity """ issues: list[ValidationIssue] = [] @@ -45,6 +49,12 @@ def run_semantic_checks(spec: PopulationSpec) -> list[ValidationIssue]: # Category 12: Condition Value Validity issues.extend(_check_condition_values(attr, attr_lookup)) + # Category 13: Partner-correlation policy completeness + issues.extend(_check_partner_correlation_policy(attr)) + + # Category 14: Modifier overlap ambiguity + issues.extend(_check_modifier_overlap(attr)) + return issues @@ -248,3 +258,164 @@ def _check_condition_values( ) return issues + + +# ============================================================================= +# Category 13: Partner-Correlation Policy Completeness +# ============================================================================= + + +def _check_partner_correlation_policy(attr: AttributeSpec) -> list[ValidationIssue]: + """Warn when partner-correlated attributes lack explicit policy metadata.""" + if attr.scope != "partner_correlated": + return [] + + has_policy = attr.partner_correlation_policy is not None + has_semantics = attr.semantic_type is not None or attr.identity_type is not None + has_explicit_rate = attr.correlation_rate is not None + + if has_policy or has_semantics or has_explicit_rate: + return [] + + return [ + ValidationIssue( + severity=Severity.WARNING, + category="PARTNER_POLICY", + location=attr.name, + message=( + "partner_correlated attribute has no explicit correlation policy or " + "semantic metadata; behavior will fall back to legacy defaults" + ), + suggestion=( + "Set partner_correlation_policy, semantic_type/identity_type, " + "or correlation_rate" + ), + ) + ] + + +# ============================================================================= +# Category 14: Modifier Overlap Ambiguity +# ============================================================================= + + +def _extract_literal_sets(expr: str) -> dict[str, set[str]]: + """Extract literal comparison sets per attribute from an expression.""" + try: + tree = ast.parse(expr, mode="eval") + except SyntaxError: + return {} + + extracted: dict[str, set[str]] = {} + + def _add(attr_name: str, value: str) -> None: + extracted.setdefault(attr_name, set()).add(value) + + def _extract_values(node: ast.AST) -> list[str]: + values: list[str] = [] + if isinstance(node, ast.Constant): + if isinstance(node.value, (str, bool, int, float)): + values.append(str(node.value)) + elif isinstance(node, ast.List): + for elt in node.elts: + if isinstance(elt, ast.Constant) and isinstance( + elt.value, (str, bool, int, float) + ): + values.append(str(elt.value)) + elif isinstance(node, ast.Tuple): + for elt in node.elts: + if isinstance(elt, ast.Constant) and isinstance( + elt.value, (str, bool, int, float) + ): + values.append(str(elt.value)) + return values + + for node in ast.walk(tree): + if not isinstance(node, ast.Compare): + continue + if not isinstance(node.left, ast.Name): + continue + attr_name = node.left.id + for comparator in node.comparators: + for value in _extract_values(comparator): + _add(attr_name, value) + + return extracted + + +def _conditions_could_overlap(expr_a: str, expr_b: str) -> bool: + """Conservative overlap check for modifier conditions. + + Returns True only when overlap is plausible from literal comparisons. + """ + if expr_a.strip() == expr_b.strip(): + return True + + literals_a = _extract_literal_sets(expr_a) + literals_b = _extract_literal_sets(expr_b) + if not literals_a or not literals_b: + return False + + shared_attrs = set(literals_a) & set(literals_b) + if not shared_attrs: + return False + + for attr_name in shared_attrs: + if literals_a[attr_name].isdisjoint(literals_b[attr_name]): + return False + + return True + + +def _check_modifier_overlap(attr: AttributeSpec) -> list[ValidationIssue]: + """Warn when categorical/boolean modifiers appear to overlap ambiguously.""" + if attr.type not in {"categorical", "boolean"}: + return [] + if len(attr.sampling.modifiers) < 2: + return [] + if attr.sampling.modifier_overlap_policy == "ordered_override": + return [] + + issues: list[ValidationIssue] = [] + for i in range(len(attr.sampling.modifiers)): + for j in range(i + 1, len(attr.sampling.modifiers)): + mod_i = attr.sampling.modifiers[i] + mod_j = attr.sampling.modifiers[j] + if not mod_i.when or not mod_j.when: + continue + if not _conditions_could_overlap(mod_i.when, mod_j.when): + continue + + policy = attr.sampling.modifier_overlap_policy + if policy == "exclusive": + message = ( + f"modifiers[{i}] and modifiers[{j}] overlap despite " + "modifier_overlap_policy='exclusive'" + ) + suggestion = ( + "Make the conditions mutually exclusive, or set " + "modifier_overlap_policy: ordered_override if precedence is intended" + ) + category = "MODIFIER_OVERLAP_EXCLUSIVE" + else: + message = ( + f"modifiers[{i}] and modifiers[{j}] may overlap; categorical/boolean " + "modifiers use last-wins semantics" + ) + suggestion = ( + "Set sampling.modifier_overlap_policy to 'ordered_override' " + "if intentional, otherwise make conditions mutually exclusive" + ) + category = "MODIFIER_OVERLAP" + + issues.append( + ValidationIssue( + severity=Severity.WARNING, + category=category, + location=attr.name, + message=message, + suggestion=suggestion, + ) + ) + + return issues diff --git a/extropy/scenario/validator.py b/extropy/scenario/validator.py index 164f483..7f65e01 100644 --- a/extropy/scenario/validator.py +++ b/extropy/scenario/validator.py @@ -17,6 +17,7 @@ ValidationResult, ) from ..utils.expressions import ( + extract_comparisons_from_expression, extract_names_from_expression, validate_expression_syntax, ) @@ -116,10 +117,31 @@ def validate_scenario( errors: list[ValidationIssue] = [] warnings: list[ValidationIssue] = [] - # Build set of known attributes from population spec + # Build set of known attributes from population spec + scenario extensions known_attributes: set[str] = set() + categorical_options_by_attr: dict[str, set[str]] = {} if population_spec: known_attributes = {attr.name for attr in population_spec.attributes} + for attr in population_spec.attributes: + dist = attr.sampling.distribution + if ( + attr.type == "categorical" + and dist is not None + and hasattr(dist, "options") + and getattr(dist, "options") + ): + categorical_options_by_attr[attr.name] = set(dist.options) + if spec.extended_attributes: + known_attributes |= {attr.name for attr in spec.extended_attributes} + for attr in spec.extended_attributes: + dist = attr.sampling.distribution + if ( + attr.type == "categorical" + and dist is not None + and hasattr(dist, "options") + and getattr(dist, "options") + ): + categorical_options_by_attr[attr.name] = set(dist.options) # Build set of known edge types from network # Check both 'edge_type' and 'type' fields (different network formats) @@ -213,7 +235,7 @@ def validate_scenario( ) else: # Check attribute references - if population_spec: + if known_attributes: refs = extract_names_from_expression(rule.when) unknown_refs = refs - known_attributes if unknown_refs: @@ -226,6 +248,31 @@ def validate_scenario( ) ) + # Check compared string literals against categorical domains + comparisons = extract_comparisons_from_expression(rule.when) + for attr_name, compared_values in comparisons: + valid_options = categorical_options_by_attr.get(attr_name) + if not valid_options: + continue + invalid_values = [ + value for value in compared_values if value not in valid_options + ] + if invalid_values: + errors.append( + ValidationError( + category="condition_value", + location=f"seed_exposure.rules[{i}].when", + message=( + f"Condition compares {attr_name} to invalid value(s): " + f"{', '.join(repr(v) for v in invalid_values)}" + ), + suggestion=( + f"Valid options for {attr_name}: " + f"{', '.join(sorted(valid_options))}" + ), + ) + ) + # Check probability bounds (already enforced by Pydantic, but double-check) if not 0 <= rule.probability <= 1: errors.append( @@ -291,7 +338,7 @@ def validate_scenario( # Allow edge attributes injected during propagation refs_without_edge_fields = refs - {"edge_type", "edge_weight"} - if population_spec: + if known_attributes: unknown_refs = refs_without_edge_fields - known_attributes if unknown_refs: errors.append( @@ -320,6 +367,31 @@ def validate_scenario( ) ) + # Check compared string literals against categorical domains + comparisons = extract_comparisons_from_expression(modifier.when) + for attr_name, compared_values in comparisons: + valid_options = categorical_options_by_attr.get(attr_name) + if not valid_options: + continue + invalid_values = [ + value for value in compared_values if value not in valid_options + ] + if invalid_values: + errors.append( + ValidationError( + category="condition_value", + location=f"spread.share_modifiers[{i}].when", + message=( + f"Condition compares {attr_name} to invalid value(s): " + f"{', '.join(repr(v) for v in invalid_values)}" + ), + suggestion=( + f"Valid options for {attr_name}: " + f"{', '.join(sorted(valid_options))}" + ), + ) + ) + # Warn about potentially problematic multipliers if modifier.multiply < 0: warnings.append( @@ -449,6 +521,64 @@ def validate_scenario( ) ) + # Validate custom timeline exposure rules (if provided) + if te.exposure_rules: + for j, rule in enumerate(te.exposure_rules): + syntax_error = validate_expression_syntax(rule.when) + if syntax_error: + errors.append( + ValidationError( + category="timeline", + location=f"timeline[{i}].exposure_rules[{j}].when", + message=f"Invalid expression syntax: {syntax_error}", + suggestion="Use valid Python expression syntax", + ) + ) + continue + + if known_attributes: + refs = extract_names_from_expression(rule.when) + unknown_refs = refs - known_attributes + if unknown_refs: + errors.append( + ValidationError( + category="attribute_reference", + location=f"timeline[{i}].exposure_rules[{j}].when", + message=( + "References unknown attribute(s): " + f"{', '.join(sorted(unknown_refs))}" + ), + suggestion="Check attribute names in population/spec", + ) + ) + + comparisons = extract_comparisons_from_expression(rule.when) + for attr_name, compared_values in comparisons: + valid_options = categorical_options_by_attr.get(attr_name) + if not valid_options: + continue + invalid_values = [ + value + for value in compared_values + if value not in valid_options + ] + if invalid_values: + errors.append( + ValidationError( + category="condition_value", + location=f"timeline[{i}].exposure_rules[{j}].when", + message=( + "Condition compares " + f"{attr_name} to invalid value(s): " + f"{', '.join(repr(v) for v in invalid_values)}" + ), + suggestion=( + f"Valid options for {attr_name}: " + f"{', '.join(sorted(valid_options))}" + ), + ) + ) + # ========================================================================= # Validate Simulation Config # ========================================================================= diff --git a/extropy/simulation/persona.py b/extropy/simulation/persona.py index 85326fd..b4811ef 100644 --- a/extropy/simulation/persona.py +++ b/extropy/simulation/persona.py @@ -276,15 +276,25 @@ def generate_persona( # Build display_format_map from population_spec if available display_format_map = None + semantic_type_map = None if population_spec: display_format_map = { attr.name: attr.display_format for attr in population_spec.attributes if attr.display_format } + semantic_type_map = { + attr.name: attr.semantic_type + for attr in population_spec.attributes + if attr.semantic_type + } return render_new_persona( - agent, persona_config, decision_relevant_attributes, display_format_map + agent, + persona_config, + decision_relevant_attributes, + display_format_map, + semantic_type_map, ) # Legacy rendering below diff --git a/tests/test_cli.py b/tests/test_cli.py index e61d40e..519f03a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -8,7 +8,11 @@ from typer.testing import CliRunner from extropy.cli.app import app -from extropy.cli.commands.validate import _is_scenario_file +from extropy.cli.commands.validate import ( + _detect_spec_type, + _is_persona_file, + _is_scenario_file, +) from extropy.population.network.config import NetworkConfig from extropy.storage import open_study_db @@ -56,6 +60,444 @@ def test_scenario_filename_detection(self): assert _is_scenario_file(Path("foo.scenario.yaml")) assert not _is_scenario_file(Path("population.yaml")) + def test_persona_filename_detection(self): + assert _is_persona_file(Path("persona.yaml")) + assert _is_persona_file(Path("foo.persona.yaml")) + assert _is_persona_file(Path("persona.v1.yaml")) + assert not _is_persona_file(Path("population.yaml")) + + def test_detect_spec_type_by_content(self, tmp_path): + persona_like = tmp_path / "config.yaml" + persona_like.write_text(""" +intro_template: "I am {role}." +treatments: [] +groups: [] +phrasings: + boolean: [] + categorical: [] + relative: [] + concrete: [] +""") + + assert _detect_spec_type(persona_like) == "persona" + + def test_validate_persona_config_passes(self, tmp_path): + study_dir = tmp_path / "study" + scenario_dir = study_dir / "scenario" / "test" + scenario_dir.mkdir(parents=True) + + pop_yaml = study_dir / "population.v1.yaml" + pop_yaml.write_text(""" +meta: + description: Test population + geography: USA + agent_focus: test agents + +grounding: + overall: low + sources_count: 0 + strong_count: 0 + medium_count: 0 + low_count: 2 + +attributes: + - name: role + type: categorical + category: population_specific + description: Role + sampling: + strategy: independent + distribution: + type: categorical + options: [x, y] + weights: [0.5, 0.5] + grounding: + level: low + method: estimated + - name: has_children + type: boolean + category: population_specific + description: Children in household + sampling: + strategy: independent + distribution: + type: boolean + probability_true: 0.4 + grounding: + level: low + method: estimated + +sampling_order: + - role + - has_children +""") + + scenario_yaml = scenario_dir / "scenario.v1.yaml" + scenario_yaml.write_text(""" +meta: + name: test + description: Test scenario + base_population: population.v1 + created_at: 2024-01-01T00:00:00 + +event: + type: announcement + content: Test announcement + source: test_source + credibility: 0.8 + ambiguity: 0.2 + emotional_valence: 0.0 + +seed_exposure: + channels: [] + rules: [] + +interaction: + primary_model: passive_observation + description: Test interaction + +spread: + share_probability: 0.3 + +outcomes: + suggested_outcomes: [] + capture_full_reasoning: true + +simulation: + max_timesteps: 10 + timestep_unit: day +""") + + persona_yaml = scenario_dir / "persona.v1.yaml" + persona_yaml.write_text(""" +population_description: Test population +created_at: "2024-01-01T00:00:00" +intro_template: "I am {role}." +treatments: + - attribute: role + treatment: concrete + group: basics + - attribute: has_children + treatment: concrete + group: basics +groups: + - name: basics + label: Basics + attributes: + - role + - has_children +phrasings: + boolean: + - attribute: has_children + true_phrase: I have children. + false_phrase: I do not have children. + categorical: + - attribute: role + phrases: + x: I am in role x. + y: I am in role y. + null_options: [] + null_phrase: null + fallback: null + relative: [] + concrete: [] +population_stats: + stats: {} +""") + + result = runner.invoke(app, ["validate", str(persona_yaml)]) + assert result.exit_code == 0, result.output + + def test_validate_persona_config_fails_on_option_mismatch(self, tmp_path): + study_dir = tmp_path / "study" + scenario_dir = study_dir / "scenario" / "test" + scenario_dir.mkdir(parents=True) + + pop_yaml = study_dir / "population.v1.yaml" + pop_yaml.write_text(""" +meta: + description: Test population + geography: USA + agent_focus: test agents + +grounding: + overall: low + sources_count: 0 + strong_count: 0 + medium_count: 0 + low_count: 2 + +attributes: + - name: role + type: categorical + category: population_specific + description: Role + sampling: + strategy: independent + distribution: + type: categorical + options: [x, y] + weights: [0.5, 0.5] + grounding: + level: low + method: estimated + +sampling_order: + - role +""") + + scenario_yaml = scenario_dir / "scenario.v1.yaml" + scenario_yaml.write_text(""" +meta: + name: test + description: Test scenario + base_population: population.v1 + created_at: 2024-01-01T00:00:00 + +event: + type: announcement + content: Test announcement + source: test_source + credibility: 0.8 + ambiguity: 0.2 + emotional_valence: 0.0 + +seed_exposure: + channels: [] + rules: [] + +interaction: + primary_model: passive_observation + description: Test interaction + +spread: + share_probability: 0.3 + +outcomes: + suggested_outcomes: [] + capture_full_reasoning: true + +simulation: + max_timesteps: 10 + timestep_unit: day +""") + + persona_yaml = scenario_dir / "persona.v1.yaml" + persona_yaml.write_text(""" +population_description: Test population +created_at: "2024-01-01T00:00:00" +intro_template: "I am {role}." +treatments: + - attribute: role + treatment: concrete + group: basics +groups: + - name: basics + label: Basics + attributes: + - role +phrasings: + boolean: [] + categorical: + - attribute: role + phrases: + x: I am in role x. + null_options: [] + null_phrase: null + fallback: null + relative: [] + concrete: [] +population_stats: + stats: {} +""") + + result = runner.invoke(app, ["validate", str(persona_yaml)]) + assert result.exit_code != 0 + assert "missing options" in result.output.lower() + + +class TestSampleValidationGates: + """Tests for promoted-warning gating in sample command.""" + + def _setup_study_with_condition_value_warning(self, tmp_path): + study_dir = tmp_path / "study" + scenario_dir = study_dir / "scenario" / "test" + scenario_dir.mkdir(parents=True) + study_db = study_dir / "study.db" + with open_study_db(study_db): + pass + + pop_yaml = study_dir / "population.v1.yaml" + pop_yaml.write_text(""" +meta: + description: Test population + geography: USA + agent_focus: test agents + +grounding: + overall: low + sources_count: 0 + strong_count: 0 + medium_count: 0 + low_count: 2 + +attributes: + - name: region + type: categorical + category: universal + description: Region + sampling: + strategy: independent + distribution: + type: categorical + options: [Urban, Rural] + weights: [0.6, 0.4] + grounding: + level: low + method: estimated + - name: audience_segment + type: categorical + category: population_specific + description: Audience segment + sampling: + strategy: conditional + distribution: + type: categorical + options: [A, B] + weights: [0.5, 0.5] + depends_on: [region] + modifiers: + - when: "region == 'Suburban'" + weight_overrides: + A: 0.8 + B: 0.2 + grounding: + level: low + method: estimated + +sampling_order: + - region + - audience_segment +""") + + scenario_yaml = scenario_dir / "scenario.v1.yaml" + scenario_yaml.write_text(""" +meta: + name: test + description: Test scenario + base_population: population.v1 + created_at: 2024-01-01T00:00:00 + +event: + type: announcement + content: Test announcement + source: test_source + credibility: 0.8 + ambiguity: 0.2 + emotional_valence: 0.0 + +seed_exposure: + channels: [] + rules: [] + +interaction: + primary_model: passive_observation + description: Test interaction + +spread: + share_probability: 0.3 + +outcomes: + suggested_outcomes: [] + capture_full_reasoning: true + +simulation: + max_timesteps: 10 + timestep_unit: day +""") + + persona_yaml = scenario_dir / "persona.v1.yaml" + persona_yaml.write_text(""" +population_description: Test population +created_at: "2024-01-01T00:00:00" +intro_template: "I am in {region}." +treatments: + - attribute: region + treatment: concrete + group: basics + - attribute: audience_segment + treatment: concrete + group: basics +groups: + - name: basics + label: Basics + attributes: + - region + - audience_segment +phrasings: + boolean: [] + categorical: + - attribute: region + phrases: + Urban: I live in an urban area. + Rural: I live in a rural area. + null_options: [] + null_phrase: null + fallback: null + - attribute: audience_segment + phrases: + A: I am in segment A. + B: I am in segment B. + null_options: [] + null_phrase: null + fallback: null + relative: [] + concrete: [] +population_stats: + stats: {} +""") + + return study_dir + + def test_sample_fails_on_promoted_warnings_by_default(self, tmp_path): + study_dir = self._setup_study_with_condition_value_warning(tmp_path) + + old_cwd = os.getcwd() + try: + os.chdir(study_dir) + result = runner.invoke( + app, ["sample", "-s", "test", "-n", "30", "--seed", "42"] + ) + finally: + os.chdir(old_cwd) + + assert result.exit_code != 0 + assert "promoted warning" in result.output.lower() + + def test_sample_skip_validation_allows_promoted_warnings(self, tmp_path): + study_dir = self._setup_study_with_condition_value_warning(tmp_path) + + old_cwd = os.getcwd() + try: + os.chdir(study_dir) + result = runner.invoke( + app, + [ + "sample", + "-s", + "test", + "-n", + "30", + "--seed", + "42", + "--skip-validation", + ], + ) + finally: + os.chdir(old_cwd) + + assert result.exit_code == 0 + class TestVersionFlag: """Test the --version flag.""" diff --git a/tests/test_household_sampling.py b/tests/test_household_sampling.py index 115d727..8106605 100644 --- a/tests/test_household_sampling.py +++ b/tests/test_household_sampling.py @@ -9,6 +9,7 @@ GroundingInfo, NormalDistribution, CategoricalDistribution, + BooleanDistribution, HouseholdType, HouseholdConfig, Dependent, @@ -175,6 +176,101 @@ def _make_individual_spec(size: int = 50) -> PopulationSpec: ) +def _make_household_consistency_spec(size: int = 200) -> PopulationSpec: + """Spec with household-related attributes for reconciliation tests.""" + return PopulationSpec( + meta=SpecMeta(description="Household consistency spec", size=size), + grounding=GroundingSummary( + overall="medium", + sources_count=1, + strong_count=2, + medium_count=2, + low_count=2, + ), + attributes=[ + AttributeSpec( + name="age", + type="int", + category="universal", + description="Age", + scope="partner_correlated", + sampling=SamplingConfig( + strategy="independent", + distribution=NormalDistribution( + type="normal", mean=38, std=12, min=18, max=85 + ), + ), + grounding=GroundingInfo(level="strong", method="researched"), + ), + AttributeSpec( + name="gender", + type="categorical", + category="universal", + description="Gender", + sampling=SamplingConfig( + strategy="independent", + distribution=CategoricalDistribution( + type="categorical", + options=["male", "female"], + weights=[0.49, 0.51], + ), + ), + grounding=GroundingInfo(level="strong", method="researched"), + ), + AttributeSpec( + name="marital_status", + type="categorical", + category="universal", + description="Marital status", + sampling=SamplingConfig( + strategy="independent", + distribution=CategoricalDistribution( + type="categorical", + options=["Single", "Married", "Divorced", "Widowed"], + weights=[0.62, 0.25, 0.09, 0.04], + ), + ), + grounding=GroundingInfo(level="medium", method="estimated"), + ), + AttributeSpec( + name="household_size", + type="int", + category="universal", + description="Number of people in household", + scope="household", + sampling=SamplingConfig( + strategy="independent", + distribution=NormalDistribution( + type="normal", mean=1.3, std=0.6, min=1, max=6 + ), + ), + grounding=GroundingInfo(level="medium", method="estimated"), + ), + AttributeSpec( + name="has_children", + type="boolean", + category="universal", + description="Whether the agent has children under 18", + sampling=SamplingConfig( + strategy="independent", + distribution=BooleanDistribution( + type="boolean", + probability_true=0.35, + ), + ), + grounding=GroundingInfo(level="medium", method="estimated"), + ), + ], + sampling_order=[ + "age", + "gender", + "marital_status", + "household_size", + "has_children", + ], + ) + + class TestHouseholdModels: def test_household_type_enum(self): assert HouseholdType.SINGLE.value == "single" @@ -290,6 +386,63 @@ def test_correlate_with_default_rate(self): # Expect ~85% (default_same_group_rate) assert 0.75 < rate < 0.95, f"Default rate {rate:.2f} outside expected range" + def test_correlate_uses_semantic_type_for_age_policy(self): + """Semantic metadata should trigger gaussian age policy without name matching.""" + rng = random.Random(42) + values = [ + correlate_partner_attribute( + "years_old", + "int", + 35, + None, + rng, + _DEFAULT_CONFIG, + semantic_type="age", + ) + for _ in range(200) + ] + assert all(v >= _DEFAULT_CONFIG.min_adult_age for v in values) + assert len(set(values)) > 1 + + def test_correlate_uses_identity_type_for_group_rate_policy(self): + """Identity metadata should trigger same-group rate without name matching.""" + rng = random.Random(42) + same_count = 0 + trials = 500 + for _ in range(trials): + result = correlate_partner_attribute( + "ethnic_group", + "categorical", + "white", + None, + rng, + _DEFAULT_CONFIG, + available_options=["white", "black", "hispanic"], + identity_type="race_ethnicity", + ) + if result == "white": + same_count += 1 + rate = same_count / trials + assert 0.80 < rate < 0.97, f"Same-group rate {rate:.2f} outside expected range" + + def test_correlate_respects_explicit_policy_override(self): + """Explicit partner policy should override inferred/default behavior.""" + rng = random.Random(42) + values = [ + correlate_partner_attribute( + "custom_numeric", + "int", + 40, + None, + rng, + _DEFAULT_CONFIG, + partner_correlation_policy="gaussian_offset", + ) + for _ in range(100) + ] + assert all(v >= _DEFAULT_CONFIG.min_adult_age for v in values) + assert len(set(values)) > 1 + def test_generate_dependents_no_kids(self): rng = random.Random(42) deps = generate_dependents( @@ -465,3 +618,64 @@ def test_country_correlation(self): rate = same_country / total # Should be close to 0.95 (within statistical margin) assert 0.90 < rate < 0.99, f"Same-country rate {rate:.2%} out of expected range" + + +class TestHouseholdReconciliation: + def test_partnered_agents_are_not_left_single(self): + spec = _make_household_consistency_spec(size=500) + result = sample_population(spec, count=500, seed=42) + + partnered = [ + a + for a in result.agents + if a.get("partner_id") is not None or a.get("partner_npc") is not None + ] + assert partnered, "Expected at least some partnered agents" + assert all(a.get("marital_status") == "Married" for a in partnered) + + def test_household_size_matches_actual_membership(self): + spec = _make_household_consistency_spec(size=300) + result = sample_population(spec, count=300, seed=7, agent_focus_mode="couples") + + by_household: dict[str, list[dict]] = {} + for agent in result.agents: + by_household.setdefault(agent["household_id"], []).append(agent) + + for members in by_household.values(): + primary = next( + (m for m in members if m.get("household_role") == "adult_primary"), + members[0], + ) + dependents = primary.get("dependents", []) + expected_size = len(members) + ( + len(dependents) if isinstance(dependents, list) else 0 + ) + for member in members: + assert member.get("household_size") == expected_size + + def test_has_children_matches_generated_dependents(self): + spec = _make_household_consistency_spec(size=300) + result = sample_population(spec, count=300, seed=21) + + by_household: dict[str, list[dict]] = {} + for agent in result.agents: + by_household.setdefault(agent["household_id"], []).append(agent) + + for members in by_household.values(): + primary = next( + (m for m in members if m.get("household_role") == "adult_primary"), + members[0], + ) + dependents = primary.get("dependents", []) + child_count = 0 + if isinstance(dependents, list): + for dep in dependents: + relationship = str(dep.get("relationship", "")).lower() + if any( + token in relationship + for token in ("son", "daughter", "child", "kid") + ): + child_count += 1 + expected = child_count > 0 + for member in members: + assert member.get("has_children") is expected diff --git a/tests/test_persona_renderer.py b/tests/test_persona_renderer.py index a1fbb00..0eff322 100644 --- a/tests/test_persona_renderer.py +++ b/tests/test_persona_renderer.py @@ -1,7 +1,17 @@ """Tests for persona categorical rendering behavior.""" -from extropy.population.persona.config import CategoricalPhrasing -from extropy.population.persona.renderer import _format_categorical_value +from extropy.population.persona.config import ( + AttributeGroup, + AttributePhrasing, + AttributeTreatment, + CategoricalPhrasing, + ConcretePhrasing, + PersonaConfig, +) +from extropy.population.persona.renderer import ( + _format_categorical_value, + render_persona, +) def test_categorical_null_option_prefers_null_phrase(): @@ -57,3 +67,93 @@ def test_categorical_non_null_options_render_normally(): rendered = _format_categorical_value("moderator", phrasing) assert rendered == "I moderate one or more communities" + + +def _make_contextual_test_config() -> PersonaConfig: + return PersonaConfig( + population_description="Test population", + intro_template="I'm {age} years old.", + treatments=[ + AttributeTreatment(attribute="age", treatment="concrete", group="identity"), + AttributeTreatment( + attribute="employment_status", treatment="concrete", group="work" + ), + AttributeTreatment( + attribute="occupation", treatment="concrete", group="identity" + ), + ], + groups=[ + AttributeGroup( + name="identity", + label="Who I Am", + attributes=["age", "occupation"], + ), + AttributeGroup( + name="work", + label="Work", + attributes=["employment_status"], + ), + ], + phrasings=AttributePhrasing( + boolean=[], + categorical=[ + CategoricalPhrasing( + attribute="employment_status", + phrases={ + "Unemployed": "I'm currently unemployed and looking for work." + }, + null_options=[], + null_phrase=None, + fallback=None, + ), + CategoricalPhrasing( + attribute="occupation", + phrases={"Tech": "I work in software engineering."}, + null_options=[], + null_phrase=None, + fallback=None, + ), + ], + relative=[], + concrete=[ + ConcretePhrasing( + attribute="age", + template="I'm {value} years old", + ) + ], + ), + ) + + +def test_render_persona_adjusts_occupation_when_unemployed(): + config = _make_contextual_test_config() + agent = { + "age": 32, + "employment_status": "Unemployed", + "occupation": "Tech", + } + + rendered = render_persona( + agent, + config, + semantic_type_map={ + "employment_status": "employment", + "occupation": "occupation", + }, + ) + + assert "My background is in software engineering." in rendered + assert "I work in software engineering." not in rendered + + +def test_render_persona_uses_more_about_me_instead_of_duplicate_header(): + config = _make_contextual_test_config() + agent = { + "age": 32, + "employment_status": "Unemployed", + "occupation": "Tech", + } + + rendered = render_persona(agent, config) + assert rendered.count("## Who I Am") == 1 + assert "## More About Me" in rendered diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 9c63433..ba09eca 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -1244,3 +1244,195 @@ def test_quick_validate_accepts_static_max_for_constraint(self): result = validate_conditional_base_response(data, ["score"]) # Should NOT error - static max=100 satisfies the constraint assert not any("max_formula" in str(e.suggestion) for e in result.errors) + + def test_quick_validate_categorical_bad_weights_are_reported_not_crashed(self): + """Mixed invalid weight types should produce validation errors, not runtime crashes.""" + from extropy.population.validator import validate_independent_response + + data = { + "attributes": [ + { + "name": "media_source", + "distribution": { + "type": "categorical", + "options": ["tv", "social"], + "weights": [0.5, "bad-value"], + }, + "constraints": [], + }, + ], + } + + result = validate_independent_response(data, ["media_source"]) + assert not result.valid + assert any( + e.location == "media_source.distribution.weights[1]" for e in result.errors + ) + + def test_quick_validate_modifier_bad_weight_overrides_are_reported(self): + """Invalid weight_overrides values should be flagged as validation errors.""" + from extropy.population.validator import validate_modifier_data + + errors = validate_modifier_data( + modifier_data={ + "when": "region == 'Urban'", + "weight_overrides": {"A": 0.6, "B": "bad-value"}, + }, + attr_name="political_affiliation", + modifier_index=0, + dist_type="categorical", + ) + assert any( + e.location == "political_affiliation.modifiers[0].weight_overrides.B" + for e in errors + ) + + +class TestModifierConditionErrorHandling: + """Runtime behavior for modifier condition evaluation failures.""" + + def _make_spec_with_bad_modifier_condition(self) -> PopulationSpec: + return PopulationSpec( + meta=SpecMeta(description="bad condition test", size=10), + grounding=GroundingSummary( + overall="low", + sources_count=0, + strong_count=0, + medium_count=0, + low_count=2, + ), + attributes=[ + AttributeSpec( + name="age", + type="int", + category="universal", + description="Age", + sampling=SamplingConfig( + strategy="independent", + distribution=NormalDistribution( + type="normal", + mean=40, + std=10, + min=18, + max=80, + ), + ), + grounding=GroundingInfo(level="low", method="estimated"), + ), + AttributeSpec( + name="salary", + type="float", + category="population_specific", + description="Salary", + sampling=SamplingConfig( + strategy="conditional", + distribution=NormalDistribution( + type="normal", + mean=70000, + std=15000, + min=0, + max=300000, + ), + depends_on=["age"], + modifiers=[ + Modifier( + when="unknown_var > 0", + multiply=1.1, + ) + ], + ), + grounding=GroundingInfo(level="low", method="estimated"), + ), + ], + sampling_order=["age", "salary"], + ) + + def test_strict_condition_errors_raise_sampling_error(self): + spec = self._make_spec_with_bad_modifier_condition() + with pytest.raises(SamplingError): + sample_population(spec, count=20, seed=42, strict_condition_errors=True) + + def test_permissive_condition_errors_record_warnings(self): + spec = self._make_spec_with_bad_modifier_condition() + result = sample_population( + spec, count=20, seed=42, strict_condition_errors=False + ) + assert len(result.agents) == 20 + assert len(result.stats.condition_warnings) > 0 + + +class TestExpressionConstraintEnforcement: + """Strict/permissive handling for expression constraints.""" + + def _make_constraint_violation_spec(self) -> PopulationSpec: + return PopulationSpec( + meta=SpecMeta(description="constraint enforcement test", size=20), + grounding=GroundingSummary( + overall="low", + sources_count=0, + strong_count=0, + medium_count=0, + low_count=2, + ), + attributes=[ + AttributeSpec( + name="age", + type="int", + category="universal", + description="Age", + sampling=SamplingConfig( + strategy="independent", + distribution=NormalDistribution( + type="normal", + mean=24, + std=2, + min=18, + max=30, + ), + ), + grounding=GroundingInfo(level="low", method="estimated"), + ), + AttributeSpec( + name="children_count", + type="int", + category="universal", + description="Children count", + sampling=SamplingConfig( + strategy="independent", + distribution=NormalDistribution( + type="normal", + mean=3, + std=1, + min=0, + max=6, + ), + ), + constraints=[ + Constraint( + type="expression", + expression="value <= max(0, age - 30)", + ) + ], + grounding=GroundingInfo(level="low", method="estimated"), + ), + ], + sampling_order=["age", "children_count"], + ) + + def test_enforced_constraints_raise_sampling_error(self): + spec = self._make_constraint_violation_spec() + with pytest.raises(SamplingError): + sample_population( + spec, count=30, seed=123, enforce_expression_constraints=True + ) + + def test_non_enforced_constraints_record_violations(self): + spec = self._make_constraint_violation_spec() + result = sample_population( + spec, + count=30, + seed=123, + enforce_expression_constraints=False, + ) + assert len(result.agents) == 30 + assert result.stats.constraint_violations diff --git a/tests/test_scenario_validator.py b/tests/test_scenario_validator.py index bf10a1f..cb1f3da 100644 --- a/tests/test_scenario_validator.py +++ b/tests/test_scenario_validator.py @@ -18,6 +18,7 @@ SeedExposure, SpreadModifier, SpreadConfig, + TimelineEvent, ) from extropy.scenario.validator import load_and_validate_scenario, validate_scenario from extropy.storage import open_study_db @@ -182,3 +183,133 @@ def test_validate_scenario_allows_edge_weight_in_spread_modifier(tmp_path: Path) if issue.location == "spread.share_modifiers[0].when" ] assert not edge_weight_errors + + +def test_validate_scenario_allows_extended_attribute_reference( + minimal_population_spec, +): + """Extended attributes should be valid in scenario when-expressions.""" + spec = _make_scenario_spec("population.yaml", "study.db") + spec.seed_exposure.rules[0].when = "extended_signal > 0" + spec.extended_attributes = [ + minimal_population_spec.attributes[0].model_copy( + update={"name": "extended_signal"} + ) + ] + + result = validate_scenario(spec, population_spec=minimal_population_spec) + + ref_errors = [ + issue + for issue in result.errors + if issue.location == "seed_exposure.rules[0].when" + and issue.category == "attribute_reference" + ] + assert not ref_errors + + +def test_validate_scenario_still_rejects_unknown_attribute_reference( + minimal_population_spec, +): + """Unknown attributes must still fail even with extended attrs present.""" + spec = _make_scenario_spec("population.yaml", "study.db") + spec.seed_exposure.rules[0].when = "missing_signal > 0" + spec.extended_attributes = [ + minimal_population_spec.attributes[0].model_copy( + update={"name": "extended_signal"} + ) + ] + + result = validate_scenario(spec, population_spec=minimal_population_spec) + + assert any( + issue.location == "seed_exposure.rules[0].when" + and issue.category == "attribute_reference" + for issue in result.errors + ) + + +def test_validate_scenario_rejects_invalid_seed_condition_literal( + minimal_population_spec, +): + """Seed rule literals must match categorical option domains exactly.""" + spec = _make_scenario_spec("population.yaml", "study.db") + spec.seed_exposure.rules[0].when = "gender == 'Male'" + + result = validate_scenario(spec, population_spec=minimal_population_spec) + + assert any( + issue.location == "seed_exposure.rules[0].when" + and issue.category == "condition_value" + for issue in result.errors + ) + + +def test_validate_scenario_rejects_invalid_spread_modifier_literal( + minimal_population_spec, +): + """Spread modifier literals must match categorical option domains exactly.""" + spec = _make_scenario_spec("population.yaml", "study.db") + spec.spread.share_modifiers = [ + SpreadModifier(when="gender == 'Male'", multiply=1.1, add=0.0) + ] + + result = validate_scenario(spec, population_spec=minimal_population_spec) + + assert any( + issue.location == "spread.share_modifiers[0].when" + and issue.category == "condition_value" + for issue in result.errors + ) + + +def test_validate_scenario_rejects_invalid_timeline_rule_literal( + minimal_population_spec, +): + """Timeline exposure-rule literals must match categorical option domains.""" + spec = _make_scenario_spec("population.yaml", "study.db") + spec.timeline = [ + TimelineEvent( + timestep=1, + event=Event( + type=EventType.NEWS, + content="Follow-up event", + source="Test source", + credibility=0.9, + ambiguity=0.1, + emotional_valence=0.0, + ), + exposure_rules=[ + ExposureRule( + channel="official_notice", + when="gender == 'Male'", + probability=1.0, + timestep=0, + ) + ], + ) + ] + + result = validate_scenario(spec, population_spec=minimal_population_spec) + + assert any( + issue.location == "timeline[0].exposure_rules[0].when" + and issue.category == "condition_value" + for issue in result.errors + ) + + +def test_validate_scenario_accepts_valid_condition_literals(minimal_population_spec): + """Valid categorical literals should pass domain validation.""" + spec = _make_scenario_spec("population.yaml", "study.db") + spec.seed_exposure.rules[0].when = "gender == 'male'" + spec.spread.share_modifiers = [ + SpreadModifier(when="gender in ['male', 'female']", multiply=1.05, add=0.0) + ] + + result = validate_scenario(spec, population_spec=minimal_population_spec) + + condition_value_errors = [ + issue for issue in result.errors if issue.category == "condition_value" + ] + assert not condition_value_errors diff --git a/tests/test_validator.py b/tests/test_validator.py index 3643ef1..c4efef1 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -55,6 +55,7 @@ def make_attr( formula: str | None = None, depends_on: list[str] | None = None, modifiers: list[Modifier] | None = None, + modifier_overlap_policy: str | None = None, ) -> AttributeSpec: """Helper to create an AttributeSpec for testing.""" if distribution is None and strategy != "derived": @@ -78,6 +79,7 @@ def make_attr( formula=formula, depends_on=depends_on or [], modifiers=modifiers or [], + modifier_overlap_policy=modifier_overlap_policy, ), grounding=GroundingInfo(level="low", method="estimated"), ) @@ -691,6 +693,165 @@ def test_conditional_with_formula(self): assert any("formula" in str(e) for e in result.errors) +class TestPartnerCorrelationPolicyWarnings: + """Tests for partner-correlation policy completeness warnings.""" + + def test_partner_correlated_without_policy_metadata_warns(self): + attr = AttributeSpec( + name="custom_trait", + type="categorical", + category="population_specific", + description="Custom trait", + scope="partner_correlated", + sampling=SamplingConfig( + strategy="independent", + distribution=CategoricalDistribution( + options=["A", "B"], + weights=[0.5, 0.5], + ), + ), + grounding=GroundingInfo(level="low", method="estimated"), + ) + spec = make_spec([attr]) + result = validate_spec(spec) + + assert any(w.category == "PARTNER_POLICY" for w in result.warnings) + + def test_partner_correlated_with_explicit_rate_does_not_warn(self): + attr = AttributeSpec( + name="custom_trait", + type="categorical", + category="population_specific", + description="Custom trait", + scope="partner_correlated", + correlation_rate=0.65, + sampling=SamplingConfig( + strategy="independent", + distribution=CategoricalDistribution( + options=["A", "B"], + weights=[0.5, 0.5], + ), + ), + grounding=GroundingInfo(level="low", method="estimated"), + ) + spec = make_spec([attr]) + result = validate_spec(spec) + + assert not any(w.category == "PARTNER_POLICY" for w in result.warnings) + + +class TestModifierOverlapWarnings: + """Tests for categorical/boolean modifier overlap detection.""" + + def test_overlap_warning_for_categorical_modifiers(self): + region = make_attr( + "region", + "categorical", + distribution=CategoricalDistribution( + options=["Urban", "Suburban", "Rural"], + weights=[0.4, 0.3, 0.3], + ), + ) + audience = make_attr( + "audience_segment", + "categorical", + strategy="conditional", + distribution=CategoricalDistribution( + options=["A", "B"], + weights=[0.5, 0.5], + ), + depends_on=["region"], + modifiers=[ + Modifier( + when="region == 'Urban'", weight_overrides={"A": 0.8, "B": 0.2} + ), + Modifier( + when="region in ['Urban', 'Suburban']", + weight_overrides={"A": 0.3, "B": 0.7}, + ), + ], + ) + + spec = make_spec([region, audience], ["region", "audience_segment"]) + result = validate_spec(spec) + + assert any(w.category == "MODIFIER_OVERLAP" for w in result.warnings) + + def test_ordered_override_policy_suppresses_overlap_warning(self): + region = make_attr( + "region", + "categorical", + distribution=CategoricalDistribution( + options=["Urban", "Suburban", "Rural"], + weights=[0.4, 0.3, 0.3], + ), + ) + audience = make_attr( + "audience_segment", + "categorical", + strategy="conditional", + distribution=CategoricalDistribution( + options=["A", "B"], + weights=[0.5, 0.5], + ), + depends_on=["region"], + modifier_overlap_policy="ordered_override", + modifiers=[ + Modifier( + when="region == 'Urban'", weight_overrides={"A": 0.8, "B": 0.2} + ), + Modifier( + when="region in ['Urban', 'Suburban']", + weight_overrides={"A": 0.3, "B": 0.7}, + ), + ], + ) + + spec = make_spec([region, audience], ["region", "audience_segment"]) + result = validate_spec(spec) + + assert not any(w.category == "MODIFIER_OVERLAP" for w in result.warnings) + + def test_exclusive_policy_still_warns_on_overlap(self): + region = make_attr( + "region", + "categorical", + distribution=CategoricalDistribution( + options=["Urban", "Suburban", "Rural"], + weights=[0.4, 0.3, 0.3], + ), + ) + audience = make_attr( + "audience_segment", + "categorical", + strategy="conditional", + distribution=CategoricalDistribution( + options=["A", "B"], + weights=[0.5, 0.5], + ), + depends_on=["region"], + modifier_overlap_policy="exclusive", + modifiers=[ + Modifier( + when="region == 'Urban'", weight_overrides={"A": 0.8, "B": 0.2} + ), + Modifier( + when="region in ['Urban', 'Suburban']", + weight_overrides={"A": 0.3, "B": 0.7}, + ), + ], + ) + + spec = make_spec([region, audience], ["region", "audience_segment"]) + result = validate_spec(spec) + + overlap_warnings = [ + w for w in result.warnings if w.category == "MODIFIER_OVERLAP_EXCLUSIVE" + ] + assert overlap_warnings + assert "exclusive" in overlap_warnings[0].message + + class TestValidationIssue: """Tests for ValidationIssue class."""