Skip to content

Commit

Permalink
Update dist_select plots
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander März committed Aug 29, 2023
1 parent ec43e67 commit b22e976
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 31 deletions.
9 changes: 4 additions & 5 deletions xgboostlss/distributions/distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,12 +616,11 @@ def dist_select(self,
}
)
dist_list.append(fit_df)
fit_df = pd.concat(dist_list).sort_values(by=self.loss_fn, ascending=True)
fit_df["rank"] = fit_df[self.loss_fn].rank().astype(int)
fit_df.set_index(fit_df["rank"], inplace=True)
pbar.update(1)
pbar.set_description(f"Fitting of candidate distributions completed")

fit_df = pd.concat(dist_list).sort_values(by=self.loss_fn, ascending=True)
fit_df["rank"] = fit_df[self.loss_fn].rank().astype(int)
fit_df.set_index(fit_df["rank"], inplace=True)
if plot:
# Select best distribution
best_dist = fit_df[fit_df["rank"] == 1].reset_index(drop=True)
Expand Down Expand Up @@ -652,7 +651,7 @@ def dist_select(self,
sns.kdeplot(target.reshape(-1, ), label="Actual")
sns.kdeplot(dist_samples.reshape(-1, ), label=f"Best-Fit: {best_dist['distribution'].values[0]}")
plt.legend()
plt.title("Actual vs. Best-Fit Density")
plt.title("Actual vs. Best-Fit Density", fontweight="bold", fontsize=16)
plt.show()

fit_df.drop(columns=["rank", "params"], inplace=True)
Expand Down
9 changes: 4 additions & 5 deletions xgboostlss/distributions/flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,11 @@ def flow_select(self,
}
)
flow_list.append(fit_df)
fit_df = pd.concat(flow_list).sort_values(by=flow_sel.loss_fn, ascending=True)
fit_df["rank"] = fit_df[flow_sel.loss_fn].rank().astype(int)
fit_df.set_index(fit_df["rank"], inplace=True)
pbar.update(1)
pbar.set_description(f"Fitting of candidate normalizing flows completed")

fit_df = pd.concat(flow_list).sort_values(by=flow_sel.loss_fn, ascending=True)
fit_df["rank"] = fit_df[flow_sel.loss_fn].rank().astype(int)
fit_df.set_index(fit_df["rank"], inplace=True)
if plot:
# Select normalizing flow with the lowest loss
best_flow = fit_df[fit_df["rank"] == 1].reset_index(drop=True)
Expand All @@ -706,7 +705,7 @@ def flow_select(self,
sns.kdeplot(target.reshape(-1, ), label="Actual")
sns.kdeplot(flow_samples.reshape(-1, ), label=f"Best-Fit: {best_flow['NormFlow'].values[0]}")
plt.legend()
plt.title("Actual vs. Best-Fit Density")
plt.title("Actual vs. Best-Fit Density", fontweight="bold", fontsize=16)
plt.show()

fit_df.drop(columns=["rank", "params"], inplace=True)
Expand Down
2 changes: 1 addition & 1 deletion xgboostlss/distributions/mixture_distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def dist_select(self,
sns.kdeplot(target.reshape(-1,), label="Actual")
sns.kdeplot(dist_samples.reshape(-1,), label=f"Best-Fit: {best_dist['distribution'].values[0]}")
plt.legend()
plt.title("Actual vs. Best-Fit Density")
plt.title("Actual vs. Best-Fit Density", fontweight="bold", fontsize=16)
plt.show()

fit_df.drop(columns=["rank", "params", "dist_pos", "M"], inplace=True)
Expand Down
23 changes: 3 additions & 20 deletions xgboostlss/distributions/multivariate_distribution_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,12 +568,11 @@ def dist_select(self,
}
)
dist_list.append(fit_df)
fit_df = pd.concat(dist_list).sort_values(by=dist_sel.loss_fn, ascending=True)
fit_df["rank"] = fit_df[dist_sel.loss_fn].rank().astype(int)
fit_df.set_index(fit_df["rank"], inplace=True)
pbar.update(1)
pbar.set_description(f"Fitting of candidate distributions completed")

fit_df = pd.concat(dist_list).sort_values(by=dist_sel.loss_fn, ascending=True)
fit_df["rank"] = fit_df[dist_sel.loss_fn].rank().astype(int)
fit_df.set_index(fit_df["rank"], inplace=True)
if plot:
warnings.simplefilter(action='ignore', category=UserWarning)
# Select distribution
Expand Down Expand Up @@ -630,22 +629,6 @@ def dist_select(self,
g.fig.suptitle("Actual vs. Best-Fit Density", weight="bold", fontsize=16)
g.fig.tight_layout(rect=[0, 0, 1, 0.9])

# print(
# ggplot(plot_df,
# aes(x="value",
# color="type")) +
# geom_density(alpha=0.5) +
# facet_wrap("target",
# scales="free",
# ncol=ncol) +
# theme_bw(base_size=15) +
# theme(figure_size=figure_size,
# legend_position="right",
# legend_title=element_blank(),
# plot_title=element_text(hjust=0.5)) +
# labs(title=f"Actual vs. Fitted Density")
# )

fit_df.drop(columns=["rank", "params"], inplace=True)

return fit_df
Expand Down

0 comments on commit b22e976

Please sign in to comment.