Skip to content

Commit

Permalink
Merge pull request #83 from pat-alt/73-aries-comments
Browse files Browse the repository at this point in the history
73 aries comments
  • Loading branch information
pat-alt authored Sep 26, 2023
2 parents cc69941 + bc0ce9f commit 901f4c0
Show file tree
Hide file tree
Showing 79 changed files with 8,724 additions and 3,844 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Other:
/artifacts/
/.quarto/
**/.quarto/
/Manifest.toml
/results*/
**/.CondaPkg
Expand Down
2,837 changes: 0 additions & 2,837 deletions bib.bib

This file was deleted.

496 changes: 481 additions & 15 deletions experiments/Manifest.toml

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions experiments/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
MosaicViews = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
ghr_jll = "07c12ed4-43bc-5495-8a2a-d5838ef8d533"
5 changes: 3 additions & 2 deletions experiments/grid_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ const ECCCo_Δ_NAMES = [
"ECCCo-Δ",
"ECCCo-Δ (no CP)",
"ECCCo-Δ (no EBM)",
"ECCCo-Δ (latent)",
]

"""
Expand Down Expand Up @@ -117,7 +118,7 @@ Return the best outcome from grid search results. The best outcome is defined as
function best_absolute_outcome(
outcomes::Dict;
generator=ECCCO_NAMES,
measure::AbstractArray=["distance_from_targets_l2", "distance_from_energy_l2"],
measure::AbstractArray=["distance_from_energy_l2"],
model::Union{Nothing,AbstractArray}=nothing,
weights::Union{Nothing,AbstractArray}=nothing
)
Expand All @@ -129,7 +130,7 @@ function best_absolute_outcome(
for (params, outcome) in outcomes

# Setup
evaluation = outcome.bmk.evaluation
evaluation = deepcopy(outcome.bmk.evaluation)
exper = outcome.exper
generator_dict = outcome.generator_dict
model_dict = outcome.model_dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

module load 2023r1 openmpi

srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=25 > experiments/grid_search_california_housing.log
srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=25 store_ce > experiments/grid_search_california_housing.log
2 changes: 1 addition & 1 deletion experiments/jobscripts/tuning/generators/tabular.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

module load 2023r1 openmpi

srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc,german_credit output_path=results mpi grid_search n_individuals=25 > experiments/grid_search_tabular.log
srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc,german_credit output_path=results mpi grid_search n_individuals=25 store_ce > experiments/grid_search_tabular.log
204 changes: 204 additions & 0 deletions experiments/post_processing/plotting.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
using Plots

function choose_random_mnist(outcome::ExperimentOutcome; model::String="LeNet-5", img_height=125, seed=966)

# Set seed:
if !isnothing(seed)
Random.seed!(seed)
end

# Get output:
bmk = outcome.bmk()
grouped_bmk = groupby(bmk[bmk.variable.=="distance" .&& bmk.model.==model,:], [:dataname, :target, :factual])
random_choice = rand(1:length(grouped_bmk))
generators = unique(bmk.generator)
n_generators = length(generators)

# Get data:
df = grouped_bmk[random_choice][1:n_generators, :] |>
x -> sort(x, :generator) |>
x -> subset(x, :generator => ByRow(x -> x != "ECCCo"))
generators = df.generator
replace!(generators, "ECCCo-Δ" => "ECCCo")
replace!(generators, "ECCCo-Δ (latent)" => "ECCCo+")
n_generators = length(generators)

# Factual:
img = CounterfactualExplanations.factual(grouped_bmk[random_choice][1:n_generators,:].ce[1]) |> ECCCo.convert2mnist
p1 = Plots.plot(
img,
axis=([], false),
size=(img_height, img_height),
title="Factual",
)
plts = [p1]
ces = []

# Counterfactuals:
for (i, generator) in enumerate(generators)
ce = df.ce[i]
img = CounterfactualExplanations.counterfactual(ce) |> ECCCo.convert2mnist
p = Plots.plot(
img,
axis=([], false),
size=(img_height, img_height),
title="$generator",
)
push!(plts, p)
push!(ces, ce)
end

plt = Plots.plot(
plts...,
layout=(1, n_generators + 1),
size=(img_height * (n_generators + 1), img_height),
dpi=300
)
display(plt)

return plt, df.target[1], seed, ces, df.sample[1]

end

function plot_random_eccco(outcome::ExperimentOutcome; ce=nothing, generator="ECCCo-Δ", img_height=200, seed=966)
# Set seed:
if !isnothing(seed)
Random.seed!(seed)
end

# Get output:
bmk = outcome.bmk()
ce = isnothing(ce) ? rand(bmk.ce) : ce
gen = outcome.generator_dict[generator]
models = outcome.model_dict
x = CounterfactualExplanations.counterfactual(ce)
target = ce.target
data = ce.data

# Factual:
img = CounterfactualExplanations.factual(ce) |> ECCCo.convert2mnist
p1 = Plots.plot(
img,
axis=([], false),
size=(img_height, img_height),
title="Factual",
)
plts = [p1]

for (model_name, M) in models
ce = generate_counterfactual(x, target, data, M, gen; initialization=:identity, converge_when=:generator_conditions)
img = CounterfactualExplanations.counterfactual(ce) |> ECCCo.convert2mnist
p = Plots.plot(
img,
axis=([], false),
size=(img_height, img_height),
title="$model_name",
)
push!(plts, p)
end
n_models = length(models)

plt = Plots.plot(
plts...,
layout=(1, n_models + 1),
size=(img_height * (n_models + 1), img_height),
dpi=300
)
display(plt)

return plt, target, seed
end

function plot_all_mnist(gen, model, data=load_mnist_test(); img_height=150, seed=123, maxoutdim=64)

# Set seed:
if !isnothing(seed)
Random.seed!(seed)
end

# Dimensionality reduction:
data.dt = MultivariateStats.fit(MultivariateStats.PCA, data.X; maxoutdim=maxoutdim)

# VAE for REVISE:
data.generative_model = CounterfactualExplanations.Models.load_mnist_vae()

targets = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
factuals = targets
plts = []

for factual in factuals
chosen = rand(findall(data.output_encoder.labels .== factual))
x = select_factual(data, chosen)
for target in targets
if factual != target
@info "Generating counterfactual for $(factual) -> $(target)"
ce = generate_counterfactual(
x, target, data, model, gen;
initialization=:identity, converge_when=:generator_conditions
)
plt = Plots.plot(
CounterfactualExplanations.counterfactual(ce) |> ECCCo.convert2mnist,
axis=([], false),
size=(img_height, img_height),
title="$factual$target",
)
else
plt = Plots.plot(
x |> ECCCo.convert2mnist,
axis=([], false),
size=(img_height, img_height),
title="Factual",
)
end
push!(plts, plt)
end
end

plt = Plots.plot(
plts...,
layout=(length(factuals), length(targets)),
size=(img_height * length(targets), img_height * length(factuals)),
dpi=300
)

return plt

end

using MLDatasets
using MosaicViews
function vae_reconstructions(seed=123)

# Set seed:
if !isnothing(seed)
Random.seed!(seed)
end

counterfactual_data = load_mnist()
counterfactual_data.generative_model = CounterfactualExplanations.Models.load_mnist_vae()
X = counterfactual_data.X
y = counterfactual_data.output_encoder.y
images = []
rec_images = []
for i in 0:9
j = 0
while j < 10
x = X[:,rand(findall(y .== i))]
= CounterfactualExplanations.GenerativeModels.reconstruct(vae, x)[1] |>
-> clamp.((x̂ .+ 1.0) ./ 2.0, 0.0, 1.0) |>
-> reshape(x̂, 28,28) |>
-> MLDatasets.convert2image(MNIST, x̂)
x = clamp.((x .+ 1.0) ./ 2.0, 0.0, 1.0) |>
x -> reshape(x, 28,28) |>
x -> MLDatasets.convert2image(MNIST, x)
push!(images, x)
push!(rec_images, x̂)
j += 1
end
end
p1 = plot(mosaic(images..., ncol=10), title="Images")
p2 = plot(mosaic(rec_images..., ncol=10), title="Reconstructions")
plt = plot(p1, p2, axis=false, size=(800,375))

return plt
end
3 changes: 2 additions & 1 deletion experiments/post_processing/post_processing.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include("meta_data.jl")
include("artifacts.jl")
include("results.jl")
include("results.jl")
include("plotting.jl")
2 changes: 1 addition & 1 deletion experiments/setup_env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ DEFAULT_GENERATOR_TUNING_LARGE = (
],
reg_strength=[0.0, 0.1, 0.5],
opt=[
Flux.Optimise.Descent(0.05),
Flux.Optimise.Descent(0.01),
Flux.Optimise.Descent(0.05),
],
)

Expand Down
16 changes: 8 additions & 8 deletions notebooks/tables.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ kbl(
format="latex", linesep = line_sep
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
add_header_above(header) %>%
collapse_rows(columns = 1:2, latex_hline = "major", valign = "middle") %>%
save_kable(file_name)
Expand Down Expand Up @@ -227,7 +227,7 @@ kbl(
format="latex", linesep = line_sep
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
add_header_above(header) %>%
collapse_rows(columns = 1:2, latex_hline = "major", valign = "middle") %>%
save_kable(file_name)
Expand Down Expand Up @@ -255,7 +255,7 @@ kbl(
format="latex"
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
collapse_rows(columns = 1:3, latex_hline = "custom", valign = "top", custom_latex_hline = 1:2) %>%
save_kable("paper/contents/table_all.tex")
```
Expand All @@ -282,7 +282,7 @@ kbl(
format="latex"
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
collapse_rows(columns = 1:3, latex_hline = "custom", valign = "top", custom_latex_hline = 1:2) %>%
save_kable("paper/contents/table_all_valid.tex")
```
Expand Down Expand Up @@ -317,7 +317,7 @@ kbl(
format="latex"
) %>%
kable_styling(font_size = 8) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
save_kable("paper/contents/table_ebm_params.tex")
```

Expand All @@ -337,7 +337,7 @@ kbl(
format="latex"
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
add_header_above(header) %>%
save_kable("paper/contents/table_params.tex")
```
Expand All @@ -361,7 +361,7 @@ kbl(
format="latex"
) %>%
kable_styling(font_size = 8) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
save_kable("paper/contents/table_gen_params.tex")
```

Expand All @@ -387,7 +387,7 @@ kbl(
format="latex", digits=2
) %>%
kable_styling(font_size = 8) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
add_header_above(c(" "=2, "Performance Metrics" = 3)) %>%
collapse_rows(columns = 1, latex_hline = "custom", valign = "top", custom_latex_hline = 1) %>%
save_kable("paper/contents/table_perf.tex")
Expand Down
1 change: 1 addition & 0 deletions paper/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/.quarto/
1 change: 1 addition & 0 deletions paper/_quarto.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
bibliography: bib.bib
Loading

0 comments on commit 901f4c0

Please sign in to comment.