@@ -0,0 +1,86 @@
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
import pathlib

# Define colors
COLORS = ['#4daf4a']

# Paths
PATH = pathlib.Path(__file__).parent.parent.parent.absolute()
DATA_PATH = PATH / "data"
LOGGING = False

matplotlib.rc('text', usetex=True)
matplotlib.rc('font', family='serif')
matplotlib.rc('text.latex', preamble=r'\usepackage{amsmath}')

datasets = {
    "uniform_new_HC": DATA_PATH / "Gaussian_logCA0_uniform_J=20_HC.csv",
    "uniform_new_MC": DATA_PATH / "Gaussian_logCA0_uniform_J=20_MC.csv",
    "adaptive_new_HC": DATA_PATH / "Gaussian_logCA0_adaptive_J=20_HC.csv",
    "adaptive_new_MC": DATA_PATH / "Gaussian_logCA0_adaptive_J=20_MC.csv",
}

# Loading data
def load_data(path: str) -> tuple:
    data = np.loadtxt(path, delimiter=",", skiprows=1, dtype=np.float32)
    x_train = data[:, 0]
    y_train = data[:, 1]

    return x_train, y_train


def remove_points(train_x, train_y, nobs):
    """
    Subsample the dataset down to `nobs` equally spaced points (endpoints included).
    """
    idx = np.linspace(0, len(train_x) - 1, nobs).astype(int)
    return train_x[idx], train_y[idx]
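
# Example with hypothetical values: np.linspace keeps both endpoints, so
# remove_points(np.arange(10.0), np.arange(10.0), 5) returns the points at
# indices 0, 2, 4, 6 and 9.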

true_MC_x, true_MC_y = load_data(DATA_PATH / "true_Gaussian_logCA0_MC.csv")
true_HC_x, true_HC_y = load_data(DATA_PATH / "true_Gaussian_logCA0_HC.csv")

model = 'SCAM'
if __name__ == '__main__':
    for dataset_name, dataset_path in datasets.items():
        if 'HC' in dataset_name:
            x_true, y_true = true_HC_x, true_HC_y
        else:
            x_true, y_true = true_MC_x[1:], true_MC_y[1:]
        for nobs in list(range(10, 21, 2)) + [21]:
            fig, ax = plt.subplots(1, 1, figsize=(5, 4), tight_layout=True)
            name = f"{model}_{nobs}_nobs"
            io_path = PATH / "experiments" / str(dataset_name)
            io_path.mkdir(parents=True, exist_ok=True)

            train_x, train_y = load_data(dataset_path)
            train_x, train_y = remove_points(train_x, train_y, nobs)

            test_x, test_y = load_data(io_path / f'{name}_.csv')
            mse_loss = np.mean((test_y - y_true)**2)
            ax.plot(x_true, y_true, color=COLORS[0])
            ax.plot(test_x, test_y, "r--")  # SCAM predictions loaded above (assumed to belong in the plot)
            ax.plot(train_x, train_y, "k*")
            ax.legend(["True function", "SCAM prediction", "Observed values"])
            ax.set_title("Function values")
            ax.set_xlim([0, 1])
            # ax.set_ylim([-7.5, 12.5])
            ax.set_xlabel(r"$\alpha$")
            ax.set_ylabel(r"$f_{\boldsymbol{z}}(\alpha)$")
            ax.text(0.5, 0.05, f"MSE: {mse_loss:.2f}",
                    transform=ax.transAxes,
                    fontsize=16,
                    bbox=dict(facecolor=COLORS[0], alpha=0.5))

            fig.suptitle("Shape Constrained Additive Model (SCAM)",
                         fontsize=18,
                         color=COLORS[0],
                         fontweight='bold')

            # plt.show()
            fig.savefig(io_path / (name + ".png"), dpi=600, bbox_inches="tight")
            plt.close(fig)
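
The prediction files that `load_data` reads back above are written by the R script below: `write.csv(grid, ..., row.names = FALSE)` produces a header row followed by two columns (grid abscissa, predicted value), and `np.loadtxt(..., skiprows=1)` skips that header. A minimal sketch of this interface, using dummy values and a throwaway file that only follows the naming convention:

import numpy as np
import pandas as pd

# Write a file in the same layout as the R script's output: a header row, then
# two columns (grid x, predicted y). The values here are dummies.
grid = pd.DataFrame({"x_sampled": np.linspace(0.0, 1.0, 5),
                     "y_pred": np.zeros(5)})
grid.to_csv("SCAM_10_nobs_.csv", index=False)

# Read it back exactly the way load_data does.
data = np.loadtxt("SCAM_10_nobs_.csv", delimiter=",", skiprows=1, dtype=np.float32)
test_x, test_y = data[:, 0], data[:, 1]
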
@@ -0,0 +1,92 @@
library(scam)
library(dplyr)
library(readr)
library(ggplot2)  # needed for the diagnostic plot at the end of the script

# Function to remove points equally spaced in the dataset
remove_points <- function(x, y, nobs) {
  idx <- seq(1, length(x), length.out = nobs)
  return(list(x = x[idx], y = y[idx]))
}

# Paths to datasets
data_paths <- list(
  "uniform_new_HC" = "../../data/Gaussian_logCA0_uniform_J=20_HC.csv",
  "uniform_new_MC" = "../../data/Gaussian_logCA0_uniform_J=20_MC.csv",
  "adaptive_new_HC" = "../../data/Gaussian_logCA0_adaptive_J=20_HC.csv",
  "adaptive_new_MC" = "../../data/Gaussian_logCA0_adaptive_J=20_MC.csv"
)

# Define nobs scenarios
nobs_scenarios <- append(seq(6, 20, by = 2), 21)

# Data frame to store results
results <- data.frame(model = character(), scenario = integer(), dataset = character(), mse = numeric())

x_true <- read_csv("../../data/true_Gaussian_logCA0_HC.csv")$a0

# Experiment loop
for (dataset_name in names(data_paths)) {
  dataset_path <- data_paths[[dataset_name]]

  # Read data
  data <- read_csv(dataset_path)
  x <- data$a0
  y <- data$lc_a0

  for (nobs in nobs_scenarios) {

    # Remove points to get desired number of observations
    sampled_data <- remove_points(x, y, nobs)
    x_sampled <- sampled_data$x
    y_sampled <- sampled_data$y

    # Fit the SCAM model
    model <- tryCatch(
      {
        scam(y_sampled ~ s(x_sampled, bs = "cx"))
      },
      error = function(e) {
        cat("Error in model fitting:", e$message, "\n")
        return(NULL)
      }
    )

    if (!is.null(model)) {
      # Generate predictions on the a0 grid of the true curve
      grid <- data.frame(x_sampled = x_true)
      grid$y_pred <- predict(model, newdata = grid)

      # Save predictions
      output_dir <- file.path("../../experiments", dataset_name)
      dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)  # ensure the target folder exists
      output_file <- paste0(output_dir, "/SCAM_", nobs, "_nobs_", ".csv")
      write.csv(grid, output_file, row.names = FALSE)

      # Calculate Mean Squared Error (MSE) on the sampled data
      #y_pred_sampled <- predict(model, newdata = data.frame(x_sampled = x_sampled))
      #mse <- mean((y_sampled - y_pred_sampled)^2)

      # Store results
      #results <- rbind(
      #  results,
      #  data.frame(model = "SCAM", scenario = nobs, dataset = dataset_name, mse = mse)
      #)

      cat("Finished SCAM model on", dataset_name, "with nobs =", nobs, "\n")
      # the last `grid` is reused by the diagnostic plot at the end of the script
    }
  }
}

# Save MSE results to CSV
#write.csv(results, "scam_mse_results.csv", row.names = FALSE)

# Display results
#print(results)

# Diagnostic plot of the last dataset read and the last fitted grid
ggplot(data, aes(x = a0, y = lc_a0)) +
  geom_point(color = "blue", alpha = 0.5) +                               # Original data points
  geom_line(data = grid, aes(x = x_sampled, y = y_pred), color = "red") + # Predicted line on grid
  labs(title = "Shape-Constrained Additive Model Fit",
       x = "x", y = "y") +
  theme_minimal()
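
The MSE bookkeeping that the R script leaves commented out could also be done on the Python side once the prediction files exist. The sketch below is a hypothetical follow-up, not part of either script: it reuses the directory layout, the `SCAM_{nobs}_nobs_.csv` naming convention, and the true-curve files from above, and it mirrors the MSE computed in the plotting script. The helper name `load_xy` and the output filename are placeholders.

import pathlib
import numpy as np
import pandas as pd

PATH = pathlib.Path(__file__).parent.parent.parent.absolute()
DATA_PATH = PATH / "data"

def load_xy(path):
    # Same two-column, header-skipping format as load_data in the plotting script.
    data = np.loadtxt(path, delimiter=",", skiprows=1, dtype=np.float32)
    return data[:, 0], data[:, 1]

_, true_mc_y = load_xy(DATA_PATH / "true_Gaussian_logCA0_MC.csv")
_, true_hc_y = load_xy(DATA_PATH / "true_Gaussian_logCA0_HC.csv")

rows = []
for dataset_name in ["uniform_new_HC", "uniform_new_MC",
                     "adaptive_new_HC", "adaptive_new_MC"]:
    # Mirror the HC/MC handling of the plotting script; assumes predictions and
    # the true curve share the same grid, as that script does.
    y_true = true_hc_y if "HC" in dataset_name else true_mc_y[1:]
    for nobs in list(range(10, 21, 2)) + [21]:
        pred_file = PATH / "experiments" / dataset_name / f"SCAM_{nobs}_nobs_.csv"
        if not pred_file.exists():
            continue  # skip scenarios that were not fitted
        _, y_pred = load_xy(pred_file)
        rows.append({"model": "SCAM", "dataset": dataset_name, "nobs": nobs,
                     "mse": float(np.mean((y_pred - y_true) ** 2))})

pd.DataFrame(rows).to_csv(PATH / "experiments" / "scam_mse_results.csv", index=False)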