Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 85 additions & 58 deletions src/pyfia/estimation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,11 +733,14 @@ def _expand_plots_for_all_groups(
tuple[pl.DataFrame, List[str]]
Expanded plot data and valid group columns
"""
# Get all plots from stratification
# Get all plots from stratification (include B&P columns when available)
strat_data = self._get_stratification_data()
all_plots = (
strat_data.select("PLT_CN", "STRATUM_CN", "EXPNS").unique().collect()
)
strat_schema = strat_data.collect_schema().names()
bp_cols = ["ESTN_UNIT_CN", "STRATUM_WGT", "AREA_USED", "P2POINTCNT"]
select_cols = ["PLT_CN", "STRATUM_CN", "EXPNS"] + [
c for c in bp_cols if c in strat_schema
]
all_plots = strat_data.select(select_cols).unique().collect()

valid_group_cols = [c for c in group_cols if c in plot_data.columns]

Expand Down Expand Up @@ -1023,9 +1026,11 @@ def _calculate_variance_for_metrics(
)

# Build plot-level aggregation
plot_agg_exprs = [
pl.sum("CONDPROP_UNADJ").cast(pl.Float64).alias("x_i")
] if "CONDPROP_UNADJ" in plot_cond_data.columns else [pl.lit(1.0).alias("x_i")]
plot_agg_exprs = (
[pl.sum("CONDPROP_UNADJ").cast(pl.Float64).alias("x_i")]
if "CONDPROP_UNADJ" in plot_cond_data.columns
else [pl.lit(1.0).alias("x_i")]
)

for i in range(len(metric_configs)):
plot_agg_exprs.append(pl.sum(f"y_{i}_ic").alias(f"y_{i}_i"))
Expand All @@ -1034,9 +1039,13 @@ def _calculate_variance_for_metrics(

# Step 3: Get ALL plots in the evaluation for proper variance calculation
strat_data = self._get_stratification_data()
all_plots = (
strat_data.select("PLT_CN", "STRATUM_CN", "EXPNS").unique().collect()
)
strat_schema = strat_data.collect_schema().names()
# Select B&P variance columns when available
bp_cols = ["ESTN_UNIT_CN", "STRATUM_WGT", "AREA_USED", "P2POINTCNT"]
select_cols = ["PLT_CN", "STRATUM_CN", "EXPNS"] + [
c for c in bp_cols if c in strat_schema
]
all_plots = strat_data.select(select_cols).unique().collect()

# Step 4: Calculate variance for each group or overall
if group_cols:
Expand Down Expand Up @@ -1074,8 +1083,11 @@ def _calculate_grouped_multi_metric_variance(

Uses a loop over groups for now; could be further optimized with
vectorized operations in the future.

Uses ratio-of-means variance for per-acre SE:
V(R) = (1/X^2) * [V(Y) + R^2*V(X) - 2*R*Cov(Y,X)]
"""
from .variance import calculate_domain_total_variance
from .variance import calculate_ratio_of_means_variance

variance_results = []

Expand All @@ -1097,7 +1109,9 @@ def _calculate_grouped_multi_metric_variance(
group_plot_data = plot_data.filter(group_filter)

# Build select columns for join
select_cols = ["PLT_CN", "x_i"] + [f"y_{i}_i" for i in range(len(metric_configs))]
select_cols = ["PLT_CN", "x_i"] + [
f"y_{i}_i" for i in range(len(metric_configs))
]
select_cols = [c for c in select_cols if c in group_plot_data.columns]

# Join with ALL plots, filling missing with zeros
Expand All @@ -1120,29 +1134,26 @@ def _calculate_grouped_multi_metric_variance(
result_row = dict(group_dict)

if len(all_plots_group) > 0:
# Calculate total area for per-acre SE
total_area = (
all_plots_group["EXPNS"] * all_plots_group["x_i"]
).sum()

for i, cfg in enumerate(metric_configs):
y_col = f"y_{i}_i"
if y_col in all_plots_group.columns:
var_stats = calculate_domain_total_variance(
all_plots_group, y_col
)
se_acre = (
var_stats["se_total"] / total_area if total_area > 0 else 0.0
ratio_stats = calculate_ratio_of_means_variance(
all_plots_group, y_col, "x_i"
)
se_acre = ratio_stats["se_ratio"]

result_row[cfg["acre_se_col"]] = se_acre
result_row[cfg["total_se_col"]] = var_stats["se_total"]
result_row[cfg["total_se_col"]] = ratio_stats["se_total"]

# Add variance columns if specified
if "acre_var_col" in cfg:
result_row[cfg["acre_var_col"]] = se_acre**2
result_row[cfg["acre_var_col"]] = ratio_stats[
"variance_ratio"
]
if "total_var_col" in cfg:
result_row[cfg["total_var_col"]] = var_stats["variance_total"]
result_row[cfg["total_var_col"]] = ratio_stats[
"variance_total"
]
else:
# No data for this group
for cfg in metric_configs:
Expand All @@ -1159,7 +1170,9 @@ def _calculate_grouped_multi_metric_variance(
if variance_results:
var_df = pl.DataFrame(variance_results)
# Use only valid group columns that exist in both dataframes
join_cols = [c for c in group_cols if c in var_df.columns and c in results.columns]
join_cols = [
c for c in group_cols if c in var_df.columns and c in results.columns
]
if join_cols:
results = results.join(var_df, on=join_cols, how="left")

Expand All @@ -1174,11 +1187,16 @@ def _calculate_overall_multi_metric_variance(
) -> pl.DataFrame:
"""
Calculate overall variance (ungrouped) for multiple metrics.

Uses ratio-of-means variance for per-acre SE:
V(R) = (1/X^2) * [V(Y) + R^2*V(X) - 2*R*Cov(Y,X)]
"""
from .variance import calculate_domain_total_variance
from .variance import calculate_ratio_of_means_variance

# Build select columns for join
select_cols = ["PLT_CN", "x_i"] + [f"y_{i}_i" for i in range(len(metric_configs))]
select_cols = ["PLT_CN", "x_i"] + [
f"y_{i}_i" for i in range(len(metric_configs))
]
select_cols = [c for c in select_cols if c in plot_data.columns]

# Join with ALL plots, filling missing with zeros
Expand All @@ -1197,34 +1215,39 @@ def _calculate_overall_multi_metric_variance(

all_plots_with_values = all_plots_with_values.with_columns(fill_exprs)

# Calculate total area for per-acre SE
total_area = (
all_plots_with_values["EXPNS"] * all_plots_with_values["x_i"]
).sum()

# Calculate variance for each metric and add to results
for i, cfg in enumerate(metric_configs):
y_col = f"y_{i}_i"
if y_col in all_plots_with_values.columns:
var_stats = calculate_domain_total_variance(
all_plots_with_values, y_col
ratio_stats = calculate_ratio_of_means_variance(
all_plots_with_values, y_col, "x_i"
)
se_acre = var_stats["se_total"] / total_area if total_area > 0 else 0.0
se_acre = ratio_stats["se_ratio"]

results = results.with_columns([
pl.lit(se_acre).alias(cfg["acre_se_col"]),
pl.lit(var_stats["se_total"]).alias(cfg["total_se_col"]),
])
results = results.with_columns(
[
pl.lit(se_acre).alias(cfg["acre_se_col"]),
pl.lit(ratio_stats["se_total"]).alias(cfg["total_se_col"]),
]
)

# Add variance columns if specified
if "acre_var_col" in cfg:
results = results.with_columns([
pl.lit(se_acre**2).alias(cfg["acre_var_col"]),
])
results = results.with_columns(
[
pl.lit(ratio_stats["variance_ratio"]).alias(
cfg["acre_var_col"]
),
]
)
if "total_var_col" in cfg:
results = results.with_columns([
pl.lit(var_stats["variance_total"]).alias(cfg["total_var_col"]),
])
results = results.with_columns(
[
pl.lit(ratio_stats["variance_total"]).alias(
cfg["total_var_col"]
),
]
)

return results

Expand All @@ -1242,21 +1265,25 @@ def _add_cv_columns(

if acre_col and acre_col in results.columns:
cv_col = acre_se_col.replace("_SE", "_CV")
results = results.with_columns([
pl.when(pl.col(acre_col) > 0)
.then(pl.col(acre_se_col) / pl.col(acre_col) * 100)
.otherwise(None)
.alias(cv_col),
])
results = results.with_columns(
[
pl.when(pl.col(acre_col) > 0)
.then(pl.col(acre_se_col) / pl.col(acre_col) * 100)
.otherwise(None)
.alias(cv_col),
]
)

if total_col and total_col in results.columns:
cv_col = total_se_col.replace("_SE", "_CV")
results = results.with_columns([
pl.when(pl.col(total_col) > 0)
.then(pl.col(total_se_col) / pl.col(total_col) * 100)
.otherwise(None)
.alias(cv_col),
])
results = results.with_columns(
[
pl.when(pl.col(total_col) > 0)
.then(pl.col(total_se_col) / pl.col(total_col) * 100)
.otherwise(None)
.alias(cv_col),
]
)

return results

Expand Down
49 changes: 49 additions & 0 deletions src/pyfia/estimation/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,17 +590,66 @@ def get_stratification_data(self) -> pl.LazyFrame:
# when joining with other tables that also have STATECD, INVYR, etc.
ppsa_selected = ppsa_unique.select(["PLT_CN", "STRATUM_CN"])

# Load POP_ESTN_UNIT for exact B&P variance formula
if "POP_ESTN_UNIT" not in self.db.tables:
self.db.load_table("POP_ESTN_UNIT")
pop_estn_unit = self.db.tables["POP_ESTN_UNIT"]
if not isinstance(pop_estn_unit, pl.LazyFrame):
pop_estn_unit = pop_estn_unit.lazy()

# Apply EVALID filter to POP_ESTN_UNIT
if self.db.evalid:
pop_estn_unit = pop_estn_unit.filter(
pl.col("EVALID").is_in(self.db.evalid)
)

# Deduplicate POP_ESTN_UNIT (same reason as POP_STRATUM)
pop_estn_unit_unique = pop_estn_unit.unique(subset=["CN"])

# Select columns from POP_ESTN_UNIT
pop_estn_unit_selected = pop_estn_unit_unique.select(
[
pl.col("CN").alias("ESTN_UNIT_CN"),
"AREA_USED",
"P1PNTCNT_EU",
]
)

# Select necessary columns from POP_STRATUM
# Include ESTN_UNIT_CN, P1POINTCNT, P2POINTCNT for exact B&P variance
pop_stratum_selected = pop_stratum_unique.select(
[
pl.col("CN").alias("STRATUM_CN"),
"ESTN_UNIT_CN",
"EXPNS",
"ADJ_FACTOR_MICR",
"ADJ_FACTOR_SUBP",
"ADJ_FACTOR_MACR",
"P1POINTCNT",
"P2POINTCNT",
]
)

# Join POP_STRATUM with POP_ESTN_UNIT to get AREA_USED and P1PNTCNT_EU
pop_stratum_selected = pop_stratum_selected.join(
pop_estn_unit_selected, on="ESTN_UNIT_CN", how="left"
)

# Compute STRATUM_WGT = P1POINTCNT / P1PNTCNT_EU
# Guard against null or zero P1PNTCNT_EU to avoid inf/NaN propagation
pop_stratum_selected = pop_stratum_selected.with_columns(
pl.when(
pl.col("P1PNTCNT_EU").is_not_null()
& (pl.col("P1PNTCNT_EU").cast(pl.Float64) > 0)
)
.then(
pl.col("P1POINTCNT").cast(pl.Float64)
/ pl.col("P1PNTCNT_EU").cast(pl.Float64)
)
.otherwise(0.0)
.alias("STRATUM_WGT")
)

# Select MACRO_BREAKPOINT_DIA from PLOT table
# This is CRITICAL for correct adjustment factor selection in states with macroplots
plot_cols = [pl.col("CN").alias("PLT_CN"), "MACRO_BREAKPOINT_DIA"]
Expand Down
Loading
Loading