From e9287eb305645a35b60f5098dcb63329583ceaab Mon Sep 17 00:00:00 2001 From: Chris Mihiar <28452317+mihiarc@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:13:29 -0500 Subject: [PATCH 1/3] Fix bugs and clean up technical debt from critical review - Fix version mismatch: __init__.py now matches pyproject.toml (1.2.3) - Fix FIA.__exit__ no-op: added close() to properly disconnect DB backend - Remove pandas dependency: not used anywhere in src/pyfia/ - Fix hardcoded year 2023 default in format_output - Preserve AREA_TOTAL column in aggregation output (was dropped prematurely) - Replace lru_cache on instance method with manual cache to prevent GC leak - Migrate carbon_pools to AggregationResult pattern, remove dead code - Add algebraic identity documentation to variance formula - Use NotImplementedError for aggregate_results (not abstract) to support estimators like AreaChangeEstimator that override estimate() entirely All 677 unit tests pass. --- pyproject.toml | 4 +- src/pyfia/__init__.py | 2 +- src/pyfia/core/fia.py | 50 ++- src/pyfia/estimation/aggregation.py | 4 +- src/pyfia/estimation/base.py | 323 +++++++++--------- src/pyfia/estimation/data_loading.py | 24 +- .../estimation/estimators/carbon_pools.py | 179 ++-------- src/pyfia/estimation/variance.py | 5 + 8 files changed, 254 insertions(+), 337 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 48732ca..ccaf518 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,13 +67,11 @@ dependencies = [ "pydantic-settings>=2.0.0", "rich", "pyyaml>=6.0.0", - "pandas>=2.0.0", "requests>=2.31.0", ] [project.optional-dependencies] spatial = ["geopandas>=0.14.0", "shapely>=2.0.0"] -pandas = ["pandas>=2.0.0"] dev = [ "pytest>=8.4.1", "pytest-cov>=6.2.1", @@ -86,7 +84,7 @@ dev = [ "mkdocs-material>=9.5.0", "pre-commit>=3.6.0", ] -all = ["pyfia[spatial,pandas,dev]"] +all = ["pyfia[spatial,dev]"] [project.urls] Homepage = "https://github.com/mihiarc/pyfia" diff --git a/src/pyfia/__init__.py b/src/pyfia/__init__.py index b937ee3..c20d85b 100755 --- a/src/pyfia/__init__.py +++ b/src/pyfia/__init__.py @@ -11,7 +11,7 @@ - AskFIA: AI conversational interface (https://github.com/mihiarc/askfia) """ -__version__ = "1.1.0b7" +__version__ = "1.2.3" __author__ = "Chris Mihiar" # Core exports - Main functionality diff --git a/src/pyfia/core/fia.py b/src/pyfia/core/fia.py index 9afb520..f3a312f 100755 --- a/src/pyfia/core/fia.py +++ b/src/pyfia/core/fia.py @@ -10,6 +10,7 @@ import logging import warnings from pathlib import Path +from typing import TYPE_CHECKING import polars as pl @@ -24,6 +25,9 @@ SpatialFileError, ) +if TYPE_CHECKING: + from .backends import MotherDuckBackend + logger = logging.getLogger(__name__) @@ -96,8 +100,10 @@ def __init__(self, db_path: str | Path, engine: str | None = None): "motherduck:" ) + # Type annotation: str for MotherDuck, Path for local files + self.db_path: str | Path if self._is_motherduck: - self.db_path = db_str # type: ignore[assignment] + self.db_path = db_str else: self.db_path = Path(db_path) if not self.db_path.exists(): @@ -186,8 +192,13 @@ def __enter__(self): def __exit__(self, _exc_type, _exc_val, _exc_tb): """Context manager exit.""" - # Connection cleanup handled by FIADataReader - pass + self.close() + + def close(self): + """Close the database connection.""" + if hasattr(self, "_reader") and self._reader: + if hasattr(self._reader, "_backend") and self._reader._backend: + self._reader._backend.disconnect() # Connection management moved to FIADataReader with backend support @@ -331,6 +342,10 @@ def query_batch(batch: list) -> pl.LazyFrame: result = batch_query_by_values(valid_plot_cns, query_batch) + # Ensure result is a LazyFrame for consistent processing + if isinstance(result, pl.DataFrame): + result = result.lazy() + # Join polygon attributes for PLOT table if available if table_name == "PLOT" and self._polygon_attributes is not None: result = result.join( @@ -437,12 +452,23 @@ def find_evalid( if eval_type is not None: # FIA uses 'EXP' prefix for evaluation types - # Special case: "ALL" maps to "EXPALL" for area estimation - if eval_type.upper() == "ALL": + # Special cases: + # - "ALL" maps to "EXPALL" for area estimation + # - "GRM" maps to EXPGROW, EXPMORT, or EXPREMV for growth/mortality/removals + # - "CURR" maps to "EXPCURR" for current area + eval_type_upper = eval_type.upper() + if eval_type_upper == "ALL": eval_type_full = "EXPALL" + df = df.filter(pl.col("EVAL_TYP") == eval_type_full) + elif eval_type_upper == "GRM": + # GRM is a composite type - filter for any of the three GRM eval types + grm_types = ["EXPGROW", "EXPMORT", "EXPREMV"] + df = df.filter(pl.col("EVAL_TYP").is_in(grm_types)) + elif eval_type_upper == "CURR": + df = df.filter(pl.col("EVAL_TYP") == "EXPCURR") else: - eval_type_full = f"EXP{eval_type}" - df = df.filter(pl.col("EVAL_TYP") == eval_type_full) + eval_type_full = f"EXP{eval_type_upper}" + df = df.filter(pl.col("EVAL_TYP") == eval_type_full) if most_recent: # Add parsed EVALID columns for robust year sorting @@ -1352,7 +1378,8 @@ def __init__(self, database: str, motherduck_token: str | None = None): self._backend.connect() # Create a minimal reader-like wrapper for compatibility - self._reader = _MotherDuckReaderWrapper(self._backend) + # Type note: _MotherDuckReaderWrapper provides the same interface as FIADataReader + self._reader = _MotherDuckReaderWrapper(self._backend) # type: ignore[assignment] def __enter__(self): """Context manager entry.""" @@ -1388,12 +1415,13 @@ class _MotherDuckReaderWrapper: without requiring a full FIADataReader instance. """ - def __init__(self, backend) -> None: + def __init__(self, backend: "MotherDuckBackend") -> None: self._backend = backend def get_table_schema(self, table_name: str) -> dict[str, str]: """Get schema for a table from the MotherDuck database.""" - return self._backend.get_table_schema(table_name) + schema: dict[str, str] = self._backend.get_table_schema(table_name) + return schema def read_table( self, @@ -1411,7 +1439,7 @@ def read_table( if where: query += f" WHERE {where}" - df = self._backend.execute_query(query) + df: pl.DataFrame = self._backend.execute_query(query) if lazy: return df.lazy() diff --git a/src/pyfia/estimation/aggregation.py b/src/pyfia/estimation/aggregation.py index c0e73cf..176b387 100644 --- a/src/pyfia/estimation/aggregation.py +++ b/src/pyfia/estimation/aggregation.py @@ -226,8 +226,8 @@ def compute_per_acre_values( results_df = results_df.with_columns(per_acre_exprs) - # Clean up intermediate columns (keep totals and per-acre values) - cols_to_drop = ["N_CONDITIONS", "AREA_TOTAL"] + # Clean up intermediate columns (keep totals, per-acre values, and AREA_TOTAL) + cols_to_drop = ["N_CONDITIONS"] for adj_col, cond_col in metric_mappings.items(): metric_name = cond_col.replace("CONDITION_", "") cols_to_drop.append(f"{metric_name}_NUM") diff --git a/src/pyfia/estimation/base.py b/src/pyfia/estimation/base.py index 5d7bb25..a8ce255 100644 --- a/src/pyfia/estimation/base.py +++ b/src/pyfia/estimation/base.py @@ -198,23 +198,23 @@ def apply_filters(self, data: pl.LazyFrame) -> pl.LazyFrame: data = data.filter(gs_filter) # "all" means no filter - # Apply land type filter using centralized indicator function - # This replaces magic numbers with named constants from status_codes.py + # Apply land type filter using centralized indicator function. + # NOTE: DataLoader._build_cond_sql_filter also pushes this to SQL for + # performance (reduces data transfer). This Polars filter is the + # definitive filter that guarantees correctness regardless of how + # data was loaded or cached. land_type = self.config.get("land_type", "forest") if land_type and land_type != "all" and "COND_STATUS_CD" in columns: data = data.filter(get_land_domain_indicator(land_type)) return data - def aggregate_results( - self, data: pl.LazyFrame | None - ) -> AggregationResult | pl.DataFrame: + def aggregate_results(self, data: pl.LazyFrame | None) -> AggregationResult: """ Aggregate results with stratification. - Subclasses should override this to return AggregationResult for proper - variance calculation. The base implementation returns a DataFrame for - backward compatibility. + Subclasses must implement this to return AggregationResult containing + the aggregated results, plot-tree data for variance, and group columns. Parameters ---------- @@ -223,57 +223,23 @@ def aggregate_results( Returns ------- - Union[AggregationResult, pl.DataFrame] - AggregationResult with results, plot_tree_data, and group_cols, - or DataFrame for backward compatibility + AggregationResult + Bundle with results, plot_tree_data, and group_cols for + variance calculation. """ - # Get stratification data - strat_data = self._get_stratification_data() - - if data is None: - # Area-only estimation - return self._aggregate_area_only(strat_data) - - # Join with stratification - data_with_strat = data.join(strat_data, on="PLT_CN", how="inner") - - # Setup grouping columns - group_cols = self._setup_grouping() - - # Aggregate by groups - if group_cols: - results = ( - data_with_strat.group_by(group_cols) - .agg( - [ - pl.sum("ESTIMATE_VALUE").alias("ESTIMATE"), - pl.count("PLT_CN").alias("N_PLOTS"), - ] - ) - .collect() - ) - else: - results = data_with_strat.select( - [ - pl.sum("ESTIMATE_VALUE").alias("ESTIMATE"), - pl.count("PLT_CN").alias("N_PLOTS"), - ] - ).collect() - - return results + raise NotImplementedError( + f"{type(self).__name__} must implement aggregate_results()" + ) - def calculate_variance( - self, agg_result: AggregationResult | pl.DataFrame - ) -> pl.DataFrame: + def calculate_variance(self, agg_result: AggregationResult) -> pl.DataFrame: """ Calculate variance for estimates. Parameters ---------- - agg_result : Union[AggregationResult, pl.DataFrame] - Either an AggregationResult containing results, plot_tree_data, - and group_cols for explicit data passing, or a DataFrame for - backward compatibility with subclasses that haven't been updated. + agg_result : AggregationResult + Bundle containing results, plot_tree_data, and group_cols + from aggregate_results(). Returns ------- @@ -306,9 +272,8 @@ def format_output(self, results: pl.DataFrame) -> pl.DataFrame: Formatted results """ # Add metadata columns - results = results.with_columns( - [pl.lit(self.config.get("year", 2023)).alias("YEAR")] - ) + year = self._extract_evaluation_year() + results = results.with_columns([pl.lit(year).alias("YEAR")]) # Reorder columns col_order = ["YEAR", "ESTIMATE", "SE", "N_PLOTS"] @@ -497,11 +462,6 @@ def _get_stratification_data(self) -> pl.LazyFrame: """ return self.data_loader.get_stratification_data() - def _aggregate_area_only(self, strat_data: pl.LazyFrame) -> pl.DataFrame: - """Handle area-only aggregation without tree data.""" - # This would be implemented by area estimator - return pl.DataFrame() - def _preserve_plot_tree_data( self, data_with_strat: pl.LazyFrame, @@ -1023,9 +983,11 @@ def _calculate_variance_for_metrics( ) # Build plot-level aggregation - plot_agg_exprs = [ - pl.sum("CONDPROP_UNADJ").cast(pl.Float64).alias("x_i") - ] if "CONDPROP_UNADJ" in plot_cond_data.columns else [pl.lit(1.0).alias("x_i")] + plot_agg_exprs = ( + [pl.sum("CONDPROP_UNADJ").cast(pl.Float64).alias("x_i")] + if "CONDPROP_UNADJ" in plot_cond_data.columns + else [pl.lit(1.0).alias("x_i")] + ) for i in range(len(metric_configs)): plot_agg_exprs.append(pl.sum(f"y_{i}_ic").alias(f"y_{i}_i")) @@ -1072,96 +1034,113 @@ def _calculate_grouped_multi_metric_variance( """ Calculate variance for grouped estimates with multiple metrics. - Uses a loop over groups for now; could be further optimized with - vectorized operations in the future. + Uses vectorized operations via calculate_grouped_domain_total_variance + to compute variance for all groups in a single pass per metric, + avoiding the O(groups × metrics) loop pattern. """ - from .variance import calculate_domain_total_variance + from .variance import calculate_grouped_domain_total_variance - variance_results = [] + # Get valid group columns that exist in the data + valid_group_cols = [c for c in group_cols if c in plot_data.columns] - for group_vals in results.iter_rows(): - # Build filter for this group - group_filter = pl.lit(True) - group_dict = {} + if not valid_group_cols: + # No valid grouping - fall back to overall calculation + return self._calculate_overall_multi_metric_variance( + plot_data, all_plots, results, metric_configs + ) - for col in group_cols: - if col in plot_data.columns: - val = group_vals[results.columns.index(col)] - group_dict[col] = val - if val is None: - group_filter = group_filter & pl.col(col).is_null() - else: - group_filter = group_filter & (pl.col(col) == val) + # Step 1: Get unique group values from results + # This ensures we only create variance rows for groups that have estimates + unique_groups = results.select(valid_group_cols).unique() - # Filter plot data for this specific group - group_plot_data = plot_data.filter(group_filter) + # Step 2: Cross-join all_plots with unique_groups + # Result: Every plot appears once for each group value + # This is essential for correct variance calculation - plots without + # a given forest type should contribute y=0 to that type's variance + all_plots_expanded = all_plots.join(unique_groups, how="cross") - # Build select columns for join - select_cols = ["PLT_CN", "x_i"] + [f"y_{i}_i" for i in range(len(metric_configs))] - select_cols = [c for c in select_cols if c in group_plot_data.columns] + # Step 3: Prepare plot_data columns for join + select_cols = ["PLT_CN", "x_i"] + valid_group_cols + select_cols += [f"y_{i}_i" for i in range(len(metric_configs))] + select_cols = [c for c in select_cols if c in plot_data.columns] - # Join with ALL plots, filling missing with zeros - all_plots_group = all_plots.join( - group_plot_data.select(select_cols), - on="PLT_CN", - how="left", + # Step 4: Left join on PLT_CN + group_cols (KEY FIX for Issue #68) + # Plot P1 with FORTYPCD=809 data only matches (P1, 809) row + # All other (P1, other_type) rows get NULL y values -> filled with 0 + join_keys = ["PLT_CN"] + valid_group_cols + all_plots_with_data = all_plots_expanded.join( + plot_data.select(select_cols), + on=join_keys, + how="left", + ) + + # Step 5: Fill NULL y values with 0.0 + # Plots without data for a specific group contribute 0 to that group's estimate + fill_exprs = [pl.col("x_i").fill_null(0.0)] + for i in range(len(metric_configs)): + col_name = f"y_{i}_i" + if col_name in all_plots_with_data.columns: + fill_exprs.append(pl.col(col_name).fill_null(0.0)) + + all_plots_with_data = all_plots_with_data.with_columns(fill_exprs) + + # Step 2: Calculate variance for each metric using vectorized operations + variance_dfs = [] + for i, cfg in enumerate(metric_configs): + y_col = f"y_{i}_i" + if y_col not in all_plots_with_data.columns: + continue + + # Use vectorized grouped variance calculation + var_df = calculate_grouped_domain_total_variance( + all_plots_with_data, + group_cols=valid_group_cols, + y_col=y_col, + x_col="x_i", + stratum_col="STRATUM_CN", + weight_col="EXPNS", ) - # Fill nulls with zeros - fill_exprs = [pl.col("x_i").fill_null(0.0)] - for i in range(len(metric_configs)): - col_name = f"y_{i}_i" - if col_name in all_plots_group.columns: - fill_exprs.append(pl.col(col_name).fill_null(0.0)) - - all_plots_group = all_plots_group.with_columns(fill_exprs) - - # Calculate variance for each metric - result_row = dict(group_dict) - - if len(all_plots_group) > 0: - # Calculate total area for per-acre SE - total_area = ( - all_plots_group["EXPNS"] * all_plots_group["x_i"] - ).sum() - - for i, cfg in enumerate(metric_configs): - y_col = f"y_{i}_i" - if y_col in all_plots_group.columns: - var_stats = calculate_domain_total_variance( - all_plots_group, y_col - ) - se_acre = ( - var_stats["se_total"] / total_area if total_area > 0 else 0.0 - ) + # Rename columns to match metric config + rename_dict = { + "se_acre": cfg["acre_se_col"], + "se_total": cfg["total_se_col"], + } + if "acre_var_col" in cfg: + rename_dict["variance_acre"] = cfg["acre_var_col"] + if "total_var_col" in cfg: + rename_dict["variance_total"] = cfg["total_var_col"] + + # Select only the columns we need + keep_cols = valid_group_cols + list(rename_dict.keys()) + keep_cols = [c for c in keep_cols if c in var_df.columns] + var_df = var_df.select(keep_cols).rename(rename_dict) + + variance_dfs.append(var_df) + + # Step 3: Join all variance results together + if not variance_dfs: + return results - result_row[cfg["acre_se_col"]] = se_acre - result_row[cfg["total_se_col"]] = var_stats["se_total"] + # Start with the first variance df and join others + combined_var = variance_dfs[0] + for var_df in variance_dfs[1:]: + # Get new columns (not group cols) - only select these from right side + # to avoid duplicate columns with _right suffix from outer join + new_cols = [c for c in var_df.columns if c not in valid_group_cols] + if not new_cols: + continue + # Use left join since all group values should be present in first df + combined_var = combined_var.join( + var_df.select(valid_group_cols + new_cols), + on=valid_group_cols, + how="left", + ) - # Add variance columns if specified - if "acre_var_col" in cfg: - result_row[cfg["acre_var_col"]] = se_acre**2 - if "total_var_col" in cfg: - result_row[cfg["total_var_col"]] = var_stats["variance_total"] - else: - # No data for this group - for cfg in metric_configs: - result_row[cfg["acre_se_col"]] = 0.0 - result_row[cfg["total_se_col"]] = 0.0 - if "acre_var_col" in cfg: - result_row[cfg["acre_var_col"]] = 0.0 - if "total_var_col" in cfg: - result_row[cfg["total_var_col"]] = 0.0 - - variance_results.append(result_row) - - # Join variance results back to main results - if variance_results: - var_df = pl.DataFrame(variance_results) - # Use only valid group columns that exist in both dataframes - join_cols = [c for c in group_cols if c in var_df.columns and c in results.columns] - if join_cols: - results = results.join(var_df, on=join_cols, how="left") + # Step 4: Join variance results back to main results + join_cols = [c for c in valid_group_cols if c in results.columns] + if join_cols: + results = results.join(combined_var, on=join_cols, how="left") return results @@ -1178,7 +1157,9 @@ def _calculate_overall_multi_metric_variance( from .variance import calculate_domain_total_variance # Build select columns for join - select_cols = ["PLT_CN", "x_i"] + [f"y_{i}_i" for i in range(len(metric_configs))] + select_cols = ["PLT_CN", "x_i"] + [ + f"y_{i}_i" for i in range(len(metric_configs)) + ] select_cols = [c for c in select_cols if c in plot_data.columns] # Join with ALL plots, filling missing with zeros @@ -1211,20 +1192,28 @@ def _calculate_overall_multi_metric_variance( ) se_acre = var_stats["se_total"] / total_area if total_area > 0 else 0.0 - results = results.with_columns([ - pl.lit(se_acre).alias(cfg["acre_se_col"]), - pl.lit(var_stats["se_total"]).alias(cfg["total_se_col"]), - ]) + results = results.with_columns( + [ + pl.lit(se_acre).alias(cfg["acre_se_col"]), + pl.lit(var_stats["se_total"]).alias(cfg["total_se_col"]), + ] + ) # Add variance columns if specified if "acre_var_col" in cfg: - results = results.with_columns([ - pl.lit(se_acre**2).alias(cfg["acre_var_col"]), - ]) + results = results.with_columns( + [ + pl.lit(se_acre**2).alias(cfg["acre_var_col"]), + ] + ) if "total_var_col" in cfg: - results = results.with_columns([ - pl.lit(var_stats["variance_total"]).alias(cfg["total_var_col"]), - ]) + results = results.with_columns( + [ + pl.lit(var_stats["variance_total"]).alias( + cfg["total_var_col"] + ), + ] + ) return results @@ -1242,21 +1231,25 @@ def _add_cv_columns( if acre_col and acre_col in results.columns: cv_col = acre_se_col.replace("_SE", "_CV") - results = results.with_columns([ - pl.when(pl.col(acre_col) > 0) - .then(pl.col(acre_se_col) / pl.col(acre_col) * 100) - .otherwise(None) - .alias(cv_col), - ]) + results = results.with_columns( + [ + pl.when(pl.col(acre_col) > 0) + .then(pl.col(acre_se_col) / pl.col(acre_col) * 100) + .otherwise(None) + .alias(cv_col), + ] + ) if total_col and total_col in results.columns: cv_col = total_se_col.replace("_SE", "_CV") - results = results.with_columns([ - pl.when(pl.col(total_col) > 0) - .then(pl.col(total_se_col) / pl.col(total_col) * 100) - .otherwise(None) - .alias(cv_col), - ]) + results = results.with_columns( + [ + pl.when(pl.col(total_col) > 0) + .then(pl.col(total_se_col) / pl.col(total_col) * 100) + .otherwise(None) + .alias(cv_col), + ] + ) return results diff --git a/src/pyfia/estimation/data_loading.py b/src/pyfia/estimation/data_loading.py index 5d160fc..d32e8db 100644 --- a/src/pyfia/estimation/data_loading.py +++ b/src/pyfia/estimation/data_loading.py @@ -12,7 +12,6 @@ """ import logging -from functools import lru_cache from typing import List, Optional, Tuple import polars as pl @@ -58,6 +57,7 @@ def __init__(self, db: FIA, config: dict) -> None: """ self.db = db self.config = config + self._stratification_cache: Optional[pl.LazyFrame] = None # Validate grp_by columns early during initialization self._validate_grp_by_columns() @@ -488,6 +488,10 @@ def _build_tree_sql_filter(self) -> Optional[str]: This pushes common filters to the database level to reduce memory usage. + NOTE: STATUSCD filter is NOT pushed to SQL because the TREE table is + cached and subsequent calls with different tree_type would use stale data. + The STATUSCD filter is applied in Polars in BaseEstimator.apply_filters(). + Returns ------- Optional[str] @@ -495,17 +499,9 @@ def _build_tree_sql_filter(self) -> Optional[str]: """ filters = [] - # Tree type filter (most common optimization) - tree_type = self.config.get("tree_type", "live") - if tree_type == "live": - filters.append("STATUSCD = 1") - elif tree_type == "dead": - filters.append("STATUSCD = 2") - elif tree_type == "gs": - # Growing stock: live trees with valid tree class - # Note: TREECLCD filter applied in Polars since it's conditional - filters.append("STATUSCD = 1") - # "all" means no STATUSCD filter + # NOTE: Do NOT add STATUSCD filter here - it causes caching bugs! + # The STATUSCD filter is applied in apply_filters() instead. + # See: https://github.com/mihiarc/pyfia/issues/XXX # Basic validity filters (these are always applied in apply_tree_filters) filters.append("DIA IS NOT NULL") @@ -542,7 +538,6 @@ def _build_cond_sql_filter(self) -> Optional[str]: return " AND ".join(filters) return None - @lru_cache(maxsize=1) def get_stratification_data(self) -> pl.LazyFrame: """ Get stratification data with simple caching. @@ -552,6 +547,8 @@ def get_stratification_data(self) -> pl.LazyFrame: pl.LazyFrame Joined PPSA, POP_STRATUM, and PLOT data including MACRO_BREAKPOINT_DIA """ + if self._stratification_cache is not None: + return self._stratification_cache # Load PPSA if "POP_PLOT_STRATUM_ASSGN" not in self.db.tables: self.db.load_table("POP_PLOT_STRATUM_ASSGN") @@ -632,4 +629,5 @@ def get_stratification_data(self) -> pl.LazyFrame: # Join with PLOT to get MACRO_BREAKPOINT_DIA strat_data = strat_data.join(plot_selected, on="PLT_CN", how="left") + self._stratification_cache = strat_data return strat_data diff --git a/src/pyfia/estimation/estimators/carbon_pools.py b/src/pyfia/estimation/estimators/carbon_pools.py index 5adb61c..d286ad9 100644 --- a/src/pyfia/estimation/estimators/carbon_pools.py +++ b/src/pyfia/estimation/estimators/carbon_pools.py @@ -13,16 +13,18 @@ - Matches EVALIDator exactly for live tree carbon estimates """ -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import polars as pl from ...core import FIA -from ..base import BaseEstimator +from ..base import AggregationResult, BaseEstimator from ..constants import LBS_TO_SHORT_TONS from ..tree_expansion import apply_tree_adjustment_factors -from ..utils import validate_required_columns -from ..variance import calculate_domain_total_variance +from ..utils import validate_aggregation_result, validate_required_columns + +if TYPE_CHECKING: + from ..base import AggregationResult class CarbonPoolEstimator(BaseEstimator): @@ -138,13 +140,19 @@ def calculate_values(self, data: pl.LazyFrame) -> pl.LazyFrame: return data - def aggregate_results(self, data: pl.LazyFrame) -> pl.DataFrame: # type: ignore[override] + def aggregate_results(self, data: pl.LazyFrame) -> AggregationResult: # type: ignore[override] """ Aggregate carbon with two-stage aggregation for correct per-acre estimates. Implements FIA's design-based estimation methodology: Stage 1: Aggregate trees to plot-condition level Stage 2: Apply expansion factors and calculate population totals + + Returns + ------- + AggregationResult + Bundle containing results, plot_tree_data, and group_cols for + explicit variance calculation. """ # Validate required columns using shared utility validate_required_columns(data, ["PLT_CN", "CARBON_ACRE"], "carbon data") @@ -167,10 +175,9 @@ def aggregate_results(self, data: pl.LazyFrame) -> pl.DataFrame: # type: ignore # Setup grouping group_cols = self._setup_grouping() - self.group_cols = group_cols - # CRITICAL: Store plot-tree level data for variance calculation - self.plot_tree_data, data_with_strat = self._preserve_plot_tree_data( + # Preserve plot-tree level data for variance calculation + plot_tree_data, data_with_strat = self._preserve_plot_tree_data( data_with_strat, metric_cols=["CARBON_ADJ"], group_cols=group_cols, @@ -193,152 +200,40 @@ def aggregate_results(self, data: pl.LazyFrame) -> pl.DataFrame: # type: ignore if cols_to_drop: results = results.drop(cols_to_drop) - return results + return AggregationResult( + results=results, + plot_tree_data=plot_tree_data, + group_cols=group_cols, + ) - def calculate_variance(self, results: pl.DataFrame) -> pl.DataFrame: + def calculate_variance( + self, + agg_result: AggregationResult, # type: ignore[override] + ) -> pl.DataFrame: """ Calculate variance for carbon estimates using domain total variance formula. - Implements the stratified domain total variance formula from - Bechtold & Patterson (2005): - - V(Ŷ) = Σ_h W_h² × s²_yh × n_h - - Where W_h is the stratum expansion factor (EXPNS), s²_yh is the sample - variance within stratum h, and n_h is the number of plots in stratum h. + Uses the unified _calculate_variance_for_metrics method with the + stratified domain total variance formula from Bechtold & Patterson (2005). Raises ------ ValueError If plot_tree_data is not available for variance calculation. """ - if self.plot_tree_data is None: - raise ValueError( - "Plot-tree data is required for carbon variance calculation. " - "Cannot compute statistically valid standard errors without tree-level " - "data. Ensure data preservation is working correctly in the estimation " - "pipeline." - ) - - # Step 1: Aggregate to plot-condition level - plot_group_cols = ["PLT_CN", "CONDID", "EXPNS"] - if "STRATUM_CN" in self.plot_tree_data.columns: - plot_group_cols.insert(2, "STRATUM_CN") - - # Add grouping columns - if self.group_cols: - for col in self.group_cols: - if col in self.plot_tree_data.columns and col not in plot_group_cols: - plot_group_cols.append(col) - - plot_cond_agg = [ - pl.sum("CARBON_ADJ").alias("y_carb_ic"), # Carbon per condition + validate_aggregation_result(agg_result, "Carbon") + + metric_configs = [ + { + "adjusted_col": "CARBON_ADJ", + "acre_se_col": "CARBON_ACRE_SE", + "total_se_col": "CARBON_TOTAL_SE", + "acre_var_col": "CARBON_ACRE_VARIANCE", + "total_var_col": "CARBON_TOTAL_VARIANCE", + } ] - plot_cond_data = self.plot_tree_data.group_by(plot_group_cols).agg( - plot_cond_agg - ) - - # Step 2: Aggregate to plot level - plot_level_cols = ["PLT_CN", "EXPNS"] - if "STRATUM_CN" in plot_cond_data.columns: - plot_level_cols.insert(1, "STRATUM_CN") - if self.group_cols: - plot_level_cols.extend( - [c for c in self.group_cols if c in plot_cond_data.columns] - ) - - plot_data = plot_cond_data.group_by(plot_level_cols).agg( - [ - pl.sum("y_carb_ic").alias("y_carb_i"), # Total carbon per plot - pl.lit(1.0).alias("x_i"), # Area proportion per plot (full plot = 1) - ] - ) - - # Step 3: Calculate variance for each group or overall - if self.group_cols: - # Get ALL plots in the evaluation for proper variance calculation - strat_data = self._get_stratification_data() - all_plots = ( - strat_data.select("PLT_CN", "STRATUM_CN", "EXPNS").unique().collect() - ) - - # Calculate variance for each group separately - variance_results = [] - - for group_vals in results.iter_rows(): - # Build filter for this group - group_filter = pl.lit(True) - group_dict = {} - - for i, col in enumerate(self.group_cols): - if col in plot_data.columns: - group_dict[col] = group_vals[results.columns.index(col)] - group_filter = group_filter & ( - pl.col(col) == group_vals[results.columns.index(col)] - ) - - # Filter plot data for this specific group - group_plot_data = plot_data.filter(group_filter) - - # Join with ALL plots, filling missing with zeros - all_plots_group = all_plots.join( - group_plot_data.select(["PLT_CN", "y_carb_i", "x_i"]), - on="PLT_CN", - how="left", - ).with_columns( - [ - pl.col("y_carb_i").fill_null(0.0), - pl.col("x_i").fill_null(0.0), - ] - ) - - if len(all_plots_group) > 0: - carb_stats = calculate_domain_total_variance( - all_plots_group, "y_carb_i" - ) - # Calculate total area for per-acre SE - total_area = ( - all_plots_group["EXPNS"] * all_plots_group["x_i"] - ).sum() - se_acre = ( - carb_stats["se_total"] / total_area if total_area > 0 else 0.0 - ) - variance_results.append( - { - **group_dict, - "CARBON_ACRE_SE": se_acre, - "CARBON_TOTAL_SE": carb_stats["se_total"], - } - ) - else: - variance_results.append( - { - **group_dict, - "CARBON_ACRE_SE": 0.0, - "CARBON_TOTAL_SE": 0.0, - } - ) - - # Join variance results back to main results - if variance_results: - var_df = pl.DataFrame(variance_results) - results = results.join(var_df, on=self.group_cols, how="left") - else: - # No grouping, calculate overall variance - carb_stats = calculate_domain_total_variance(plot_data, "y_carb_i") - # Calculate total area for per-acre SE - total_area = (plot_data["EXPNS"] * plot_data["x_i"]).sum() - se_acre = carb_stats["se_total"] / total_area if total_area > 0 else 0.0 - - results = results.with_columns( - [ - pl.lit(se_acre).alias("CARBON_ACRE_SE"), - pl.lit(carb_stats["se_total"]).alias("CARBON_TOTAL_SE"), - ] - ) - - return results + return self._calculate_variance_for_metrics(agg_result, metric_configs) def format_output(self, results: pl.DataFrame) -> pl.DataFrame: """Format carbon estimation output.""" diff --git a/src/pyfia/estimation/variance.py b/src/pyfia/estimation/variance.py index 3720ecf..df7974d 100644 --- a/src/pyfia/estimation/variance.py +++ b/src/pyfia/estimation/variance.py @@ -14,6 +14,11 @@ - s²_yh is the sample variance within stratum h (with ddof=1) - n_h is the number of plots in stratum h +This is algebraically equivalent to the textbook form in Bechtold & +Patterson (2005): V(Ŷ_d) = Σ_h (N_h² / n_h) × s²_yh, because +W_h = N_h / n_h (EXPNS = total stratum acres / plots in stratum), +so W_h² × n_h = (N_h/n_h)² × n_h = N_h²/n_h. + This is the variance formula used by EVALIDator for tree-based estimates (volume, biomass, TPA, GRM) and produces SE estimates within 1-3% of EVALIDator output. From dcbf27e938024e76d4475b30843a58b6170eba56 Mon Sep 17 00:00:00 2001 From: Chris Mihiar <28452317+mihiarc@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:15:09 -0500 Subject: [PATCH 2/3] Fix area variance underestimation, remove deprecated code, update examples Bug fixes: - Fix Issue #68: Area variance severely underestimated for rare categories when using grp_by. Now cross-joins all stratum plots with group values so non-matching plots contribute y=0 to variance calculation. - Fix LazyFrame/DataFrame type mismatch in batch_query_by_values concat - Fix EVALIDator client using incorrect JSONDecodeError for empty responses - Pass area_domain to get_cond_columns in all estimators so domain filter columns are loaded from the database - Update area_change calculate_variance signature for AggregationResult Cleanup: - Remove deprecated assign_forest_type_group (use add_forest_type_group) - Remove 5 outdated example scripts, modernize 3 remaining ones - Add descriptive aliases for EVALIDator estimate type codes - Fix data_reader type annotation for MotherDuck path Tests: - Add 4 tests for rare category variance (Issue #68) - Update classification tests to use add_forest_type_group directly - Update reference table download test for bundled zip handling - Improve biomass validation skip reason documentation --- examples/database_backend_demo.py | 171 ----------- examples/harvest_panel_analysis.py | 236 --------------- examples/mortality_by_cause.py | 116 +++++-- examples/mortality_calculator_demo.py | 88 ------ examples/mortality_config_demo.py | 134 --------- examples/mortality_with_new_config.py | 191 ------------ examples/plot_data_access.py | 178 +++++++++-- examples/plot_domain_example.py | 272 ++++++++++++++--- examples/using_output_formatter.py | 186 ------------ scripts/validate_against_evalidator.py | 62 ++-- src/pyfia/core/data_reader.py | 4 +- src/pyfia/core/utils.py | 7 +- src/pyfia/estimation/columns.py | 21 +- src/pyfia/estimation/estimators/area.py | 121 +++++--- .../estimation/estimators/area_change.py | 25 +- src/pyfia/estimation/estimators/biomass.py | 4 +- src/pyfia/estimation/estimators/tpa.py | 4 +- src/pyfia/estimation/estimators/volume.py | 2 + src/pyfia/estimation/grm_base.py | 2 + src/pyfia/evalidator/client.py | 6 +- src/pyfia/evalidator/estimate_types.py | 13 + src/pyfia/filtering/__init__.py | 2 - src/pyfia/filtering/utils.py | 47 --- tests/e2e/test_download_e2e.py | 18 +- tests/unit/test_classification.py | 97 ++---- tests/unit/test_variance_formulas.py | 284 ++++++++++++++++++ tests/validation/test_biomass.py | 16 +- uv.lock | 4 +- 28 files changed, 984 insertions(+), 1327 deletions(-) delete mode 100644 examples/database_backend_demo.py delete mode 100644 examples/harvest_panel_analysis.py delete mode 100644 examples/mortality_calculator_demo.py delete mode 100644 examples/mortality_config_demo.py delete mode 100644 examples/mortality_with_new_config.py delete mode 100644 examples/using_output_formatter.py diff --git a/examples/database_backend_demo.py b/examples/database_backend_demo.py deleted file mode 100644 index 20e6e40..0000000 --- a/examples/database_backend_demo.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -Demonstration of multi-backend support in pyFIA. - -This example shows how to use pyFIA with both DuckDB and SQLite backends. -""" - -from pathlib import Path - -import polars as pl -from rich.console import Console -from rich.table import Table - -from pyfia import FIA -from pyfia.core.data_reader import FIADataReader - -console = Console() - - -def demo_backend_autodetection(db_path: Path): - """Demonstrate automatic backend detection.""" - console.print("\n[bold blue]Backend Auto-detection Demo[/bold blue]") - - # Create reader with auto-detection - reader = FIADataReader(db_path) # Auto-detects backend - - # Read some data - console.print(f"Database: {db_path.name}") - console.print(f"Backend: {reader._backend.__class__.__name__}") - - # Read a sample of tree data - trees = reader.read_table( - "TREE", - columns=["CN", "SPCD", "DIA", "STATUSCD"], - where="STATUSCD = 1", - limit=5 - ) - - console.print(f"Sample live trees: {len(trees)} records") - - # Show data types (demonstrates backend-specific handling) - table = Table(title="Column Types") - table.add_column("Column") - table.add_column("Type") - - for col, dtype in trees.schema.items(): - table.add_row(col, str(dtype)) - - console.print(table) - - -def demo_explicit_backend(db_path: Path, engine: str): - """Demonstrate explicit backend selection.""" - console.print(f"\n[bold blue]Explicit {engine.upper()} Backend Demo[/bold blue]") - - # Create reader with explicit backend - reader = FIADataReader(db_path, engine=engine) - - # Demonstrate batch processing for large queries - console.print("Testing batch processing with large IN clause...") - - # Get some plot CNs - plots = reader.read_table("PLOT", columns=["CN"], limit=2000) - plot_cns = plots["CN"].to_list() - - # Read filtered data (will use batching automatically) - trees = reader.read_filtered_data("TREE", "PLT_CN", plot_cns[:1500]) - - console.print(f"Trees from {len(plot_cns[:1500])} plots: {len(trees)} records") - - -def demo_fia_class_with_backend(db_path: Path, engine: str = None): - """Demonstrate FIA class with backend support.""" - console.print(f"\n[bold blue]FIA Class Backend Demo[/bold blue]") - - # Create FIA instance (auto-detect or explicit engine) - fia = FIA(db_path, engine=engine) - - # Find evaluations - evalids = fia.find_evalid(most_recent=True) - console.print(f"Found {len(evalids)} most recent evaluations") - - if evalids: - # Use the first evaluation - fia.clip_by_evalid(evalids[0]) - - # Get some data - plots = fia.get_plots(columns=["CN", "STATECD", "PLOT"]) - console.print(f"Plots in evaluation: {len(plots)}") - - # Show state distribution - state_counts = plots.group_by("STATECD").agg( - pl.count().alias("plot_count") - ).sort("plot_count", descending=True) - - table = Table(title="Plots by State") - table.add_column("State") - table.add_column("Plot Count") - - for row in state_counts.head(5).iter_rows(): - table.add_row(str(row[0]), str(row[1])) - - console.print(table) - - -def demo_performance_options(db_path: Path): - """Demonstrate backend-specific performance options.""" - console.print("\n[bold blue]Performance Options Demo[/bold blue]") - - # DuckDB with memory configuration - if db_path.suffix.lower() in [".duckdb", ".ddb"]: - console.print("Creating DuckDB reader with performance settings...") - reader = FIADataReader( - db_path, - engine="duckdb", - memory_limit="8GB", - threads=4 - ) - console.print("✓ DuckDB configured with 8GB memory limit and 4 threads") - - # SQLite with timeout configuration - elif db_path.suffix.lower() in [".db", ".sqlite"]: - console.print("Creating SQLite reader with timeout settings...") - reader = FIADataReader( - db_path, - engine="sqlite", - timeout=60.0 - ) - console.print("✓ SQLite configured with 60 second timeout") - - -def main(): - """Run the demonstration.""" - # Example database paths (adjust to your setup) - duckdb_path = Path("data/fia_georgia.duckdb") - sqlite_path = Path("data/fia_georgia.db") - - console.print("[bold green]pyFIA Multi-Backend Demonstration[/bold green]") - - # Check which databases are available - available_dbs = [] - if duckdb_path.exists(): - available_dbs.append(("DuckDB", duckdb_path)) - if sqlite_path.exists(): - available_dbs.append(("SQLite", sqlite_path)) - - if not available_dbs: - console.print("[red]No FIA databases found. Please adjust the paths in main()[/red]") - return - - # Run demos for available databases - for db_type, db_path in available_dbs: - console.print(f"\n[yellow]{'=' * 60}[/yellow]") - console.print(f"[yellow]Testing with {db_type} database: {db_path}[/yellow]") - console.print(f"[yellow]{'=' * 60}[/yellow]") - - # Auto-detection demo - demo_backend_autodetection(db_path) - - # Explicit backend demo - engine = "duckdb" if db_type == "DuckDB" else "sqlite" - demo_explicit_backend(db_path, engine) - - # FIA class demo - demo_fia_class_with_backend(db_path) - - # Performance options demo - demo_performance_options(db_path) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/harvest_panel_analysis.py b/examples/harvest_panel_analysis.py deleted file mode 100644 index e3dfe40..0000000 --- a/examples/harvest_panel_analysis.py +++ /dev/null @@ -1,236 +0,0 @@ -""" -Example script demonstrating harvest panel analysis using pyFIA. - -This script shows how to use the panel() function to: -1. Create condition-level panels for harvest probability analysis -2. Create tree-level panels for individual tree fate tracking -3. Analyze harvest rates by ownership and species -4. Explore multi-period remeasurement chains - -Requirements: - - pyFIA installed - - FIA database (e.g., downloaded via pyfia.download()) -""" - -from pathlib import Path - -import polars as pl -from rich.console import Console -from rich.table import Table - -from pyfia import FIA, panel - -console = Console() - - -def analyze_condition_harvest(db: FIA) -> None: - """Analyze harvest rates at the condition level.""" - console.print("\n[bold blue]CONDITION-LEVEL HARVEST ANALYSIS[/bold blue]") - console.print("-" * 50) - - # Create condition panel for forest land - cond_panel = panel(db, level="condition", land_type="forest") - - # Basic statistics - n_conditions = len(cond_panel) - n_harvested = cond_panel["HARVEST"].sum() - harvest_rate = cond_panel["HARVEST"].mean() - avg_remper = cond_panel["REMPER"].mean() - - # Calculate annualized rate - annual_rate = 1 - (1 - harvest_rate) ** (1 / avg_remper) - - console.print(f"Total condition pairs: {n_conditions:,}") - console.print(f"Harvested conditions: {n_harvested:,}") - console.print(f"Period harvest rate: {harvest_rate:.1%}") - console.print(f"Avg remeasurement period: {avg_remper:.1f} years") - console.print(f"[green]Annualized harvest rate: {annual_rate:.2%}/year[/green]") - - # Harvest by ownership - console.print("\n[bold]Harvest Rate by Ownership:[/bold]") - own_names = {10: "USFS", 20: "Other Federal", 30: "State/Local", 40: "Private"} - - harvest_by_owner = ( - cond_panel.group_by("t2_OWNGRPCD") - .agg([pl.len().alias("n_conditions"), pl.col("HARVEST").mean().alias("harvest_rate")]) - .sort("t2_OWNGRPCD") - ) - - table = Table(show_header=True, header_style="bold") - table.add_column("Ownership") - table.add_column("N Conditions", justify="right") - table.add_column("Harvest Rate", justify="right") - - for row in harvest_by_owner.iter_rows(named=True): - name = own_names.get(row["t2_OWNGRPCD"], f"Code {row['t2_OWNGRPCD']}") - table.add_row(name, f"{row['n_conditions']:,}", f"{row['harvest_rate']:.1%}") - - console.print(table) - - -def analyze_tree_fates(db: FIA) -> None: - """Analyze tree fates including cut trees.""" - console.print("\n[bold blue]TREE-LEVEL FATE ANALYSIS[/bold blue]") - console.print("-" * 50) - - # Create tree panel (infer_cut=True by default) - tree_panel = panel(db, level="tree", tree_type="all") - - console.print(f"Total tree pairs: {len(tree_panel):,}") - - # Tree fate distribution - console.print("\n[bold]Tree Fate Distribution:[/bold]") - fate_dist = ( - tree_panel.group_by("TREE_FATE") - .agg(pl.len().alias("n")) - .with_columns((pl.col("n") / pl.col("n").sum() * 100).alias("pct")) - .sort("n", descending=True) - ) - - table = Table(show_header=True, header_style="bold") - table.add_column("Fate") - table.add_column("Count", justify="right") - table.add_column("Percent", justify="right") - - for row in fate_dist.iter_rows(named=True): - table.add_row(row["TREE_FATE"], f"{row['n']:,}", f"{row['pct']:.1f}%") - - console.print(table) - - # Cut trees analysis - cut_trees = tree_panel.filter(pl.col("TREE_FATE") == "cut") - if len(cut_trees) > 0: - console.print(f"\n[bold]Top 10 Species Cut ({len(cut_trees):,} total):[/bold]") - - spcd_names = { - 131: "Loblolly pine", - 132: "Longleaf pine", - 110: "Shortleaf pine", - 111: "Slash pine", - 611: "Sweetgum", - 621: "Yellow-poplar", - 802: "White oak", - 833: "N. red oak", - 316: "Red maple", - 261: "E. white pine", - } - - species_cut = ( - cut_trees.filter(pl.col("t1_SPCD").is_not_null()) - .group_by("t1_SPCD") - .agg([pl.len().alias("n_trees"), pl.col("t1_DIA").mean().alias("avg_dia")]) - .sort("n_trees", descending=True) - .head(10) - ) - - table = Table(show_header=True, header_style="bold") - table.add_column("Species") - table.add_column("Trees Cut", justify="right") - table.add_column("Avg DBH", justify="right") - - for row in species_cut.iter_rows(named=True): - spcd = int(row["t1_SPCD"]) - name = spcd_names.get(spcd, f"SPCD {spcd}") - table.add_row(name, f"{row['n_trees']:,}", f"{row['avg_dia']:.1f}\"") - - console.print(table) - - -def analyze_remeasurement_chains(db: FIA) -> None: - """Analyze multi-period remeasurement chains.""" - console.print("\n[bold blue]REMEASUREMENT CHAIN ANALYSIS[/bold blue]") - console.print("-" * 50) - - cond_panel = panel(db, level="condition") - - # Count plots by number of measurement periods - chain_lengths = ( - cond_panel.group_by("PLT_CN") - .agg(pl.len().alias("periods")) - .group_by("periods") - .agg(pl.len().alias("n_plots")) - .sort("periods") - ) - - console.print("[bold]Plots by Number of Measurement Periods:[/bold]") - table = Table(show_header=True, header_style="bold") - table.add_column("Periods") - table.add_column("N Plots", justify="right") - - for row in chain_lengths.iter_rows(named=True): - table.add_row(str(row["periods"]), f"{row['n_plots']:,}") - - console.print(table) - - # Harvest transition analysis for multi-period plots - multi_period_plots = ( - cond_panel.group_by("PLT_CN").agg(pl.len().alias("n")).filter(pl.col("n") > 1)["PLT_CN"] - ) - - multi_period = cond_panel.filter(pl.col("PLT_CN").is_in(multi_period_plots)).sort( - ["PLT_CN", "INVYR"] - ) - - if len(multi_period) > 0: - console.print("\n[bold]Harvest Transitions (plots with 2+ periods):[/bold]") - - transitions = ( - multi_period.with_columns( - [pl.col("HARVEST").shift(1).over("PLT_CN").alias("PREV_HARVEST")] - ) - .filter(pl.col("PREV_HARVEST").is_not_null()) - .group_by(["PREV_HARVEST", "HARVEST"]) - .agg(pl.len().alias("count")) - .sort(["PREV_HARVEST", "HARVEST"]) - ) - - table = Table(show_header=True, header_style="bold") - table.add_column("Previous Period") - table.add_column("Current Period") - table.add_column("Count", justify="right") - - for row in transitions.iter_rows(named=True): - prev = "Harvested" if row["PREV_HARVEST"] == 1 else "Not harvested" - curr = "Harvested" if row["HARVEST"] == 1 else "Not harvested" - table.add_row(prev, curr, f"{row['count']:,}") - - console.print(table) - - -def main(db_path: str = "data/nc.duckdb", state_code: int = 37): - """ - Run harvest panel analysis. - - Args: - db_path: Path to FIA DuckDB database - state_code: FIPS state code (default 37 = North Carolina) - """ - console.print(f"[bold green]Harvest Panel Analysis[/bold green]") - console.print(f"Database: {db_path}") - console.print(f"State code: {state_code}") - - # Check if database exists - if not Path(db_path).exists(): - console.print(f"[red]Database not found: {db_path}[/red]") - console.print("Download data first with: pyfia.download(states='NC', dir='data/')") - return - - with FIA(db_path) as db: - db.clip_by_state(state_code) - - # Run analyses - analyze_condition_harvest(db) - analyze_tree_fates(db) - analyze_remeasurement_chains(db) - - console.print("\n[bold green]Analysis complete![/bold green]") - - -if __name__ == "__main__": - import sys - - # Allow command-line arguments for database path and state - db_path = sys.argv[1] if len(sys.argv) > 1 else "data/nc.duckdb" - state_code = int(sys.argv[2]) if len(sys.argv) > 2 else 37 - - main(db_path, state_code) diff --git a/examples/mortality_by_cause.py b/examples/mortality_by_cause.py index 608a246..50ac05f 100644 --- a/examples/mortality_by_cause.py +++ b/examples/mortality_by_cause.py @@ -1,22 +1,62 @@ #!/usr/bin/env python3 """ -Example: Mortality estimates grouped by cause of death (AGENTCD). - -This example demonstrates the new AGENTCD and DSTRBCD grouping capabilities -in the mortality() function, enabling timber casualty loss analysis. +Mortality Estimates Grouped by Cause of Death +============================================== + +This example demonstrates how to estimate annual tree mortality grouped by +cause of death (AGENTCD) - a real-world use case for forest landowners who +need to classify timber losses for federal income tax purposes. + +Tax Classification for Timber Losses +------------------------------------ +The IRS allows deductions for timber casualties, but the treatment varies: + +- CASUALTY (fully deductible): Sudden, unexpected events + - Fire (AGENTCD=30) + - Weather damage like hurricanes, tornadoes, ice storms (AGENTCD=50) + +- NON-CASUALTY (limited deduction): Gradual losses + - Insect damage (AGENTCD=10) + - Disease (AGENTCD=20) + +- NON-DEDUCTIBLE: Normal forest losses + - Animal damage (AGENTCD=40) + - Vegetation competition (AGENTCD=60) + - Silvicultural activities (AGENTCD=80) + +How This Script Works +--------------------- +1. Connects to an FIA database (local DuckDB or cloud MotherDuck) +2. Runs mortality estimation with `grp_by="AGENTCD"` to group by cause +3. Maps AGENTCD codes to human-readable names and tax classifications +4. Summarizes total mortality volume by tax category +5. Also shows mortality by disturbance type (DSTRBCD1) for context + +Key pyFIA Features Demonstrated +------------------------------- +- `grp_by` parameter: Group estimates by any column (AGENTCD, SPCD, etc.) +- `measure` parameter: Choose what to measure (volume, tpa, biomass, etc.) +- `variance=True`: Include standard errors for uncertainty quantification +- Support for both local DuckDB and cloud MotherDuck databases + +Usage +----- + # With local DuckDB file + uv run python examples/mortality_by_cause.py --duckdb data/ri/ri/ri.duckdb + + # With MotherDuck cloud database + uv run python examples/mortality_by_cause.py --motherduck fia_va -Use Case: Forest landowners need to classify mortality by cause for -federal income tax purposes: -- Casualty Loss (tax-deductible): Fire, hurricanes/wind -- Non-Casualty Loss: Insects, disease, drought -- Non-Deductible: Animal damage, vegetation competition + # With specific EVALID + uv run python examples/mortality_by_cause.py --duckdb data/va.duckdb --evalid 512001 -Usage: - # With MotherDuck - uv run python examples/mortality_by_cause.py --motherduck fia_va +Output +------ +The script produces two tables: +1. Annual mortality volume by cause (AGENTCD) with tax classification +2. Annual mortality volume by disturbance type (DSTRBCD1) - # With local DuckDB - uv run python examples/mortality_by_cause.py --duckdb data/virginia.duckdb +Plus a summary showing total volume in each tax category. """ import argparse @@ -26,7 +66,7 @@ console = Console() -# AGENTCD code descriptions +# AGENTCD code descriptions (from FIA documentation) AGENTCD_NAMES = { 0: "No agent recorded", 10: "Insect", @@ -54,13 +94,35 @@ def run_mortality_by_agentcd(db): - """Run mortality estimates grouped by AGENTCD.""" + """ + Run mortality estimates grouped by AGENTCD (cause of death). + + This function demonstrates the core pyFIA pattern: + 1. Call an estimation function (mortality, volume, area, etc.) + 2. Use grp_by to group results by a column of interest + 3. Process and display the results + + Parameters + ---------- + db : FIA or MotherDuckFIA + Connected database instance. + + Returns + ------- + pl.DataFrame + Mortality results grouped by AGENTCD. + """ from pyfia import mortality console.print("\n[bold]Mortality by Cause of Death (AGENTCD)[/bold]") console.print("=" * 60) - # Run mortality with AGENTCD grouping + # Run mortality estimation with AGENTCD grouping + # Key parameters: + # - grp_by="AGENTCD": Group results by mortality agent code + # - measure="volume": Report mortality in cubic feet + # - tree_type="gs": Growing stock trees only + # - variance=True: Include standard errors result = mortality( db, grp_by="AGENTCD", @@ -70,7 +132,7 @@ def run_mortality_by_agentcd(db): variance=True, ) - # Display results + # Display results in a formatted table table = Table(title="Annual Mortality Volume by Cause") table.add_column("AGENTCD", justify="right") table.add_column("Cause", justify="left") @@ -122,7 +184,22 @@ def run_mortality_by_agentcd(db): def run_mortality_by_dstrbcd(db): - """Run mortality estimates grouped by DSTRBCD1 (disturbance code).""" + """ + Run mortality estimates grouped by DSTRBCD1 (disturbance code). + + DSTRBCD1 records the primary disturbance affecting a condition, + providing additional context beyond the mortality agent. + + Parameters + ---------- + db : FIA or MotherDuckFIA + Connected database instance. + + Returns + ------- + pl.DataFrame + Mortality results grouped by DSTRBCD1. + """ from pyfia import mortality console.print("\n[bold]Mortality by Disturbance Type (DSTRBCD1)[/bold]") @@ -161,6 +238,7 @@ def run_mortality_by_dstrbcd(db): def main(): + """Main entry point - parse arguments and run analysis.""" parser = argparse.ArgumentParser( description="Demonstrate mortality grouping by cause of death" ) diff --git a/examples/mortality_calculator_demo.py b/examples/mortality_calculator_demo.py deleted file mode 100644 index 3d4a39e..0000000 --- a/examples/mortality_calculator_demo.py +++ /dev/null @@ -1,88 +0,0 @@ -#!/usr/bin/env python -""" -Example demonstrating the enhanced mortality calculator for pyFIA. - -This script shows how to use the new mortality estimation features -with various grouping options and variance calculations. -""" - -from pyfia import FIA, mortality - - -def main(): - """Run mortality estimation examples.""" - # Initialize FIA database (adjust path as needed) - db_path = "path/to/fia.duckdb" - - print("pyFIA Enhanced Mortality Calculator Demo") - print("=" * 50) - - try: - with FIA(db_path) as db: - # Example 1: Basic mortality estimation - print("\n1. Basic mortality estimation (trees per acre)") - results = mortality(db) - print(results.select(["MORTALITY_TPA", "MORTALITY_TPA_SE", "N_PLOTS"])) - - # Example 2: Mortality by species - print("\n2. Mortality by species") - results = mortality(db, by_species=True) - print(results.select(["SPCD", "COMMON_NAME", "MORTALITY_TPA", "MORTALITY_TPA_SE"]) - .sort("MORTALITY_TPA", descending=True) - .head(10)) - - # Example 3: Mortality by ownership and agent - print("\n3. Mortality by ownership group and mortality agent") - results = mortality( - db, - by_ownership=True, - by_agent=True, - include_components=True - ) - print(results.select([ - "OWNGRPCD", "OWNGRPNM", "AGENTCD", "AGENTNM", - "MORTALITY_TPA", "MORTALITY_BA", "MORTALITY_VOL" - ]).head(10)) - - # Example 4: Mortality with multiple groupings and variance - print("\n4. Detailed mortality with variance components") - results = mortality( - db, - by_species=True, - by_ownership=True, - by_disturbance=True, - variance=True, - totals=True - ) - print(results.columns) - - # Example 5: Mortality for specific domain - print("\n5. Mortality for loblolly pine in forest land") - results = mortality( - db, - tree_domain="SPCD == 131", # Loblolly pine - land_type="forest", - by_agent=True, - include_components=True - ) - print(results.select([ - "AGENTCD", "AGENTNM", - "MORTALITY_TPA", "MORTALITY_BA", "MORTALITY_VOL", - "N_PLOTS" - ])) - - # Example 6: Compare growing stock vs all trees mortality - print("\n6. Comparing growing stock vs all trees mortality") - gs_mort = mortality(db, tree_class="growing_stock") - all_mort = mortality(db, tree_class="all") - - print(f"Growing stock mortality: {gs_mort['MORTALITY_TPA'][0]:.2f} TPA") - print(f"All trees mortality: {all_mort['MORTALITY_TPA'][0]:.2f} TPA") - - except Exception as e: - print(f"Error: {e}") - print("\nNote: Update db_path to point to your FIA database") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/mortality_config_demo.py b/examples/mortality_config_demo.py deleted file mode 100644 index 941c4e6..0000000 --- a/examples/mortality_config_demo.py +++ /dev/null @@ -1,134 +0,0 @@ -#!/usr/bin/env python3 -""" -Demonstration of the new MortalityConfig usage in pyFIA. - -This script shows how to use the enhanced Pydantic-based configuration -for mortality estimation with proper validation and type safety. -""" - -from pyfia import FIA, MortalityCalculator, MortalityConfig -import polars as pl - - -def main(): - """Demonstrate mortality configuration usage.""" - - # Example database path (adjust as needed) - db_path = "/path/to/fia.duckdb" - - # Example 1: Basic mortality configuration - print("Example 1: Basic mortality estimation") - basic_config = MortalityConfig( - mortality_type="tpa", - land_type="forest", - tree_type="all", - variance=True, - totals=True - ) - print(f"Basic config grouping columns: {basic_config.get_grouping_columns()}") - print(f"Expected output columns: {basic_config.get_output_columns()}") - print() - - # Example 2: Mortality by species and ownership - print("Example 2: Mortality by species and ownership") - species_config = MortalityConfig( - mortality_type="both", # Calculate both TPA and volume - by_species=True, - group_by_ownership=True, - variance=True, - totals=True, - include_components=True # Include BA components - ) - print(f"Species config grouping columns: {species_config.get_grouping_columns()}") - print(f"Expected output columns: {species_config.get_output_columns()}") - print() - - # Example 3: Complex grouping with validation - print("Example 3: Complex mortality grouping") - complex_config = MortalityConfig( - mortality_type="volume", - grp_by=["STATECD", "UNITCD", "COUNTYCD"], - by_species=True, - group_by_species_group=True, - group_by_agent=True, - group_by_disturbance=True, - tree_domain="DIA >= 10.0", - area_domain="COND_STATUS_CD == 1", - variance=True, - totals=True, - tree_class="timber" - ) - print(f"Complex config grouping columns: {complex_config.get_grouping_columns()}") - print(f"Expected output columns: {complex_config.get_output_columns()}") - print() - - # Example 4: Validation examples - print("Example 4: Configuration validation") - - try: - # This will raise a validation error - invalid_config = MortalityConfig( - mortality_type="volume", - tree_type="live", # Can't calculate mortality on live trees! - land_type="forest" - ) - except ValueError as e: - print(f"Validation error caught: {e}") - - try: - # This will also raise a validation error - invalid_config2 = MortalityConfig( - tree_class="timber", - land_type="forest" # timber class requires timber land type - ) - except ValueError as e: - print(f"Validation error caught: {e}") - - # Example 5: Using with MortalityCalculator (if database available) - print("\nExample 5: Using with MortalityCalculator") - - # Create a configuration for actual use - calc_config = MortalityConfig( - mortality_type="tpa", - by_species=True, - group_by_ownership=True, - tree_domain="STATUSCD == 2", # Dead trees - variance=True, - totals=True, - variance_method="ratio" - ) - - # Show how it would be used (uncomment with real database) - # with FIA(db_path) as db: - # db.clip_by_state(37) # North Carolina - # calculator = MortalityCalculator(db, calc_config) - # results = calculator.estimate() - # print(results.head()) - - # Example 6: Converting to legacy config for backwards compatibility - print("\nExample 6: Backwards compatibility") - legacy_config = calc_config.to_estimator_config() - print(f"Legacy config type: {type(legacy_config)}") - print(f"Legacy config grp_by: {legacy_config.grp_by}") - print(f"Legacy config extra_params: {legacy_config.extra_params}") - - # Example 7: Domain expression validation - print("\nExample 7: Domain expression validation") - - safe_config = MortalityConfig( - tree_domain="DIA >= 10.0 AND STATUSCD == 2", - area_domain="FORTYPCD IN (121, 122, 123)" - ) - print("Safe domain expressions validated successfully") - - try: - # This would be caught by validation - dangerous_config = MortalityConfig( - tree_domain="DIA >= 10; DROP TABLE TREE; --" - ) - except ValueError as e: - print(f"Dangerous SQL caught: {e}") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/mortality_with_new_config.py b/examples/mortality_with_new_config.py deleted file mode 100644 index 983943b..0000000 --- a/examples/mortality_with_new_config.py +++ /dev/null @@ -1,191 +0,0 @@ -#!/usr/bin/env python3 -""" -Example of using the new MortalityConfig with pyFIA mortality estimation. - -This script demonstrates how to use the enhanced Pydantic-based configuration -for mortality analysis with real FIA data. -""" - -from pyfia import FIA, MortalityCalculator, MortalityConfig, mortality -import polars as pl - - -def example_basic_mortality(db_path: str): - """Basic mortality estimation example.""" - print("=" * 60) - print("Example 1: Basic Mortality Estimation") - print("=" * 60) - - # Using the convenience function (backwards compatible) - with FIA(db_path) as db: - db.clip_by_state(13) # Georgia - - # Traditional approach still works - results = mortality( - db, - by_species=True, - variance=True, - totals=True - ) - - print("\nTop 5 species by mortality (TPA):") - print(results.sort("MORTALITY_TPA", descending=True).head(5)) - - -def example_advanced_config(db_path: str): - """Advanced mortality estimation with new config.""" - print("\n" + "=" * 60) - print("Example 2: Advanced Mortality with MortalityConfig") - print("=" * 60) - - # Create a sophisticated configuration - config = MortalityConfig( - # Mortality calculation options - mortality_type="both", # Calculate both TPA and volume - tree_class="timber", # Focus on timber trees - land_type="timber", # Timber land only - - # Grouping options - grp_by=["UNITCD", "COUNTYCD"], # Geographic grouping - by_species=True, # Group by species - group_by_ownership=True, # Include ownership - group_by_agent=True, # Include mortality agent - - # Domain filters - tree_domain="DIA >= 10.0 AND STATUSCD == 2", # Large dead trees - area_domain="COND_STATUS_CD == 1", # Forested conditions - - # Output options - variance=True, # Include variance calculations - totals=True, # Include total estimates - include_components=True, # Include BA components - - # Variance method - variance_method="ratio" # Use ratio variance method - ) - - with FIA(db_path) as db: - db.clip_by_state(13) # Georgia - - # Use the calculator directly with new config - calculator = MortalityCalculator(db, config) - results = calculator.estimate() - - print(f"\nGrouping columns used: {config.get_grouping_columns()}") - print(f"Output columns: {results.columns}") - print(f"\nNumber of groups: {len(results)}") - - # Show summary by ownership - if "OWNGRPCD" in results.columns: - ownership_summary = ( - results - .group_by("OWNGRPCD") - .agg([ - pl.sum("MORTALITY_TPA_TOTAL").alias("TOTAL_MORTALITY_TPA"), - pl.sum("MORTALITY_VOL_TOTAL").alias("TOTAL_MORTALITY_VOL") - ]) - .sort("TOTAL_MORTALITY_TPA", descending=True) - ) - print("\nMortality by ownership group:") - print(ownership_summary) - - -def example_validation(db_path: str): - """Demonstrate configuration validation.""" - print("\n" + "=" * 60) - print("Example 3: Configuration Validation") - print("=" * 60) - - # Example of validation in action - try: - # This will fail validation - bad_config = MortalityConfig( - mortality_type="volume", - tree_type="live" # Can't calculate mortality on live trees! - ) - except ValueError as e: - print(f"✓ Validation caught error: {e}") - - try: - # This will also fail - bad_config2 = MortalityConfig( - tree_class="timber", - land_type="forest" # Timber class needs timber land type - ) - except ValueError as e: - print(f"✓ Validation caught error: {e}") - - # Show a valid timber configuration - valid_config = MortalityConfig( - mortality_type="volume", - tree_type="dead", - tree_class="timber", - land_type="timber" - ) - print(f"\n✓ Valid timber mortality config created successfully") - print(f" - Tree type: {valid_config.tree_type}") - print(f" - Tree class: {valid_config.tree_class}") - print(f" - Land type: {valid_config.land_type}") - - -def example_comparison(db_path: str): - """Compare old and new configuration approaches.""" - print("\n" + "=" * 60) - print("Example 4: Old vs New Configuration Approaches") - print("=" * 60) - - with FIA(db_path) as db: - db.clip_by_state(13) # Georgia - - # Old approach - passing many parameters - print("Old approach - function with many parameters:") - results_old = mortality( - db, - by_species=True, - by_ownership=True, - by_agent=True, - tree_domain="DIA >= 10.0", - variance=True, - totals=True - ) - - # New approach - structured configuration - print("\nNew approach - structured configuration:") - config = MortalityConfig( - by_species=True, - group_by_ownership=True, - group_by_agent=True, - tree_domain="DIA >= 10.0", - variance=True, - totals=True - ) - - calculator = MortalityCalculator(db, config) - results_new = calculator.estimate() - - # Results should be identical - print(f"\nOld approach result shape: {results_old.shape}") - print(f"New approach result shape: {results_new.shape}") - print(f"Results identical: {results_old.equals(results_new)}") - - -def main(): - """Run all examples.""" - # Update this path to your FIA database - db_path = "/path/to/fia_georgia.duckdb" - - # Check if path needs updating - import os - if not os.path.exists(db_path): - print("Please update db_path in the script to point to your FIA database.") - print("Example: db_path = '/data/fia/georgia.duckdb'") - return - - example_basic_mortality(db_path) - example_advanced_config(db_path) - example_validation(db_path) - example_comparison(db_path) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/plot_data_access.py b/examples/plot_data_access.py index 7a4efe2..4806630 100644 --- a/examples/plot_data_access.py +++ b/examples/plot_data_access.py @@ -1,20 +1,81 @@ """ -Example: Accessing Plot-Level Data from FIA Database. - -This example demonstrates how to retrieve raw plot-level data from pyFIA, -including site index, location coordinates, and stand attributes. This is -useful for: - -- Building predictive models with FIA plot data -- Spatial analysis and mapping -- Linking FIA data with external datasets (climate, soils, etc.) -- Custom analyses beyond the standard estimation functions - -Key Concepts: -- PLOT table: Contains location (LAT, LON), elevation, inventory year -- COND table: Contains stand-level attributes (site index, forest type, age) -- A single plot can have multiple conditions (stands) -- Site index is a condition-level (stand) variable, not plot-level +Accessing Raw Plot-Level Data from FIA +====================================== + +This example demonstrates how to retrieve raw plot-level data from pyFIA +for custom analyses beyond the standard estimation functions. + +When to Use This Approach +------------------------- +Use direct data access when you need to: + +- Build predictive models (ML/statistics) with FIA plot data +- Create maps or perform spatial analysis +- Link FIA data with external datasets (climate, soils, remote sensing) +- Perform custom analyses not covered by estimation functions +- Export data for use in other tools (R, GIS, etc.) + +Understanding FIA Data Structure +-------------------------------- +FIA uses a hierarchical data model: + + PLOT (1 per location) + | + +-- COND (1+ per plot, represents different "conditions" or stands) + | | + | +-- Forest type, ownership, site index, etc. + | + +-- TREE (many per plot) + | + +-- Species, diameter, height, volume, etc. + +Key insight: A single plot can have MULTIPLE conditions. For example, +a plot might be 60% pine plantation and 40% hardwood. Each condition +has its own site index, forest type, and ownership. + +Key Tables and Columns +---------------------- +PLOT table (location information): + - CN: Unique plot identifier + - LAT, LON: Coordinates (fuzzed for privacy) + - ELEV: Elevation in feet + - INVYR: Inventory year + - COUNTYCD: County FIPS code + +COND table (stand attributes - one row per condition): + - PLT_CN: Links to PLOT.CN + - CONDID: Condition number (1, 2, 3...) + - SICOND: Site index (height at base age) + - SIBASE: Base age for site index (25 or 50 years) + - SISP: Species code used for site index + - FORTYPCD: Forest type code + - STDAGE: Stand age + - OWNGRPCD: Ownership group code + - CONDPROP_UNADJ: Proportion of plot in this condition + +TREE table (individual trees): + - PLT_CN: Links to PLOT.CN + - SPCD: Species code + - DIA: Diameter at breast height (inches) + - HT: Total height (feet) + - VOLCFNET: Net cubic foot volume + - DRYBIO_AG: Aboveground dry biomass (pounds) + +Usage +----- + # Default: uses NC data from ~/.pyfia/ + uv run python examples/plot_data_access.py + + # With custom database + uv run python examples/plot_data_access.py /path/to/state.duckdb + +Functions Provided +------------------ +- get_plot_locations(): Basic plot coordinates +- get_site_index_by_condition(): Site index at condition level +- get_site_index_with_location(): Conditions joined with plot locations +- get_weighted_site_index_by_plot(): Area-weighted site index per plot +- get_tree_data_by_plot(): Tree data with plot locations """ from pyfia import FIA @@ -25,6 +86,9 @@ def get_plot_locations(db_path: str, state_fips: int) -> pl.DataFrame: """ Retrieve plot locations with basic attributes. + This is the simplest form of data access - just plot coordinates + and identifiers. Useful for mapping or as a starting point for joins. + Parameters ---------- db_path : str @@ -36,6 +100,11 @@ def get_plot_locations(db_path: str, state_fips: int) -> pl.DataFrame: ------- pl.DataFrame Plot data with CN, LAT, LON, ELEV, INVYR, COUNTYCD. + + Example + ------- + >>> plots = get_plot_locations("data/nc.duckdb", 37) + >>> print(plots.head()) """ db = FIA(db_path) db.clip_by_state(state_fips) @@ -52,8 +121,12 @@ def get_site_index_by_condition(db_path: str, state_fips: int) -> pl.DataFrame: """ Retrieve site index data at the condition (stand) level. - Site index is measured per condition, not per plot. A plot can have - multiple conditions if it contains different forest types or ownerships. + Site index is a measure of site productivity - the expected height + of dominant trees at a base age (typically 25 or 50 years). Higher + site index = more productive site. + + IMPORTANT: Site index is measured per CONDITION, not per plot. + A plot can have multiple conditions with different site indices. Parameters ---------- @@ -67,13 +140,18 @@ def get_site_index_by_condition(db_path: str, state_fips: int) -> pl.DataFrame: pl.DataFrame Condition data with site index and related attributes. - Notes - ----- - Key columns: + Key Columns + ----------- - SICOND: Site index for the condition (feet at base age) - SIBASE: Base age for site index (typically 25 or 50 years) - SISP: Species code used for site index determination - CONDPROP_UNADJ: Proportion of plot area in this condition + + Example + ------- + >>> conds = get_site_index_by_condition("data/nc.duckdb", 37) + >>> # Filter to valid site index values + >>> valid = conds.filter(pl.col("SICOND").is_not_null()) """ db = FIA(db_path) db.clip_by_state(state_fips) @@ -98,7 +176,8 @@ def get_site_index_with_location(db_path: str, state_fips: int) -> pl.DataFrame: Join site index data with plot locations. Returns one row per condition, with both stand attributes and - plot location information. + plot location information. Useful for spatial analysis of site + productivity. Parameters ---------- @@ -111,6 +190,12 @@ def get_site_index_with_location(db_path: str, state_fips: int) -> pl.DataFrame: ------- pl.DataFrame Combined plot location and condition site index data. + + Example + ------- + >>> data = get_site_index_with_location("data/nc.duckdb", 37) + >>> # Export for GIS analysis + >>> data.write_csv("site_index_locations.csv") """ db = FIA(db_path) db.clip_by_state(state_fips) @@ -152,6 +237,16 @@ def get_weighted_site_index_by_plot(db_path: str, state_fips: int) -> pl.DataFra this function computes a weighted average based on condition proportions. This provides a single representative site index per plot location. + The Formula + ----------- + weighted_SI = sum(SICOND * CONDPROP_UNADJ) / sum(CONDPROP_UNADJ) + + For example, if a plot is: + - 60% pine plantation with SI=80 + - 40% hardwood with SI=70 + + Then: weighted_SI = (80*0.6 + 70*0.4) / (0.6 + 0.4) = 76 + Parameters ---------- db_path : str @@ -162,15 +257,17 @@ def get_weighted_site_index_by_plot(db_path: str, state_fips: int) -> pl.DataFra Returns ------- pl.DataFrame - One row per plot with weighted site index and location. + One row per plot with: + - weighted_SICOND: Area-weighted site index + - dominant_SISP: Site index species from largest condition + - dominant_FORTYPCD: Forest type from largest condition + - n_conditions: Number of conditions contributing to average - Notes - ----- - The weighted average formula is: - weighted_SI = sum(SICOND * CONDPROP_UNADJ) / sum(CONDPROP_UNADJ) - - Only conditions with valid (non-null) site index values are included. - The dominant site index species (SISP) is taken from the largest condition. + Example + ------- + >>> weighted = get_weighted_site_index_by_plot("data/nc.duckdb", 37) + >>> # Find high-productivity sites + >>> high_si = weighted.filter(pl.col("weighted_SICOND") > 90) """ db = FIA(db_path) db.clip_by_state(state_fips) @@ -236,6 +333,10 @@ def get_tree_data_by_plot(db_path: str, state_fips: int) -> pl.DataFrame: """ Retrieve tree-level data with plot locations. + Returns individual tree measurements joined with their plot + coordinates. Useful for species distribution modeling or + tree-level analyses. + Parameters ---------- db_path : str @@ -247,6 +348,23 @@ def get_tree_data_by_plot(db_path: str, state_fips: int) -> pl.DataFrame: ------- pl.DataFrame Tree data joined with plot locations. + + Key Tree Columns + ---------------- + - SPCD: Species code + - DIA: Diameter at breast height (inches) + - HT: Total height (feet) + - STATUSCD: 1=live, 2=dead + - VOLCFNET: Net cubic foot volume + - DRYBIO_AG: Aboveground dry biomass (pounds) + + Example + ------- + >>> trees = get_tree_data_by_plot("data/nc.duckdb", 37) + >>> # Filter to live loblolly pine + >>> loblolly = trees.filter( + ... (pl.col("SPCD") == 131) & (pl.col("STATUSCD") == 1) + ... ) """ db = FIA(db_path) db.clip_by_state(state_fips) diff --git a/examples/plot_domain_example.py b/examples/plot_domain_example.py index eddca49..c185acd 100644 --- a/examples/plot_domain_example.py +++ b/examples/plot_domain_example.py @@ -1,19 +1,92 @@ """ -Example: Using plot_domain to filter FIA estimates by county and location. +Filtering Estimates by Plot Attributes (plot_domain) +===================================================== -This example demonstrates how to use the new plot_domain parameter to filter -FIA estimates by PLOT-level attributes like COUNTYCD, UNITCD, LAT, LON, and ELEV. +This example documents the `plot_domain` parameter, which allows filtering +FIA estimates by PLOT-level attributes like county, coordinates, and elevation. -The plot_domain parameter is useful when you need to filter by attributes that -are stored in the PLOT table rather than the COND table. +The Problem +----------- +FIA data is stored in multiple tables: + + PLOT table: Location info (LAT, LON, COUNTYCD, ELEV, INVYR) + COND table: Stand attributes (ownership, forest type, site class) + TREE table: Individual tree measurements + +The standard `area_domain` and `tree_domain` parameters only filter COND +and TREE attributes. Before `plot_domain`, filtering by county or geographic +bounds required custom SQL or post-processing. + +The Solution +------------ +The `plot_domain` parameter accepts SQL-like expressions that filter on +PLOT table columns: + + area(db, plot_domain="COUNTYCD == 183") # Single county + area(db, plot_domain="LAT >= 35.0 AND LAT <= 36.0") # Lat range + area(db, plot_domain="ELEV > 2000") # High elevation + +Available PLOT Columns +---------------------- +Location: + - LAT: Latitude (decimal degrees, fuzzed for privacy) + - LON: Longitude (decimal degrees, fuzzed for privacy) + - ELEV: Elevation (feet) + +Administrative: + - STATECD: State FIPS code + - COUNTYCD: County FIPS code + - UNITCD: FIA survey unit code + +Temporal: + - INVYR: Inventory year (when plot was assigned to panel) + - MEASYEAR: Measurement year (when plot was actually measured) + - MEASMON: Measurement month (1-12) + +Identifiers: + - PLOT: Plot number within county + - CN: Unique plot identifier + +Combining with Other Filters +---------------------------- +You can use plot_domain together with area_domain and tree_domain: + + # County AND ownership filter + area(db, + plot_domain="COUNTYCD == 183", # PLOT-level: Wake County + area_domain="OWNGRPCD == 40") # COND-level: Private land + + # Geographic bounds AND species filter + volume(db, + plot_domain="LAT >= 35 AND LAT <= 36", # PLOT-level + tree_domain="SPCD == 131") # TREE-level: Loblolly pine + +Common Use Cases +---------------- +1. County-level estimates for local planning +2. Geographic subsets for regional analysis +3. Elevation bands for mountain/lowland comparisons +4. Temporal subsets for trend analysis +5. Survey unit summaries for FIA reporting + +Note: This file contains example code patterns, not a runnable script. +Replace 'path/to/fia.duckdb' with your actual database path. """ from pyfia import FIA, area, volume, biomass, tpa -# Example 1: Filter by county -# This was not possible before without custom SQL - now it's simple! + +# ============================================================================= +# Example 1: Single County Filter +# ============================================================================= + def example_county_filter(): - """Estimate forest area for a specific county.""" + """ + Estimate forest area for a specific county. + + Use Case: A county forester needs forest area statistics for their + jurisdiction. Previously this required custom SQL; now it's one line. + """ with FIA("path/to/fia.duckdb") as db: db.clip_by_state(37) # North Carolina @@ -27,26 +100,42 @@ def example_county_filter(): print(results) -# Example 2: Filter by multiple counties +# ============================================================================= +# Example 2: Multiple Counties with Grouping +# ============================================================================= + def example_multiple_counties(): - """Estimate forest area for multiple counties.""" + """ + Estimate forest area for multiple counties, grouped by county. + + Use Case: Compare forest resources across a multi-county region, + such as a metropolitan statistical area or watershed. + """ with FIA("path/to/fia.duckdb") as db: db.clip_by_state(37) # North Carolina - # Get area for multiple counties + # Get area for multiple counties, grouped by county results = area( db, plot_domain="COUNTYCD IN (183, 185, 187)", # Wake, Warren, Washington - grp_by="COUNTYCD", + grp_by="COUNTYCD", # Separate estimate per county land_type="forest" ) print("Forest area by county:") print(results) -# Example 3: Filter by survey unit +# ============================================================================= +# Example 3: Survey Unit Filter +# ============================================================================= + def example_survey_unit(): - """Estimate volume by survey unit.""" + """ + Estimate volume by FIA survey unit. + + Use Case: FIA reports are often organized by survey unit (groups of + counties). This enables matching pyFIA output to official FIA reports. + """ with FIA("path/to/fia.duckdb") as db: db.clip_by_state(37) @@ -61,13 +150,25 @@ def example_survey_unit(): print(results) -# Example 4: Geographic filtering by latitude/longitude +# ============================================================================= +# Example 4: Geographic Bounding Box +# ============================================================================= + def example_geographic_filter(): - """Estimate biomass within a geographic bounding box.""" + """ + Estimate biomass within a geographic bounding box. + + Use Case: Analyze forest resources within a specific geographic area, + such as a national forest boundary approximation or study area. + + Note: FIA coordinates are fuzzed up to 1 mile for privacy, so precise + boundary matching is not possible. Use spatial filtering (clip_by_polygon) + for accurate boundary analysis. + """ with FIA("path/to/fia.duckdb") as db: db.clip_by_state(37) - # Filter to plots within a specific geographic area + # Filter to plots within a lat/lon box results = biomass( db, plot_domain="LAT >= 35.0 AND LAT <= 36.0 AND LON >= -80.0 AND LON <= -79.0", @@ -78,16 +179,24 @@ def example_geographic_filter(): print(results) -# Example 5: Filter by elevation +# ============================================================================= +# Example 5: Elevation Filter +# ============================================================================= + def example_elevation_filter(): - """Estimate trees per acre at high elevations.""" + """ + Estimate trees per acre at high elevations. + + Use Case: Compare forest characteristics between mountain and lowland + forests, or focus analysis on specific elevation zones. + """ with FIA("path/to/fia.duckdb") as db: db.clip_by_state(37) - # Get TPA for high-elevation forests + # Get TPA for high-elevation forests (> 2000 feet) results = tpa( db, - plot_domain="ELEV > 2000", # Above 2000 feet + plot_domain="ELEV > 2000", land_type="forest", tree_type="live" ) @@ -95,74 +204,137 @@ def example_elevation_filter(): print(results) -# Example 6: Combine plot_domain with area_domain +# ============================================================================= +# Example 6: Combining PLOT and COND Filters +# ============================================================================= + def example_combined_filters(): - """Combine PLOT-level and COND-level filters.""" + """ + Combine PLOT-level and COND-level filters. + + Use Case: Answer questions like "What is the private forest area + in Wake County by forest type?" - requires filtering on both + PLOT (county) and COND (ownership, forest type) attributes. + """ with FIA("path/to/fia.duckdb") as db: db.clip_by_state(37) # Filter by county (PLOT) AND ownership (COND) results = area( db, - plot_domain="COUNTYCD == 183", # Wake County - area_domain="OWNGRPCD == 40", # Private land - grp_by="FORTYPCD", + plot_domain="COUNTYCD == 183", # PLOT-level: Wake County + area_domain="OWNGRPCD == 40", # COND-level: Private land + grp_by="FORTYPCD", # Group by forest type land_type="forest" ) print("Private forest area in Wake County by forest type:") print(results) -# Example 7: Filter by inventory year +# ============================================================================= +# Example 7: Temporal Filter (Measurement Year) +# ============================================================================= + def example_temporal_filter(): - """Estimate area from specific inventory years.""" + """ + Estimate area from specific measurement years. + + Use Case: Focus on recently measured plots for more current data, + or analyze temporal patterns in forest attributes. + + Note: MEASYEAR is when the plot was actually visited. INVYR is when + it was assigned to the panel (may differ by 1-2 years). + """ with FIA("path/to/fia.duckdb") as db: db.clip_by_state(37) - # Get area from recent inventory + # Get area from plots measured since 2015 results = area( db, - plot_domain="MEASYEAR >= 2015", # Plots measured since 2015 + plot_domain="MEASYEAR >= 2015", land_type="forest" ) print("Forest area from plots measured since 2015:") print(results) -# Example 8: Complex plot filtering +# ============================================================================= +# Example 8: Complex Multi-Condition Filter +# ============================================================================= + def example_complex_plot_filter(): - """Use complex plot filtering with multiple conditions.""" + """ + Use complex plot filtering with multiple conditions. + + Use Case: Highly specific analyses combining geographic, temporal, + and administrative constraints. + """ with FIA("path/to/fia.duckdb") as db: db.clip_by_state(37) - # Complex filter: specific counties, elevation range, recent measurements + # Complex filter: specific counties + elevation range + recent measurements results = volume( db, plot_domain=( - "COUNTYCD IN (183, 185) AND " - "ELEV >= 100 AND ELEV <= 500 AND " - "MEASYEAR >= 2015" + "COUNTYCD IN (183, 185) AND " # Wake or Warren county + "ELEV >= 100 AND ELEV <= 500 AND " # Piedmont elevation + "MEASYEAR >= 2015" # Recent measurements ), land_type="forest", tree_type="live", - grp_by="COUNTYCD" + grp_by="COUNTYCD" # Separate results per county ) print("Volume in specific counties, elevation range, and time period:") print(results) +# ============================================================================= +# Quick Reference +# ============================================================================= + if __name__ == "__main__": print(__doc__) - print("\nNote: Replace 'path/to/fia.duckdb' with actual database path.") - print("\nAvailable PLOT columns for filtering:") - print(" - COUNTYCD: County FIPS code") - print(" - UNITCD: Survey unit code") - print(" - STATECD: State FIPS code") - print(" - LAT: Latitude (decimal degrees)") - print(" - LON: Longitude (decimal degrees)") - print(" - ELEV: Elevation (feet)") - print(" - INVYR: Inventory year") - print(" - MEASYEAR: Measurement year") - print(" - MEASMON: Measurement month") - print(" - PLOT: Plot number") - print("\nFor COND-level attributes (ownership, forest type, etc.), use area_domain instead.") + print("\n" + "=" * 70) + print("QUICK REFERENCE: plot_domain vs area_domain vs tree_domain") + print("=" * 70) + print(""" + plot_domain - Filters PLOT table (location, county, elevation) + Example: plot_domain="COUNTYCD == 183" + + area_domain - Filters COND table (ownership, forest type, site class) + Example: area_domain="OWNGRPCD == 40" + + tree_domain - Filters TREE table (species, diameter, status) + Example: tree_domain="SPCD == 131 AND DIA >= 10" + + All three can be combined in a single call: + + volume(db, + plot_domain="COUNTYCD == 183", + area_domain="OWNGRPCD == 40", + tree_domain="SPCD == 131") + """) + + print("\n" + "=" * 70) + print("AVAILABLE PLOT COLUMNS FOR FILTERING") + print("=" * 70) + print(""" + Location: + - LAT Latitude (decimal degrees) + - LON Longitude (decimal degrees) + - ELEV Elevation (feet) + + Administrative: + - STATECD State FIPS code + - COUNTYCD County FIPS code + - UNITCD FIA survey unit code + + Temporal: + - INVYR Inventory year + - MEASYEAR Measurement year + - MEASMON Measurement month (1-12) + + Identifiers: + - PLOT Plot number within county + - CN Unique plot identifier + """) diff --git a/examples/using_output_formatter.py b/examples/using_output_formatter.py deleted file mode 100644 index 9538aee..0000000 --- a/examples/using_output_formatter.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -Example of using the OutputFormatter with estimation functions. - -This example demonstrates how to use the centralized output formatter -to ensure consistent output across different estimators. -""" - -import polars as pl -from pyfia.constants.constants import EstimatorType -from pyfia.estimation.formatters import OutputFormatter, format_estimation_output - - -def example_direct_formatter_usage(): - """Example of using OutputFormatter directly.""" - - # Simulate raw estimation output from area estimator - raw_output = pl.DataFrame({ - "LAND_TYPE": ["Forest", "Timber", "Other"], - "FA_TOTAL": [1000.0, 800.0, 200.0], - "FAD_TOTAL": [2000.0, 2000.0, 2000.0], - "AREA_PERC": [50.0, 40.0, 10.0], - "AREA_PERC_VAR": [4.0, 3.0, 1.0], - "nPlots": [100, 80, 20], - }) - - # Create formatter for area estimator - formatter = OutputFormatter(EstimatorType.AREA) - - # Format the output - formatted = formatter.format_output( - raw_output, - variance=False, # Convert to SE - totals=True, # Include total columns - group_cols=["LAND_TYPE"], - year=2023 - ) - - print("Formatted Area Output:") - print(formatted) - print() - - return formatted - - -def example_convenience_function(): - """Example using the convenience function.""" - - # Simulate raw TPA output - raw_tpa = pl.DataFrame({ - "SPCD": [110, 131, 833], - "TPA": [125.5, 95.2, 78.3], - "TPA_VAR": [15.5, 12.2, 9.8], - "BAA": [85.2, 72.1, 65.5], - "BAA_VAR": [8.5, 6.8, 5.2], - "TREE_TOTAL": [50000, 40000, 35000], - "BA_TOTAL": [35000, 30000, 28000], - "nPlots_TREE": [250, 230, 210], - }) - - # Use convenience function - formatted = format_estimation_output( - raw_tpa, - EstimatorType.TPA, - variance=False, # Output SE instead of variance - totals=True, # Include totals - group_cols=["SPCD"], - year=2023 - ) - - print("Formatted TPA Output:") - print(formatted) - print() - - return formatted - - -def example_biomass_formatting(): - """Example of formatting biomass output.""" - - # Simulate biomass estimation output - raw_biomass = pl.DataFrame({ - "SPCD": [110, 131], - "BIO_ACRE": [25.5, 32.8], - "BIO_ACRE_SE": [1.2, 1.5], - "CARB_ACRE": [12.0, 15.4], - "CARB_ACRE_SE": [0.6, 0.7], - "BIO_TOTAL": [125000, 164000], - "CARB_TOTAL": [58750, 77080], - "nPlots_TREE": [150, 180], - }) - - formatter = OutputFormatter(EstimatorType.BIOMASS) - - # Add metadata and ensure consistent formatting - formatted = formatter.format_output( - raw_biomass, - variance=False, - totals=True, - group_cols=["SPCD"], - year=2023 - ) - - print("Formatted Biomass Output:") - print(formatted) - print() - - return formatted - - -def example_variance_conversion(): - """Example showing variance/SE conversion.""" - - # Data with variance values - data_with_var = pl.DataFrame({ - "VOLUME_ACRE": [1500.0, 1200.0], - "VOLUME_ACRE_VAR": [225.0, 144.0], # Variance - "VOLUME": [750000, 600000], - "VOLUME_VAR": [56250, 36000], - }) - - formatter = OutputFormatter(EstimatorType.VOLUME) - - # Convert to SE - data_with_se = formatter.convert_variance_to_se(data_with_var) - - print("Converted to Standard Error:") - print(data_with_se) - print() - - # Convert back to variance - data_with_var_again = formatter.convert_se_to_variance(data_with_se) - - print("Converted back to Variance:") - print(data_with_var_again) - print() - - -def example_custom_formatter(): - """Example of customizing the formatter for specific needs.""" - - # Create a custom formatter by extending OutputFormatter - class CustomAreaFormatter(OutputFormatter): - def format_output(self, df, **kwargs): - # First apply standard formatting - df = super().format_output(df, **kwargs) - - # Add custom columns - if "AREA_PERC" in df.columns: - df = df.with_columns([ - # Add confidence interval - (pl.col("AREA_PERC") - 1.96 * pl.col("AREA_PERC_SE")).alias("CI_LOWER"), - (pl.col("AREA_PERC") + 1.96 * pl.col("AREA_PERC_SE")).alias("CI_UPPER"), - ]) - - return df - - # Use custom formatter - raw_data = pl.DataFrame({ - "AREA_PERC": [45.5, 54.5], - "AREA_PERC_VAR": [4.0, 5.0], - "N_PLOTS": [100, 120], - }) - - custom_formatter = CustomAreaFormatter(EstimatorType.AREA) - formatted = custom_formatter.format_output( - raw_data, - variance=False, - year=2023 - ) - - print("Custom Formatted Output with Confidence Intervals:") - print(formatted) - print() - - -if __name__ == "__main__": - print("=== Output Formatter Examples ===\n") - - # Run examples - example_direct_formatter_usage() - example_convenience_function() - example_biomass_formatting() - example_variance_conversion() - example_custom_formatter() - - print("All examples completed successfully!") \ No newline at end of file diff --git a/scripts/validate_against_evalidator.py b/scripts/validate_against_evalidator.py index 34e72a4..4c6503b 100644 --- a/scripts/validate_against_evalidator.py +++ b/scripts/validate_against_evalidator.py @@ -27,44 +27,56 @@ def extract_estimate_and_se(result: pl.DataFrame, estimate_type: str) -> Tuple[f Different estimators use different column naming conventions: - area(): AREA, AREA_SE - - volume(): VOLUME_TOTAL, VOLUME_SE - - biomass(): BIOMASS_TOTAL, BIOMASS_SE - - tpa(): TPA_TOTAL, TPA_SE (when totals=True) + - volume(): VOLCFNET_TOTAL, VOLCFNET_TOTAL_SE + - biomass(): BIO_TOTAL, BIO_TOTAL_SE + - tpa(): TPA_TOTAL, TPA_TOTAL_SE (when totals=True) """ cols = result.columns - # Find the main estimate column + # Find the main estimate column and corresponding SE column + est_col = None + se_col = None + if estimate_type == "area": est_col = "AREA" if "AREA" in cols else None se_col = "AREA_SE" if "AREA_SE" in cols else None elif estimate_type == "volume": - est_col = "VOLUME_TOTAL" if "VOLUME_TOTAL" in cols else "VOL_TOTAL" if "VOL_TOTAL" in cols else None - se_col = "VOLUME_SE" if "VOLUME_SE" in cols else "VOL_SE" if "VOL_SE" in cols else None + # Look for VOLCFNET_TOTAL or similar volume total columns + for col in cols: + if col.endswith("_TOTAL") and "VOL" in col.upper() and "SE" not in col: + est_col = col + se_col = f"{col}_SE" if f"{col}_SE" in cols else None + break elif estimate_type == "biomass": - est_col = "BIOMASS_TOTAL" if "BIOMASS_TOTAL" in cols else "BIO_TOTAL" if "BIO_TOTAL" in cols else None - se_col = "BIOMASS_SE" if "BIOMASS_SE" in cols else "BIO_SE" if "BIO_SE" in cols else None + # Look for BIO_TOTAL or BIOMASS_TOTAL + for col in cols: + if col.endswith("_TOTAL") and "BIO" in col.upper() and "SE" not in col: + est_col = col + se_col = f"{col}_SE" if f"{col}_SE" in cols else None + break elif estimate_type == "tpa": - est_col = "TPA_TOTAL" if "TPA_TOTAL" in cols else "TREE_TOTAL" if "TREE_TOTAL" in cols else None - se_col = "TPA_SE" if "TPA_SE" in cols else "TREE_SE" if "TREE_SE" in cols else None - else: - est_col = None - se_col = None + # Look for TPA_TOTAL + for col in cols: + if col.endswith("_TOTAL") and "TPA" in col.upper() and "SE" not in col: + est_col = col + se_col = f"{col}_SE" if f"{col}_SE" in cols else None + break - # Fallback: look for columns ending in _TOTAL or just named like the estimate type + # Fallback: look for columns ending in _TOTAL if est_col is None: for col in cols: col_upper = col.upper() - if col_upper.endswith("_TOTAL") and "EXPNS" not in col_upper: - est_col = col - break - if col_upper == estimate_type.upper(): + if col_upper.endswith("_TOTAL") and "EXPNS" not in col_upper and "SE" not in col_upper: est_col = col + # Look for corresponding SE column + se_col = f"{col}_SE" if f"{col}_SE" in cols else None break + # Final fallback for SE: look for _TOTAL_SE columns (not _ACRE_SE) if se_col is None: for col in cols: col_upper = col.upper() - if "_SE" in col_upper and "PCT" not in col_upper and "PERCENT" not in col_upper and "EXPNS" not in col_upper: + if "_TOTAL_SE" in col_upper: se_col = col break @@ -389,19 +401,9 @@ def main(): client = EVALIDatorClient(timeout=60) all_results = [] - # Test Georgia - georgia_results = validate_state( - db_path="/Users/mihiarc/pyfia/data/georgia.duckdb", - state_code=13, - state_name="Georgia", - year=2023, - client=client - ) - all_results.extend(georgia_results) - # Test Rhode Island ri_results = validate_state( - db_path="/Users/mihiarc/pyfia/data/rhode_island.duckdb", + db_path=str(Path(__file__).parent.parent / "data" / "ri" / "ri" / "ri.duckdb"), state_code=44, state_name="Rhode Island", year=2024, # Use 2024 to match most recent EVALID in database diff --git a/src/pyfia/core/data_reader.py b/src/pyfia/core/data_reader.py index 4603344..fbe9050 100755 --- a/src/pyfia/core/data_reader.py +++ b/src/pyfia/core/data_reader.py @@ -53,8 +53,10 @@ def __init__( "motherduck:" ) + # Type annotation: str for MotherDuck, Path for local files + self.db_path: Union[str, Path] if self._is_motherduck: - self.db_path = db_str # type: ignore[assignment] + self.db_path = db_str else: self.db_path = Path(db_path) if not self.db_path.exists(): diff --git a/src/pyfia/core/utils.py b/src/pyfia/core/utils.py index d6973cf..6328c7f 100644 --- a/src/pyfia/core/utils.py +++ b/src/pyfia/core/utils.py @@ -58,4 +58,9 @@ def batch_query_by_values( if len(results) == 1: return results[0] - return pl.concat(results) + # Ensure consistent types for concat - collect LazyFrames if any exist + first_is_lazy = isinstance(results[0], pl.LazyFrame) + if first_is_lazy: + return pl.concat([r if isinstance(r, pl.LazyFrame) else r.lazy() for r in results]) + else: + return pl.concat([r.collect() if isinstance(r, pl.LazyFrame) else r for r in results]) diff --git a/src/pyfia/estimation/columns.py b/src/pyfia/estimation/columns.py index 1b208dc..5e066a1 100644 --- a/src/pyfia/estimation/columns.py +++ b/src/pyfia/estimation/columns.py @@ -134,12 +134,13 @@ def get_cond_columns( grp_by: Optional[Union[str, List[str]]] = None, base_cols: Optional[List[str]] = None, include_prop_basis: bool = False, + area_domain: Optional[str] = None, ) -> List[str]: """ Resolve condition columns for estimation. - Combines base columns, land type-specific columns, and grouping columns - into a single deduplicated list. + Combines base columns, land type-specific columns, grouping columns, + and columns referenced in area_domain into a single deduplicated list. Parameters ---------- @@ -155,6 +156,10 @@ def get_cond_columns( Override default base columns. If not provided, uses BASE_COND_COLUMNS. include_prop_basis : bool, default False Whether to include PROP_BASIS column for area adjustment calculations. + area_domain : str, optional + SQL-like domain expression. Columns referenced in this expression + will be automatically added to the column list if they are valid + condition columns. Returns ------- @@ -171,6 +176,9 @@ def get_cond_columns( >>> get_cond_columns(grp_by="OWNGRPCD") ['PLT_CN', 'CONDID', 'COND_STATUS_CD', 'CONDPROP_UNADJ', 'OWNGRPCD'] + + >>> get_cond_columns(area_domain="FORTYPCD == 161") + ['PLT_CN', 'CONDID', 'COND_STATUS_CD', 'CONDPROP_UNADJ', 'FORTYPCD'] """ cols = list(base_cols or BASE_COND_COLUMNS) @@ -192,4 +200,13 @@ def get_cond_columns( if col not in cols and col in COND_GROUPING_COLUMNS: cols.append(col) + # Add columns referenced in area_domain expression + if area_domain: + from ..filtering.parser import DomainExpressionParser + + domain_cols = DomainExpressionParser.extract_columns(area_domain) + for col in domain_cols: + if col not in cols and col in COND_GROUPING_COLUMNS: + cols.append(col) + return cols diff --git a/src/pyfia/estimation/estimators/area.py b/src/pyfia/estimation/estimators/area.py index 83015f5..02e1d4f 100644 --- a/src/pyfia/estimation/estimators/area.py +++ b/src/pyfia/estimation/estimators/area.py @@ -355,43 +355,92 @@ def calculate_variance(self, agg_result: AggregationResult) -> pl.DataFrame: # # If we have grouping variables, calculate variance for each group if group_cols: - # Calculate variance for each group separately - variance_results = [] - - for group_vals in results.iter_rows(): - # Filter plot data for this group - group_filter = pl.lit(True) - group_dict = {} - - for i, col in enumerate(group_cols): - if col in plot_data.columns: - val = group_vals[results.columns.index(col)] - group_dict[col] = val - if val is None: - group_filter = group_filter & pl.col(col).is_null() - else: - group_filter = group_filter & (pl.col(col) == val) - - group_plot_data = plot_data.filter(group_filter) - - if len(group_plot_data) > 0: - # Calculate variance for this group - var_stats = self._calculate_variance_for_group( - group_plot_data, strat_cols + # Use vectorized grouped variance calculation + from ..variance import calculate_grouped_domain_total_variance + + # Get valid group columns that exist in the data + valid_group_cols = [c for c in group_cols if c in plot_data.columns] + + if valid_group_cols: + # Determine the stratum column name + # Prefer STRATUM_CN over STRATUM, and don't use ESTN_UNIT for variance + if "STRATUM_CN" in strat_cols: + stratum_col = "STRATUM_CN" + elif "STRATUM" in strat_cols: + stratum_col = "STRATUM" + else: + stratum_col = strat_cols[0] if strat_cols else "STRATUM" + + # FIX for Issue #68: Variance underestimation for rare categories + # When grouping by a categorical variable like FORTYPCD, variance must + # include ALL plots in each stratum, not just plots matching each group. + # Non-matching plots contribute y=0 to the variance calculation. + # + # Step 1: Get unique group values from results + unique_groups = results.select(valid_group_cols).unique() + + # Step 2: Get base plot data without group columns (all plots with stratum info) + base_plot_cols = ["PLT_CN", stratum_col, "EXPNS"] + all_plots_base = cond_data.select( + [c for c in base_plot_cols if c in cond_data.columns] + ).unique() + + # Step 3: Cross-join all plots with unique groups + # Result: Every plot appears once for each group value + all_plots_expanded = all_plots_base.join(unique_groups, how="cross") + + # Step 4: Left join with actual plot_data on PLT_CN + group columns + # Plot P1 with FORTYPCD=809 data only matches (P1, 809) row + # All other (P1, other_type) rows get NULL y values -> filled with 0 + join_keys = ["PLT_CN"] + valid_group_cols + all_plots_with_data = all_plots_expanded.join( + plot_data.select(["PLT_CN", "y_i"] + valid_group_cols), + on=join_keys, + how="left", + ) + + # Step 5: Fill NULL y values with 0.0 + all_plots_with_data = all_plots_with_data.with_columns( + pl.col("y_i").fill_null(0.0) + ) + + # Calculate variance for all groups in one vectorized operation + var_df = calculate_grouped_domain_total_variance( + all_plots_with_data, + group_cols=valid_group_cols, + y_col="y_i", + x_col="y_i", # For area, x is also area (no ratio) + stratum_col=stratum_col, + weight_col="EXPNS", + ) + + # Rename columns to match expected output + var_df = var_df.rename( + { + "se_total": "AREA_SE", + "variance_total": "AREA_VARIANCE", + } + ) + + # SE_PERCENT will be calculated after joining with results + # since we need the actual AREA_TOTAL values + + # Select only the columns we need + keep_cols = valid_group_cols + ["AREA_SE", "AREA_VARIANCE"] + keep_cols = [c for c in keep_cols if c in var_df.columns] + var_df = var_df.select(keep_cols) + + # Join variance results back to main results + results = results.join(var_df, on=valid_group_cols, how="left") + + # Calculate SE_PERCENT using actual AREA_TOTAL from results + if "AREA_TOTAL" in results.columns and "AREA_SE" in results.columns: + results = results.with_columns( + pl.when(pl.col("AREA_TOTAL") > 0) + .then(100 * pl.col("AREA_SE") / pl.col("AREA_TOTAL")) + .otherwise(0.0) + .alias("AREA_SE_PERCENT") ) - variance_results.append( - { - **group_dict, - "AREA_SE": var_stats["se_total"], - "AREA_SE_PERCENT": var_stats["se_percent"], - "AREA_VARIANCE": var_stats["variance"], - } - ) - - # Join variance results back to main results - if variance_results: - var_df = pl.DataFrame(variance_results) - results = results.join(var_df, on=group_cols, how="left") else: # No grouping, calculate overall variance var_stats = self._calculate_variance_for_group(plot_data, strat_cols) diff --git a/src/pyfia/estimation/estimators/area_change.py b/src/pyfia/estimation/estimators/area_change.py index 1baac2c..ecfd7b5 100644 --- a/src/pyfia/estimation/estimators/area_change.py +++ b/src/pyfia/estimation/estimators/area_change.py @@ -11,7 +11,7 @@ FIA Database User Guide, SUBP_COND_CHNG_MTRX table documentation """ -from typing import List, Literal, Optional, Union +from typing import TYPE_CHECKING, List, Literal, Optional, Union import polars as pl @@ -19,6 +19,9 @@ from ..base import BaseEstimator from ..utils import format_output_columns +if TYPE_CHECKING: + from ..base import AggregationResult + class AreaChangeEstimator(BaseEstimator): """ @@ -412,12 +415,24 @@ def calculate_totals(self, data: pl.LazyFrame) -> pl.DataFrame: return result - def calculate_variance(self, result: pl.DataFrame) -> pl.DataFrame: + def calculate_variance( + self, agg_result: "AggregationResult | pl.DataFrame" + ) -> pl.DataFrame: """ Calculate variance for area change estimates. Uses stratified variance estimation following Bechtold & Patterson. + + Parameters + ---------- + agg_result : AggregationResult or pl.DataFrame + Either an AggregationResult bundle or a DataFrame for backward compat. """ + # Extract results DataFrame from AggregationResult if needed + from ..base import AggregationResult + + result = agg_result.results if isinstance(agg_result, AggregationResult) else agg_result + if self.plot_change_data is None: return result @@ -426,7 +441,7 @@ def calculate_variance(self, result: pl.DataFrame) -> pl.DataFrame: if grp_by: if isinstance(grp_by, str): grp_by = [grp_by] - group_cols = list(grp_by) + group_cols: list[str] = list(grp_by) else: group_cols = [] @@ -438,8 +453,8 @@ def calculate_variance(self, result: pl.DataFrame) -> pl.DataFrame: var_result = calculate_grouped_domain_total_variance( plot_data=self.plot_change_data, - value_col="AREA_CHANGE", - group_cols=group_cols if group_cols else None, + y_col="AREA_CHANGE", + group_cols=group_cols if group_cols else [], ) # Join variance to result diff --git a/src/pyfia/estimation/estimators/biomass.py b/src/pyfia/estimation/estimators/biomass.py index c62b6f1..e37c670 100644 --- a/src/pyfia/estimation/estimators/biomass.py +++ b/src/pyfia/estimation/estimators/biomass.py @@ -66,12 +66,14 @@ def get_cond_columns(self) -> List[str]: Uses centralized column resolution from columns.py to reduce duplication. Dynamically includes timber land columns when land_type='timber' and - adds grouping columns as needed. + adds grouping columns as needed. Also includes columns referenced in + area_domain for proper filtering. """ return _get_cond_columns( land_type=self.config.get("land_type", "forest"), grp_by=self.config.get("grp_by"), include_prop_basis=False, # Biomass doesn't need PROP_BASIS + area_domain=self.config.get("area_domain"), ) def calculate_values(self, data: pl.LazyFrame) -> pl.LazyFrame: diff --git a/src/pyfia/estimation/estimators/tpa.py b/src/pyfia/estimation/estimators/tpa.py index eba3682..4dffff7 100644 --- a/src/pyfia/estimation/estimators/tpa.py +++ b/src/pyfia/estimation/estimators/tpa.py @@ -57,12 +57,14 @@ def get_cond_columns(self) -> List[str]: Uses centralized column resolution from columns.py to reduce duplication. Dynamically includes timber land columns when land_type='timber' and - adds grouping columns as needed. + adds grouping columns as needed. Also includes columns referenced in + area_domain for proper filtering. """ return _get_cond_columns( land_type=self.config.get("land_type", "forest"), grp_by=self.config.get("grp_by"), include_prop_basis=False, # TPA doesn't need PROP_BASIS + area_domain=self.config.get("area_domain"), ) def calculate_values(self, data: pl.LazyFrame) -> pl.LazyFrame: diff --git a/src/pyfia/estimation/estimators/volume.py b/src/pyfia/estimation/estimators/volume.py index 46e9edb..030748e 100644 --- a/src/pyfia/estimation/estimators/volume.py +++ b/src/pyfia/estimation/estimators/volume.py @@ -64,11 +64,13 @@ def get_cond_columns(self) -> List[str]: Uses centralized column resolution from columns.py to reduce duplication. Volume estimation needs PROP_BASIS for area adjustment calculations. + Also includes columns referenced in area_domain for proper filtering. """ return _get_cond_columns( land_type=self.config.get("land_type", "forest"), grp_by=self.config.get("grp_by"), include_prop_basis=True, # Volume needs PROP_BASIS for area adjustment + area_domain=self.config.get("area_domain"), ) def calculate_values(self, data: pl.LazyFrame) -> pl.LazyFrame: diff --git a/src/pyfia/estimation/grm_base.py b/src/pyfia/estimation/grm_base.py index 15e6254..84deb28 100644 --- a/src/pyfia/estimation/grm_base.py +++ b/src/pyfia/estimation/grm_base.py @@ -102,11 +102,13 @@ def get_cond_columns(self) -> List[str]: Uses centralized column resolution from columns.py to reduce duplication. GRM estimation needs additional columns for filtering and grouping. + Also includes columns referenced in area_domain for proper filtering. """ base_cols = _get_cond_columns( land_type=self.config.get("land_type", "forest"), grp_by=self.config.get("grp_by"), include_prop_basis=False, + area_domain=self.config.get("area_domain"), ) # GRM estimation needs these columns for aggregate_cond_to_plot() diff --git a/src/pyfia/evalidator/client.py b/src/pyfia/evalidator/client.py index 3fd3632..31ef12e 100644 --- a/src/pyfia/evalidator/client.py +++ b/src/pyfia/evalidator/client.py @@ -147,11 +147,7 @@ def _make_request( # Handle empty responses (server returns 200 but no content) if not response.content or not response.content.strip(): - raise requests.exceptions.JSONDecodeError( - "Empty response from EVALIDator API", - doc="", - pos=0, - ) + raise ValueError("Empty response from EVALIDator API") data = response.json() diff --git a/src/pyfia/evalidator/estimate_types.py b/src/pyfia/evalidator/estimate_types.py index 3b5f69f..0ec81fe 100644 --- a/src/pyfia/evalidator/estimate_types.py +++ b/src/pyfia/evalidator/estimate_types.py @@ -42,17 +42,25 @@ class EstimateType(IntEnum): # --- AREA CHANGE (10 estimates) --- SNUM_126 = 126 + AREA_CHANGE_SAMPLED = 126 # alias SNUM_127 = 127 + AREA_CHANGE_FOREST_EITHER = 127 # alias SNUM_128 = 128 + AREA_CHANGE_FOREST_REMEASURED = 128 # alias SNUM_129 = 129 + AREA_CHANGE_TIMBERLAND_REMEASURED = 129 # alias SNUM_130 = 130 + AREA_CHANGE_TIMBERLAND_EITHER = 130 # alias SNUM_135 = 135 + AREA_CHANGE_ANNUAL_SAMPLED = 135 # alias SNUM_136 = 136 AREA_CHANGE_ANNUAL_FOREST_BOTH = 136 # alias SNUM_137 = 137 AREA_CHANGE_ANNUAL_FOREST_EITHER = 137 # alias SNUM_138 = 138 + AREA_CHANGE_ANNUAL_TIMBERLAND_BOTH = 138 # alias SNUM_139 = 139 + AREA_CHANGE_ANNUAL_TIMBERLAND_EITHER = 139 # alias # --- TREE COUNT (10 estimates) --- SNUM_4 = 4 @@ -398,9 +406,11 @@ class EstimateType(IntEnum): SNUM_10 = 10 BIOMASS_AG_LIVE = 10 # alias SNUM_13 = 13 + BIOMASS_AG_LIVE_5INCH = 13 # alias - aboveground biomass trees >=5" DBH SNUM_59 = 59 BIOMASS_BG_LIVE = 59 # alias SNUM_73 = 73 + BIOMASS_BG_LIVE_5INCH = 73 # alias - belowground biomass trees >=5" DBH SNUM_96 = 96 SNUM_105 = 105 SNUM_108 = 108 @@ -410,6 +420,7 @@ class EstimateType(IntEnum): SNUM_121 = 121 SNUM_124 = 124 SNUM_311 = 311 + GROWTH_NET_BIOMASS = 311 # alias - net annual growth of biomass SNUM_312 = 312 SNUM_313 = 313 SNUM_314 = 314 @@ -423,6 +434,7 @@ class EstimateType(IntEnum): SNUM_322 = 322 SNUM_335 = 335 SNUM_336 = 336 + MORTALITY_BIOMASS = 336 # alias - annual mortality of biomass SNUM_337 = 337 SNUM_338 = 338 SNUM_339 = 339 @@ -434,6 +446,7 @@ class EstimateType(IntEnum): SNUM_345 = 345 SNUM_346 = 346 SNUM_369 = 369 + REMOVALS_BIOMASS = 369 # alias - annual removals of biomass SNUM_370 = 370 SNUM_371 = 371 SNUM_372 = 372 diff --git a/src/pyfia/filtering/__init__.py b/src/pyfia/filtering/__init__.py index b76d14e..c123f45 100644 --- a/src/pyfia/filtering/__init__.py +++ b/src/pyfia/filtering/__init__.py @@ -29,7 +29,6 @@ add_forest_type_group, add_ownership_group_name, add_species_info, - assign_forest_type_group, assign_size_class, assign_species_group, assign_tree_basis, @@ -63,7 +62,6 @@ # Classification "assign_tree_basis", "assign_size_class", - "assign_forest_type_group", "assign_species_group", # Validation "ColumnValidator", diff --git a/src/pyfia/filtering/utils.py b/src/pyfia/filtering/utils.py index d4e5a8f..dccdc9c 100644 --- a/src/pyfia/filtering/utils.py +++ b/src/pyfia/filtering/utils.py @@ -583,53 +583,6 @@ def assign_size_class( return tree_df.with_columns(size_expr) -def assign_forest_type_group( - cond_df: pl.DataFrame, - fortypcd_column: str = "FORTYPCD", - output_column: str = "FOREST_TYPE_GROUP", -) -> pl.DataFrame: - """ - Assign forest type groups based on forest type codes. - - .. deprecated:: - Use `add_forest_type_group` instead for more accurate - western forest type handling. - - Groups forest types into major categories following FIA classification. - - Parameters - ---------- - cond_df : pl.DataFrame - Condition dataframe with forest type codes - fortypcd_column : str, default "FORTYPCD" - Column containing forest type codes - output_column : str, default "FOREST_TYPE_GROUP" - Name for output column - - Returns - ------- - pl.DataFrame - Condition dataframe with forest type group column added - - Examples - -------- - >>> # Add forest type groups - >>> conds_with_groups = assign_forest_type_group(conditions) - """ - import warnings - - warnings.warn( - "assign_forest_type_group is deprecated. Use add_forest_type_group instead.", - DeprecationWarning, - stacklevel=2, - ) - return add_forest_type_group( - cond_df, - fortypcd_col=fortypcd_column, - output_col=output_column, - ) - - def assign_species_group( tree_df: pl.DataFrame, species_df: pl.DataFrame, diff --git a/tests/e2e/test_download_e2e.py b/tests/e2e/test_download_e2e.py index 5f541d9..c34f501 100644 --- a/tests/e2e/test_download_e2e.py +++ b/tests/e2e/test_download_e2e.py @@ -273,25 +273,19 @@ def temp_data_dir(self): with tempfile.TemporaryDirectory() as temp_dir: yield Path(temp_dir) - @pytest.mark.skip( - reason="Reference tables are bundled in FIADB_REFERENCE.zip, not individual files" - ) def test_download_reference_species(self, temp_data_dir): - """Test downloading REF_SPECIES table. - - Note: FIA DataMart bundles all reference tables in FIADB_REFERENCE.zip - rather than providing individual files like REF_SPECIES.zip. - This test is skipped until we implement reference bundle download. - """ + """Test downloading REF_SPECIES table via reference bundle.""" client = DataMartClient(timeout=120) - csv_path = client.download_table( - state="REF", - table="REF_SPECIES", + # Use download_reference_tables which handles the bundled FIADB_REFERENCE.zip + result = client.download_reference_tables( dest_dir=temp_data_dir, + tables=["REF_SPECIES"], show_progress=True, ) + csv_path = result.get("REF_SPECIES") + assert csv_path is not None, "REF_SPECIES should be in result" assert csv_path.exists() # Verify contents diff --git a/tests/unit/test_classification.py b/tests/unit/test_classification.py index c1c6318..4a6ce98 100644 --- a/tests/unit/test_classification.py +++ b/tests/unit/test_classification.py @@ -10,7 +10,7 @@ import pytest from pyfia.filtering.utils import ( - assign_forest_type_group, + add_forest_type_group, assign_size_class, assign_species_group, assign_tree_basis, @@ -217,132 +217,89 @@ def test_boundary_values(self): assert classes[5] == "Large" -class TestAssignForestTypeGroup: - """Tests for assign_forest_type_group function. +class TestAddForestTypeGroup: + """Tests for add_forest_type_group function. - Note: This function is deprecated and now delegates to add_forest_type_group - from grouping_functions, which has more accurate western forest type handling. + This function adds forest type group names based on FORTYPCD codes, + with more accurate western forest type handling. """ - def test_deprecation_warning(self): - """Test that deprecation warning is raised.""" - import warnings - cond_df = pl.DataFrame({"FORTYPCD": [100]}) - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - assign_forest_type_group(cond_df) - - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert "deprecated" in str(w[0].message).lower() - def test_white_red_jack_pine(self): """Test 100-199 range returns White/Red/Jack Pine.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [100, 150, 199]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) assert all(g == "White/Red/Jack Pine" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_spruce_fir_and_western_types(self): """Test 200-299 range returns Spruce/Fir or western forest type variants. - Note: The new implementation has more granular western forest type handling. + Note: The implementation has more granular western forest type handling. Code 200 returns 'Douglas-fir', 250/290 return 'Spruce/Fir'. """ - import warnings cond_df = pl.DataFrame({"FORTYPCD": [250, 290]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) # These non-special codes should still return Spruce/Fir assert all(g == "Spruce/Fir" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_douglas_fir_specific(self): """Test code 200 returns Douglas-fir specifically.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [200]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) assert result["FOREST_TYPE_GROUP"][0] == "Douglas-fir" def test_longleaf_slash_pine(self): """Test 300-399 range returns Longleaf/Slash Pine or western variants.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [350, 399]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) assert all(g == "Longleaf/Slash Pine" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_oak_pine(self): """Test 400-499 range returns Oak/Pine.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [400, 450, 499]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) assert all(g == "Oak/Pine" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_oak_hickory(self): """Test 500-599 range returns Oak/Hickory.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [500, 550, 599]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) assert all(g == "Oak/Hickory" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_oak_gum_cypress(self): """Test 600-699 range returns Oak/Gum/Cypress.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [600, 650, 699]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) assert all(g == "Oak/Gum/Cypress" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_elm_ash_cottonwood(self): """Test 700-799 range returns Elm/Ash/Cottonwood.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [700, 750, 799]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) assert all(g == "Elm/Ash/Cottonwood" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_maple_beech_birch(self): """Test 800-899 range returns Maple/Beech/Birch.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [800, 850, 899]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) assert all(g == "Maple/Beech/Birch" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_900_range_has_variants(self): """Test 900-999 range has various western hardwood types. - Note: The new implementation distinguishes between Aspen/Birch, + Note: The implementation distinguishes between Aspen/Birch, Alder/Maple, Western Oak, etc. in the 900 range. """ - import warnings cond_df = pl.DataFrame({"FORTYPCD": [900, 950, 999]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) groups = result["FOREST_TYPE_GROUP"].to_list() # 900 should be Aspen/Birch, 950 Other Western Hardwoods, 999 Nonstocked @@ -352,26 +309,20 @@ def test_900_range_has_variants(self): def test_other_unknown(self): """Test out-of-range codes return Other.""" - import warnings cond_df = pl.DataFrame({"FORTYPCD": [50, 1000, 0]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group(cond_df) + result = add_forest_type_group(cond_df) - # The new implementation returns "Other" for out-of-range codes + # Returns "Other" for out-of-range codes assert all(g == "Other" for g in result["FOREST_TYPE_GROUP"].to_list()) def test_custom_column_names(self): """Test custom input and output column names.""" - import warnings cond_df = pl.DataFrame({"MY_FORTYP": [500]}) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - result = assign_forest_type_group( - cond_df, - fortypcd_column="MY_FORTYP", - output_column="MY_GROUP", - ) + result = add_forest_type_group( + cond_df, + fortypcd_col="MY_FORTYP", + output_col="MY_GROUP", + ) assert "MY_GROUP" in result.columns assert result["MY_GROUP"][0] == "Oak/Hickory" diff --git a/tests/unit/test_variance_formulas.py b/tests/unit/test_variance_formulas.py index 8df512f..67410e6 100644 --- a/tests/unit/test_variance_formulas.py +++ b/tests/unit/test_variance_formulas.py @@ -1172,3 +1172,287 @@ def test_stratified_sampling_additivity(self, mock_fia_database): f"Additivity violation: V1={v1}, V2={v2}, expected total={expected_total}, " f"got {var_stats['variance']}" ) + + +# ============================================================================= +# TestRareCategoryVariance (Issue #68) +# ============================================================================= + + +class TestRareCategoryVariance: + """ + Test variance calculation for rare categories when using grp_by. + + Issue #68: Variance was severely underestimated for rare categories because + the calculation only used plots matching each group value instead of all + plots in the stratum (with non-matching plots contributing y=0). + + Example from issue: + - FORTYPCD 809 in stratum: 43 total plots, 3 have Type 809 + - Buggy variance (3 plots): 0.0209 + - Correct variance (43 plots with zeros): 0.0568 + - Underestimation: 2.7x + + The fix cross-joins all plots with unique group values, then left-joins + on PLT_CN + group columns so non-matching plots get y=0. + """ + + def test_rare_category_includes_all_plots_in_variance(self, mock_fia_database): + """ + Test that rare category variance uses all stratum plots, not just matching ones. + + Setup: 10 plots in one stratum, only 2 have forest type 999. + - Plots P1, P2 have FORTYPCD=999 with y=0.8, y=1.0 + - Plots P3-P10 have FORTYPCD=161 (common type) + + Buggy calculation (only 2 plots): + - mean = 0.9, s2 = 0.02, V = 1000^2 * 0.02 * 2 = 40,000 + + Correct calculation (10 plots, 8 with y=0): + - values = [0.8, 1.0, 0, 0, 0, 0, 0, 0, 0, 0] + - mean = 0.18, s2 = 0.1218, V = 1000^2 * 0.1218 * 10 = 1,217,777.78 + - This is ~30x higher than the buggy calculation! + """ + from pyfia.estimation.base import AggregationResult + from pyfia.estimation.estimators.area import AreaEstimator + + config = {"grp_by": "FORTYPCD"} + estimator = AreaEstimator(mock_fia_database, config) + + # Create plot data: 10 plots, 2 with rare type 999, 8 with common type 161 + plot_condition_data = pl.DataFrame( + { + "PLT_CN": [f"P{i}" for i in range(1, 11)], + "CONDID": [1] * 10, + "AREA_VALUE": [0.8, 1.0] + [0.5] * 8, # Type 999 has 0.8, 1.0 + "ADJ_FACTOR_AREA": [1.0] * 10, + "EXPNS": [1000.0] * 10, + "ESTN_UNIT": [1] * 10, + "STRATUM": [1] * 10, + "STRATUM_CN": [1] * 10, + "FORTYPCD": [999, 999] + [161] * 8, # 2 rare, 8 common + } + ) + + # Create results for type 999 only (we're testing the rare type) + results = pl.DataFrame( + { + "FORTYPCD": [999], + "AREA_TOTAL": [1800.0], # 0.8 + 1.0 = 1.8 * 1000 + "N_PLOTS": [2], + } + ) + + agg_result = AggregationResult( + results=results, + plot_tree_data=plot_condition_data, + group_cols=["FORTYPCD"], + ) + + variance_results = estimator.calculate_variance(agg_result) + + # Calculate expected variance with ALL 10 plots + # For type 999: 2 plots with y=0.8, 1.0; 8 plots with y=0 + y_values_correct = [0.8, 1.0] + [0.0] * 8 + s2_correct = np.var(y_values_correct, ddof=1) + expected_variance_correct = (1000.0**2) * s2_correct * 10 + + # Calculate buggy variance (only 2 plots) + y_values_buggy = [0.8, 1.0] + s2_buggy = np.var(y_values_buggy, ddof=1) + variance_buggy = (1000.0**2) * s2_buggy * 2 + + actual_variance = variance_results["AREA_VARIANCE"][0] + + # The correct variance should be much larger than the buggy one + assert actual_variance > variance_buggy * 5, ( + f"Variance {actual_variance} not significantly larger than buggy " + f"variance {variance_buggy}. Fix may not be working." + ) + + # Should be close to the correctly calculated value + assert abs(actual_variance - expected_variance_correct) < 1e-3, ( + f"Expected variance {expected_variance_correct}, got {actual_variance}" + ) + + def test_multiple_rare_categories_in_same_stratum(self, mock_fia_database): + """ + Test variance for multiple rare categories, each using all stratum plots. + + Setup: 10 plots, 3 different forest types: + - Type 999: 2 plots (P1, P2) + - Type 888: 1 plot (P3) + - Type 161: 7 plots (P4-P10) + """ + from pyfia.estimation.base import AggregationResult + from pyfia.estimation.estimators.area import AreaEstimator + + config = {"grp_by": "FORTYPCD"} + estimator = AreaEstimator(mock_fia_database, config) + + plot_condition_data = pl.DataFrame( + { + "PLT_CN": [f"P{i}" for i in range(1, 11)], + "CONDID": [1] * 10, + "AREA_VALUE": [0.8, 1.0, 0.6] + [0.5] * 7, + "ADJ_FACTOR_AREA": [1.0] * 10, + "EXPNS": [1000.0] * 10, + "ESTN_UNIT": [1] * 10, + "STRATUM": [1] * 10, + "STRATUM_CN": [1] * 10, + "FORTYPCD": [999, 999, 888] + [161] * 7, + } + ) + + results = pl.DataFrame( + { + "FORTYPCD": [999, 888], + "AREA_TOTAL": [1800.0, 600.0], + "N_PLOTS": [2, 1], + } + ) + + agg_result = AggregationResult( + results=results, + plot_tree_data=plot_condition_data, + group_cols=["FORTYPCD"], + ) + + variance_results = estimator.calculate_variance(agg_result) + + # Both rare types should have variance calculated using all 10 plots + assert len(variance_results) == 2 + + # Verify variances are calculated correctly for each type + for i, fortypcd in enumerate([999, 888]): + row = variance_results.filter(pl.col("FORTYPCD") == fortypcd) + actual_variance = row["AREA_VARIANCE"][0] + + # Get expected y values (matching plots have their value, others have 0) + if fortypcd == 999: + y_correct = [0.8, 1.0] + [0.0] * 8 + else: # 888 + y_correct = [0.0, 0.0, 0.6] + [0.0] * 7 + + s2_correct = np.var(y_correct, ddof=1) + expected_variance = (1000.0**2) * s2_correct * 10 + + assert abs(actual_variance - expected_variance) < 1e-3, ( + f"FORTYPCD {fortypcd}: expected {expected_variance}, got {actual_variance}" + ) + + def test_rare_category_across_multiple_strata(self, mock_fia_database): + """ + Test that rare category variance sums correctly across strata. + + Setup: 2 strata, rare type appears in both: + - Stratum 1: 6 plots, 1 has type 999 + - Stratum 2: 4 plots, 1 has type 999 + + Each stratum should independently contribute to the variance using + all its plots. + """ + from pyfia.estimation.base import AggregationResult + from pyfia.estimation.estimators.area import AreaEstimator + + config = {"grp_by": "FORTYPCD"} + estimator = AreaEstimator(mock_fia_database, config) + + # Stratum 1: 6 plots, P1 has type 999 with y=0.9 + # Stratum 2: 4 plots, P7 has type 999 with y=1.1 + plot_condition_data = pl.DataFrame( + { + "PLT_CN": [f"P{i}" for i in range(1, 11)], + "CONDID": [1] * 10, + "AREA_VALUE": [0.9] + [0.5] * 5 + [1.1] + [0.6] * 3, + "ADJ_FACTOR_AREA": [1.0] * 10, + "EXPNS": [1000.0] * 6 + [1500.0] * 4, # Different weights per stratum + "ESTN_UNIT": [1] * 10, + "STRATUM": [1] * 6 + [2] * 4, + "STRATUM_CN": [1] * 6 + [2] * 4, + "FORTYPCD": [999] + [161] * 5 + [999] + [161] * 3, + } + ) + + results = pl.DataFrame( + { + "FORTYPCD": [999], + "AREA_TOTAL": [0.9 * 1000 + 1.1 * 1500], # 2550 + "N_PLOTS": [2], + } + ) + + agg_result = AggregationResult( + results=results, + plot_tree_data=plot_condition_data, + group_cols=["FORTYPCD"], + ) + + variance_results = estimator.calculate_variance(agg_result) + + # Calculate expected variance for each stratum + # Stratum 1: y=[0.9, 0, 0, 0, 0, 0], w=1000 + y1 = [0.9] + [0.0] * 5 + s2_1 = np.var(y1, ddof=1) + v1 = (1000.0**2) * s2_1 * 6 + + # Stratum 2: y=[1.1, 0, 0, 0], w=1500 + y2 = [1.1] + [0.0] * 3 + s2_2 = np.var(y2, ddof=1) + v2 = (1500.0**2) * s2_2 * 4 + + expected_total_variance = v1 + v2 + + actual_variance = variance_results["AREA_VARIANCE"][0] + + assert abs(actual_variance - expected_total_variance) < 1e-3, ( + f"Expected variance {expected_total_variance} (V1={v1}, V2={v2}), " + f"got {actual_variance}" + ) + + def test_null_group_values_handled(self, mock_fia_database): + """ + Test that NULL group values don't cause errors in variance calculation. + """ + from pyfia.estimation.base import AggregationResult + from pyfia.estimation.estimators.area import AreaEstimator + + config = {"grp_by": "FORTYPCD"} + estimator = AreaEstimator(mock_fia_database, config) + + # Include some NULL FORTYPCD values + plot_condition_data = pl.DataFrame( + { + "PLT_CN": ["P1", "P2", "P3", "P4", "P5"], + "CONDID": [1, 1, 1, 1, 1], + "AREA_VALUE": [0.8, 1.0, 0.5, 0.6, 0.7], + "ADJ_FACTOR_AREA": [1.0] * 5, + "EXPNS": [1000.0] * 5, + "ESTN_UNIT": [1] * 5, + "STRATUM": [1] * 5, + "STRATUM_CN": [1] * 5, + "FORTYPCD": [999, 999, None, 161, 161], # P3 has NULL + } + ) + + # Results only include non-NULL types + results = pl.DataFrame( + { + "FORTYPCD": [999, 161], + "AREA_TOTAL": [1800.0, 1300.0], + "N_PLOTS": [2, 2], + } + ) + + agg_result = AggregationResult( + results=results, + plot_tree_data=plot_condition_data, + group_cols=["FORTYPCD"], + ) + + # Should not raise an error + variance_results = estimator.calculate_variance(agg_result) + + # Should have results for both non-NULL types + assert len(variance_results) == 2 + assert all(v >= 0 for v in variance_results["AREA_VARIANCE"].to_list()) diff --git a/tests/validation/test_biomass.py b/tests/validation/test_biomass.py index b5dde2e..864d3e5 100644 --- a/tests/validation/test_biomass.py +++ b/tests/validation/test_biomass.py @@ -107,11 +107,21 @@ def test_belowground_biomass(self, fia_db, evalidator_client): f"pyFIA: {pyfia_plot_count} vs EVALIDator: {ev_result.plot_count}" ) - @pytest.mark.skip(reason="snum=13 returns trees >=1\" not >=5\" - need correct snum") + @pytest.mark.skip( + reason="No direct EVALIDator snum for 'live tree biomass >=5\" DBH'. " + "EVALIDator only offers growing-stock (snum=96 is dead trees, " + "snum=312 is growth). Would need strFilter parameter to validate." + ) def test_biomass_5inch_trees(self, fia_db, evalidator_client): """Validate aboveground biomass for trees >=5" DBH. - NOTE: Skipped - EVALIDator snum=13 is for trees >=1" DBH, not >=5". - Need to verify correct snum or use strFilter parameter. + NOTE: Skipped - EVALIDator doesn't have a direct estimate for + "aboveground biomass of live trees >= 5 inches DBH". Options: + - snum=10: Live trees >= 1" DBH (too inclusive) + - snum=96: Dead trees >= 5" DBH (wrong tree type) + - snum=312: Growing-stock growth >= 5" DBH (wrong metric) + + To validate this, would need to use EVALIDator's strFilter parameter + to filter snum=10 to DIA >= 5. """ pass diff --git a/uv.lock b/uv.lock index 7898b9a..2046a45 100644 --- a/uv.lock +++ b/uv.lock @@ -1101,13 +1101,12 @@ wheels = [ [[package]] name = "pyfia" -version = "1.2.0" +version = "1.2.3" source = { editable = "." } dependencies = [ { name = "connectorx" }, { name = "duckdb" }, { name = "numpy" }, - { name = "pandas" }, { name = "polars" }, { name = "pyarrow" }, { name = "pydantic" }, @@ -1162,7 +1161,6 @@ requires-dist = [ { name = "mkdocs-material", marker = "extra == 'dev'", specifier = ">=9.5.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.16.1" }, { name = "numpy", specifier = ">=1.26.0" }, - { name = "pandas", specifier = ">=2.0.0" }, { name = "pandas", marker = "extra == 'pandas'", specifier = ">=2.0.0" }, { name = "polars", specifier = ">=1.0.0" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" }, From 2da8f66d4b489987729db7870188b4a76efb5290 Mon Sep 17 00:00:00 2001 From: Chris Mihiar <28452317+mihiarc@users.noreply.github.com> Date: Sat, 7 Feb 2026 06:34:54 -0500 Subject: [PATCH 3/3] Fix hardcoded year 2023 in area estimator format_output The base class was fixed to use _extract_evaluation_year() but the area estimator overrides format_output and still had pl.lit(2023). Now uses the same pattern as all other estimators. --- src/pyfia/estimation/estimators/area.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pyfia/estimation/estimators/area.py b/src/pyfia/estimation/estimators/area.py index 02e1d4f..852c1d7 100644 --- a/src/pyfia/estimation/estimators/area.py +++ b/src/pyfia/estimation/estimators/area.py @@ -568,7 +568,8 @@ def format_output(self, results: pl.DataFrame) -> pl.DataFrame: ) # Add year - results = results.with_columns([pl.lit(2023).alias("YEAR")]) + year = self._extract_evaluation_year() + results = results.with_columns([pl.lit(year).alias("YEAR")]) # Format columns results = format_output_columns(