diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 00000000..7a73a41b
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,2 @@
+{
+}
\ No newline at end of file
diff --git a/Artifacts.toml b/Artifacts.toml
index 06bdd8ea..a1a752c0 100644
--- a/Artifacts.toml
+++ b/Artifacts.toml
@@ -22,6 +22,14 @@ lazy = true
 sha256 = "2e7decbde29d8d34333931e856d6b86e0c1f61cb6942a1a04cfd672aa2e64ce3"
 url = "https://github.com/pat-alt/ECCCo.jl/releases/download/artifacts_results_20230907_1535/artifacts_results_20230907_1535.tar.gz"
 
+[results-extended]
+git-tree-sha1 = "80fb590edda3c364e3bba56f7ebb810b52c73447"
+lazy = true
+
+    [[results-extended.download]]
+    sha256 = "1d161c7ec10a47347d5f3a8a5a915c9564688738640087de40437281b02e971d"
+    url = "https://github.com/pat-alt/ECCCo.jl/releases/download/results-extended/results-extended.tar.gz"
+
 ["results-paper-submission-1.8.5"]
 git-tree-sha1 = "3be5119c4ce466017db79f10fdb72b97e745bd7d"
 lazy = true
diff --git a/dev/artifacts.jl b/dev/artifacts.jl
index 99ffb7f7..f60ac842 100644
--- a/dev/artifacts.jl
+++ b/dev/artifacts.jl
@@ -9,11 +9,11 @@ artifact_toml = LazyArtifacts.find_artifacts_toml(".")
 
 function generate_artifacts(
     datafiles;
-    artifact_name="artifacts-$VERSION",
-    root=".",
-    artifact_toml=joinpath(root, "Artifacts.toml"),
-    deploy=true,
-    tag=nothing,
+    artifact_name = "artifacts-$VERSION",
+    root = ".",
+    artifact_toml = joinpath(root, "Artifacts.toml"),
+    deploy = true,
+    tag = nothing,
 )
     if isnothing(tag)
         tag = replace(lowercase(artifact_name), " " => "-")
@@ -54,9 +54,9 @@ function generate_artifacts(
         artifact_toml,
         artifact_name,
         hash;
-        download_info=[(tarball_url, tarball_hash)],
-        lazy=true,
-        force=true,
+        download_info = [(tarball_url, tarball_hash)],
+        lazy = true,
+        force = true,
     )
 end
@@ -76,7 +76,7 @@ function generate_artifacts(
     end
 end
 
-function get_git_remote_url(repo_path::String=".")
+function get_git_remote_url(repo_path::String = ".")
     repo = LibGit2.GitRepo(repo_path)
     origin = LibGit2.get(LibGit2.GitRemote, repo, "origin")
     return LibGit2.url(origin)
diff --git a/dev/rebuttal/aaai/draft.md b/dev/rebuttal/aaai/draft.md
new file mode 100644
index 00000000..7a00d0c6
--- /dev/null
+++ b/dev/rebuttal/aaai/draft.md
@@ -0,0 +1,60 @@
+# Reviewer 1
+
+## Weaknesses:
+
+1. Experiment results: need more comprehensive linguistic explanation of results
+
+Following the suggestion by the reviewer, we plan to add the following linguistic explanation in a prominent place in Section 6:
+
+"Overall, our findings demonstrate that \textit{ECCCo} produces plausible counterfactuals if and only if the black-box model itself has learned plausible explanations for the data. Thus, \textit{ECCCo} avoids the risk of generating plausible but potentially misleading explanations for models that are highly susceptible to implausible explanations. We therefore believe that \textit{ECCCo} can help researchers and practitioners to generate explanations they can trust, and to discern unreliable from trustworthy models."
+
+Elements of this explanation are already scattered across the paper, but we agree that it would be useful to highlight this notion in Section 6 as well.
+
+2. Core innovation: need more visualizations in 2D/3D space
+
+Following the reviewer's suggestion, we have plotted the distance of randomly generated MNIST images from images in the target class against their energy-constrained score. As expected, this relationship is positive: the higher the distance, the higher the corresponding generative loss. The strength of this relationship appears to depend positively on the model's generative property: the observed relationships are stronger for joint energy models.
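For concreteness, a minimal Julia sketch of how such a relationship could be computed. The classifier `clf`, the energy definition `E(x, y) = -logit_y(x)` (as commonly used for joint energy models), and the use of Euclidean distance are assumptions for illustration, not the paper's exact implementation:

```julia
using LinearAlgebra, Statistics

# Class-conditional energy, as commonly defined for joint energy models:
# the negative logit of class y. `clf` is any callable returning logits.
energy(clf, x, y) = -clf(x)[y]

# For n random images, record the mean distance to target-class samples and the
# model energy for the target class; a positive trend mirrors the claim above.
function distance_vs_energy(clf, X_target::AbstractMatrix, y⁺::Integer; n::Integer = 100)
    d, s = Float64[], Float64[]
    for _ in 1:n
        x = rand(Float32, size(X_target, 1))                     # random (flattened) image
        push!(d, mean(norm(c .- x) for c in eachcol(X_target)))  # distance to target class
        push!(s, energy(clf, x, y⁺))                             # generative loss proxy
    end
    return d, s  # e.g. scatter-plot s against d to visualize the relationship
end
```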
+3. Structural clarity: add a flow chart
+
+Figure 2 shows our synthetic linearly separable data in the feature space, so the highlighted path corresponds to the actual counterfactual path of the sample. We will clarify this in the paper.
+
+Adding a systematic flowchart is a great idea. Due to limited scope, may we suggest adding the following flowchart to the appendix? Alternatively, we may swap out Figure 2 for the flowchart.
+
+# Reviewer 2
+
+## Weaknesses:
+
+1. Why the embedding?
+
+We agree that for any type of surrogate model, there is a risk of introducing bias. In exceptional cases, however, it may be necessary to accept some degree of bias in favour of plausibility. Our results for \textit{ECCCo+} demonstrate this tradeoff, as we discuss in Section 6.3. In the context of PCA, the introduced bias can be explained intuitively: by constraining the counterfactual search to the space spanned by the first $n_z$ principal components, the search is sensitive only to the variation in the data explained by those components. In other words, we would expect counterfactuals to be less sensitive to small variations in features that do not typically vary much. It is therefore an intuitive finding that \textit{ECCCo+} tends to generate less noisy counterfactual images, for example (the same is true for \textit{REVISE}). To our mind, restricting the search space to the first $n_z$ components quite literally corresponds to denoising the search space and hence the resulting counterfactuals. We will highlight this rationale in Section 6.3.
+
+We think that the bias introduced by PCA may be acceptable in some cases, precisely because it "will not add any information on the input distribution", as the reviewer correctly points out. To maintain faithfulness, we want to avoid introducing additional information through surrogate models as much as possible. We will make this intuition clearer in Section 6.3.
+
+Another argument in favour of using a lower-dimensional latent embedding is the reduction in computational costs, which can be prohibitive for high-dimensional input data. We will highlight this in Section 5.
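To illustrate the intuition, a self-contained sketch of a gradient search constrained to the span of the first $n_z$ principal components. The toy data, linear model, and validity-only loss are hypothetical stand-ins, not the \textit{ECCCo+} implementation:

```julia
using Flux, MultivariateStats

# Toy stand-ins (assumptions for illustration): 100-d inputs, a linear
# classifier returning logits, and target class 2.
X   = rand(Float32, 100, 500)        # training data (features × samples)
clf = Chain(Dense(100 => 2))         # black-box model
x₀  = rand(Float32, 100)             # factual instance
y⁺  = Flux.onehot(2, 1:2)            # target label

n_z = 10
M   = fit(PCA, X; maxoutdim = n_z)   # first n_z principal components
decode(z) = reconstruct(M, z)        # map latent candidates back to feature space

# Search over z: every candidate lies in the span of the n_z components, so
# variation outside those components is "denoised" away. For brevity the loss
# only enforces the target prediction, omitting distance and energy penalties.
z   = predict(M, x₀)                 # start from the factual's projection
opt = Flux.setup(Adam(0.1), z)
for _ in 1:200
    g = Flux.gradient(z -> Flux.logitcrossentropy(clf(decode(z)), y⁺), z)
    Flux.update!(opt, z, g[1])
end
x_cf = decode(z)                     # counterfactual constrained to the PCA subspace
```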
+2. What is "epsilon" and "s"?
+
+From the paper: "$\mathbf{r}_j \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ is the stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed. [...] To allow for faster sampling, we follow the common practice of choosing the step-size $\epsilon_j$ and the standard deviation of $\mathbf{r}_j$ separately." We go on to explain in the appendix that we use the following biased sampler
+
+$$
+\hat{\mathbf{x}}_{j+1} \leftarrow \hat{\mathbf{x}}_j - \frac{\phi}{2} \mathcal{E}_{\theta}(\hat{\mathbf{x}}_j|\mathbf{y}^+) + \sigma \mathbf{r}_j, \quad j=1,...,J
+$$
+
+where "consistent with~\citet{grathwohl2020your}, we have specified $\phi=2$ and $\sigma=0.01$ as the default values for all of our experiments". Intuitively, $\epsilon_j$ determines the size of the gradient updates and of the random noise in each iteration of SGLD.
+
+Regarding $s(\cdot)$, this was an oversight, apologies. In the appendix we explain that "[the calibration dataset] is then used to compute so-called nonconformity scores: $\mathcal{S}=\{s(\mathbf{x}_i,\mathbf{y}_i)\}_{i \in \mathcal{D}_{\text{cal}}}$ where $s: (\mathcal{X},\mathcal{Y}) \mapsto \mathbb{R}$ is referred to as \textit{score function}." We will add this in Section 4.2 of the main paper.
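For illustration, a minimal sketch of this sampler. We take the update to step along the gradient of the class-conditional energy, which is an assumption consistent with standard SGLD and \citet{grathwohl2020your} rather than a verbatim excerpt of the implementation; the toy classifier is likewise a hypothetical stand-in:

```julia
using Flux

# Biased SGLD sampler sketch: x̂ⱼ₊₁ ← x̂ⱼ - (ϕ/2) ∇ₓE(x̂ⱼ|y⁺) + σ rⱼ,
# with the quoted defaults ϕ = 2 and σ = 0.01.
function sgld_sample(clf, x₀::AbstractVector, y⁺::Integer; J = 100, ϕ = 2.0f0, σ = 0.01f0)
    E(x) = -clf(x)[y⁺]                 # class-conditional energy E(x|y⁺)
    x = copy(x₀)
    for _ in 1:J
        g = Flux.gradient(E, x)[1]     # ∇ₓ E(x|y⁺)
        x .= x .- (ϕ / 2) .* g .+ σ .* randn(Float32, length(x))
    end
    return x
end

# Toy usage: draw a sample conditioned on class 4 from a linear "classifier".
clf = Chain(Dense(784 => 10))
x̂ = sgld_sample(clf, rand(Float32, 784), 4)
```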
+3. Euclidean distance: problems in high dim? Latent space?
+
+As we mention in the additional author response, we investigated different distance metrics and found that the overall qualitative results were largely independent of the exact metric. For the high-dimensional image data, we nonetheless decided to report results for a dissimilarity metric that is more appropriate in that context. All of our distance-based metrics are computed with respect to features, not latent features. This is because, as the reviewer correctly points out, we would expect certain discrepancies between distances evaluated in the feature space and distances evaluated in the latent space of the VAE, for example. Working in the feature space does come with higher computational costs, but evaluating counterfactuals was generally less costly than generating them in the first place. In cases where high dimensionality leads to prohibitive computational costs, we would suggest either reducing the number of nearest neighbours or working in a lower-dimensional subspace that is independent of the underlying classifier itself (such as PCA).
+
+4. Faithfulness measure biased?
+
+We have taken measures not to unfairly bias our generator with respect to the unfaithfulness metric: instead of penalizing the unfaithfulness metric directly, we penalize model energy in our preferred implementation. In contrast, \textit{Wachter} penalizes the closeness criterion directly and hence does particularly well in this regard. That being said, \textit{ECCCo} is of course designed to generate faithful explanations first and foremost and therefore has an advantage with respect to our faithfulness metric. In lieu of other established metrics to measure faithfulness, we can only point out that \textit{ECCCo} also achieves strong performance for other commonly used metrics. With respect to \textit{validity}, for example, which as we have explained corresponds to \textit{fidelity}, \textit{ECCCo} typically outperforms \textit{REVISE} and \textit{Schut}.
+
+Our joint energy models (JEMs) are indeed explicitly trained to model $\mathcal{X}|y$, and the same quantity is used in our proposed faithfulness metric. But the faithfulness metric itself is not computed with respect to samples generated by our JEMs. It is computed with respect to counterfactuals generated by merely constraining model energy, and we would therefore argue that it is not unfairly biased. Our empirical findings support this argument: firstly, \textit{ECCCo} achieves high faithfulness also for classifiers that have not been trained to model $\mathcal{X}|y$; secondly, our additional results in the appendix for \textit{ECCCo-L1} show that if we do indeed explicitly penalise the unfaithfulness metric, we achieve even better results in this regard.
+
+6. Test with unreliable models
+
+We would argue that the simple multi-layer perceptrons (MLPs) are unreliable, especially compared to ensembles, joint energy models and, for our image datasets, convolutional neural networks. Simple neural networks have been shown to be vulnerable to adversarial attacks, which makes them susceptible to implausible counterfactual explanations, as we point out in Section 3. Our results support this notion, in that they demonstrate that faithful model explanations coincide with high plausibility only if the model itself has been trained to be more reliable. Consistent with the idea proposed by the reviewer, we originally considered introducing "poisoned" VAEs as well, to illustrate what we identify as the key vulnerability of \textit{REVISE}: if the underlying VAE is trained on poisoned data, this can be expected to adversely affect counterfactual outcomes as well. We ultimately discarded this idea due to limited scope and because we decided that Section 3 sufficiently illustrates our thinking.
+
diff --git a/dev/rebuttal/.gitignore b/dev/rebuttal/neurips/.gitignore similarity index 100% rename from dev/rebuttal/.gitignore rename to dev/rebuttal/neurips/.gitignore diff --git a/dev/rebuttal/6zGr.md b/dev/rebuttal/neurips/6zGr.md similarity index 100% rename from dev/rebuttal/6zGr.md rename to dev/rebuttal/neurips/6zGr.md diff --git a/dev/rebuttal/ZaU8.md b/dev/rebuttal/neurips/ZaU8.md similarity index 100% rename from dev/rebuttal/ZaU8.md rename to dev/rebuttal/neurips/ZaU8.md diff --git a/dev/rebuttal/_quarto.yml b/dev/rebuttal/neurips/_quarto.yml similarity index 100% rename from dev/rebuttal/_quarto.yml rename to dev/rebuttal/neurips/_quarto.yml diff --git a/dev/rebuttal/global.md b/dev/rebuttal/neurips/global.md similarity index 100% rename from dev/rebuttal/global.md rename to dev/rebuttal/neurips/global.md diff --git a/dev/rebuttal/pekM.md b/dev/rebuttal/neurips/pekM.md similarity index 100% rename from dev/rebuttal/pekM.md rename to dev/rebuttal/neurips/pekM.md diff --git a/dev/rebuttal/support.pdf b/dev/rebuttal/neurips/support.pdf similarity index 100% rename from dev/rebuttal/support.pdf rename to dev/rebuttal/neurips/support.pdf diff --git a/dev/rebuttal/support.qmd b/dev/rebuttal/neurips/support.qmd similarity index 100% rename from dev/rebuttal/support.qmd rename to dev/rebuttal/neurips/support.qmd diff --git a/dev/rebuttal/uCjw.md b/dev/rebuttal/neurips/uCjw.md similarity index 100% rename from dev/rebuttal/uCjw.md rename to dev/rebuttal/neurips/uCjw.md diff --git a/dev/rebuttal/www/fmnist_boot.png b/dev/rebuttal/neurips/www/fmnist_boot.png similarity index 100% rename from dev/rebuttal/www/fmnist_boot.png rename to dev/rebuttal/neurips/www/fmnist_boot.png diff --git a/dev/rebuttal/www/fmnist_boot_lenet.png b/dev/rebuttal/neurips/www/fmnist_boot_lenet.png similarity index 100% rename from dev/rebuttal/www/fmnist_boot_lenet.png rename to dev/rebuttal/neurips/www/fmnist_boot_lenet.png diff --git a/dev/rebuttal/www/fmnist_dress.png b/dev/rebuttal/neurips/www/fmnist_dress.png similarity index 100% rename from dev/rebuttal/www/fmnist_dress.png rename to dev/rebuttal/neurips/www/fmnist_dress.png diff --git a/dev/rebuttal/www/fmnist_dress_lenet.png b/dev/rebuttal/neurips/www/fmnist_dress_lenet.png similarity index 100% rename from dev/rebuttal/www/fmnist_dress_lenet.png rename to dev/rebuttal/neurips/www/fmnist_dress_lenet.png diff --git a/dev/rebuttal/www/fmnist_pullover.png b/dev/rebuttal/neurips/www/fmnist_pullover.png similarity index 100% rename from dev/rebuttal/www/fmnist_pullover.png rename to dev/rebuttal/neurips/www/fmnist_pullover.png diff --git a/dev/rebuttal/www/fmnist_pullover_lenet.png b/dev/rebuttal/neurips/www/fmnist_pullover_lenet.png similarity index 100% rename from dev/rebuttal/www/fmnist_pullover_lenet.png rename to dev/rebuttal/neurips/www/fmnist_pullover_lenet.png diff --git a/dev/rebuttal/www/mnist_0to3_28.png b/dev/rebuttal/neurips/www/mnist_0to3_28.png similarity index 100% rename from dev/rebuttal/www/mnist_0to3_28.png rename to dev/rebuttal/neurips/www/mnist_0to3_28.png diff --git
a/dev/rebuttal/www/mnist_0to3_29.png b/dev/rebuttal/neurips/www/mnist_0to3_29.png similarity index 100% rename from dev/rebuttal/www/mnist_0to3_29.png rename to dev/rebuttal/neurips/www/mnist_0to3_29.png diff --git a/dev/rebuttal/www/mnist_0to3_30.png b/dev/rebuttal/neurips/www/mnist_0to3_30.png similarity index 100% rename from dev/rebuttal/www/mnist_0to3_30.png rename to dev/rebuttal/neurips/www/mnist_0to3_30.png diff --git a/dev/rebuttal/www/mnist_1to4_25.png b/dev/rebuttal/neurips/www/mnist_1to4_25.png similarity index 100% rename from dev/rebuttal/www/mnist_1to4_25.png rename to dev/rebuttal/neurips/www/mnist_1to4_25.png diff --git a/dev/rebuttal/www/mnist_1to4_26.png b/dev/rebuttal/neurips/www/mnist_1to4_26.png similarity index 100% rename from dev/rebuttal/www/mnist_1to4_26.png rename to dev/rebuttal/neurips/www/mnist_1to4_26.png diff --git a/dev/rebuttal/www/mnist_1to4_27.png b/dev/rebuttal/neurips/www/mnist_1to4_27.png similarity index 100% rename from dev/rebuttal/www/mnist_1to4_27.png rename to dev/rebuttal/neurips/www/mnist_1to4_27.png diff --git a/dev/rebuttal/www/mnist_1to7_22.png b/dev/rebuttal/neurips/www/mnist_1to7_22.png similarity index 100% rename from dev/rebuttal/www/mnist_1to7_22.png rename to dev/rebuttal/neurips/www/mnist_1to7_22.png diff --git a/dev/rebuttal/www/mnist_1to7_23.png b/dev/rebuttal/neurips/www/mnist_1to7_23.png similarity index 100% rename from dev/rebuttal/www/mnist_1to7_23.png rename to dev/rebuttal/neurips/www/mnist_1to7_23.png diff --git a/dev/rebuttal/www/mnist_1to7_24.png b/dev/rebuttal/neurips/www/mnist_1to7_24.png similarity index 100% rename from dev/rebuttal/www/mnist_1to7_24.png rename to dev/rebuttal/neurips/www/mnist_1to7_24.png diff --git a/dev/rebuttal/www/mnist_2to3_13.png b/dev/rebuttal/neurips/www/mnist_2to3_13.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_13.png rename to dev/rebuttal/neurips/www/mnist_2to3_13.png diff --git a/dev/rebuttal/www/mnist_2to3_14.png b/dev/rebuttal/neurips/www/mnist_2to3_14.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_14.png rename to dev/rebuttal/neurips/www/mnist_2to3_14.png diff --git a/dev/rebuttal/www/mnist_2to3_15.png b/dev/rebuttal/neurips/www/mnist_2to3_15.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_15.png rename to dev/rebuttal/neurips/www/mnist_2to3_15.png diff --git a/dev/rebuttal/www/mnist_2to3_16.png b/dev/rebuttal/neurips/www/mnist_2to3_16.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_16.png rename to dev/rebuttal/neurips/www/mnist_2to3_16.png diff --git a/dev/rebuttal/www/mnist_2to3_17.png b/dev/rebuttal/neurips/www/mnist_2to3_17.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_17.png rename to dev/rebuttal/neurips/www/mnist_2to3_17.png diff --git a/dev/rebuttal/www/mnist_2to3_18.png b/dev/rebuttal/neurips/www/mnist_2to3_18.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_18.png rename to dev/rebuttal/neurips/www/mnist_2to3_18.png diff --git a/dev/rebuttal/www/mnist_2to3_19.png b/dev/rebuttal/neurips/www/mnist_2to3_19.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_19.png rename to dev/rebuttal/neurips/www/mnist_2to3_19.png diff --git a/dev/rebuttal/www/mnist_2to3_20.png b/dev/rebuttal/neurips/www/mnist_2to3_20.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_20.png rename to dev/rebuttal/neurips/www/mnist_2to3_20.png diff --git a/dev/rebuttal/www/mnist_2to3_7.png b/dev/rebuttal/neurips/www/mnist_2to3_7.png similarity index 100% rename from 
dev/rebuttal/www/mnist_2to3_7.png rename to dev/rebuttal/neurips/www/mnist_2to3_7.png diff --git a/dev/rebuttal/www/mnist_2to3_8.png b/dev/rebuttal/neurips/www/mnist_2to3_8.png similarity index 100% rename from dev/rebuttal/www/mnist_2to3_8.png rename to dev/rebuttal/neurips/www/mnist_2to3_8.png diff --git a/dev/rebuttal/www/mnist_4to1_10.png b/dev/rebuttal/neurips/www/mnist_4to1_10.png similarity index 100% rename from dev/rebuttal/www/mnist_4to1_10.png rename to dev/rebuttal/neurips/www/mnist_4to1_10.png diff --git a/dev/rebuttal/www/mnist_4to1_11.png b/dev/rebuttal/neurips/www/mnist_4to1_11.png similarity index 100% rename from dev/rebuttal/www/mnist_4to1_11.png rename to dev/rebuttal/neurips/www/mnist_4to1_11.png diff --git a/dev/rebuttal/www/mnist_4to1_12.png b/dev/rebuttal/neurips/www/mnist_4to1_12.png similarity index 100% rename from dev/rebuttal/www/mnist_4to1_12.png rename to dev/rebuttal/neurips/www/mnist_4to1_12.png diff --git a/dev/rebuttal/www/mnist_4to1_13.png b/dev/rebuttal/neurips/www/mnist_4to1_13.png similarity index 100% rename from dev/rebuttal/www/mnist_4to1_13.png rename to dev/rebuttal/neurips/www/mnist_4to1_13.png diff --git a/dev/rebuttal/www/mnist_4to1_14.png b/dev/rebuttal/neurips/www/mnist_4to1_14.png similarity index 100% rename from dev/rebuttal/www/mnist_4to1_14.png rename to dev/rebuttal/neurips/www/mnist_4to1_14.png diff --git a/dev/rebuttal/www/mnist_4to1_15.png b/dev/rebuttal/neurips/www/mnist_4to1_15.png similarity index 100% rename from dev/rebuttal/www/mnist_4to1_15.png rename to dev/rebuttal/neurips/www/mnist_4to1_15.png diff --git a/dev/rebuttal/www/mnist_4to1_5.png b/dev/rebuttal/neurips/www/mnist_4to1_5.png similarity index 100% rename from dev/rebuttal/www/mnist_4to1_5.png rename to dev/rebuttal/neurips/www/mnist_4to1_5.png diff --git a/dev/rebuttal/www/mnist_4to1_6.png b/dev/rebuttal/neurips/www/mnist_4to1_6.png similarity index 100% rename from dev/rebuttal/www/mnist_4to1_6.png rename to dev/rebuttal/neurips/www/mnist_4to1_6.png diff --git a/dev/rebuttal/www/mnist_5to8_10.png b/dev/rebuttal/neurips/www/mnist_5to8_10.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_10.png rename to dev/rebuttal/neurips/www/mnist_5to8_10.png diff --git a/dev/rebuttal/www/mnist_5to8_16.png b/dev/rebuttal/neurips/www/mnist_5to8_16.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_16.png rename to dev/rebuttal/neurips/www/mnist_5to8_16.png diff --git a/dev/rebuttal/www/mnist_5to8_17.png b/dev/rebuttal/neurips/www/mnist_5to8_17.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_17.png rename to dev/rebuttal/neurips/www/mnist_5to8_17.png diff --git a/dev/rebuttal/www/mnist_5to8_18.png b/dev/rebuttal/neurips/www/mnist_5to8_18.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_18.png rename to dev/rebuttal/neurips/www/mnist_5to8_18.png diff --git a/dev/rebuttal/www/mnist_5to8_21.png b/dev/rebuttal/neurips/www/mnist_5to8_21.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_21.png rename to dev/rebuttal/neurips/www/mnist_5to8_21.png diff --git a/dev/rebuttal/www/mnist_5to8_22.png b/dev/rebuttal/neurips/www/mnist_5to8_22.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_22.png rename to dev/rebuttal/neurips/www/mnist_5to8_22.png diff --git a/dev/rebuttal/www/mnist_5to8_23.png b/dev/rebuttal/neurips/www/mnist_5to8_23.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_23.png rename to dev/rebuttal/neurips/www/mnist_5to8_23.png diff --git 
a/dev/rebuttal/www/mnist_5to8_24.png b/dev/rebuttal/neurips/www/mnist_5to8_24.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_24.png rename to dev/rebuttal/neurips/www/mnist_5to8_24.png diff --git a/dev/rebuttal/www/mnist_5to8_25.png b/dev/rebuttal/neurips/www/mnist_5to8_25.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_25.png rename to dev/rebuttal/neurips/www/mnist_5to8_25.png diff --git a/dev/rebuttal/www/mnist_5to8_9.png b/dev/rebuttal/neurips/www/mnist_5to8_9.png similarity index 100% rename from dev/rebuttal/www/mnist_5to8_9.png rename to dev/rebuttal/neurips/www/mnist_5to8_9.png diff --git a/dev/rebuttal/www/mnist_6to0_10.png b/dev/rebuttal/neurips/www/mnist_6to0_10.png similarity index 100% rename from dev/rebuttal/www/mnist_6to0_10.png rename to dev/rebuttal/neurips/www/mnist_6to0_10.png diff --git a/dev/rebuttal/www/mnist_6to0_3.png b/dev/rebuttal/neurips/www/mnist_6to0_3.png similarity index 100% rename from dev/rebuttal/www/mnist_6to0_3.png rename to dev/rebuttal/neurips/www/mnist_6to0_3.png diff --git a/dev/rebuttal/www/mnist_6to0_4.png b/dev/rebuttal/neurips/www/mnist_6to0_4.png similarity index 100% rename from dev/rebuttal/www/mnist_6to0_4.png rename to dev/rebuttal/neurips/www/mnist_6to0_4.png diff --git a/dev/rebuttal/www/mnist_6to0_6.png b/dev/rebuttal/neurips/www/mnist_6to0_6.png similarity index 100% rename from dev/rebuttal/www/mnist_6to0_6.png rename to dev/rebuttal/neurips/www/mnist_6to0_6.png diff --git a/dev/rebuttal/www/mnist_6to0_7.png b/dev/rebuttal/neurips/www/mnist_6to0_7.png similarity index 100% rename from dev/rebuttal/www/mnist_6to0_7.png rename to dev/rebuttal/neurips/www/mnist_6to0_7.png diff --git a/dev/rebuttal/www/mnist_6to0_8.png b/dev/rebuttal/neurips/www/mnist_6to0_8.png similarity index 100% rename from dev/rebuttal/www/mnist_6to0_8.png rename to dev/rebuttal/neurips/www/mnist_6to0_8.png diff --git a/dev/rebuttal/www/mnist_6to0_9.png b/dev/rebuttal/neurips/www/mnist_6to0_9.png similarity index 100% rename from dev/rebuttal/www/mnist_6to0_9.png rename to dev/rebuttal/neurips/www/mnist_6to0_9.png diff --git a/dev/rebuttal/www/mnist_6to5_19.png b/dev/rebuttal/neurips/www/mnist_6to5_19.png similarity index 100% rename from dev/rebuttal/www/mnist_6to5_19.png rename to dev/rebuttal/neurips/www/mnist_6to5_19.png diff --git a/dev/rebuttal/www/mnist_6to5_20.png b/dev/rebuttal/neurips/www/mnist_6to5_20.png similarity index 100% rename from dev/rebuttal/www/mnist_6to5_20.png rename to dev/rebuttal/neurips/www/mnist_6to5_20.png diff --git a/dev/rebuttal/www/mnist_6to5_21.png b/dev/rebuttal/neurips/www/mnist_6to5_21.png similarity index 100% rename from dev/rebuttal/www/mnist_6to5_21.png rename to dev/rebuttal/neurips/www/mnist_6to5_21.png diff --git a/dev/rebuttal/www/mnist_7to2_4.png b/dev/rebuttal/neurips/www/mnist_7to2_4.png similarity index 100% rename from dev/rebuttal/www/mnist_7to2_4.png rename to dev/rebuttal/neurips/www/mnist_7to2_4.png diff --git a/dev/rebuttal/www/mnist_7to2_5.png b/dev/rebuttal/neurips/www/mnist_7to2_5.png similarity index 100% rename from dev/rebuttal/www/mnist_7to2_5.png rename to dev/rebuttal/neurips/www/mnist_7to2_5.png diff --git a/dev/rebuttal/www/mnist_7to2_6.png b/dev/rebuttal/neurips/www/mnist_7to2_6.png similarity index 100% rename from dev/rebuttal/www/mnist_7to2_6.png rename to dev/rebuttal/neurips/www/mnist_7to2_6.png diff --git a/dev/rebuttal/www/mnist_9to7_1.png b/dev/rebuttal/neurips/www/mnist_9to7_1.png similarity index 100% rename from dev/rebuttal/www/mnist_9to7_1.png rename 
to dev/rebuttal/neurips/www/mnist_9to7_1.png diff --git a/dev/rebuttal/www/mnist_9to7_2.png b/dev/rebuttal/neurips/www/mnist_9to7_2.png similarity index 100% rename from dev/rebuttal/www/mnist_9to7_2.png rename to dev/rebuttal/neurips/www/mnist_9to7_2.png diff --git a/dev/rebuttal/www/mnist_9to7_3.png b/dev/rebuttal/neurips/www/mnist_9to7_3.png similarity index 100% rename from dev/rebuttal/www/mnist_9to7_3.png rename to dev/rebuttal/neurips/www/mnist_9to7_3.png diff --git a/dev/rebuttal/www/mnist_9to7_4.png b/dev/rebuttal/neurips/www/mnist_9to7_4.png similarity index 100% rename from dev/rebuttal/www/mnist_9to7_4.png rename to dev/rebuttal/neurips/www/mnist_9to7_4.png diff --git a/dev/rebuttal/www/mnist_9to7_5.png b/dev/rebuttal/neurips/www/mnist_9to7_5.png similarity index 100% rename from dev/rebuttal/www/mnist_9to7_5.png rename to dev/rebuttal/neurips/www/mnist_9to7_5.png diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index e39e6d98..4ad990c6 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.3" +julia_version = "1.9.4" manifest_format = "2.0" project_hash = "4b0671d5fb3c16506a733dc9942e46bad62320c0" @@ -28,9 +28,9 @@ version = "0.4.4" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" +git-tree-sha1 = "02f731463748db57cc2ebfbd9fbc9ce8280d3433" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.2" +version = "3.7.1" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -65,9 +65,9 @@ version = "3.5.1+1" [[deps.ArrayInterface]] deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" +git-tree-sha1 = "247efbccf92448be332d154d6ca56b9fcdd93c31" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.4.11" +version = "7.6.1" [deps.ArrayInterface.extensions] ArrayInterfaceBandedMatricesExt = "BandedMatrices" @@ -85,12 +85,6 @@ version = "7.4.11" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.29" - [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -158,9 +152,9 @@ uuid = "9718e550-a3fa-408a-8086-8db961cd8217" version = "0.1.1" [[deps.BitFlags]] -git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" +git-tree-sha1 = "2dc09997850d68179b69dafb58ae806167a32b1b" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.7" +version = "0.1.8" [[deps.BitTwiddlingConvenienceFunctions]] deps = ["Static"] @@ -197,16 +191,21 @@ uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.10.11" [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "968c1365e2992824c3e7a794e30907483f8469a9" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", 
"ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "Statistics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "76582ae19006b1186e87dadd781747f76cead72c" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "4.4.1" +version = "5.1.1" +weakdeps = ["ChainRulesCore", "SpecialFunctions"] + + [deps.CUDA.extensions] + ChainRulesCoreExt = "ChainRulesCore" + SpecialFunctionsExt = "SpecialFunctions" [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b" +git-tree-sha1 = "1e42ef1bdb45487ff28de16182c0df4920181dc3" uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.5.0+1" +version = "0.7.0+0" [[deps.CUDA_Runtime_Discovery]] deps = ["Libdl"] @@ -216,9 +215,9 @@ version = "0.2.2" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "5248d9c45712e51e27ba9b30eebec65658c6ce29" +git-tree-sha1 = "9704e50c9158cf8896c2776b8dbc5edd136caf80" uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.6.0+0" +version = "0.10.1+0" [[deps.CUDNN_jll]] deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] @@ -259,9 +258,9 @@ weakdeps = ["JSON", "RecipesBase", "SentinelArrays", "StructTypes"] [[deps.CategoricalDistributions]] deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"] -git-tree-sha1 = "ed760a4fde49997ff9360a780abe6e20175162aa" +git-tree-sha1 = "3124343a1b0c9a2f5fdc1d9bcc633ba11735a4c4" uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" -version = "0.1.11" +version = "0.1.13" [deps.CategoricalDistributions.extensions] UnivariateFiniteDisplayExt = "UnicodePlots" @@ -276,15 +275,19 @@ version = "0.5.0" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "dbeca245b0680f5393b4e6c40dcead7230ab0b3b" +git-tree-sha1 = "006cc7170be3e0fa02ccac6d4164a1eee1fc8c27" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.54.0" +version = "1.58.0" [[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "e0af648f0692ec1691b5d094b8724ba1346281cf" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.16.0" +version = "1.18.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" [[deps.Chemfiles]] deps = ["AtomsBase", "Chemfiles_jll", "DocStringExtensions", "PeriodicTable", "Unitful", "UnitfulAtomic"] @@ -306,21 +309,21 @@ version = "0.1.12" [[deps.Clustering]] deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "b86ac2c5543660d238957dbde5ac04520ae977a7" +git-tree-sha1 = "05f9816a77231b07e634ab8715ba50e5249d6f76" uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" -version = "0.15.4" +version = "0.15.5" [[deps.CodeTracking]] deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "a1296f0fe01a4c3f9bf0dc2934efbf4416f5db31" +git-tree-sha1 = 
"c0216e792f518b39b22212127d4a84dc31e4e386" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.3.4" +version = "1.3.5" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" +git-tree-sha1 = "cd67fc487743b2f0fd4380d4cbd3a24660d0eec8" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.2" +version = "0.7.3" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] @@ -335,14 +338,10 @@ uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" version = "0.11.4" [[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] +git-tree-sha1 = "600cc5508d66b78aae350f7accdb58763ac18589" uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.10.0" -weakdeps = ["SpecialFunctions"] - - [deps.ColorVectorSpace.extensions] - SpecialFunctionsExt = "SpecialFunctions" +version = "0.9.10" [[deps.Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] @@ -363,9 +362,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["UUIDs"] -git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" +git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.9.0" +version = "4.10.1" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -394,9 +393,9 @@ version = "0.3.2" [[deps.ConcurrentUtilities]] deps = ["Serialization", "Sockets"] -git-tree-sha1 = "5372dbbf8f0bdb8c700db5367132925c0771ef7e" +git-tree-sha1 = "8cfa272e8bdedfa88b6aefbbca7c19f1befac519" uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.2.1" +version = "2.3.0" [[deps.ConformalPrediction]] deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "ProgressMeter", "Random", "Serialization", "StatsBase", "Tables"] @@ -434,19 +433,15 @@ version = "0.6.3" [[deps.CounterfactualExplanations]] deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] -git-tree-sha1 = "14da4a8ea118b96c2477b05d5bc1c353c1d80e79" +git-tree-sha1 = "88bbfd76d8d531becf02596c94fb9975545b0a96" +repo-rev = "main" +repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.28" +version = "0.1.31" +weakdeps = ["MPI"] [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" - PythonCallExt = "PythonCall" - RCallExt = "RCall" - - [deps.CounterfactualExplanations.weakdeps] - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" - RCall = "6f49c342-dc21-5d91-9882-a32aef131414" [[deps.CpuId]] deps = ["Markdown"] @@ -498,9 +493,9 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" [[deps.DecisionTree]] deps = ["AbstractTrees", "DelimitedFiles", "LinearAlgebra", "Random", 
"ScikitLearnBase", "Statistics"] -git-tree-sha1 = "c6475a3ccad06cb1c2ebc0740c1bb4fe5a0731b7" +git-tree-sha1 = "526ca14aaaf2d5a0e242f3a8a7966eb9065d7d78" uuid = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb" -version = "0.12.3" +version = "0.12.4" [[deps.DefineSingletons]] git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" @@ -527,12 +522,13 @@ version = "1.15.1" [[deps.Distances]] deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "b6def76ffad15143924a2199f72a5cd883a2e8a9" +git-tree-sha1 = "66c4c81f259586e8f002eacebc177e1fb06363b0" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.9" -weakdeps = ["SparseArrays"] +version = "0.10.11" +weakdeps = ["ChainRulesCore", "SparseArrays"] [deps.Distances.extensions] + DistancesChainRulesCoreExt = "ChainRulesCore" DistancesSparseArraysExt = "SparseArrays" [[deps.Distributed]] @@ -540,18 +536,20 @@ deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" +deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] +git-tree-sha1 = "a6c00f894f24460379cb7136633cef54ac9f6f4a" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.100" +version = "0.25.103" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" DistributionsDensityInterfaceExt = "DensityInterface" + DistributionsTestExt = "Test" [deps.Distributions.weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.DocStringExtensions]] deps = ["LibGit2"] @@ -595,10 +593,14 @@ uuid = "2702e6a9-849d-5ed8-8c21-79e8b8f9ee43" version = "0.0.20230411+0" [[deps.EvoTrees]] -deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "a1fa1d1743478394a0a7188d054b67546e4ca143" +deps = ["BSON", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "f08d64339d7259b0c69a00a1e321dc6da79672ea" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.16.1" +version = "0.16.5" +weakdeps = ["CUDA"] + + [deps.EvoTrees.extensions] + EvoTreesCUDAExt = "CUDA" [[deps.ExceptionUnwrapping]] deps = ["Test"] @@ -618,9 +620,9 @@ uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" version = "0.1.10" [[deps.Extents]] -git-tree-sha1 = "5e1e4c53fa39afe63a7d356e30452249365fba99" +git-tree-sha1 = "2140cd04483da90b2da7f99b2add0750504fc39c" uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910" -version = "0.1.1" +version = "0.1.2" [[deps.FFMPEG]] deps = ["FFMPEG_jll"] @@ -672,21 +674,22 @@ version = "1.16.1" [[deps.FilePathsBase]] deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.20" +version = "0.9.21" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra", "Random"] -git-tree-sha1 = "a20eaa3ad64254c61eeb5f230d9306e937405434" 
+git-tree-sha1 = "25a10f2b86118664293062705fd9c7e2eda881a2" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.6.1" -weakdeps = ["SparseArrays", "Statistics"] +version = "1.9.2" +weakdeps = ["PDMats", "SparseArrays", "Statistics"] [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" FillArraysSparseArraysExt = "SparseArrays" FillArraysStatisticsExt = "Statistics" @@ -697,10 +700,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.4" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "b97c3fc4f3628b8835d83789b09382961a254da4" +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "75e3b3732929e880e7fd121e8a4e4dd5e1bfeaee" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.6" +version = "0.14.7" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" @@ -766,9 +769,9 @@ version = "3.3.8+0" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" +git-tree-sha1 = "85d7fb51afb3def5dcb85ad31c3707795c8bccc1" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.8.1" +version = "9.1.0" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -778,9 +781,9 @@ version = "0.1.5" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3" +git-tree-sha1 = "a846f297ce9d09ccba02ead0cae70690e072a119" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.21.4" +version = "0.25.0" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"] @@ -796,15 +799,15 @@ version = "0.72.8+0" [[deps.GZip]] deps = ["Libdl", "Zlib_jll"] -git-tree-sha1 = "6388a2d8e409ce23de7d03a7c73d83c5753b3eb2" +git-tree-sha1 = "0085ccd5ec327c077ec5b91a5f937b759810ba62" uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.6.1" +version = "0.6.2" [[deps.GeoInterface]] deps = ["Extents"] -git-tree-sha1 = "bb198ff907228523f3dee1070ceee63b9359b6ab" +git-tree-sha1 = "d53480c0793b13341c40199190f92c611aa2e93c" uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" -version = "1.3.1" +version = "1.3.2" [[deps.GeometryBasics]] deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] @@ -843,9 +846,9 @@ version = "1.3.14+0" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "1cf1d7dcb4bc32d7b4a5add4232db3750c27ecb4" +git-tree-sha1 = "899050ace26649433ef1af25bc17a815b3db52b7" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.8.0" +version = "1.9.0" [[deps.Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" @@ -853,10 +856,14 @@ uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" version = "1.0.2" 
[[deps.HDF5]] -deps = ["Compat", "HDF5_jll", "Libdl", "Mmap", "Printf", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "114e20044677badbc631ee6fdc80a67920561a29" +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "26407bd1c60129062cec9da63dc7d08251544d53" uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.16.16" +version = "0.17.1" +weakdeps = ["MPI"] + + [deps.HDF5.extensions] + MPIExt = "MPI" [[deps.HDF5_jll]] deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"] @@ -866,9 +873,9 @@ version = "1.12.2+2" [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "5eab648309e2e060198b45820af1a37182de3cce" +git-tree-sha1 = "abbbb9ec3afd783a7cbd82ef01dcd088ea051398" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.10.0" +version = "1.10.1" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] @@ -902,9 +909,9 @@ version = "0.3.23" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" +git-tree-sha1 = "8aa91235360659ca7560db43a7d57541120aa31d" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.10" +version = "0.4.11" [[deps.IfElse]] git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" @@ -919,9 +926,9 @@ version = "0.6.11" [[deps.ImageBase]] deps = ["ImageCore", "Reexport"] -git-tree-sha1 = "eb49b82c172811fd2c86759fa0553a2221feb909" +git-tree-sha1 = "b51bb8cae22c66d0f6357e3bcb6363145ef20835" uuid = "c817782e-172a-44cc-b673-b171935fbb9e" -version = "0.1.7" +version = "0.1.5" [[deps.ImageBinarization]] deps = ["HistogramThresholding", "ImageCore", "LinearAlgebra", "Polynomials", "Reexport", "Statistics"] @@ -936,10 +943,10 @@ uuid = "f332f351-ec65-5f6a-b3d1-319c6670881a" version = "0.3.12" [[deps.ImageCore]] -deps = ["AbstractFFTs", "ColorVectorSpace", "Colors", "FixedPointNumbers", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "PrecompileTools", "Reexport"] -git-tree-sha1 = "fc5d1d3443a124fde6e92d0260cd9e064eba69f8" +deps = ["AbstractFFTs", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Graphics", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "Reexport"] +git-tree-sha1 = "acf614720ef026d38400b3817614c45882d75500" uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.10.1" +version = "0.9.4" [[deps.ImageCorners]] deps = ["ImageCore", "ImageFiltering", "PrecompileTools", "StaticArrays", "StatsBase"] @@ -955,15 +962,15 @@ version = "0.2.17" [[deps.ImageFiltering]] deps = ["CatIndices", "ComputationalResources", "DataStructures", "FFTViews", "FFTW", "ImageBase", "ImageCore", "LinearAlgebra", "OffsetArrays", "PrecompileTools", "Reexport", "SparseArrays", "StaticArrays", "Statistics", "TiledIteration"] -git-tree-sha1 = "432ae2b430a18c58eb7eca9ef8d0f2db90bc749c" +git-tree-sha1 = "3447781d4c80dbe6d71d239f7cfb1f8049d4c84f" uuid = "6a3955dd-da59-5b1f-98d4-e7296123deb5" -version = "0.7.8" +version = "0.7.6" [[deps.ImageIO]] -deps = ["FileIO", "IndirectArrays", "JpegTurbo", "LazyModules", "Netpbm", "OpenEXR", "PNGFiles", "QOI", "Sixel", "TiffImages", "UUIDs"] -git-tree-sha1 = "bca20b2f5d00c4fbc192c3212da8fa79f4688009" +deps = ["FileIO", 
"Netpbm", "PNGFiles"] +git-tree-sha1 = "0d6d09c28d67611c68e25af0c2df7269c82b73c7" uuid = "82e4d734-157c-48bb-816b-45c225c6df19" -version = "0.6.7" +version = "0.4.1" [[deps.ImageMagick]] deps = ["FileIO", "ImageCore", "ImageMagick_jll", "InteractiveUtils"] @@ -997,9 +1004,9 @@ version = "0.3.7" [[deps.ImageSegmentation]] deps = ["Clustering", "DataStructures", "Distances", "Graphs", "ImageCore", "ImageFiltering", "ImageMorphology", "LinearAlgebra", "MetaGraphs", "RegionTrees", "SimpleWeightedGraphs", "StaticArrays", "Statistics"] -git-tree-sha1 = "3ff0ca203501c3eedde3c6fa7fd76b703c336b5f" +git-tree-sha1 = "44664eea5408828c03e5addb84fa4f916132fc26" uuid = "80713f31-8817-5129-9cf8-209ff8fb23e1" -version = "1.8.2" +version = "1.8.1" [[deps.ImageShow]] deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] @@ -1019,21 +1026,15 @@ git-tree-sha1 = "d438268ed7a665f8322572be0dabda83634d5f45" uuid = "916415d5-f1e6-5110-898d-aaa5f9f070e0" version = "0.26.0" -[[deps.Imath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "3d09a9f60edf77f8a4d99f9e015e8fbf9989605d" -uuid = "905a6f67-0a94-5f89-b386-d35d92009cd1" -version = "3.1.7+0" - [[deps.IndirectArrays]] git-tree-sha1 = "012e604e1c7458645cb8b436f8fba789a51b257f" uuid = "9b13fd28-a010-5f03-acff-a1bbcff69959" version = "1.0.0" [[deps.Inflate]] -git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428" +git-tree-sha1 = "ea8031dea4aff6bd41f1df8f2fdfb25b33626381" uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.3" +version = "0.1.4" [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" @@ -1053,10 +1054,10 @@ uuid = "1d092043-8f09-5a30-832f-7509e371ab51" version = "0.1.5" [[deps.IntelOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ad37c091f7d7daf900963171600d7c1c5c3ede32" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "31d6adb719886d4e32e38197aae466e98881320b" uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2023.2.0+0" +version = "2024.0.0+0" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -1076,9 +1077,9 @@ version = "0.14.7" [[deps.IntervalSets]] deps = ["Dates", "Random"] -git-tree-sha1 = "8e59ea773deee525c99a8018409f64f19fb719e6" +git-tree-sha1 = "3d8866c029dd6b16e69e0d4a939c4dfcb98fac47" uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.7.7" +version = "0.7.8" weakdeps = ["Statistics"] [deps.IntervalSets.extensions] @@ -1111,16 +1112,16 @@ uuid = "82899510-4779-5014-852e-03e436cf321d" version = "1.0.0" [[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "Printf", "Reexport", "Requires", "TranscodingStreams", "UUIDs"] -git-tree-sha1 = "c11d691a0dc8e90acfa4740d293ade57f68bfdbb" +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "PrecompileTools", "Printf", "Reexport", "Requires", "TranscodingStreams", "UUIDs"] +git-tree-sha1 = "9bbb5130d3b4fa52846546bca4791ecbdfb52730" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.35" +version = "0.4.38" [[deps.JLFzf]] deps = ["Pipe", "REPL", "Random", "fzf_jll"] -git-tree-sha1 = "f377670cda23b6b7c1c0b3893e37451c5c1a2185" +git-tree-sha1 = "a53ebe394b71470c7f97c2e7e170d51df21b17af" uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" -version = "0.1.5" +version = "0.1.7" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] @@ -1146,17 +1147,17 @@ git-tree-sha1 = "1b4a8ae085fa69edf41b66bd6a18bc7cf37465c2" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.3" 
-[[deps.JpegTurbo]] -deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] -git-tree-sha1 = "327713faef2a3e5c80f96bf38d1fa26f7a6ae29e" -uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" -version = "0.1.3" - [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc" +git-tree-sha1 = "60b1194df0a3298f460063de985eae7b01bc011a" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "2.1.91+0" +version = "3.0.1+0" + +[[deps.JuliaNVTXCallbacks_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" +uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" +version = "0.2.1+0" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -1166,9 +1167,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" +git-tree-sha1 = "81de11f7b02465435aab0ed7e935965bfcb3072b" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.8" +version = "0.9.14" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -1189,16 +1190,25 @@ uuid = "88015f11-f218-50d7-93a8-a6af411a945d" version = "3.0.0+1" [[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "a9d2ce1d5007b1e8f6c5b89c5a31ff8bd146db5c" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "0678579657515e88b6632a3a482d39adcbb80445" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.2.1" +version = "6.4.1" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "7ca6850ae880cc99b59b88517545f91a52020afa" +git-tree-sha1 = "98eaee04d96d973e79c25d49167668c5c8fb50e2" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.25+0" +version = "0.0.27+1" + +[[deps.LLVMLoopInfo]] +git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" +uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" +version = "1.0.0" [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1213,9 +1223,9 @@ uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" version = "2.10.1+0" [[deps.LaTeXStrings]] -git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.0" +version = "1.3.1" [[deps.LaplaceRedux]] deps = ["CSV", "Compat", "ComputationalResources", "DataFrames", "Flux", "LinearAlgebra", "MLJ", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"] @@ -1245,9 +1255,9 @@ version = "1.9.0" [[deps.LayoutPointers]] deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] -git-tree-sha1 = "88b8f66b604da079a627b6fb2860d3704a6729a1" +git-tree-sha1 = "62edfee3211981241b57ff1cedf4d74d79519277" uuid = "10f19ff3-798f-405d-979b-55457f8fc047" -version = "0.1.14" +version = "0.1.15" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -1261,12 +1271,12 @@ version = "0.3.1" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" +version = 
"0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" +version = "8.4.0+0" [[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] @@ -1275,7 +1285,7 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" +version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -1353,15 +1363,15 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] -git-tree-sha1 = "0d097476b6c381ab7906460ef1ef1638fbce1d91" +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.2" +version = "1.0.3" [[deps.LoopVectorization]] -deps = ["ArrayInterface", "ArrayInterfaceCore", "CPUSummary", "CloseOpenIntervals", "DocStringExtensions", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "PrecompileTools", "SIMDTypes", "SLEEFPirates", "Static", "StaticArrayInterface", "ThreadingUtilities", "UnPack", "VectorizationBase"] -git-tree-sha1 = "c88a4afe1703d731b1c4fdf4e3c7e77e3b176ea2" +deps = ["ArrayInterface", "CPUSummary", "CloseOpenIntervals", "DocStringExtensions", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "PrecompileTools", "SIMDTypes", "SLEEFPirates", "Static", "StaticArrayInterface", "ThreadingUtilities", "UnPack", "VectorizationBase"] +git-tree-sha1 = "0f5648fbae0d015e3abe5867bca2b362f67a5894" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.12.165" +version = "0.12.166" weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] [deps.LoopVectorization.extensions] @@ -1380,9 +1390,9 @@ weakdeps = ["CategoricalArrays"] [[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "79fd0b5ee384caf8ebba6c8fb3f365ca3e2c5493" +git-tree-sha1 = "ed1cf0a322d78cee07718bed5fd945e2218c35a1" uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.5" +version = "0.10.6" [[deps.MKL_jll]] deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] @@ -1392,9 +1402,9 @@ version = "2023.2.0+0" [[deps.MLDatasets]] deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "10bc70e4c875f1b2ca65cef3ef9ebe705ef936b5" +git-tree-sha1 = "aab72207b3c687086a400be710650a57494992bd" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.13" +version = "0.7.14" [[deps.MLFlowClient]] deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] @@ -1446,15 +1456,15 @@ version = "0.5.1" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "03ae109be87f460fe3c96b8a0dbbf9c7bf840bd5" +git-tree-sha1 = "381d99f0af76d98f50bd5512dcf96a99c13f8223" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.9.2" +version = "1.9.3" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", 
"PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "da9f2bfefa08d1a63167cbd9bdd862a44a1b3d9d" +git-tree-sha1 = "10d221910fc3f3eedad567178ddbca3cc0f776a3" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.11" +version = "0.16.12" [[deps.MLJTuning]] deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] @@ -1475,9 +1485,9 @@ version = "0.4.3" [[deps.MPI]] deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"] -git-tree-sha1 = "df53d0e1e0dbebf2315f4cd35e13e52ad43416c2" +git-tree-sha1 = "4e3136db3735924f96632a5b40a5979f1f53fa07" uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195" -version = "0.20.15" +version = "0.20.19" [deps.MPI.extensions] AMDGPUExt = "AMDGPU" @@ -1495,9 +1505,9 @@ version = "4.1.2+0" [[deps.MPIPreferences]] deps = ["Libdl", "Preferences"] -git-tree-sha1 = "781916a2ebf2841467cda03b6f1af43e23839d85" +git-tree-sha1 = "8f6af051b9e8ec597fa09d8885ed79fd582f33c9" uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" -version = "0.1.9" +version = "0.1.10" [[deps.MPItrampoline_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] @@ -1526,10 +1536,10 @@ deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" [[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] -git-tree-sha1 = "03a9b9718f5682ecb107ac9f7308991db4ce395b" +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.7" +version = "1.1.9" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] @@ -1561,9 +1571,9 @@ version = "0.1.4" [[deps.MicrosoftMPI_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "a8027af3d1743b3bfae34e54872359fdebb31422" +git-tree-sha1 = "b01beb91d20b0d1312a9471a36017b5b339d26de" uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.3+4" +version = "10.1.4+1" [[deps.Missings]] deps = ["DataAPI"] @@ -1598,9 +1608,9 @@ version = "0.10.2" [[deps.Mustache]] deps = ["Printf", "Tables"] -git-tree-sha1 = "821e918c170ead5298ff84bffee41dd28929a681" +git-tree-sha1 = "a7cefa21a2ff993bff0456bf7521f46fc077ddf1" uuid = "ffc61752-8dc7-55ee-8c37-f3e9cdd09e70" -version = "1.0.17" +version = "1.0.19" [[deps.MyterialColors]] git-tree-sha1 = "01d8466fb449436348999d7c6ad740f8f853a579" @@ -1609,18 +1619,20 @@ version = "0.3.0" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "6e4e90c2e2ef091ef50b91af65fa4bb09c3d0728" +git-tree-sha1 = "ac86d2944bf7a670ac8bf0f7ec099b5898abcc09" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.6" +version = "0.9.8" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NPZ]] @@ -1629,6 +1641,18 @@ 
git-tree-sha1 = "60a8e272fe0c5079363b28b0953831e2dd7b7e6f" uuid = "15e1cf62-19b3-5cfa-8e77-841668bca605" version = "0.4.3" +[[deps.NVTX]] +deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] +git-tree-sha1 = "8bc9ce4233be3c63f8dcd78ccaf1b63a9c0baa34" +uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" +version = "0.3.3" + +[[deps.NVTX_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" +uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" +version = "3.1.0+2" + [[deps.NaNMath]] deps = ["OpenLibm_jll"] git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" @@ -1665,10 +1689,14 @@ uuid = "f09324ee-3d7c-5217-9330-fc30815ba969" version = "1.1.1" [[deps.NetworkLayout]] -deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "2bfd8cd7fba3e46ce48139ae93904ee848153660" +deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "StaticArrays"] +git-tree-sha1 = "91bb2fedff8e43793650e7a677ccda6e6e6e166b" uuid = "46757867-2c16-5918-afeb-47bfcb05e46a" -version = "0.4.5" +version = "0.4.6" +weakdeps = ["Graphs"] + + [deps.NetworkLayout.extensions] + NetworkLayoutGraphsExt = "Graphs" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" @@ -1697,18 +1725,6 @@ deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" version = "0.3.21+4" -[[deps.OpenEXR]] -deps = ["Colors", "FileIO", "OpenEXR_jll"] -git-tree-sha1 = "327f53360fdb54df7ecd01e96ef1983536d1e633" -uuid = "52e1d378-f018-4a11-a4be-720524705ac7" -version = "0.3.2" - -[[deps.OpenEXR_jll]] -deps = ["Artifacts", "Imath_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "a4ca623df1ae99d09bc9868b008262d0c0ac1e4f" -uuid = "18a262bb-aa17-5467-a713-aee519bc75cb" -version = "3.1.4+0" - [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" @@ -1722,9 +1738,9 @@ version = "0.3.1" [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "f3080f4212a8ba2ceb10a34b938601b862094314" +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.5+0" +version = "4.1.6+0" [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] @@ -1757,9 +1773,9 @@ uuid = "91d4177d-7536-5919-b921-800302f37372" version = "1.3.2+0" [[deps.OrderedCollections]] -git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.2" +version = "1.6.3" [[deps.PCRE2_jll]] deps = ["Artifacts", "Libdl"] @@ -1768,15 +1784,15 @@ version = "10.42.0+0" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "3129380a93388e5062e946974246fe3f2e7c73e2" +git-tree-sha1 = "4e5be6bb265d33669f98eb55d2a57addd1eeb72c" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.18" +version = "0.11.30" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] -git-tree-sha1 = "9b02b27ac477cad98114584ff964e3052f656a0f" +git-tree-sha1 = "f809158b27eba0c18c269cf2a2be6ed751d3e81d" uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" -version = "0.4.0" +version = "0.3.17" [[deps.PackageExtensionCompat]] git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" @@ -1798,14 
+1814,15 @@ version = "0.12.3" [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" +git-tree-sha1 = "a935806434c9d4c506ba941871b327b96d41f2bf" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.7.2" +version = "2.8.0" [[deps.PartialFunctions]] -git-tree-sha1 = "b3901ea034cfd8aae57a2fa0dde0b0ea18bad1cb" +deps = ["MacroTools"] +git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" uuid = "570af359-4316-4cb7-8c74-252c00c2016b" -version = "1.1.1" +version = "1.2.0" [[deps.PeriodicTable]] deps = ["Base64", "Test", "Unitful"] @@ -1837,9 +1854,9 @@ version = "1.9.2" [[deps.PkgTemplates]] deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"] -git-tree-sha1 = "693ad322c84159a1f1f875891115423695fc08e8" +git-tree-sha1 = "52aa978f90f67ec52326ecbf3155f6fc6035b4e5" uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" -version = "0.7.45" +version = "0.7.46" [[deps.PkgVersion]] deps = ["Pkg"] @@ -1930,10 +1947,10 @@ uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" version = "0.4.1" [[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "ee094908d720185ddbdc58dbe0c1cbe35453ec7a" +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "88b895d13d53b5577fd53379d913b9ab9ac82660" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.2.7" +version = "2.3.1" [[deps.Printf]] deps = ["Unicode"] @@ -1951,12 +1968,6 @@ git-tree-sha1 = "00099623ffee15972c16111bcf84c58a0051257c" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.9.0" -[[deps.QOI]] -deps = ["ColorTypes", "FileIO", "FixedPointNumbers"] -git-tree-sha1 = "18e8f4d1426e965c7b532ddd260599e1510d26ce" -uuid = "4b34888f-f399-49d4-9bb3-47ed5cae4e65" -version = "1.0.0" - [[deps.Qt5Base_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "xkbcommon_jll"] git-tree-sha1 = "0c03844e2231e12fda4d0086fd7cbe4098ee8dc5" @@ -1971,9 +1982,9 @@ version = "2.9.1" [[deps.Quaternions]] deps = ["LinearAlgebra", "Random", "RealDot"] -git-tree-sha1 = "da095158bdc8eaccb7890f9884048555ab771019" +git-tree-sha1 = "9a46862d248ea548e340e30e2894118749dc7f51" uuid = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" -version = "0.7.4" +version = "0.7.5" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] @@ -2041,9 +2052,9 @@ version = "0.3.2" [[deps.RelocatableFolders]] deps = ["SHA", "Scratch"] -git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" +git-tree-sha1 = "ffdaf70d81cf6ff22c2b6e733c900c3321cab864" uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "1.0.0" +version = "1.0.1" [[deps.Requires]] deps = ["UUIDs"] @@ -2065,9 +2076,9 @@ version = "0.4.0+0" [[deps.Rotations]] deps = ["LinearAlgebra", "Quaternions", "Random", "StaticArrays"] -git-tree-sha1 = "0783924e4a332493f72490253ba4e668aeba1d73" +git-tree-sha1 = "792d8fd4ad770b6d517a13ebb8dadfcac79405b8" uuid = "6038ab10-8711-5258-84ad-4b1120ba62dc" -version = "1.6.0" +version = "1.6.1" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -2080,9 +2091,9 @@ version = "0.1.0" [[deps.SLEEFPirates]] deps = ["IfElse", "Static", 
"VectorizationBase"] -git-tree-sha1 = "4b8586aece42bee682399c4c4aee95446aa5cd19" +git-tree-sha1 = "3aac6d68c5e57449f5b9b865c9ba50ac2970c4cf" uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" -version = "0.6.39" +version = "0.6.42" [[deps.ScientificTypes]] deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] @@ -2103,15 +2114,15 @@ version = "0.5.0" [[deps.Scratch]] deps = ["Dates"] -git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.0" +version = "1.2.1" [[deps.SentinelArrays]] deps = ["Dates", "Random"] -git-tree-sha1 = "04bdff0b09c65ff3e06a05e3eb7b120223da3d39" +git-tree-sha1 = "0e7508ff27ba32f26cd459474ca2ede1bc10991f" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.0" +version = "1.4.1" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -2154,26 +2165,14 @@ git-tree-sha1 = "4b33e0e081a825dbfaf314decf58fa47e53d6acb" uuid = "47aef6b3-ad0c-573a-a1e2-d07658019622" version = "1.4.0" -[[deps.Sixel]] -deps = ["Dates", "FileIO", "ImageCore", "IndirectArrays", "OffsetArrays", "REPL", "libsixel_jll"] -git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" -uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" -version = "0.1.3" - -[[deps.SnoopPrecompile]] -deps = ["Preferences"] -git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" -uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" -version = "1.0.3" - [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" +git-tree-sha1 = "5165dfb9fd131cf0c6957a3a7605dede376e7b63" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.1" +version = "1.2.0" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] @@ -2231,10 +2230,10 @@ weakdeps = ["OffsetArrays", "StaticArrays"] StaticArrayInterfaceStaticArraysExt = "StaticArrays" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "d5fb407ec3179063214bc6277712928ba78459e2" +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "5ef59aea6f18c25168842bded46b16662141ab87" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.4" +version = "1.7.0" weakdeps = ["Statistics"] [deps.StaticArrays.extensions] @@ -2334,9 +2333,9 @@ version = "1.0.1" [[deps.Tables]] deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "a1f34829d5ac0ef499f6d84428bd6b4c71f02ead" +git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.11.0" +version = "1.11.1" [[deps.Tar]] deps = ["ArgTools", "SHA"] @@ -2365,12 +2364,6 @@ git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" version = "0.5.2" -[[deps.TiffImages]] -deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] -git-tree-sha1 = "b7dc44cb005a7ef743b8fe98970afef003efdce7" -uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69" -version = "0.6.6" - [[deps.TiledIteration]] deps = ["OffsetArrays", "StaticArrayInterface"] git-tree-sha1 = 
"1176cc31e867217b06928e2f140c90bd1bc88283" @@ -2384,16 +2377,19 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" version = "0.5.23" [[deps.TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" +git-tree-sha1 = "1fbeaaca45801b4ba17c251dd8603ef24801dd84" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.13" +version = "0.10.2" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "53bd5978b182fa7c57577bdb452c35e5b4fb73a5" +git-tree-sha1 = "e579d3c991938fecbb225699e8f611fa3fbf2141" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.78" +version = "0.4.79" [deps.Transducers.extensions] TransducersBlockArraysExt = "BlockArrays" @@ -2410,10 +2406,22 @@ version = "0.4.78" Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" [[deps.Tullio]] -deps = ["ChainRulesCore", "DiffRules", "LinearAlgebra", "Requires"] -git-tree-sha1 = "7871a39eac745697ee512a87eeff06a048a7905b" +deps = ["DiffRules", "LinearAlgebra", "Requires"] +git-tree-sha1 = "6d476962ba4e435d7f4101a403b1d3d72afe72f3" uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" -version = "0.3.5" +version = "0.3.7" + + [deps.Tullio.extensions] + TullioCUDAExt = "CUDA" + TullioChainRulesCoreExt = "ChainRulesCore" + TullioFillArraysExt = "FillArrays" + TullioTrackerExt = "Tracker" + + [deps.Tullio.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [[deps.TupleTools]] git-tree-sha1 = "155515ed4c4236db30049ac1495e2969cc06be9d" @@ -2421,9 +2429,9 @@ uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" version = "1.4.3" [[deps.URIs]] -git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.0" +version = "1.5.1" [[deps.UUIDs]] deps = ["Random", "SHA"] @@ -2445,9 +2453,9 @@ version = "0.4.1" [[deps.Unitful]] deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "a72d22c7e13fe2de562feda8645aa134712a87ee" +git-tree-sha1 = "3c793be6df9dd77a0cf49d80984ef9ff996948fa" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.17.0" +version = "1.19.0" [deps.Unitful.extensions] ConstructionBaseUnitfulExt = "ConstructionBase" @@ -2487,9 +2495,9 @@ version = "0.2.0" [[deps.VectorizationBase]] deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static", "StaticArrayInterface"] -git-tree-sha1 = "b182207d4af54ac64cbc71797765068fdeff475d" +git-tree-sha1 = "7209df901e6ed7489fe9b7aa3e46fb788e15db85" uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" -version = "0.21.64" +version = "0.21.65" [[deps.Wayland_jll]] deps = ["Artifacts", "EpollShim_jll", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] @@ -2511,9 +2519,9 @@ version = "1.4.2" [[deps.WoodburyMatrices]] deps = ["LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "de67fa59e33ad156a590055375a30b23c40299d3" +git-tree-sha1 = "5f24e158cf4cee437052371455fe361f526da062" uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" -version = "0.5.5" +version = 
"0.5.6" [[deps.WorkerUtilities]] git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" @@ -2522,9 +2530,9 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "04a51d15436a572301b5abbb9d099713327e9fc4" +git-tree-sha1 = "801cbe47eae69adc50f36c3caec4758d2650741b" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.10.4+0" +version = "2.12.2+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] @@ -2677,9 +2685,9 @@ version = "1.5.5+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "b97c927497c1de55a78dc9030f6068be5d83ef80" +git-tree-sha1 = "5ded212acd815612df112bb895ef3910c5a03f57" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.64" +version = "0.6.67" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -2693,21 +2701,21 @@ version = "0.6.64" [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" +git-tree-sha1 = "9d749cd449fb448aeca4feee9a2f4186dbb5d184" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.3" +version = "0.2.4" [[deps.cuDNN]] -deps = ["CEnum", "CUDA", "CUDNN_jll"] -git-tree-sha1 = "5a1ba43303c62f4a09b0d6751422de03424ab0cd" +deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] +git-tree-sha1 = "c092c26591a851083ed3358890d0d916c58dde62" uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" -version = "1.1.1" +version = "1.2.1" [[deps.fzf_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "868e669ccb12ba16eaf50cb2957ee2ff61261c56" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a68c9655fbe6dfcab3d972808f1aafec151ce3f8" uuid = "214eeab7-80f7-51ab-84ad-2988db7cef09" -version = "0.29.0+0" +version = "0.43.0+0" [[deps.ghr_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -2739,16 +2747,10 @@ uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" version = "2.0.2+0" [[deps.libpng_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "94d180a6d2b5e55e447e2d27a29ed04fe79eb30c" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] +git-tree-sha1 = "93284c28274d9e75218a416c65ec49d0e0fcdf3d" uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.38+0" - -[[deps.libsixel_jll]] -deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Pkg", "libpng_jll"] -git-tree-sha1 = "d4f63314c8aa1e48cd22aa0c17ed76cd1ae48c3c" -uuid = "075b6546-f08a-558a-be8f-8157d0f608a5" -version = "1.10.3+0" +version = "1.6.40+0" [[deps.libvorbis_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] @@ -2759,7 +2761,7 @@ version = "1.3.7+1" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" +version = "1.52.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index 61164009..0c313370 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -1,14 +1,15 @@ function default_generators(; - 
Λ::AbstractArray=[0.25, 0.75, 0.75], - Λ_Δ::AbstractArray=Λ, - use_variants::Bool=true, - use_class_loss::Bool=false, - opt=Flux.Optimise.Descent(0.01), - niter_eccco::Union{Nothing,Int}=nothing, - nsamples::Union{Nothing,Int}=nothing, - nmin::Union{Nothing,Int}=nothing, - reg_strength::Real=0.5, - dim_reduction::Bool=false, + Λ::AbstractArray = [0.25, 0.75, 0.75], + Λ_Δ::AbstractArray = Λ, + use_variants::Bool = true, + use_class_loss::Bool = false, + opt = Flux.Optimise.Descent(0.01), + niter_eccco::Union{Nothing,Int} = nothing, + nsamples::Union{Nothing,Int} = nothing, + nmin::Union{Nothing,Int} = nothing, + reg_strength::Real = 0.5, + decay::Tuple = (0.1, 5), + dim_reduction::Bool = false, ) @info "Begin benchmarking counterfactual explanations." @@ -17,23 +18,91 @@ function default_generators(; if use_variants generator_dict = Dict( - "Wachter" => WachterGenerator(λ=λ₁, opt=opt), - "REVISE" => REVISEGenerator(λ=λ₁, opt=opt), - "Schut" => GreedyGenerator(η=opt.eta), - "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco), - "ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco), - "ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco), - "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength = reg_strength), - "ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λ₁_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength), - "ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λ₁_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength), + "Wachter" => WachterGenerator(λ = λ₁, opt = opt), + "REVISE" => REVISEGenerator(λ = λ₁, opt = opt), + "Schut" => GreedyGenerator(η=get_learning_rate(opt)), + "ECCCo" => ECCCoGenerator( + λ = Λ, + opt = opt, + use_class_loss = use_class_loss, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + ), + "ECCCo (no CP)" => ECCCoGenerator( + λ = [λ₁, 0.0, λ₃], + opt = opt, + use_class_loss = use_class_loss, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + ), + "ECCCo (no EBM)" => ECCCoGenerator( + λ = [λ₁, λ₂, 0.0], + opt = opt, + use_class_loss = use_class_loss, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + ), + "ECCCo-Δ" => ECCCoGenerator( + λ = Λ_Δ, + opt = opt, + use_class_loss = use_class_loss, + use_energy_delta = true, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + reg_strength = reg_strength, + decay = decay, + ), + "ECCCo-Δ (no CP)" => ECCCoGenerator( + λ = [λ₁_Δ, 0.0, λ₃_Δ], + opt = opt, + use_class_loss = use_class_loss, + use_energy_delta = true, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + reg_strength = reg_strength, + decay = decay, + ), + "ECCCo-Δ (no EBM)" => ECCCoGenerator( + λ = [λ₁_Δ, λ₂_Δ, 0.0], + opt = opt, + use_class_loss = use_class_loss, + use_energy_delta = true, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + reg_strength = reg_strength, + decay = decay, + ), ) else generator_dict = Dict( - "Wachter" => WachterGenerator(λ=λ₁, opt=opt), - "REVISE" => REVISEGenerator(λ=λ₁, opt=opt), + "Wachter" => WachterGenerator(λ = λ₁, opt = opt), + "REVISE" 
=> REVISEGenerator(λ = λ₁, opt = opt), "Schut" => GreedyGenerator(), - "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco), - "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength), + "ECCCo" => ECCCoGenerator( + λ = Λ, + opt = opt, + use_class_loss = use_class_loss, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + ), + "ECCCo-Δ" => ECCCoGenerator( + λ = Λ_Δ, + opt = opt, + use_class_loss = use_class_loss, + use_energy_delta = true, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + reg_strength = reg_strength, + decay = decay, + ), ) end @@ -42,9 +111,17 @@ function default_generators(; if dim_reduction eccco_latent = Dict( "ECCCo-Δ (latent)" => ECCCoGenerator( - λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength, - dim_reduction=dim_reduction - ) + λ = Λ_Δ, + opt = opt, + use_class_loss = use_class_loss, + use_energy_delta = true, + nsamples = nsamples, + nmin = nmin, + niter = niter_eccco, + reg_strength = reg_strength, + decay = decay, + dim_reduction = dim_reduction, + ), ) generator_dict = merge(generator_dict, eccco_latent) end @@ -72,31 +149,36 @@ function run_benchmark(exper::Experiment, model_dict::Dict) # Benchmark generators: if isnothing(generator_dict) generator_dict = default_generators(; - Λ=exper.Λ, - Λ_Δ=exper.Λ_Δ, - use_variants=exper.use_variants, - use_class_loss=exper.use_class_loss, - opt=exper.opt, - nsamples=exper.nsamples, - nmin=exper.nmin, - reg_strength=exper.reg_strength, - dim_reduction=exper.dim_reduction, + Λ = exper.Λ, + Λ_Δ = exper.Λ_Δ, + use_variants = exper.use_variants, + use_class_loss = exper.use_class_loss, + opt = exper.opt, + nsamples = exper.nsamples, + nmin = exper.nmin, + reg_strength = exper.reg_strength, + dim_reduction = exper.dim_reduction, + decay = exper.decay, ) end # Run benchmark: + storage_path = mkpath(joinpath(exper.output_path, "interim_$(dataname)")) bmk = benchmark( counterfactual_data; - models=model_dict, - generators=generator_dict, - measure=measures, - suppress_training=true, dataname=dataname, - n_individuals=n_individuals, - initialization=:identity, - converge_when=:generator_conditions, - parallelizer=parallelizer, - store_ce=exper.store_ce, + models = model_dict, + generators = generator_dict, + measure = measures, + suppress_training = true, + dataname = dataname, + n_individuals = n_individuals, + initialization = :identity, + converge_when = :generator_conditions, + parallelizer = parallelizer, + store_ce = exper.store_ce, + n_runs = exper.n_runs, + vertical_splits = VERTICAL_SPLITS, + storage_path = storage_path ) return bmk, generator_dict end - diff --git a/experiments/california_housing.jl b/experiments/california_housing.jl index 4a486a1b..e8a8db16 100644 --- a/experiments/california_housing.jl +++ b/experiments/california_housing.jl @@ -1,18 +1,26 @@ # Data: dataname = "California Housing" -counterfactual_data, test_data = train_test_split(load_california_housing(nothing); test_size=TEST_SIZE) +counterfactual_data, test_data = + train_test_split(load_california_housing(nothing); test_size = TEST_SIZE) + +# Domain constraints: +counterfactual_data.domain = extrema(counterfactual_data.X, dims=2) # VAE: using CounterfactualExplanations.GenerativeModels: VAE, train! 
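Aside on the new domain-constraint line above: `extrema` with `dims = 2` collapses over observations (columns), returning one `(min, max)` tuple per feature (row) — the same per-feature bounds format that the Fashion-MNIST script later builds by hand with `fill((min, max), d)`. A minimal, self-contained sketch with a toy matrix (illustrative values, not repo data):

# Features are rows, observations are columns, as in counterfactual_data.X.
X = [0.0  1.0  2.0;       # feature 1 spans [0, 2]
     -1.0 0.5  3.0]       # feature 2 spans [-1, 3]
bounds = extrema(X, dims = 2)   # 2x1 Matrix of (min, max) tuples
@assert bounds[1] == (0.0, 2.0)
@assert bounds[2] == (-1.0, 3.0)

The same one-liner is added for the synthetic datasets below (e.g. circles.jl).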
X = counterfactual_data.X y = counterfactual_data.output_encoder.y -vae = VAE(size(X, 1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5) +vae = VAE(size(X, 1); nll = Flux.Losses.mse, epochs = 100, λ = 0.01, latent_dim = 5) train!(vae, X, y) counterfactual_data.generative_model = vae # Dimensionality reduction: maxout_dim = vae.params.latent_dim -counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); +counterfactual_data.dt = MultivariateStats.fit( + MultivariateStats.PCA, + counterfactual_data.X; + maxoutdim = maxout_dim, +); # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_LARGE @@ -22,29 +30,29 @@ tuning_params = DEFAULT_GENERATOR_TUNING_LARGE # Parameter choices: params = ( - n_hidden=32, - activation=Flux.relu, - builder=default_builder(n_hidden=32, n_layers=3, activation=Flux.relu), - α=[1.0, 1.0, 1e-1], - sampling_batch_size=10, - sampling_steps=30, - use_ensembling=true, - opt=Flux.Optimise.Descent(0.05), - Λ=[0.1, 0.1, 0.1], - reg_strength=0.0, - n_individuals=25, - dim_reduction=true, + n_hidden = 32, + activation = Flux.relu, + builder = default_builder(n_hidden = 32, n_layers = 3, activation = Flux.relu), + α = [1.0, 1.0, 1e-1], + sampling_batch_size = 10, + sampling_steps = 30, + use_ensembling = true, + opt = Flux.Optimise.Descent(0.05), + Λ = [0.1, 0.1, 0.1], + reg_strength = 0.0, + dim_reduction = true, ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search( - counterfactual_data, test_data; - dataname=dataname, - tuning_params=tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + tuning_params = tuning_params, + params..., ) elseif FROM_GRID_SEARCH outcomes_file_path = joinpath( @@ -56,9 +64,10 @@ elseif FROM_GRID_SEARCH bmk2csv(dataname) else run_experiment( - counterfactual_data, test_data; - dataname=dataname, - model_tuning_params=model_tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + model_tuning_params = model_tuning_params, + params..., ) -end \ No newline at end of file +end diff --git a/experiments/circles.jl b/experiments/circles.jl index 167ae9ae..0d4a6374 100644 --- a/experiments/circles.jl +++ b/experiments/circles.jl @@ -1,7 +1,11 @@ # Data: dataname = "Circles" n_obs = Int(1000 / (1.0 - TEST_SIZE)) -counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE) +counterfactual_data, test_data = + train_test_split(load_circles(n_obs; noise = 0.05, factor = 0.5); test_size = TEST_SIZE) + +# Domain constraints: +counterfactual_data.domain = extrema(counterfactual_data.X, dims=2) # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_SMALL @@ -12,27 +16,28 @@ tuning_params = DEFAULT_GENERATOR_TUNING # Parameter choices: # These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM. 
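Note on the `append_best_params!` → `params = append_best_params(params, dataname)` change above (repeated in every experiment script in this diff): `params` is a NamedTuple, and NamedTuples are immutable in Julia, so a mutating `!` variant cannot update the caller's binding in place; the non-mutating version must return a fresh tuple that the caller rebinds. A hedged sketch of that pattern — `load_best_params` is a hypothetical stand-in, since the helper's actual lookup of stored grid-search results is not shown in this diff:

function append_best_params(params::NamedTuple, dataname::String)
    best = load_best_params(dataname)  # hypothetical: fetch stored grid-search winners
    return (; params..., best...)      # right-most duplicates win, so best params override
end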
params = ( - use_tuned=false, - n_hidden=32, - n_layers=3, - activation=Flux.swish, - epochs=100, - α=[1.0, 1.0, 1e-2], - sampling_steps=30, - opt=Flux.Optimise.Descent(0.05), - Λ=[0.1, 0.1, 0.05], - reg_strength=1.0, + use_tuned = false, + n_hidden = 32, + n_layers = 3, + activation = Flux.swish, + epochs = 100, + α = [1.0, 1.0, 1e-2], + sampling_steps = 30, + opt = Flux.Optimise.Descent(0.05), + Λ = [0.1, 0.1, 0.05], + reg_strength = 1.0, ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search( - counterfactual_data, test_data; - dataname=dataname, - tuning_params=tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + tuning_params = tuning_params, + params..., ) elseif FROM_GRID_SEARCH outcomes_file_path = joinpath( @@ -44,9 +49,10 @@ elseif FROM_GRID_SEARCH bmk2csv(dataname) else run_experiment( - counterfactual_data, test_data; - dataname=dataname, - model_tuning_params=model_tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + model_tuning_params = model_tuning_params, + params..., ) -end \ No newline at end of file +end diff --git a/experiments/daic/generators/california_housing.sh b/experiments/daic/generators/california_housing.sh index a6676447..ec8fddd1 100644 --- a/experiments/daic/generators/california_housing.sh +++ b/experiments/daic/generators/california_housing.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --job-name="California Housing (ECCCo)" -#SBATCH --time=3:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:40:00 +#SBATCH --ntasks=40 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=4GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi > experiments/california_housing.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/california_housing.log diff --git a/experiments/daic/generators/circles.sh b/experiments/daic/generators/circles.sh new file mode 100644 index 00000000..2a2ba161 --- /dev/null +++ b/experiments/daic/generators/circles.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Circles (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=2GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
+ +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=circles output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/circles.log diff --git a/experiments/daic/generators/credit_default.sh b/experiments/daic/generators/credit_default.sh deleted file mode 100644 index a6676447..00000000 --- a/experiments/daic/generators/credit_default.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="California Housing (ECCCo)" -#SBATCH --time=3:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=general -#SBATCH --mem-per-cpu=8GB -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. - -module use /opt/insy/modulefiles # Use DAIC INSY software collection -module load openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi > experiments/california_housing.log diff --git a/experiments/daic/generators/fmnist.sh b/experiments/daic/generators/fmnist.sh index 6131eb37..83ed1794 100644 --- a/experiments/daic/generators/fmnist.sh +++ b/experiments/daic/generators/fmnist.sh @@ -1,9 +1,9 @@ #!/bin/bash -#SBATCH --job-name="Fashion-MNIST (ECCCo)" -#SBATCH --time=10:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --job-name="Fashion MNIST - Grid (ECCCo)" +#SBATCH --time=02:00:00 +#SBATCH --ntasks=40 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general #SBATCH --mem-per-cpu=8GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. @@ -11,4 +11,6 @@ module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=fmnist retrain output_path=results threaded mpi > experiments/fmnist.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=fmnist output_path=results mpi grid_search threaded n_individuals=25 n_each=32 > experiments/logs/grid_search_fmnist.log \ No newline at end of file diff --git a/experiments/daic/generators/german_credit.sh b/experiments/daic/generators/german_credit.sh index 5ddba96e..cf7c784c 100644 --- a/experiments/daic/generators/german_credit.sh +++ b/experiments/daic/generators/german_credit.sh @@ -1,9 +1,9 @@ #!/bin/bash #SBATCH --job-name="German Credit (ECCCo)" -#SBATCH --time=3:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:40:00 +#SBATCH --ntasks=40 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general #SBATCH --mem-per-cpu=4GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
@@ -11,4 +11,6 @@ module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=german_credit output_path=results mpi > experiments/german_credit.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=german_credit output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/german_credit.log diff --git a/experiments/daic/generators/gmsc.sh b/experiments/daic/generators/gmsc.sh index a5fc350c..791c6033 100644 --- a/experiments/daic/generators/gmsc.sh +++ b/experiments/daic/generators/gmsc.sh @@ -1,9 +1,9 @@ #!/bin/bash #SBATCH --job-name="GMSC (ECCCo)" -#SBATCH --time=3:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:40:00 +#SBATCH --ntasks=40 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general #SBATCH --mem-per-cpu=4GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. @@ -11,4 +11,6 @@ module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc output_path=results mpi > experiments/gmsc.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=gmsc output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/gmsc.log diff --git a/experiments/daic/generators/linearly_separable.sh b/experiments/daic/generators/linearly_separable.sh new file mode 100644 index 00000000..c630dbb3 --- /dev/null +++ b/experiments/daic/generators/linearly_separable.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Linearly Separable (ECCCo)" +#SBATCH --time=01:10:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=2GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/linearly_separable.log diff --git a/experiments/daic/generators/mnist.sh b/experiments/daic/generators/mnist.sh index 4b638dd3..4fcc9839 100644 --- a/experiments/daic/generators/mnist.sh +++ b/experiments/daic/generators/mnist.sh @@ -1,9 +1,9 @@ #!/bin/bash -#SBATCH --job-name="MNIST (ECCCo)" -#SBATCH --time=10:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --job-name="MNIST - Grid (ECCCo)" +#SBATCH --time=02:00:00 +#SBATCH --ntasks=40 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general #SBATCH --mem-per-cpu=8GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
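Note on the recurring SBATCH changes in these scripts: the old headers requested on the order of 1000 single-threaded MPI ranks, while the new ones request a hybrid layout, e.g. 40 tasks × 10 cpus-per-task = 400 cores, with `--threads $SLURM_CPUS_PER_TASK` giving each rank its Julia threads and the new `threaded` flag presumably switching the experiment code to per-rank multithreading (its handling inside run_experiments.jl is not shown in this diff). A minimal sketch of what each launched process sees under this layout, using standard MPI.jl calls:

using MPI
MPI.Init()
rank = MPI.Comm_rank(MPI.COMM_WORLD)   # 0 .. (--ntasks - 1), one per srun task
nthreads = Threads.nthreads()          # set via --threads $SLURM_CPUS_PER_TASK
@info "Rank $rank running with $nthreads Julia threads"
MPI.Finalize()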
@@ -11,4 +11,6 @@ module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi > experiments/mnist.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=mnist output_path=results mpi grid_search threaded n_individuals=25 n_each=32 > experiments/logs/grid_search_mnist.log \ No newline at end of file diff --git a/experiments/daic/generators/mnist_memory.sh b/experiments/daic/generators/mnist_memory.sh deleted file mode 100644 index 0e164fdb..00000000 --- a/experiments/daic/generators/mnist_memory.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="MNIST (ECCCo)" -#SBATCH --time=10:00:00 -#SBATCH --ntasks=150 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=memory -#SBATCH --mem-per-cpu=64GB -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. - -module use /opt/insy/modulefiles # Use DAIC INSY software collection -module load openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi > experiments/mnist.log diff --git a/experiments/daic/generators/moons.sh b/experiments/daic/generators/moons.sh new file mode 100644 index 00000000..7bed6e3b --- /dev/null +++ b/experiments/daic/generators/moons.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Moons (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=2GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=moons output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/moons.log diff --git a/experiments/daic/generators/synthetic.sh b/experiments/daic/generators/synthetic.sh deleted file mode 100644 index d5fa69ff..00000000 --- a/experiments/daic/generators/synthetic.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="Synthetic (ECCCo)" -#SBATCH --time=02:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=general -#SBATCH --mem-per-cpu=4GB -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. - -module use /opt/insy/modulefiles # Use DAIC INSY software collection -module load openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results mpi > experiments/synthetic.log diff --git a/experiments/daic/generators/tabular.sh b/experiments/daic/generators/tabular.sh deleted file mode 100644 index da398346..00000000 --- a/experiments/daic/generators/tabular.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="Tabular (ECCCo)" -#SBATCH --time=12:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=general -#SBATCH --mem-per-cpu=8GB -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
- -module use /opt/insy/modulefiles # Use DAIC INSY software collection -module load openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc,german_credit,california_housing output_path=results mpi > experiments/tabular.log diff --git a/experiments/daic/testing/cali_100_8.sh b/experiments/daic/testing/cali_100_8.sh new file mode 100644 index 00000000..0f8d2e90 --- /dev/null +++ b/experiments/daic/testing/cali_100_8.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search California Housing (ECCCo)" +#SBATCH --time=00:35:00 +#SBATCH --ntasks=14 +#SBATCH --cpus-per-task=14 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=2GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=10 threaded > experiments/grid_search_california_housing.log + diff --git a/experiments/daic/testing/cali_50_8.sh b/experiments/daic/testing/cali_50_8.sh new file mode 100644 index 00000000..34bb1102 --- /dev/null +++ b/experiments/daic/testing/cali_50_8.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search California Housing (ECCCo)" +#SBATCH --time=02:00:00 +#SBATCH --ntasks=50 +#SBATCH --cpus-per-task=1 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=8GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=10 n_each=16 threaded > experiments/grid_search_california_housing.log + diff --git a/experiments/daic/testing/lin_sep.sh b/experiments/daic/testing/lin_sep.sh new file mode 100644 index 00000000..5ade5ba4 --- /dev/null +++ b/experiments/daic/testing/lin_sep.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Linearly Separable (ECCCo)" +#SBATCH --time=00:30:00 +#SBATCH --ntasks=10 +#SBATCH --cpus-per-task=5 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=2GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi grid_search n_individuals=10 threaded > experiments/grid_search_linearly_separable.log diff --git a/experiments/daic/testing/lin_sep_final.sh b/experiments/daic/testing/lin_sep_final.sh new file mode 100644 index 00000000..6884d986 --- /dev/null +++ b/experiments/daic/testing/lin_sep_final.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Linearly Separable (ECCCo)" +#SBATCH --time=00:30:00 +#SBATCH --ntasks=5 +#SBATCH --cpus-per-task=4 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=2GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
+ +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results_testing mpi threaded n_individuals=100 n_runs=2 > experiments/logs/linearly_separable.log diff --git a/experiments/daic/tuning/generators/tabular.sh b/experiments/daic/testing/lin_sep_no_threads.sh similarity index 55% rename from experiments/daic/tuning/generators/tabular.sh rename to experiments/daic/testing/lin_sep_no_threads.sh index cc139b8e..d69ea69a 100644 --- a/experiments/daic/tuning/generators/tabular.sh +++ b/experiments/daic/testing/lin_sep_no_threads.sh @@ -1,8 +1,8 @@ #!/bin/bash -#SBATCH --job-name="Grid-search Tabular (ECCCo)" -#SBATCH --time=04:00:00 -#SBATCH --ntasks=2000 +#SBATCH --job-name="Grid-search Linearly Separable (ECCCo)" +#SBATCH --time=00:30:00 +#SBATCH --ntasks-per-node=20 #SBATCH --cpus-per-task=1 #SBATCH --partition=general #SBATCH --mem-per-cpu=4GB @@ -11,4 +11,6 @@ module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc,german_credit output_path=results mpi grid_search > experiments/grid_search_tabular.log +source experiments/slurm_header.sh + +srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi grid_search n_individuals=10 > experiments/grid_search_linearly_separable.log diff --git a/experiments/daic/testing/lin_sep_threaded.sh b/experiments/daic/testing/lin_sep_threaded.sh new file mode 100644 index 00000000..cfd2116d --- /dev/null +++ b/experiments/daic/testing/lin_sep_threaded.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Linearly Separable (ECCCo)" +#SBATCH --time=00:30:00 +#SBATCH --ntasks=10 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=1GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi grid_search n_individuals=10 threaded > experiments/grid_search_linearly_separable.log diff --git a/experiments/daic/testing/mnist.sh b/experiments/daic/testing/mnist.sh new file mode 100644 index 00000000..352cc806 --- /dev/null +++ b/experiments/daic/testing/mnist.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +#SBATCH --job-name="MNIST test (ECCCo)" +#SBATCH --time=00:30:00 +#SBATCH --ntasks=10 +#SBATCH --cpus-per-task=4 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=8GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
+ +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=mnist output_path=results_testing mpi grid_search n_individuals=10 threaded n_each=32 > experiments/logs/grid_search_mnist.log + \ No newline at end of file diff --git a/experiments/daic/tuning/generators/california_housing.sh b/experiments/daic/tuning/generators/california_housing.sh index f8648ae6..7a6b5df0 100644 --- a/experiments/daic/tuning/generators/california_housing.sh +++ b/experiments/daic/tuning/generators/california_housing.sh @@ -1,9 +1,9 @@ #!/bin/bash #SBATCH --job-name="Grid-search California Housing (ECCCo)" -#SBATCH --time=01:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 +#SBATCH --time=02:00:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general #SBATCH --mem-per-cpu=4GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. @@ -11,4 +11,6 @@ module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=5 > experiments/grid_search_california_housing.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_california_housing.log diff --git a/experiments/daic/tuning/generators/circles.sh b/experiments/daic/tuning/generators/circles.sh index beaa7e9b..997ce415 100644 --- a/experiments/daic/tuning/generators/circles.sh +++ b/experiments/daic/tuning/generators/circles.sh @@ -1,14 +1,16 @@ #!/bin/bash -#SBATCH --job-name="Grid-search Synthetic (ECCCo)" -#SBATCH --time=04:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 +#SBATCH --job-name="Grid-search Circles (ECCCo)" +#SBATCH --time=01:00:00 +#SBATCH --ntasks=20 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general -#SBATCH --mem-per-cpu=4GB +#SBATCH --mem-per-cpu=2GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results mpi grid_search > experiments/grid_search_synthetic.log \ No newline at end of file +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=circles output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_circles.log diff --git a/experiments/daic/tuning/generators/german_credit.sh b/experiments/daic/tuning/generators/german_credit.sh index 12bc660b..218692d6 100644 --- a/experiments/daic/tuning/generators/german_credit.sh +++ b/experiments/daic/tuning/generators/german_credit.sh @@ -1,14 +1,16 @@ #!/bin/bash -#SBATCH --job-name="Grid-search Germand Credit (ECCCo)" -#SBATCH --time=04:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --job-name="Grid-search German Credit (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=4GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=german_credit output_path=results mpi grid_search > experiments/grid_search_german_credit.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=german_credit output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_german_credit.log diff --git a/experiments/daic/tuning/generators/gmsc.sh b/experiments/daic/tuning/generators/gmsc.sh index 31962abd..8f08c051 100644 --- a/experiments/daic/tuning/generators/gmsc.sh +++ b/experiments/daic/tuning/generators/gmsc.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --job-name="Grid-search GMSC (ECCCo)" -#SBATCH --time=04:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=4GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc output_path=results mpi grid_search > experiments/grid_search_gmsc.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=gmsc output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_gmsc.log diff --git a/experiments/daic/tuning/generators/linearly_separable.sh b/experiments/daic/tuning/generators/linearly_separable.sh new file mode 100644 index 00000000..b7578a85 --- /dev/null +++ b/experiments/daic/tuning/generators/linearly_separable.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Linearly Separable (ECCCo)" +#SBATCH --time=01:00:00 +#SBATCH --ntasks=20 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=2GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
+ +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_linearly_separable.log diff --git a/experiments/daic/tuning/generators/mnist.sh b/experiments/daic/tuning/generators/mnist.sh index 9e3b2c29..8dd2ada6 100644 --- a/experiments/daic/tuning/generators/mnist.sh +++ b/experiments/daic/tuning/generators/mnist.sh @@ -1,9 +1,9 @@ #!/bin/bash -#SBATCH --job-name="Grid-search MNIST (ECCCo)" -#SBATCH --time=32:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --job-name="MNIST Grid-search (ECCCo)" +#SBATCH --time=01:00:00 +#SBATCH --ntasks=40 +#SBATCH --cpus-per-task=10 #SBATCH --partition=general #SBATCH --mem-per-cpu=8GB #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. @@ -11,4 +11,6 @@ module use /opt/insy/modulefiles # Use DAIC INSY software collection module load openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi grid_search > experiments/grid_search_mnist.log \ No newline at end of file +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=mnist output_path=results mpi grid_search threaded n_individuals=10 n_each=5 > experiments/logs/grid_search_mnist.log diff --git a/experiments/daic/tuning/generators/moons.sh b/experiments/daic/tuning/generators/moons.sh new file mode 100644 index 00000000..55e16aa0 --- /dev/null +++ b/experiments/daic/tuning/generators/moons.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Moons (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=20 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=general +#SBATCH --mem-per-cpu=2GB +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module use /opt/insy/modulefiles # Use DAIC INSY software collection +module load openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=moons output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_moons.log diff --git a/experiments/daic/tuning/generators/synthetic.sh b/experiments/daic/tuning/generators/synthetic.sh deleted file mode 100644 index efe8e545..00000000 --- a/experiments/daic/tuning/generators/synthetic.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="Grid-search Synthetic (ECCCo)" -#SBATCH --time=02:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=general -#SBATCH --mem-per-cpu=8GB -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
- -module use /opt/insy/modulefiles # Use DAIC INSY software collection -module load openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=circles output_path=results mpi grid_search > experiments/grid_search_synthetic.log diff --git a/experiments/data/data.jl b/experiments/data/data.jl index 58965d3a..e81e03f7 100644 --- a/experiments/data/data.jl +++ b/experiments/data/data.jl @@ -20,9 +20,10 @@ function _prepare_data(exper::Experiment) # JEM parameters: 𝒟y = Categorical(ones(output_dim) ./ output_dim) sampler = ConditionalSampler( - 𝒟x, 𝒟y, - input_size=(input_dim,), - batch_size=sampling_batch_size, + 𝒟x, + 𝒟y, + input_size = (input_dim,), + batch_size = sampling_batch_size, ) return X, labels, n_obs, batch_size, sampler end @@ -33,11 +34,11 @@ function meta_data(exper::Experiment) end function prepare_data(exper::Experiment) - X, labels, _, _, sampler = _prepare_data(exper::Experiment) + X, labels, _, _, sampler = _prepare_data(exper::Experiment) return X, labels, sampler end function batch_size(exper::Experiment) _, _, _, batch_size, _ = _prepare_data(exper::Experiment) return batch_size -end \ No newline at end of file +end diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 632ba235..1d733971 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -9,7 +9,7 @@ Base.@kwdef struct Experiment use_pretrained::Bool = !RETRAIN models::Union{Nothing,Dict} = nothing additional_models::Union{Nothing,Dict} = nothing - 𝒟x::Distribution = Normal() + 𝒟x::Distribution = ECCCo.prior_sampling_space(counterfactual_data) sampling_batch_size::Int = 50 sampling_steps::Int = 50 min_batch_size::Int = 128 @@ -17,13 +17,15 @@ Base.@kwdef struct Experiment n_hidden::Int = 32 n_layers::Int = 3 activation::Function = Flux.relu - builder::Union{Nothing,MLJFlux.Builder} = default_builder(n_hidden=n_hidden, n_layers=n_layers, activation=activation) + builder::Union{Nothing,MLJFlux.Builder} = + default_builder(n_hidden = n_hidden, n_layers = n_layers, activation = activation) α::AbstractArray = [1.0, 1.0, 1e-1] n_ens::Int = 5 use_ensembling::Bool = true coverage::Float64 = DEFAULT_COVERAGE generators::Union{Nothing,Dict} = nothing n_individuals::Int = N_IND + n_runs::Int = N_RUNS ce_measures::AbstractArray = CE_MEASURES model_measures::Dict = MODEL_MEASURES use_class_loss::Bool = false @@ -31,13 +33,14 @@ Base.@kwdef struct Experiment Λ::AbstractArray = [0.25, 0.75, 0.75] Λ_Δ::AbstractArray = Λ opt::Flux.Optimise.AbstractOptimiser = Flux.Optimise.Descent(0.01) - parallelizer::Union{Nothing, AbstractParallelizer} = PLZ + parallelizer::Union{Nothing,AbstractParallelizer} = PLZ nsamples::Union{Nothing,Int} = nothing nmin::Union{Nothing,Int} = nothing finaliser::Function = Flux.softmax loss::Function = Flux.Losses.crossentropy train_parallel::Bool = false reg_strength::Real = 0.1 + decay::Tuple = (0.1, 5) niter_eccco::Union{Nothing,Int} = nothing model_tuning_params::NamedTuple = DEFAULT_MODEL_TUNING_SMALL use_tuned::Bool = true @@ -48,9 +51,9 @@ end "A container to hold the results of an experiment." mutable struct ExperimentOutcome exper::Experiment - model_dict::Union{Nothing, Dict} - generator_dict::Union{Nothing, Dict} - bmk::Union{Nothing, Benchmark} + model_dict::Union{Nothing,Dict} + generator_dict::Union{Nothing,Dict} + bmk::Union{Nothing,Benchmark} end """ @@ -58,11 +61,16 @@ end Train the models specified by `exper` and store them in `outcome`. 
""" -function train_models!(outcome::ExperimentOutcome, exper::Experiment; save_models::Bool=true, save_meta::Bool=false) - model_dict = prepare_models(exper; save_models=save_models) +function train_models!( + outcome::ExperimentOutcome, + exper::Experiment; + save_models::Bool = true, + save_meta::Bool = false, +) + model_dict = prepare_models(exper; save_models = save_models) outcome.model_dict = model_dict if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) - meta_model_performance(outcome; save_output=save_meta) + meta_model_performance(outcome; save_output = save_meta) end end @@ -82,10 +90,15 @@ end Run the experiment specified by `exper`. """ -function run_experiment(exper::Experiment; save_output::Bool=true, only_models::Bool=ONLY_MODELS) - +function run_experiment( + exper::Experiment; + save_output::Bool = true, + only_models::Bool = ONLY_MODELS, +) + # Setup - if save_output && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + if save_output && + !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) @info "All results will be saved to $(exper.output_path)." isdir(exper.output_path) || mkdir(exper.output_path) @info "All parameter choices will be saved to $(exper.params_path)." @@ -103,10 +116,10 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models:: # Model training: if only_models - train_models!(outcome, exper; save_models=save_output, save_meta=true) + train_models!(outcome, exper; save_models = save_output, save_meta = true) return outcome else - train_models!(outcome, exper; save_models=save_output) + train_models!(outcome, exper; save_models = save_output) end # Benchmark: @@ -116,10 +129,17 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models:: end # Save data: - if save_output && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) - Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_outcome.jls"), outcome) - Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_bmk.jls"), outcome.bmk) - meta(outcome; save_output=true) + if save_output && + !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + Serialization.serialize( + joinpath(exper.output_path, "$(exper.save_name)_outcome.jls"), + outcome, + ) + Serialization.serialize( + joinpath(exper.output_path, "$(exper.save_name)_bmk.jls"), + outcome.bmk, + ) + all_meta(outcome; save_output = true) end # Final barrier: @@ -136,14 +156,19 @@ end Overload the `run_experiment` function to allow for passing in `CounterfactualData` objects and other keyword arguments. """ -function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; save_output::Bool=true, kwargs...) +function run_experiment( + counterfactual_data::CounterfactualData, + test_data::CounterfactualData; + save_output::Bool = true, + kwargs..., +) # Parameters: exper = Experiment(; - counterfactual_data=counterfactual_data, - test_data=test_data, - kwargs... 
+ counterfactual_data = counterfactual_data, + test_data = test_data, + kwargs..., ) - return run_experiment(exper; save_output=save_output) + return run_experiment(exper; save_output = save_output) end # Pre-trained models: @@ -159,4 +184,4 @@ function pretrained_path(exper::Experiment) Pkg.Artifacts.download_artifact(ARTIFACT_HASH, ARTIFACT_TOML) return joinpath(LATEST_ARTIFACT_PATH, "results") end -end \ No newline at end of file +end diff --git a/experiments/fmnist.jl b/experiments/fmnist.jl index 54a46ef8..34f602bf 100644 --- a/experiments/fmnist.jl +++ b/experiments/fmnist.jl @@ -4,7 +4,10 @@ n_obs = 10000 counterfactual_data = load_fashion_mnist(n_obs) counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X) # Adjust domain constraints to account for noise added during pre-processing: -counterfactual_data.domain = fill((minimum(counterfactual_data.X), maximum(counterfactual_data.X)), size(counterfactual_data.X, 1)) +counterfactual_data.domain = fill( + (minimum(counterfactual_data.X), maximum(counterfactual_data.X)), + size(counterfactual_data.X, 1), +) # VAE (trained on full dataset): using CounterfactualExplanations.Models: load_fashion_mnist_vae @@ -16,52 +19,59 @@ test_data = load_fashion_mnist_test() # Dimensionality reduction: maxout_dim = vae.params.latent_dim -counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); +counterfactual_data.dt = MultivariateStats.fit( + MultivariateStats.PCA, + counterfactual_data.X; + maxoutdim = maxout_dim, +); # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_LARGE # Tuning parameters: tuning_params = DEFAULT_GENERATOR_TUNING -tuning_params = (; tuning_params..., Λ=[tuning_params.Λ[2:end]..., [0.1, 0.1, 3.0]]) +tuning_params = (; tuning_params..., Λ = [tuning_params.Λ[2:end]..., [0.01, 0.1, 3.0]]) # Additional models: -add_models = Dict( - "LeNet-5" => lenet5, -) +add_models = Dict("LeNet-5" => lenet5) # CE measures (add cosine distance): -ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_from_targets_ssim] +ce_measures = + [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_from_targets_ssim] # Parameter choices: params = ( - n_individuals=N_IND_SPECIFIED ? N_IND : 100, - builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish), - 𝒟x=Uniform(-1.0, 1.0), - α=[1.0, 1.0, 1e-2], - sampling_batch_size=10, - sampling_steps=25, - use_ensembling=true, - use_variants=false, - additional_models=add_models, - epochs=100, - nsamples=10, - nmin=1, - niter_eccco=10, - Λ=[0.01, 0.25, 0.25], - Λ_Δ=[0.01, 0.1, 0.3], - opt=Flux.Optimise.Descent(0.1), - reg_strength=0.0, - ce_measures=ce_measures, - dim_reduction=true, + n_individuals = N_IND_SPECIFIED ? N_IND : 100, + builder = default_builder(n_hidden = 128, n_layers = 1, activation = Flux.swish), + 𝒟x = Uniform(-1.0, 1.0), + α = [1.0, 1.0, 1e-2], + sampling_batch_size = 10, + sampling_steps = 25, + use_ensembling = true, + use_variants = false, + additional_models = add_models, + epochs = 100, + nsamples = 10, + nmin = 1, + niter_eccco = 10, + Λ = [0.01, 0.25, 0.25], + Λ_Δ = [0.01, 0.1, 0.3], + opt = Flux.Optimise.Descent(0.1), + reg_strength = 0.0, + ce_measures = ce_measures, + dim_reduction = true, ) +# Best grid search params: +params = append_best_params(params, dataname) + if GRID_SEARCH grid_search( - counterfactual_data, test_data; - dataname=dataname, - tuning_params=tuning_params, - params... 
+ counterfactual_data, + test_data; + dataname = dataname, + tuning_params = tuning_params, + params..., ) elseif FROM_GRID_SEARCH outcomes_file_path = joinpath( @@ -73,9 +83,10 @@ elseif FROM_GRID_SEARCH bmk2csv(dataname) else run_experiment( - counterfactual_data, test_data; - dataname=dataname, - model_tuning_params=model_tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + model_tuning_params = model_tuning_params, + params..., ) -end \ No newline at end of file +end diff --git a/experiments/german_credit.jl b/experiments/german_credit.jl index e64ba13c..6a127b07 100644 --- a/experiments/german_credit.jl +++ b/experiments/german_credit.jl @@ -1,18 +1,26 @@ # Data: dataname = "German Credit" -counterfactual_data, test_data = train_test_split(load_german_credit(nothing); test_size=TEST_SIZE) +counterfactual_data, test_data = + train_test_split(load_german_credit(nothing); test_size = TEST_SIZE) + +# Domain constraints: +counterfactual_data.domain = extrema(counterfactual_data.X, dims=2) # VAE: using CounterfactualExplanations.GenerativeModels: VAE, train! X = counterfactual_data.X y = counterfactual_data.output_encoder.y -vae = VAE(size(X,1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5) +vae = VAE(size(X, 1); nll = Flux.Losses.mse, epochs = 100, λ = 0.01, latent_dim = 5) train!(vae, X, y) counterfactual_data.generative_model = vae # Dimensionality reduction: maxout_dim = vae.params.latent_dim -counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); +counterfactual_data.dt = MultivariateStats.fit( + MultivariateStats.PCA, + counterfactual_data.X; + maxoutdim = maxout_dim, +); # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_LARGE @@ -22,29 +30,29 @@ tuning_params = DEFAULT_GENERATOR_TUNING_LARGE # Parameter choices: params = ( - n_hidden=32, - activation=Flux.relu, - builder=default_builder(n_hidden=32, n_layers=3, activation=Flux.relu), - α=[1.0, 1.0, 1e-1], - sampling_batch_size=10, - sampling_steps=30, - use_ensembling=true, - opt=Flux.Optimise.Descent(0.05), - Λ=[0.2, 0.2, 0.2], - reg_strength=0.5, - n_individuals=25, - dim_reduction=true, + n_hidden = 32, + activation = Flux.relu, + builder = default_builder(n_hidden = 32, n_layers = 3, activation = Flux.relu), + α = [1.0, 1.0, 1e-1], + sampling_batch_size = 10, + sampling_steps = 30, + use_ensembling = true, + opt = Flux.Optimise.Descent(0.05), + Λ = [0.2, 0.2, 0.2], + reg_strength = 0.5, + dim_reduction = true, ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search( - counterfactual_data, test_data; - dataname=dataname, - tuning_params=tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + tuning_params = tuning_params, + params..., ) elseif FROM_GRID_SEARCH outcomes_file_path = joinpath( @@ -56,9 +64,10 @@ elseif FROM_GRID_SEARCH bmk2csv(dataname) else run_experiment( - counterfactual_data, test_data; - dataname=dataname, - model_tuning_params=model_tuning_params, - params... 
+        counterfactual_data,
+        test_data;
+        dataname = dataname,
+        model_tuning_params = model_tuning_params,
+        params...,
     )
-end
\ No newline at end of file
+end
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index 9f5ab2c5..2b93ff4e 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -1,19 +1,26 @@
 # Data:
 dataname = "GMSC"
-counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
+counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size = TEST_SIZE)
 nobs = size(counterfactual_data.X, 2)
 
+# Domain constraints:
+counterfactual_data.domain = extrema(counterfactual_data.X, dims=2)
+
 # VAE:
 using CounterfactualExplanations.GenerativeModels: VAE, train!
 X = counterfactual_data.X
 y = counterfactual_data.output_encoder.y
-vae = VAE(size(X, 1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5)
+vae = VAE(size(X, 1); nll = Flux.Losses.mse, epochs = 100, λ = 0.01, latent_dim = 5)
 train!(vae, X, y)
 counterfactual_data.generative_model = vae
 
 # Dimensionality reduction:
 maxout_dim = vae.params.latent_dim
-counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
+counterfactual_data.dt = MultivariateStats.fit(
+    MultivariateStats.PCA,
+    counterfactual_data.X;
+    maxoutdim = maxout_dim,
+);
 
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
@@ -23,29 +30,29 @@ tuning_params = DEFAULT_GENERATOR_TUNING_LARGE
 
 # Parameter choices:
 params = (
-    n_hidden=32,
-    activation=Flux.relu,
-    builder=default_builder(n_hidden=32, n_layers=3, activation=Flux.relu),
+    n_hidden = 32,
+    activation = Flux.relu,
+    builder = default_builder(n_hidden = 32, n_layers = 3, activation = Flux.relu),
     α = [1.0, 1.0, 1e-1],
     sampling_batch_size = 10,
     sampling_steps = 30,
     use_ensembling = true,
-    opt=Flux.Optimise.Descent(0.05),
+    opt = Flux.Optimise.Descent(0.05),
     Λ = [0.1, 0.1, 0.1],
     reg_strength = 0.0,
-    n_individuals=25,
-    dim_reduction=true,
+    dim_reduction = true,
 )
 
 # Best grid search params:
-append_best_params!(params, dataname)
+params = append_best_params(params, dataname)
 
 if GRID_SEARCH
     grid_search(
-        counterfactual_data, test_data;
-        dataname=dataname,
-        tuning_params=tuning_params,
-        params...
+        counterfactual_data,
+        test_data;
+        dataname = dataname,
+        tuning_params = tuning_params,
+        params...,
     )
 elseif FROM_GRID_SEARCH
     outcomes_file_path = joinpath(
@@ -57,9 +64,10 @@ elseif FROM_GRID_SEARCH
     bmk2csv(dataname)
 else
     run_experiment(
-        counterfactual_data, test_data;
-        dataname=dataname,
-        model_tuning_params=model_tuning_params,
-        params...
+        counterfactual_data,
+        test_data;
+        dataname = dataname,
+        model_tuning_params = model_tuning_params,
+        params...,
     )
 end
diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl
index 323c47c9..67cb892d 100644
--- a/experiments/grid_search.jl
+++ b/experiments/grid_search.jl
@@ -1,7 +1,10 @@
+using DataFrames
+
 """
     grid_search(
-        couterfactual_data::CounterfactualData,
-        test_data::CounerfactualData;
+        counterfactual_data::CounterfactualData,
+        test_data::CounterfactualData;
+        warm_start::Bool = true,
         dataname::String,
         tuning_params::NamedTuple,
         kwargs...,
     )
 
 Perform a grid search over the hyperparameters specified by `tuning_params`.
 Experiment outcomes are stored on disk, which allows for warm starts.
 """
 function grid_search(
-    couterfactual_data::CounterfactualData,
+    counterfactual_data::CounterfactualData,
     test_data::CounterfactualData;
+    warm_start::Bool = true,
     dataname::String,
-    n_individuals::Int=N_IND,
+    n_individuals::Int = N_IND,
     tuning_params::NamedTuple,
     kwargs...,
 )
 
     # Output path:
-    grid_search_path = mkpath(joinpath(DEFAULT_OUTPUT_PATH, "grid_search"))
+    grid_search_path = joinpath(DEFAULT_OUTPUT_PATH, "grid_search")
+    if !isdir(grid_search_path)
+        mkpath(grid_search_path)
+    end
 
     # Grid setup:
     tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)]
     grid = Iterators.product(tuning_params...)
-    outcomes = Dict{Any,Any}()
-
+    n_total = length(grid)
+
+    # Temporary storage on disk:
+    storage_path = joinpath(grid_search_path, ".tmp_results_$(replace(lowercase(dataname), " " => "_"))")
+    if !isdir(storage_path)
+        mkpath(storage_path)
+    end
+    @info "Storing temporary results in $(storage_path)."
+
+    # Warm start:
+    if warm_start
+        existing_files = readdir(storage_path)
+        n_files = Int(floor(length(existing_files) / 2))
+        if n_files > 0
+            @info "Warm start: $(n_files) existing results found."
+            grid = Iterators.drop(grid, n_files)
+        end
+        counter = n_files + 1
+    else
+        counter = 1
+    end
+
     # Search:
-    counter = 1
-    for tuning_params in grid
-        @info "Running experiment $(counter)/$(length(grid)) with tuning parameters: $(tuning_params)"
+    for params in grid
+        @info "Running experiment $(counter)/$(n_total) with tuning parameters: $(params)"
+
+        # Filter out keyword parameters that are tuned:
+        _keys = [k[1] for k in kwargs]
+        not_these = _keys[findall([k in map(k -> k[1], params) for k in _keys])]
+        not_these = (not_these..., :n_individuals)
+        kwargs = filter(x -> !(x[1] ∈ not_these), kwargs)
+
+        # Run experiment:
         outcome = run_experiment(
-            counterfactual_data, test_data;
-            save_output=false,
-            dataname=dataname,
-            n_individuals=n_individuals,
-            output_path=grid_search_path,
-            tuning_params...,
+            counterfactual_data,
+            test_data;
+            save_output = false,
+            dataname = dataname,
+            n_individuals = n_individuals,
+            output_path = grid_search_path,
+            params...,
             kwargs...,
         )
-        outcomes[tuning_params] = outcome
+
+        # Collect:
+        _params = map(x -> typeof(x[2]) <: Vector ? x[1] => Tuple(x[2]) : x[1] => x[2], params)
+        df_params =
+            DataFrame(merge(Dict(:id => counter), Dict(_params))) |>
+            x -> select(x, :id, Not(:id))
+        df_outcomes =
+            DataFrame(Dict(:id => counter, :params => params, :outcome => outcome)) |>
+            x -> select(x, :id, Not(:id))
+
+        # Save:
+        if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
+            Serialization.serialize(
+                joinpath(storage_path, "params_$(counter).jls"),
+                df_params,
+            )
+            Serialization.serialize(
+                joinpath(storage_path, "outcomes_$(counter).jls"),
+                df_outcomes,
+            )
+        end
         counter += 1
     end
 
     # Save:
     if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
-        Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes)
-        Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_"))_best.jls"), best_absolute_outcome(outcomes))
-        Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_"))_best_eccco.jls"), best_absolute_outcome_eccco(outcomes))
-        Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_"))_best_eccco_delta.jls"), best_absolute_outcome_eccco_Δ(outcomes))
+
+        # Deserialise:
+        df_params = []
+        df_outcomes = []
+        for i in 1:n_total
+            push!(df_params, Serialization.deserialize(joinpath(storage_path, "params_$(i).jls")))
+            push!(df_outcomes, Serialization.deserialize(joinpath(storage_path, "outcomes_$(i).jls")))
+        end
+        outcomes = Dict(:df_params => vcat(df_params...), :df_outcomes => vcat(df_outcomes...))
+
+        # Save:
+        Serialization.serialize(
+            joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"),
+            outcomes,
+        )
     end
+
 end
 
 const ALL_ECCCO_NAMES = [
@@ -61,18 +128,9 @@ const ALL_ECCCO_NAMES = [
     "ECCCo-Δ (no EBM)",
 ]
 
-const ECCCO_NAMES = [
-    "ECCCo",
-    "ECCCo (no CP)",
-    "ECCCo (no EBM)",
-]
+const ECCCO_NAMES = ["ECCCo", "ECCCo (no CP)", "ECCCo (no EBM)"]
 
-const ECCCo_Δ_NAMES = [
-    "ECCCo-Δ",
-    "ECCCo-Δ (no CP)",
-    "ECCCo-Δ (no EBM)",
-    "ECCCo-Δ (latent)",
-]
+const ECCCo_Δ_NAMES = ["ECCCo-Δ", "ECCCo-Δ (no CP)", "ECCCo-Δ (no EBM)", "ECCCo-Δ (latent)"]
 
 """
-    best_outcome(outcomes; generator=ECCCO_NAMES, measure=["distance_from_energy", "distance_from_targets"])
+    best_rank_outcome(outcomes; generator=ALL_ECCCO_NAMES, measure=["distance_from_energy_l2", "distance_from_targets_l2"])
@@ -80,35 +138,43 @@ const ECCCo_Δ_NAMES = [
 
 Returns the best outcome from grid search results. The best outcome is defined as the one with the lowest average rank across all datasets and variables for the specified generator and measure.
 """
 function best_rank_outcome(
-    outcomes::Dict;
-    generator=ALL_ECCCO_NAMES,
-    measure=["distance_from_energy_l2", "distance_from_targets_l2"],
-    model::Union{Nothing,AbstractArray}=nothing,
-    weights::Union{Nothing,AbstractArray}=nothing
+    outcomes::Dict;
+    generator = ALL_ECCCO_NAMES,
+    measure = ["distance_from_energy_l2", "distance_from_targets_l2"],
+    model::Union{Nothing,AbstractArray} = nothing,
+    weights::Union{Nothing,AbstractArray} = nothing,
 )
     weights = isnothing(weights) ? ones(length(measure)) : weights
-    df_weights = DataFrame(variable=measure, weight=weights)
+    df_weights = DataFrame(variable = measure, weight = weights)
     ranks = []
-    for (params, outcome) in outcomes
-        _ranks = generator_rank(outcome; generator=generator, measure=measure, model=model) |>
-            x -> leftjoin(x, df_weights, on=:variable) |>
-            x -> x.avg_rank .* x.weight |>
-            x -> (sum(x) / length(x))[1]
+    for outcome in outcomes[:df_outcomes].outcome
+        _ranks =
+            generator_rank(
+                outcome;
+                generator = generator,
+                measure = measure,
+                model = model,
+            ) |>
+            x ->
+                leftjoin(x, df_weights, on = :variable) |>
+                x -> x.avg_rank .* x.weight |> x -> (sum(x)/length(x))[1]
         push!(ranks, _ranks)
     end
     best_index = argmin(ranks)
     best_outcome = (
-        params = collect(keys(outcomes))[best_index],
-        outcome = collect(values(outcomes))[best_index]
+        params=outcomes[:df_outcomes].params[best_index],
+        outcome=outcomes[:df_outcomes].outcome[best_index],
    )
    return best_outcome
 end
 
-best_rank_eccco(outcomes; kwrgs...) = best_rank_outcome(outcomes; generator=ECCCO_NAMES, kwrgs...)
+best_rank_eccco(outcomes; kwrgs...) =
+    best_rank_outcome(outcomes; generator = ECCCO_NAMES, kwrgs...)
 
-best_rank_eccco_Δ(outcomes; kwrgs...) = best_rank_outcome(outcomes; generator=ECCCo_Δ_NAMES, kwrgs...)
+best_rank_eccco_Δ(outcomes; kwrgs...) =
+    best_rank_outcome(outcomes; generator = ECCCo_Δ_NAMES, kwrgs...)
 
 """
-    best_absolute_outcome(outcomes; generator=ECCCO_NAMES, measure="distance_from_energy")
+    best_absolute_outcome(outcomes; generator=ECCCO_NAMES, measure=["distance_from_energy_l2"])
@@ -116,18 +182,18 @@ best_rank_eccco_Δ(outcomes; kwrgs...) = best_rank_outcome(outcomes; generator=E
 
 Return the best outcome from grid search results. The best outcome is defined as the one with the lowest average value across all datasets and variables for the specified generator and measure.
 """
 function best_absolute_outcome(
-    outcomes::Dict;
-    generator=ECCCO_NAMES,
-    measure::AbstractArray=["distance_from_energy_l2"],
-    model::Union{Nothing,AbstractArray}=nothing,
-    weights::Union{Nothing,AbstractArray}=nothing
+    outcomes::Dict;
+    generator = ECCCO_NAMES,
+    measure::AbstractArray = ["distance_from_energy_l2"],
+    model::Union{Nothing,AbstractArray} = nothing,
+    weights::Union{Nothing,AbstractArray} = nothing,
 )
     weights = isnothing(weights) ? ones(length(measure)) : weights
-    df_weights = DataFrame(variable=measure, weight=weights)
+    df_weights = DataFrame(variable = measure, weight = weights)
 
     avg_values = []
-    for (params, outcome) in outcomes
+    for (params, outcome) in zip(outcomes[:df_outcomes].params, outcomes[:df_outcomes].outcome)
 
         # Setup
         evaluation = deepcopy(outcome.bmk.evaluation)
@@ -136,7 +202,7 @@ function best_absolute_outcome(
         model_dict = outcome.model_dict
 
         # Discard outlier results:
-        if any(evaluation.value .> 1e6)
+        if any(abs.(evaluation.value) .> 1e6)
             @warn "Discarding outlier results: $(params)."
            push!(avg_values, Inf)
             continue
         end
@@ -146,50 +212,72 @@ function best_absolute_outcome(
         higher_is_better = [var ∈ ["validity", "redundancy"] for var in evaluation.variable]
         evaluation.value[higher_is_better] .= -evaluation.value[higher_is_better]
 
-        # Normalise to allow for comparison across measures:
-        evaluation =
-            groupby(evaluation, [:dataname, :variable]) |>
-            x -> transform(x, :value => standardize => :value)
-
-        # Reconstruct outcome with normalised values:
+        # Reconstruct outcome:
         bmk = CounterfactualExplanations.Evaluation.Benchmark(evaluation)
         outcome = ExperimentOutcome(exper, model_dict, generator_dict, bmk)
 
         # Compute:
-        results = summarise_outcome(outcome, measure=measure, model=model) |>
-            x -> leftjoin(x, df_weights, on=:variable)
-
+        results =
+            summarise_outcome(outcome, measure = measure, model = model) |>
+            x -> leftjoin(x, df_weights, on = :variable)
+
         # Compute weighted averages:
-        _avg_values = subset(results, :generator => ByRow(x -> x ∈ generator)) |>
-            x -> x.mean .* x.weight |>
-            x -> (sum(x)/length(x))[1]
+        _avg_values =
+            subset(results, :generator => ByRow(x -> x ∈ generator)) |>
+            x -> x.mean .* x.weight |> x -> (sum(x)/length(x))[1]
 
         # Append:
         push!(avg_values, _avg_values)
     end
 
     best_index = argmin(avg_values)
     best_outcome = (
-        params = collect(keys(outcomes))[best_index],
-        outcome = collect(values(outcomes))[best_index]
+        params=outcomes[:df_outcomes].params[best_index],
+        outcome=outcomes[:df_outcomes].outcome[best_index],
    )
+
+    return best_outcome
 end
 
-best_absolute_outcome_eccco(outcomes; kwrgs...) = best_absolute_outcome(outcomes; generator=ECCCO_NAMES, kwrgs...)
+best_absolute_outcome_eccco(outcomes; kwrgs...) =
+    best_absolute_outcome(outcomes; generator = ECCCO_NAMES, kwrgs...)
 
-best_absolute_outcome_eccco_Δ(outcomes; kwrgs...) = best_absolute_outcome(outcomes; generator=ECCCo_Δ_NAMES, kwrgs...)
+best_absolute_outcome_eccco_Δ(outcomes; kwrgs...) =
+    best_absolute_outcome(outcomes; generator = ECCCo_Δ_NAMES, kwrgs...)
+
+"""
+    best_outcome(outcomes)
+
+The best outcome is chosen as follows: choose the outcome with the minimum average unfaithfulness (`distance_from_energy_l2`), aggregated across all ECCCo-Δ generators (`ECCCo_Δ_NAMES`), for the weakest models (`MLP` and `MLP Ensemble`).
+"""
+best_outcome(outcomes; measure=["distance_from_energy_l2"]) = best_absolute_outcome(outcomes; generator=ECCCo_Δ_NAMES, measure=measure, model=["MLP", "MLP Ensemble"])
 
 """
-    append_best_params!(params::NamedTuple, dataname::String)
+    append_best_params(params::NamedTuple, dataname::String)
 
-Appends the best parameters from grid search results to the specified parameters.
+Returns a copy of `params` with the best parameters from grid search results appended (existing entries are overridden).
 """
-function append_best_params!(params::NamedTuple, dataname::String)
-    if !isfile(joinpath(DEFAULT_OUTPUT_PATH, "grid_search", "$(replace(lowercase(dataname), " " => "_")).jls"))
+function append_best_params(params::NamedTuple, dataname::String)
+    if !isfile(
+        joinpath(
+            DEFAULT_OUTPUT_PATH,
+            "grid_search",
+            "$(replace(lowercase(dataname), " " => "_")).jls",
+        ),
+    )
        @warn "No grid search results found. Using default parameters."
     else
         @info "Appending best parameters from grid search results."
-        grid_search_results = Serialization.deserialize(joinpath(DEFAULT_OUTPUT_PATH, "grid_search", "$(replace(lowercase(dataname), " " => "_")).jls"))
-        best_params = best_absolute_outcome_eccco_Δ(grid_search_results).params
+        grid_search_results = Serialization.deserialize(
+            joinpath(
+                DEFAULT_OUTPUT_PATH,
+                "grid_search",
+                "$(replace(lowercase(dataname), " " => "_")).jls",
+            ),
+        )
+        best_params = best_outcome(grid_search_results).params
         params = (; params..., best_params...)
+ + params = (; params..., (; Λ = typeof(params.Λ) <: Tuple ? collect(params.Λ) : params.Λ)...) end -end \ No newline at end of file + return params +end diff --git a/experiments/jobscripts/generators/california_housing.sh b/experiments/jobscripts/generators/california_housing.sh index 07d5770e..81fdebbc 100644 --- a/experiments/jobscripts/generators/california_housing.sh +++ b/experiments/jobscripts/generators/california_housing.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --job-name="California Housing (ECCCo)" -#SBATCH --time=3:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:40:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 #SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=4GB #SBATCH --account=research-eemcs-insy #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi > experiments/california_housing.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/california_housing.log diff --git a/experiments/jobscripts/generators/circles.sh b/experiments/jobscripts/generators/circles.sh new file mode 100644 index 00000000..194782e3 --- /dev/null +++ b/experiments/jobscripts/generators/circles.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Circles (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=2GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=circles output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/circles.log diff --git a/experiments/jobscripts/generators/credit_default.sh b/experiments/jobscripts/generators/credit_default.sh deleted file mode 100644 index e4d99b73..00000000 --- a/experiments/jobscripts/generators/credit_default.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="California Housing (ECCCo)" -#SBATCH --time=3:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB -#SBATCH --account=research-eemcs-insy -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
- -module load 2023r1 openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi > experiments/california_housing.log diff --git a/experiments/jobscripts/generators/fmnist.sh b/experiments/jobscripts/generators/fmnist.sh index 10504108..e61e8fe9 100644 --- a/experiments/jobscripts/generators/fmnist.sh +++ b/experiments/jobscripts/generators/fmnist.sh @@ -1,9 +1,9 @@ #!/bin/bash -#SBATCH --job-name="Fashion-MNIST (ECCCo)" -#SBATCH --time=10:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --job-name="Fashion MNIST (ECCCo)" +#SBATCH --time=03:00:00 +#SBATCH --ntasks=10 +#SBATCH --cpus-per-task=5 #SBATCH --partition=compute #SBATCH --mem-per-cpu=8GB #SBATCH --account=research-eemcs-insy @@ -11,4 +11,6 @@ module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=fmnist retrain output_path=results threaded mpi > experiments/fmnist.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=fmnist output_path=results mpi threaded n_individuals=100 n_runs=10 vertical_splits=100 > experiments/logs/fmnist.log diff --git a/experiments/jobscripts/generators/german_credit.sh b/experiments/jobscripts/generators/german_credit.sh index cb312771..997208fa 100644 --- a/experiments/jobscripts/generators/german_credit.sh +++ b/experiments/jobscripts/generators/german_credit.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --job-name="German Credit (ECCCo)" -#SBATCH --time=1:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:00:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 #SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=4GB #SBATCH --account=research-eemcs-insy #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=german_credit output_path=results mpi > experiments/german_credit.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=german_credit output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/german_credit.log diff --git a/experiments/jobscripts/generators/gmsc.sh b/experiments/jobscripts/generators/gmsc.sh index 5809363f..9a417354 100644 --- a/experiments/jobscripts/generators/gmsc.sh +++ b/experiments/jobscripts/generators/gmsc.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --job-name="GMSC (ECCCo)" -#SBATCH --time=3:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 #SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=4GB #SBATCH --account=research-eemcs-insy #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
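# [Editor's note] Throughout this patch, jobs move from one MPI rank per core
# (e.g. --ntasks=100, --cpus-per-task=1) to fewer ranks with multi-threaded
# Julia workers (--ntasks=30, --cpus-per-task=10 above). The sourced
# experiments/slurm_header.sh is not shown in the patch; a minimal sketch of
# what such a header plausibly contains (an assumption, not the actual file):
export JULIA_NUM_THREADS="${SLURM_CPUS_PER_TASK:-1}"  # one Julia thread per allocated CPU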
module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc output_path=results mpi > experiments/gmsc.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=gmsc output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/gmsc.log \ No newline at end of file diff --git a/experiments/jobscripts/generators/linearly_separable.sh b/experiments/jobscripts/generators/linearly_separable.sh new file mode 100644 index 00000000..8011f838 --- /dev/null +++ b/experiments/jobscripts/generators/linearly_separable.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Linearly Separable (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=2GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/linearly_separable.log diff --git a/experiments/jobscripts/generators/mnist.sh b/experiments/jobscripts/generators/mnist.sh index ef0180fe..c8d7a616 100644 --- a/experiments/jobscripts/generators/mnist.sh +++ b/experiments/jobscripts/generators/mnist.sh @@ -1,9 +1,9 @@ #!/bin/bash #SBATCH --job-name="MNIST (ECCCo)" -#SBATCH --time=10:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --time=03:00:00 +#SBATCH --ntasks=10 +#SBATCH --cpus-per-task=5 #SBATCH --partition=compute #SBATCH --mem-per-cpu=8GB #SBATCH --account=research-eemcs-insy @@ -11,4 +11,6 @@ module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi > experiments/mnist.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=mnist output_path=results mpi threaded n_individuals=100 n_runs=10 vertical_splits=100 > experiments/logs/mnist.log diff --git a/experiments/jobscripts/generators/mnist_memory.sh b/experiments/jobscripts/generators/mnist_memory.sh deleted file mode 100644 index 0f653aeb..00000000 --- a/experiments/jobscripts/generators/mnist_memory.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="MNIST (ECCCo)" -#SBATCH --time=10:00:00 -#SBATCH --ntasks=150 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=memory -#SBATCH --mem-per-cpu=64GB -#SBATCH --account=research-eemcs-insy -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. - -module load 2023r1 openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi > experiments/mnist.log diff --git a/experiments/jobscripts/generators/moons.sh b/experiments/jobscripts/generators/moons.sh new file mode 100644 index 00000000..d526b1b7 --- /dev/null +++ b/experiments/jobscripts/generators/moons.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Moons (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=2GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
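# [Editor's note] The reworked srun lines all redirect stdout to files under
# experiments/logs/ rather than experiments/. A shell redirection does not
# create missing directories, so a guard along these lines is assumed to have
# been run once before submission:
mkdir -p experiments/logs  # no-op if the directory already exists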
+
+module load 2023r1 openmpi
+
+source experiments/slurm_header.sh
+
+srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=moons output_path=results mpi threaded n_individuals=100 n_runs=50 > experiments/logs/moons.log
diff --git a/experiments/jobscripts/generators/synthetic.sh b/experiments/jobscripts/generators/synthetic.sh
deleted file mode 100644
index 15e4f310..00000000
--- a/experiments/jobscripts/generators/synthetic.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/bin/bash
-
-#SBATCH --job-name="Synthetic (ECCCo)"
-#SBATCH --time=02:00:00
-#SBATCH --ntasks=1000
-#SBATCH --cpus-per-task=1
-#SBATCH --partition=compute
-#SBATCH --mem-per-cpu=4GB
-#SBATCH --account=research-eemcs-insy
-#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes.
-
-module load 2023r1 openmpi
-
-srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results mpi > experiments/synthetic.log
diff --git a/experiments/jobscripts/generators/tabular.sh b/experiments/jobscripts/generators/tabular.sh
deleted file mode 100644
index 43b43b6f..00000000
--- a/experiments/jobscripts/generators/tabular.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/bin/bash
-
-#SBATCH --job-name="Tabular (ECCCo)"
-#SBATCH --time=12:00:00
-#SBATCH --ntasks=1000
-#SBATCH --cpus-per-task=1
-#SBATCH --partition=compute
-#SBATCH --mem-per-cpu=8GB
-#SBATCH --account=research-eemcs-insy
-#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes.
-
-module load 2023r1 openmpi
-
-srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc,german_credit,california_housing output_path=results mpi > experiments/tabular.log
diff --git a/experiments/jobscripts/testing/all_small.sh b/experiments/jobscripts/testing/all_small.sh
new file mode 100644
index 00000000..8a953f6d
--- /dev/null
+++ b/experiments/jobscripts/testing/all_small.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+#SBATCH --job-name="Grid-search All Small Datasets (ECCCo)"
+#SBATCH --time=01:00:00
+#SBATCH --ntasks=5
+#SBATCH --cpus-per-task=4
+#SBATCH --partition=compute
+#SBATCH --mem-per-cpu=2GB
+#SBATCH --account=innovation
+#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes.
+
+module load 2023r1 openmpi
+
+source experiments/slurm_header.sh
+
+srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable,moons,circles,gmsc,german_credit,california_housing output_path=results_testing mpi grid_search n_individuals=5 threaded > experiments/logs/all_testing.log
+ 
\ No newline at end of file
diff --git a/experiments/jobscripts/testing/cali.sh b/experiments/jobscripts/testing/cali.sh
new file mode 100644
index 00000000..383ef53f
--- /dev/null
+++ b/experiments/jobscripts/testing/cali.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+#SBATCH --job-name="Grid-search California Housing (ECCCo)"
+#SBATCH --time=00:35:00
+#SBATCH --ntasks=5
+#SBATCH --cpus-per-task=5
+#SBATCH --partition=compute
+#SBATCH --mem-per-cpu=1GB
+#SBATCH --account=research-eemcs-insy
+#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes.
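# [Editor's note] A quick way to sanity-check what a header like the one above
# requests: total cores = ntasks x cpus-per-task, and memory per rank =
# mem-per-cpu x cpus-per-task. For this testing job (5 x 5 at 1GB per CPU):
echo "total cores: $((5 * 5)), memory per rank: $((1 * 5))GB"  # -> 25 cores, 5GB per rank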
+ +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=1 threaded > experiments/logs/grid_search_california_housing.log \ No newline at end of file diff --git a/experiments/jobscripts/testing/fmnist.sh b/experiments/jobscripts/testing/fmnist.sh new file mode 100644 index 00000000..42a2e6a7 --- /dev/null +++ b/experiments/jobscripts/testing/fmnist.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Testing FMNIST" +#SBATCH --time=00:15:00 +#SBATCH --ntasks=10 +#SBATCH --cpus-per-task=5 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=4GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=fmnist output_path=results_testing mpi threaded n_individuals=100 n_runs=5 vertical_splits=100 > experiments/logs/testing_fmnist.log diff --git a/experiments/jobscripts/testing/lin_sep.sh b/experiments/jobscripts/testing/lin_sep.sh new file mode 100644 index 00000000..5b145457 --- /dev/null +++ b/experiments/jobscripts/testing/lin_sep.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Test Lin Sep (ECCCo)" +#SBATCH --time=00:20:00 +#SBATCH --ntasks=5 +#SBATCH --cpus-per-task=4 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=2GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi threaded n_individuals=48 n_runs=5 n_each=8 > experiments/logs/lin-sep.log diff --git a/experiments/jobscripts/testing/lin_sep_balance.sh b/experiments/jobscripts/testing/lin_sep_balance.sh new file mode 100644 index 00000000..fda1c4ed --- /dev/null +++ b/experiments/jobscripts/testing/lin_sep_balance.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Linearly Separable (ECCCo)" +#SBATCH --time=00:15:00 +#SBATCH --ntasks=80 +#SBATCH --cpus-per-task=5 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=1GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi grid_search n_individuals=20 threaded n_each=nothing > experiments/logs/grid_search_linearly_separable_threaded.log diff --git a/experiments/jobscripts/testing/lin_sep_threaded.sh b/experiments/jobscripts/testing/lin_sep_threaded.sh new file mode 100644 index 00000000..b442f6fe --- /dev/null +++ b/experiments/jobscripts/testing/lin_sep_threaded.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Linearly Separable (ECCCo)" +#SBATCH --time=00:30:00 +#SBATCH --ntasks=14 +#SBATCH --cpus-per-task=14 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=1GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
+ +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi grid_search n_individuals=20 threaded n_each=nothing > experiments/logs/grid_search_linearly_separable_threaded.log diff --git a/experiments/jobscripts/testing/mnist.sh b/experiments/jobscripts/testing/mnist.sh new file mode 100644 index 00000000..99867ae2 --- /dev/null +++ b/experiments/jobscripts/testing/mnist.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="MNIST (ECCCo)" +#SBATCH --time=00:20:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=16 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=4GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=mnist output_path=results_testing mpi threaded n_individuals=100 n_runs=5 vertical_splits=100 > experiments/logs/testing_mnist.log diff --git a/experiments/jobscripts/tuning/generators/california_housing.sh b/experiments/jobscripts/tuning/generators/california_housing.sh index 373d6607..da31b5a6 100644 --- a/experiments/jobscripts/tuning/generators/california_housing.sh +++ b/experiments/jobscripts/tuning/generators/california_housing.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --job-name="Grid-search California Housing (ECCCo)" -#SBATCH --time=04:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:40:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 #SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=4GB #SBATCH --account=research-eemcs-insy #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=25 store_ce > experiments/grid_search_california_housing.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_california_housing.log diff --git a/experiments/jobscripts/tuning/generators/circles.sh b/experiments/jobscripts/tuning/generators/circles.sh index 5d926be3..7b3bf21e 100644 --- a/experiments/jobscripts/tuning/generators/circles.sh +++ b/experiments/jobscripts/tuning/generators/circles.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --job-name="Grid-search Circles (ECCCo)" -#SBATCH --time=02:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 +#SBATCH --time=01:30:00 +#SBATCH --ntasks=15 +#SBATCH --cpus-per-task=10 #SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=2GB #SBATCH --account=research-eemcs-insy #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
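# [Editor's note] These scripts are submitted with standard SLURM tooling from
# the repository root; an illustrative (not prescribed) workflow:
sbatch experiments/jobscripts/tuning/generators/circles.sh  # queue the grid search
squeue -u "$USER"                                           # monitor its state
tail -f experiments/logs/grid_search_circles.log            # follow the run log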
module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=circles output_path=results mpi grid_search > experiments/grid_search_circles.log +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=circles output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_circles.log diff --git a/experiments/jobscripts/tuning/generators/fmnist.sh b/experiments/jobscripts/tuning/generators/fmnist.sh index d8aefd40..7678d301 100644 --- a/experiments/jobscripts/tuning/generators/fmnist.sh +++ b/experiments/jobscripts/tuning/generators/fmnist.sh @@ -1,14 +1,16 @@ #!/bin/bash -#SBATCH --job-name="Grid-search Fashion MNIST (ECCCo)" -#SBATCH --time=32:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --job-name="Fashion MNIST - Grid (ECCCo)" +#SBATCH --time=04:00:00 +#SBATCH --ntasks=10 +#SBATCH --cpus-per-task=10 #SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=16GB #SBATCH --account=research-eemcs-insy #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=fmnist output_path=results mpi grid_search > experiments/grid_search_fmnist.log \ No newline at end of file +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=fmnist output_path=results mpi grid_search threaded n_individuals=10 n_each=16 > experiments/logs/grid_search_fmnist.log \ No newline at end of file diff --git a/experiments/jobscripts/tuning/generators/german_credit.sh b/experiments/jobscripts/tuning/generators/german_credit.sh new file mode 100644 index 00000000..465881f6 --- /dev/null +++ b/experiments/jobscripts/tuning/generators/german_credit.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search German Credit (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=4GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=german_credit output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_german_credit.log diff --git a/experiments/jobscripts/tuning/generators/gmsc.sh b/experiments/jobscripts/tuning/generators/gmsc.sh new file mode 100644 index 00000000..e77cb534 --- /dev/null +++ b/experiments/jobscripts/tuning/generators/gmsc.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search GMSC (ECCCo)" +#SBATCH --time=01:40:00 +#SBATCH --ntasks=30 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=4GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
+ +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=gmsc output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_gmsc.log diff --git a/experiments/jobscripts/tuning/generators/innovation/california_housing.sh b/experiments/jobscripts/tuning/generators/innovation/california_housing.sh index a9bc302a..332b3915 100644 --- a/experiments/jobscripts/tuning/generators/innovation/california_housing.sh +++ b/experiments/jobscripts/tuning/generators/innovation/california_housing.sh @@ -2,13 +2,13 @@ #SBATCH --job-name="Grid-search California Housing (ECCCo)" #SBATCH --time=04:00:00 -#SBATCH --ntasks=48 +#SBATCH --ntasks=24 #SBATCH --cpus-per-task=1 #SBATCH --partition=compute -#SBATCH --mem-per-cpu=4GB +#SBATCH --mem-per-cpu=8GB #SBATCH --account=innovation #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=20 > experiments/grid_search_california_housing.log +srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi grid_search n_individuals=10 n_each=16 > experiments/grid_search_california_housing.log diff --git a/experiments/jobscripts/tuning/generators/linearly_separable.sh b/experiments/jobscripts/tuning/generators/linearly_separable.sh new file mode 100644 index 00000000..d05f230c --- /dev/null +++ b/experiments/jobscripts/tuning/generators/linearly_separable.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Linearly Separable (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=15 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=2GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=linearly_separable output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_linearly_separable.log diff --git a/experiments/jobscripts/tuning/generators/mnist.sh b/experiments/jobscripts/tuning/generators/mnist.sh index 8e1ea4fc..3ce1b013 100644 --- a/experiments/jobscripts/tuning/generators/mnist.sh +++ b/experiments/jobscripts/tuning/generators/mnist.sh @@ -1,14 +1,16 @@ #!/bin/bash #SBATCH --job-name="Grid-search MNIST (ECCCo)" -#SBATCH --time=32:00:00 -#SBATCH --ntasks=1000 -#SBATCH --cpus-per-task=1 +#SBATCH --time=04:00:00 +#SBATCH --ntasks=10 +#SBATCH --cpus-per-task=10 #SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB +#SBATCH --mem-per-cpu=16GB #SBATCH --account=research-eemcs-insy #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
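# [Editor's note] The tuning (grid_search) jobs write the best-parameter files
# that the later generator runs pick up via append_best_params. One way to
# chain the two stages so the final run only starts after tuning succeeds
# (illustrative use of standard sbatch dependencies, not part of the patch):
jid=$(sbatch --parsable experiments/jobscripts/tuning/generators/mnist.sh)
sbatch --dependency=afterok:"$jid" experiments/jobscripts/generators/mnist.sh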
module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi grid_search > experiments/grid_search_mnist.log \ No newline at end of file +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=mnist output_path=results mpi grid_search threaded n_individuals=10 n_each=16 > experiments/logs/grid_search_mnist.log \ No newline at end of file diff --git a/experiments/jobscripts/tuning/generators/moons.sh b/experiments/jobscripts/tuning/generators/moons.sh new file mode 100644 index 00000000..eabcb86a --- /dev/null +++ b/experiments/jobscripts/tuning/generators/moons.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Moons (ECCCo)" +#SBATCH --time=01:30:00 +#SBATCH --ntasks=15 +#SBATCH --cpus-per-task=10 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=2GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +source experiments/slurm_header.sh + +srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=moons output_path=results mpi grid_search threaded n_individuals=100 > experiments/logs/grid_search_moons.log diff --git a/experiments/jobscripts/tuning/generators/synthetic.sh b/experiments/jobscripts/tuning/generators/synthetic.sh deleted file mode 100644 index 66bcf412..00000000 --- a/experiments/jobscripts/tuning/generators/synthetic.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="Grid-search Synthetic (ECCCo)" -#SBATCH --time=05:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB -#SBATCH --account=research-eemcs-insy -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. - -module load 2023r1 openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results mpi grid_search > experiments/grid_search_synthetic.log diff --git a/experiments/jobscripts/tuning/generators/tabular.sh b/experiments/jobscripts/tuning/generators/tabular.sh deleted file mode 100644 index 3c043ea7..00000000 --- a/experiments/jobscripts/tuning/generators/tabular.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -#SBATCH --job-name="Grid-search Tabular (ECCCo)" -#SBATCH --time=04:00:00 -#SBATCH --ntasks=100 -#SBATCH --cpus-per-task=1 -#SBATCH --partition=compute -#SBATCH --mem-per-cpu=8GB -#SBATCH --account=research-eemcs-insy -#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. 
- -module load 2023r1 openmpi - -srun julia --project=experiments experiments/run_experiments.jl -- data=gmsc,german_credit output_path=results mpi grid_search n_individuals=25 store_ce > experiments/grid_search_tabular.log diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl index db09b5bb..43554f36 100644 --- a/experiments/linearly_separable.jl +++ b/experiments/linearly_separable.jl @@ -2,10 +2,13 @@ dataname = "Linearly Separable" n_obs = Int(1000 / (1.0 - TEST_SIZE)) counterfactual_data, test_data = train_test_split( - load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0)); - test_size=TEST_SIZE + load_blobs(n_obs; cluster_std = 0.1, center_box = (-1.0 => 1.0)); + test_size = TEST_SIZE, ) +# Domain constraints: +counterfactual_data.domain = extrema(counterfactual_data.X, dims=2) + # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_SMALL @@ -15,25 +18,27 @@ tuning_params = DEFAULT_GENERATOR_TUNING # Parameter choices: # These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM. params = ( - use_tuned=false, - n_hidden=16, - n_layers=3, - activation=Flux.swish, - epochs=100, - opt=Flux.Optimise.Descent(0.01), - Λ=[0.1, 0.1, 0.05], - reg_strength=0.0, + use_tuned = false, + n_hidden = 16, + n_layers = 3, + activation = Flux.swish, + epochs = 100, + opt = Flux.Optimise.Descent(0.01), + Λ = [0.1, 0.1, 0.05], + reg_strength = 0.0, ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) +@info "Using the following parameters: $(params)" -if GRID_SEARCH +if GRID_SEARCH grid_search( - counterfactual_data, test_data; - dataname=dataname, - tuning_params=tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + tuning_params = tuning_params, + params..., ) elseif FROM_GRID_SEARCH outcomes_file_path = joinpath( @@ -45,9 +50,10 @@ elseif FROM_GRID_SEARCH bmk2csv(dataname) else run_experiment( - counterfactual_data, test_data; - dataname=dataname, - model_tuning_params=model_tuning_params, - params... 
+ counterfactual_data, + test_data; + dataname = dataname, + model_tuning_params = model_tuning_params, + params..., ) -end \ No newline at end of file +end diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 0bfc9181..a4e9d50d 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -4,7 +4,10 @@ n_obs = 10000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X) # Adjust domain constraints to account for noise added during pre-processing: -counterfactual_data.domain = fill((minimum(counterfactual_data.X), maximum(counterfactual_data.X)), size(counterfactual_data.X, 1)) +counterfactual_data.domain = fill( + (minimum(counterfactual_data.X), maximum(counterfactual_data.X)), + size(counterfactual_data.X, 1), +) # VAE (trained on full dataset): using CounterfactualExplanations.Models: load_mnist_vae @@ -16,52 +19,59 @@ test_data = load_mnist_test() # Dimensionality reduction: maxout_dim = vae.params.latent_dim -counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); +counterfactual_data.dt = MultivariateStats.fit( + MultivariateStats.PCA, + counterfactual_data.X; + maxoutdim = maxout_dim, +); # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_LARGE # Tuning parameters: tuning_params = DEFAULT_GENERATOR_TUNING -tuning_params = (; tuning_params..., Λ=[tuning_params.Λ[2:end]..., [0.1, 0.1, 3.0]]) +tuning_params = (; tuning_params..., Λ = [tuning_params.Λ[2:end]..., [0.01, 0.1, 3.0]]) # Additional models: -add_models = Dict( - "LeNet-5" => lenet5, -) +add_models = Dict("LeNet-5" => lenet5) # CE measures (add cosine distance): -ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_from_targets_ssim] +ce_measures = + [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_from_targets_ssim] # Parameter choices: params = ( - n_individuals=N_IND_SPECIFIED ? N_IND : 100, - builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish), - 𝒟x=Uniform(-1.0, 1.0), - α=[1.0, 1.0, 1e-2], - sampling_batch_size=10, - sampling_steps=25, - use_ensembling=true, - use_variants=false, - additional_models=add_models, - epochs=100, - nsamples=10, - nmin=1, - niter_eccco=10, - Λ=[0.01, 0.25, 0.25], - Λ_Δ=[0.01, 0.1, 0.3], - opt=Flux.Optimise.Descent(0.1), + n_individuals = N_IND_SPECIFIED ? N_IND : 100, + builder = default_builder(n_hidden = 128, n_layers = 1, activation = Flux.swish), + 𝒟x = Uniform(-1.0, 1.0), + α = [1.0, 1.0, 1e-2], + sampling_batch_size = 10, + sampling_steps = 25, + use_ensembling = true, + use_variants = false, + additional_models = add_models, + epochs = 100, + nsamples = 10, + nmin = 1, + niter_eccco = 10, + Λ = [0.01, 0.25, 0.25], + Λ_Δ = [0.01, 0.1, 0.3], + opt = Flux.Optimise.Descent(0.1), reg_strength = 0.0, - ce_measures=ce_measures, - dim_reduction=true, + ce_measures = ce_measures, + dim_reduction = true, ) +# Best grid search params: +params = append_best_params(params, dataname) + if GRID_SEARCH grid_search( - counterfactual_data, test_data; - dataname=dataname, - tuning_params=tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + tuning_params = tuning_params, + params..., ) elseif FROM_GRID_SEARCH outcomes_file_path = joinpath( @@ -73,9 +83,10 @@ elseif FROM_GRID_SEARCH bmk2csv(dataname) else run_experiment( - counterfactual_data, test_data; - dataname=dataname, - model_tuning_params=model_tuning_params, - params... 
+ counterfactual_data, + test_data; + dataname = dataname, + model_tuning_params = model_tuning_params, + params..., ) -end \ No newline at end of file +end diff --git a/experiments/model_tuning.jl b/experiments/model_tuning.jl index dfabfbaa..c7a2abfe 100644 --- a/experiments/model_tuning.jl +++ b/experiments/model_tuning.jl @@ -17,23 +17,36 @@ function tune_mlp(exper::Experiment; kwargs...) model_tuning_path = mkpath(tuned_model_path(exper)) # Simple MLP: model = NeuralNetworkClassifier( - builder=default_builder(), - epochs=exper.epochs, - batch_size=batch_size(exper), - finaliser=exper.finaliser, - loss=exper.loss, - acceleration=CUDALibs(), + builder = default_builder(), + epochs = exper.epochs, + batch_size = batch_size(exper), + finaliser = exper.finaliser, + loss = exper.loss, + acceleration = CUDALibs(), ) # Unpack data: X, y, _ = prepare_data(exper::Experiment) # Tune model: measure = collect(values(exper.model_measures)) - mach = tune_mlp(model, X, y; tuning_params=exper.model_tuning_params, measure=measure, kwargs...) + mach = tune_mlp( + model, + X, + y; + tuning_params = exper.model_tuning_params, + measure = measure, + kwargs..., + ) # Machine is still on GPU, save CPU version of model: best_results = fitted_params(mach) - Serialization.serialize(joinpath(model_tuning_path, "$(exper.save_name)_best_mlp.jls"), best_results) + Serialization.serialize( + joinpath(model_tuning_path, "$(exper.save_name)_best_mlp.jls"), + best_results, + ) best_history = report(mach).best_history_entry - Serialization.serialize(joinpath(model_tuning_path, "$(exper.save_name)_best_mlp_history.jls"), best_history) + Serialization.serialize( + joinpath(model_tuning_path, "$(exper.save_name)_best_mlp_history.jls"), + best_history, + ) end return mach, best_results end @@ -44,39 +57,41 @@ end Tunes a model by performing a grid search over the parameters specified in `tuning_params`. """ function tune_mlp( - model::Supervised, X, y; + model::Supervised, + X, + y; tuning_params::NamedTuple, - measure::Vector=MODEL_MEASURE_VEC, - tuning=Grid(shuffle=false), - resampling=CV(nfolds=3, shuffle=true,), - kwargs... + measure::Vector = MODEL_MEASURE_VEC, + tuning = Grid(shuffle = false), + resampling = CV(nfolds = 3, shuffle = true), + kwargs..., ) ranges = [] for (k, v) in pairs(tuning_params) if k ∈ fieldnames(typeof(model)) - push!(ranges, range(model, k, values=v)) + push!(ranges, range(model, k, values = v)) elseif k ∈ fieldnames(typeof(model.builder)) - push!(ranges, range(model, :(builder.$(k)), values=v)) + push!(ranges, range(model, :(builder.$(k)), values = v)) elseif k ∈ fieldnames(typeof(model.optimiser)) - push!(ranges, range(model, :(optimiser.$(k)), values=v)) + push!(ranges, range(model, :(optimiser.$(k)), values = v)) else error("Parameter $k not found in model, builder or optimiser.") end end - + self_tuning_mod = TunedModel( - model=model, - range=ranges, - measure=measure, - tuning=tuning, - resampling=resampling, - kwargs... + model = model, + range = ranges, + measure = measure, + tuning = tuning, + resampling = resampling, + kwargs..., ) mach = machine(self_tuning_mod, X, y) - fit!(mach, verbosity=0) + fit!(mach, verbosity = 0) return mach @@ -87,5 +102,5 @@ end Checks if a tuned MLP exists. 
""" -tuned_mlp_exists(exper::Experiment) = isfile(joinpath(tuned_model_path(exper), "$(exper.save_name)_best_mlp.jls")) - +tuned_mlp_exists(exper::Experiment) = + isfile(joinpath(tuned_model_path(exper), "$(exper.save_name)_best_mlp.jls")) diff --git a/experiments/models/additional_models.jl b/experiments/models/additional_models.jl index d6e8f814..410c3c86 100644 --- a/experiments/models/additional_models.jl +++ b/experiments/models/additional_models.jl @@ -4,9 +4,9 @@ MLJFlux builder for a LeNet-like convolutional neural network. """ mutable struct LeNetBuilder - filter_size::Int - channels1::Int - channels2::Int + filter_size::Int + channels1::Int + channels2::Int end """ @@ -18,27 +18,23 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out) # Setup: _n_in = Int(sqrt(n_in)) - k, c1, c2 = b.filter_size, b.channels1, b.channels2 - mod(k, 2) == 1 || error("`filter_size` must be odd. ") + k, c1, c2 = b.filter_size, b.channels1, b.channels2 + mod(k, 2) == 1 || error("`filter_size` must be odd. ") p = div(k - 1, 2) # padding to preserve image size on convolution: # Model: - front = Flux.Chain( - Conv((k, k), 1 => c1, pad=(p, p), relu), + front = Flux.Chain( + Conv((k, k), 1 => c1, pad = (p, p), relu), MaxPool((2, 2)), - Conv((k, k), c1 => c2, pad=(p, p), relu), + Conv((k, k), c1 => c2, pad = (p, p), relu), MaxPool((2, 2)), - Flux.flatten - ) - d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first - back = Flux.Chain( - Dense(d, 120, relu), - Dense(120, 84, relu), - Dense(84, n_out), + Flux.flatten, ) + d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first + back = Flux.Chain(Dense(d, 120, relu), Dense(120, 84, relu), Dense(84, n_out)) chain = Flux.Chain(ECCCo.ToConv(_n_in), front, back) - return chain + return chain end """ @@ -46,7 +42,8 @@ end Builds a LeNet-like convolutional neural network. """ -lenet5(builder=LeNetBuilder(5, 6, 16); kwargs...) = NeuralNetworkClassifier(builder=builder; acceleration=CUDALibs(), kwargs...) +lenet5(builder = LeNetBuilder(5, 6, 16); kwargs...) = + NeuralNetworkClassifier(builder = builder; acceleration = CUDALibs(), kwargs...) """ ResNetBuilder @@ -62,13 +59,9 @@ Overloads the MLJFlux build function for a LeNet-like convolutional neural netwo """ function MLJFlux.build(b::ResNetBuilder, rng, n_in, n_out) _n_in = Int(sqrt(n_in)) - front = Metalhead.ResNet(18; inchannels=1) + front = Metalhead.ResNet(18; inchannels = 1) d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first - back = Flux.Chain( - Dense(d, 120, relu), - Dense(120, 84, relu), - Dense(84, n_out), - ) + back = Flux.Chain(Dense(d, 120, relu), Dense(120, 84, relu), Dense(84, n_out)) chain = Flux.Chain(ECCCo.ToConv(_n_in), front, back) return chain end @@ -78,4 +71,5 @@ end Builds a LeNet-like convolutional neural network. """ -resnet18(builder=ResNetBuilder(); kwargs...) = NeuralNetworkClassifier(builder=builder; acceleration=CUDALibs(), kwargs...) +resnet18(builder = ResNetBuilder(); kwargs...) = + NeuralNetworkClassifier(builder = builder; acceleration = CUDALibs(), kwargs...) diff --git a/experiments/models/default_models.jl b/experiments/models/default_models.jl index c50ff092..597010f8 100644 --- a/experiments/models/default_models.jl +++ b/experiments/models/default_models.jl @@ -6,11 +6,12 @@ mutable struct TuningBuilder <: MLJFlux.Builder end "Outer constructor." 
-TuningBuilder(; n_hidden=32, n_layers=3, activation=Flux.swish) = TuningBuilder(n_hidden, n_layers, activation) +TuningBuilder(; n_hidden = 32, n_layers = 3, activation = Flux.swish) = + TuningBuilder(n_hidden, n_layers, activation) function MLJFlux.build(nn::TuningBuilder, rng, n_in, n_out) hidden = ntuple(i -> nn.n_hidden, nn.n_layers) - return MLJFlux.build(MLJFlux.MLP(hidden=hidden, σ=nn.activation), rng, n_in, n_out) + return MLJFlux.build(MLJFlux.MLP(hidden = hidden, σ = nn.activation), rng, n_in, n_out) end """ @@ -18,8 +19,13 @@ end Default builder for MLPs. """ -function default_builder(;n_hidden::Int=16, n_layers::Int=3, activation::Function=Flux.swish) - builder = TuningBuilder(n_hidden=n_hidden, n_layers=n_layers, activation=activation) +function default_builder(; + n_hidden::Int = 16, + n_layers::Int = 3, + activation::Function = Flux.swish, +) + builder = + TuningBuilder(n_hidden = n_hidden, n_layers = n_layers, activation = activation) return builder end @@ -42,55 +48,50 @@ Builds a dictionary of default models for training. """ function default_models(; sampler::AbstractSampler, - builder::MLJFlux.Builder=default_builder(), - epochs::Int=100, - batch_size::Int=128, - finaliser::Function=Flux.softmax, - loss::Function=Flux.Losses.crossentropy, - α::AbstractArray=[1.0, 1.0, 1e-1], - verbosity::Int=10, - sampling_steps::Int=30, - n_ens::Int=5, - use_ensembling::Bool=true, + builder::MLJFlux.Builder = default_builder(), + epochs::Int = 100, + batch_size::Int = 128, + finaliser::Function = Flux.softmax, + loss::Function = Flux.Losses.crossentropy, + α::AbstractArray = [1.0, 1.0, 1e-1], + verbosity::Int = 10, + sampling_steps::Int = 30, + n_ens::Int = 5, + use_ensembling::Bool = true, ) # Simple MLP: mlp = NeuralNetworkClassifier( - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=finaliser, - loss=loss, - acceleration=CUDALibs(), + builder = builder, + epochs = epochs, + batch_size = batch_size, + finaliser = finaliser, + loss = loss, + acceleration = CUDALibs(), ) # Deep Ensemble: - mlp_ens = EnsembleModel(model=mlp, n=n_ens) + mlp_ens = EnsembleModel(model = mlp, n = n_ens) # Joint Energy Model: jem = JointEnergyClassifier( sampler; - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=finaliser, - loss=loss, - jem_training_params=( - α=α, verbosity=verbosity, - ), - sampling_steps=sampling_steps, + builder = builder, + epochs = epochs, + batch_size = batch_size, + finaliser = finaliser, + loss = loss, + jem_training_params = (α = α, verbosity = verbosity), + sampling_steps = sampling_steps, # acceleration=CUDALibs(), ) # Deep Ensemble of Joint Energy Models: - jem_ens = EnsembleModel(model=jem, n=n_ens) + jem_ens = EnsembleModel(model = jem, n = n_ens) # Dictionary of models: if !use_ensembling - models = Dict( - "MLP" => mlp, - "JEM" => jem, - ) + models = Dict("MLP" => mlp, "JEM" => jem) else models = Dict( "MLP" => mlp, @@ -101,4 +102,4 @@ function default_models(; end return models -end \ No newline at end of file +end diff --git a/experiments/models/models.jl b/experiments/models/models.jl index c446b524..519a3f7e 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -2,7 +2,7 @@ include("additional_models.jl") include("default_models.jl") include("train_models.jl") -function prepare_models(exper::Experiment; save_models::Bool=true) +function prepare_models(exper::Experiment; save_models::Bool = true) # Unpack data: X, labels, sampler = prepare_data(exper::Experiment) @@ -13,11 +13,17 @@ function 
prepare_models(exper::Experiment; save_models::Bool=true) if tuned_mlp_exists(exper) && exper.use_tuned @info "Loading tuned model architecture." # Load the best MLP: - best_mlp = Serialization.deserialize(joinpath(tuned_model_path(exper), "$(exper.save_name)_best_mlp.jls")) + best_mlp = Serialization.deserialize( + joinpath(tuned_model_path(exper), "$(exper.save_name)_best_mlp.jls"), + ) builder = best_mlp.best_model.builder else # Otherwise, use default MLP: - builder = default_builder(n_hidden=exper.n_hidden, n_layers=exper.n_layers, activation=exper.activation) + builder = default_builder( + n_hidden = exper.n_hidden, + n_layers = exper.n_layers, + activation = exper.activation, + ) end else builder = exper.builder @@ -26,16 +32,16 @@ function prepare_models(exper::Experiment; save_models::Bool=true) if isnothing(exper.models) @info "Using default models." models = default_models(; - sampler=sampler, - builder=builder, - batch_size=batch_size(exper), - sampling_steps=exper.sampling_steps, - α=exper.α, - n_ens=exper.n_ens, - use_ensembling=exper.use_ensembling, - finaliser=exper.finaliser, - loss=exper.loss, - epochs=exper.epochs, + sampler = sampler, + builder = builder, + batch_size = batch_size(exper), + sampling_steps = exper.sampling_steps, + α = exper.α, + n_ens = exper.n_ens, + use_ensembling = exper.use_ensembling, + finaliser = exper.finaliser, + loss = exper.loss, + epochs = exper.epochs, ) end # Additional models: @@ -45,10 +51,10 @@ function prepare_models(exper::Experiment; save_models::Bool=true) for (k, mod) in exper.additional_models if isa(mod, Function) add_models[k] = mod(; - batch_size=batch_size(exper), - finaliser=exper.finaliser, - loss=exper.loss, - epochs=exper.epochs, + batch_size = batch_size(exper), + finaliser = exper.finaliser, + loss = exper.loss, + epochs = exper.epochs, ) else add_models[k] = mod @@ -57,30 +63,44 @@ function prepare_models(exper::Experiment; save_models::Bool=true) models = merge(models, add_models) end @info "Training models." - model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage) + model_dict = train_models( + models, + X, + labels; + parallelizer = exper.parallelizer, + train_parallel = exper.train_parallel, + cov = exper.coverage, + ) else # Pre-trained models: if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) # Load models on root process: @info "Loading pre-trained models." - model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls")) + model_dict = Serialization.deserialize( + joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), + ) else # Dummy model on other processes: model_dict = nothing end # Broadcast models: if is_multi_processed(exper) - model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0) + model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root = 0) end end # Save models: - local_models_exist = isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls")) - on_root_process = !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + local_models_exist = + isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls")) + on_root_process = + !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) if save_models && on_root_process && !local_models_exist @info "Saving models to $(joinpath(exper.output_path , "$(exper.save_name)_models.jls"))." 
- Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict) + Serialization.serialize( + joinpath(exper.output_path, "$(exper.save_name)_models.jls"), + model_dict, + ) end return model_dict -end \ No newline at end of file +end diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl index 14d00675..14cb6b4c 100644 --- a/experiments/models/train_models.jl +++ b/experiments/models/train_models.jl @@ -5,7 +5,14 @@ using CounterfactualExplanations: AbstractParallelizer Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects. """ -function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractParallelizer}=nothing, train_parallel::Bool=false, kwargs...) +function train_models( + models::Dict, + X, + y; + parallelizer::Union{Nothing,AbstractParallelizer} = nothing, + train_parallel::Bool = false, + kwargs..., +) verbose = is_multi_processed(parallelizer) ? false : true if is_multi_processed(parallelizer) && train_parallel # Split models into groups of approximately equal size: @@ -15,7 +22,8 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa # Train models: model_dict = Dict() for (mod_name, model) in x - model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=verbose, kwargs...) + model_dict[mod_name] = + _train(model, X, y; mod_name = mod_name, verbose = verbose, kwargs...) end MPI.Barrier(parallelizer.comm) output = MPI.gather(output, parallelizer.comm) @@ -26,10 +34,14 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa output = nothing end # Broadcast output to all processes: - model_dict = MPI.bcast(output, parallelizer.comm; root=0) + model_dict = MPI.bcast(output, parallelizer.comm; root = 0) MPI.Barrier(parallelizer.comm) else - model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, verbose=verbose, kwargs...) for (mod_name, model) in models) + model_dict = Dict( + mod_name => + _train(model, X, y; mod_name = mod_name, verbose = verbose, kwargs...) + for (mod_name, model) in models + ) end return model_dict end @@ -46,14 +58,22 @@ end Trains a model and returns a `ConformalModel` object. """ -function _train(model, X, y; cov, method=:simple_inductive, mod_name="model", verbose::Bool=true) - conf_model = conformal_model(model; method=method, coverage=cov) +function _train( + model, + X, + y; + cov, + method = :simple_inductive, + mod_name = "model", + verbose::Bool = true, +) + conf_model = conformal_model(model; method = method, coverage = cov) mach = machine(conf_model, X, y) @info "Begin training $mod_name." if verbose fit!(mach) else - fit!(mach, verbosity=0) + fit!(mach, verbosity = 0) end @info "Finished training $mod_name." M = ECCCo.ConformalModel(mach.model, mach.fitresult) @@ -67,4 +87,4 @@ Helper function to save models. 
""" function save_models(model_dict::Dict; save_name::String, output_path) Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict) -end \ No newline at end of file +end diff --git a/experiments/moons.jl b/experiments/moons.jl index c2a83e10..809bdc40 100644 --- a/experiments/moons.jl +++ b/experiments/moons.jl @@ -1,7 +1,10 @@ # Data: dataname = "Moons" n_obs = Int(2500 / (1.0 - TEST_SIZE)) -counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=TEST_SIZE) +counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size = TEST_SIZE) + +# Domain constraints: +counterfactual_data.domain = extrema(counterfactual_data.X, dims=2) # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_SMALL @@ -12,27 +15,28 @@ tuning_params = DEFAULT_GENERATOR_TUNING # Parameter choices: # These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM. params = ( - use_tuned=false, - n_hidden=32, - n_layers=3, - activation=Flux.relu, - epochs=500, - sampling_batch_size=10, - sampling_steps=30, - opt=Flux.Optimise.Descent(0.01), - Λ=[0.1, 0.1, 0.5], - reg_strength=0.0, + use_tuned = false, + n_hidden = 32, + n_layers = 3, + activation = Flux.relu, + epochs = 500, + sampling_batch_size = 10, + sampling_steps = 30, + opt = Flux.Optimise.Descent(0.01), + Λ = [0.1, 0.1, 0.5], + reg_strength = 0.0, ) # Best grid search params: -append_best_params!(params, dataname) +params = append_best_params(params, dataname) if GRID_SEARCH grid_search( - counterfactual_data, test_data; - dataname=dataname, - tuning_params=tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + tuning_params = tuning_params, + params..., ) elseif FROM_GRID_SEARCH outcomes_file_path = joinpath( @@ -44,9 +48,10 @@ elseif FROM_GRID_SEARCH bmk2csv(dataname) else run_experiment( - counterfactual_data, test_data; - dataname=dataname, - model_tuning_params=model_tuning_params, - params... + counterfactual_data, + test_data; + dataname = dataname, + model_tuning_params = model_tuning_params, + params..., ) -end \ No newline at end of file +end diff --git a/experiments/notebooks/plots.qmd b/experiments/notebooks/plots.qmd new file mode 100644 index 00000000..dd207450 --- /dev/null +++ b/experiments/notebooks/plots.qmd @@ -0,0 +1,60 @@ +--- +format: pdf +--- + +# Plots + +```{julia} +using Pkg; Pkg.activate("experiments") +include("$(pwd())/experiments/setup_env.jl") +``` + +## Counterfactual Path - MNIST + +```{julia} +data_name = "mnist" +plt_order = ["MLP", "MLP Ensemble", "LeNet-5", "JEM", "JEM Ensemble"] +fixed_factual = 3 +fixed_target = 9 +models = Serialization.deserialize("models/$(data_name)_models.jls") +data = eval(Meta.parse("load_$(data_name)()")) +plt_order = plt_order[[x in collect(keys(models)) for x in plt_order]] +``` + +```{julia} +using Plots.PlotMeasures +n_samp = 100 +n_rand = 500 +σ = 0.1 +plts = [] +for (mod_name, model) in models + Δ = [] + L = [] + for i in 1:n_rand + factual = isnothing(fixed_factual) ? rand(data.y_levels) : fixed_factual + target = isnothing(fixed_target) ? 
rand(data.y_levels[data.y_levels .!= factual]) : fixed_target + t = get_target_index(data.y_levels, target) + E(x) = -logits(model, x)[t, :] + x_samp = data.X[:,rand(findall(data.output_encoder.labels.==target),n_samp)] + x_rand = data.X[:,rand(findall(data.output_encoder.labels.==factual),1)] + x_rand .+= Float32.(randn(size(x_samp,1),1)) .* σ + δ = mean(map(y -> norm(x_rand.-y),eachcol(x_samp))) + push!(Δ,δ) + l = E(x_rand)[1] + push!(L,l) + end + plt = scatter( + Δ, L; + label="", title=mod_name, smooth=:true, + lc=:red, lw=2 + ) + push!(plts, plt) +end +width = 1000 +plt = plot( + plts[sortperm(collect(keys(models)))[invperm(sortperm(plt_order))]]..., + layout=(1,length(models)), size=(width,round(1.0width)/length(models)), + ticks=false +) +savefig(plt, "www/dist_energy.png") +``` \ No newline at end of file diff --git a/experiments/post_processing/artifacts.jl b/experiments/post_processing/artifacts.jl index 2912f60e..c92459b3 100644 --- a/experiments/post_processing/artifacts.jl +++ b/experiments/post_processing/artifacts.jl @@ -17,12 +17,12 @@ using Serialization Uploads results to github releases. If `deploy=true`, then the results will be uploaded to a github release. If `deploy=false`, then the results will be saved locally. """ function generate_artifacts( - datafiles=DEFAULT_OUTPUT_PATH; - artifact_name=nothing, - root=".", - artifact_toml=LazyArtifacts.find_artifacts_toml("."), - deploy=true, - tag=nothing + datafiles = DEFAULT_OUTPUT_PATH; + artifact_name = nothing, + root = ".", + artifact_toml = LazyArtifacts.find_artifacts_toml("."), + deploy = true, + tag = nothing, ) # Artifact name: @@ -45,7 +45,7 @@ function generate_artifacts( # Try to detect where we should upload these weights to (or just override # as shown in the commented-out line) - origin_url = get_git_remote_url(root) + origin_url = replace(get_git_remote_url(root), ".git" => "") deploy_repo = "$(basename(dirname(origin_url)))/$(basename(origin_url))" end @@ -67,9 +67,9 @@ function generate_artifacts( artifact_toml, artifact_name, hash; - download_info=[(tarball_url, tarball_hash)], - lazy=true, - force=true + download_info = [(tarball_url, tarball_hash)], + lazy = true, + force = true, ) end @@ -89,7 +89,7 @@ function generate_artifacts( end end -function get_git_remote_url(repo_path::String=".") +function get_git_remote_url(repo_path::String = ".") repo = LibGit2.GitRepo(repo_path) origin = LibGit2.get(LibGit2.GitRemote, repo, "origin") return LibGit2.url(origin) diff --git a/experiments/post_processing/hypothesis_tests.jl b/experiments/post_processing/hypothesis_tests.jl new file mode 100644 index 00000000..e69de29b diff --git a/experiments/post_processing/meta_data.jl b/experiments/post_processing/meta_data.jl index 06585632..0acd94c3 100644 --- a/experiments/post_processing/meta_data.jl +++ b/experiments/post_processing/meta_data.jl @@ -1,13 +1,22 @@ """ - meta(exper::Experiment) + all_meta(exper::Experiment) Extract and save meta data about the experiment. 
""" -function meta(outcome::ExperimentOutcome; save_output::Bool=false, params_path::Union{Nothing,String}=nothing) - - model_params = meta_model(outcome; save_output=save_output, params_path=params_path) - model_performance = meta_model_performance(outcome; save_output=save_output, params_path=params_path) - generator_params = meta_generators(outcome; save_output=save_output, params_path=params_path) +function all_meta( + outcome::ExperimentOutcome; + save_output::Bool = false, + params_path::Union{Nothing,String} = nothing, +) + + model_params = meta_model(outcome; save_output = save_output, params_path = params_path) + model_performance = meta_model_performance( + outcome; + save_output = save_output, + params_path = params_path, + ) + generator_params = + meta_generators(outcome; save_output = save_output, params_path = params_path) return model_params, model_performance, generator_params @@ -18,7 +27,11 @@ end Extract and save meta data about the data and models in `outcome.model_dict`. """ -function meta_model(outcome::ExperimentOutcome; save_output::Bool=false, params_path::Union{Nothing,String}=nothing) +function meta_model( + outcome::ExperimentOutcome; + save_output::Bool = false, + params_path::Union{Nothing,String} = nothing, +) # Unpack: exper = outcome.exper @@ -38,7 +51,7 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false, params_ :n_ens => exper.n_ens, :lambda => string(exper.α[3]), :jem_sampling_steps => exper.sampling_steps, - ) + ), ) if save_output @@ -52,7 +65,11 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false, params_ end -function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false, params_path::Union{Nothing,String}=nothing) +function meta_generators( + outcome::ExperimentOutcome; + save_output::Bool = false, + params_path::Union{Nothing,String} = nothing, +) # Unpack: exper = outcome.exper @@ -66,7 +83,7 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false, pa generator_params = DataFrame( Dict( :opt => string(typeof(opt)), - :eta => opt.eta, + :eta => get_learning_rate(opt), :dataname => exper.dataname, :lambda_1 => string(Λ[1]), :lambda_2 => string(Λ[2]), @@ -76,7 +93,7 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false, pa :lambda_3_Δ => string(Λ_Δ[3]), :n_individuals => exper.n_individuals, :reg_strength => string(reg_strengh), - ) + ), ) if save_output @@ -94,7 +111,12 @@ end Compute and save the model performance for the models in `outcome.model_dict`. 
""" -function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Nothing,Dict}=nothing, save_output::Bool=false, params_path::Union{Nothing,String}=nothing) +function meta_model_performance( + outcome::ExperimentOutcome; + measures::Union{Nothing,Dict} = nothing, + save_output::Bool = false, + params_path::Union{Nothing,String} = nothing, +) # Unpack: exper = outcome.exper @@ -105,7 +127,11 @@ function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Noth model_performance = DataFrame() for (mod_name, model) in model_dict # Test performance: - _perf = CounterfactualExplanations.Models.model_evaluation(model, exper.test_data, measure=collect(values(measures))) + _perf = CounterfactualExplanations.Models.model_evaluation( + model, + exper.test_data, + measure = collect(values(measures)), + ) _perf = DataFrame([[p] for p in _perf], collect(keys(measures))) _perf.mod_name .= mod_name _perf.dataname .= exper.dataname @@ -127,4 +153,3 @@ function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Noth return model_performance end - diff --git a/experiments/post_processing/plotting.jl b/experiments/post_processing/plotting.jl index 46d36c76..45bfdd67 100644 --- a/experiments/post_processing/plotting.jl +++ b/experiments/post_processing/plotting.jl @@ -1,6 +1,11 @@ using Plots -function choose_random_mnist(outcome::ExperimentOutcome; model::String="LeNet-5", img_height=125, seed=966) +function choose_random_mnist( + outcome::ExperimentOutcome; + model::String = "LeNet-5", + img_height = 125, + seed = 966, +) # Set seed: if !isnothing(seed) @@ -9,27 +14,33 @@ function choose_random_mnist(outcome::ExperimentOutcome; model::String="LeNet-5" # Get output: bmk = outcome.bmk() - grouped_bmk = groupby(bmk[bmk.variable.=="distance" .&& bmk.model.==model,:], [:dataname, :target, :factual]) + grouped_bmk = groupby( + bmk[bmk.variable.=="distance".&&bmk.model.==model, :], + [:dataname, :target, :factual], + ) random_choice = rand(1:length(grouped_bmk)) generators = unique(bmk.generator) n_generators = length(generators) # Get data: - df = grouped_bmk[random_choice][1:n_generators, :] |> - x -> sort(x, :generator) |> - x -> subset(x, :generator => ByRow(x -> x != "ECCCo")) + df = + grouped_bmk[random_choice][1:n_generators, :] |> + x -> sort(x, :generator) |> x -> subset(x, :generator => ByRow(x -> x != "ECCCo")) generators = df.generator replace!(generators, "ECCCo-Δ" => "ECCCo") replace!(generators, "ECCCo-Δ (latent)" => "ECCCo+") n_generators = length(generators) # Factual: - img = CounterfactualExplanations.factual(grouped_bmk[random_choice][1:n_generators,:].ce[1]) |> ECCCo.convert2mnist + img = + CounterfactualExplanations.factual( + grouped_bmk[random_choice][1:n_generators, :].ce[1], + ) |> ECCCo.convert2mnist p1 = Plots.plot( img, - axis=([], false), - size=(img_height, img_height), - title="Factual", + axis = ([], false), + size = (img_height, img_height), + title = "Factual", ) plts = [p1] ces = [] @@ -40,9 +51,9 @@ function choose_random_mnist(outcome::ExperimentOutcome; model::String="LeNet-5" img = CounterfactualExplanations.counterfactual(ce) |> ECCCo.convert2mnist p = Plots.plot( img, - axis=([], false), - size=(img_height, img_height), - title="$generator", + axis = ([], false), + size = (img_height, img_height), + title = "$generator", ) push!(plts, p) push!(ces, ce) @@ -50,9 +61,9 @@ function choose_random_mnist(outcome::ExperimentOutcome; model::String="LeNet-5" plt = Plots.plot( plts..., - layout=(1, n_generators + 1), - size=(img_height * 
(n_generators + 1), img_height), - dpi=300 + layout = (1, n_generators + 1), + size = (img_height * (n_generators + 1), img_height), + dpi = 300, ) display(plt) @@ -60,7 +71,13 @@ function choose_random_mnist(outcome::ExperimentOutcome; model::String="LeNet-5" end -function plot_random_eccco(outcome::ExperimentOutcome; ce=nothing, generator="ECCCo-Δ", img_height=200, seed=966) +function plot_random_eccco( + outcome::ExperimentOutcome; + ce = nothing, + generator = "ECCCo-Δ", + img_height = 200, + seed = 966, +) # Set seed: if !isnothing(seed) Random.seed!(seed) @@ -79,20 +96,28 @@ function plot_random_eccco(outcome::ExperimentOutcome; ce=nothing, generator="EC img = CounterfactualExplanations.factual(ce) |> ECCCo.convert2mnist p1 = Plots.plot( img, - axis=([], false), - size=(img_height, img_height), - title="Factual", + axis = ([], false), + size = (img_height, img_height), + title = "Factual", ) plts = [p1] for (model_name, M) in models - ce = generate_counterfactual(x, target, data, M, gen; initialization=:identity, converge_when=:generator_conditions) + ce = generate_counterfactual( + x, + target, + data, + M, + gen; + initialization = :identity, + converge_when = :generator_conditions, + ) img = CounterfactualExplanations.counterfactual(ce) |> ECCCo.convert2mnist p = Plots.plot( img, - axis=([], false), - size=(img_height, img_height), - title="$model_name", + axis = ([], false), + size = (img_height, img_height), + title = "$model_name", ) push!(plts, p) end @@ -100,16 +125,23 @@ function plot_random_eccco(outcome::ExperimentOutcome; ce=nothing, generator="EC plt = Plots.plot( plts..., - layout=(1, n_models + 1), - size=(img_height * (n_models + 1), img_height), - dpi=300 + layout = (1, n_models + 1), + size = (img_height * (n_models + 1), img_height), + dpi = 300, ) display(plt) return plt, target, seed end -function plot_all_mnist(gen, model, data=load_mnist_test(); img_height=150, seed=123, maxoutdim=64) +function plot_all_mnist( + gen, + model, + data = load_mnist_test(); + img_height = 150, + seed = 123, + maxoutdim = 64, +) # Set seed: if !isnothing(seed) @@ -117,8 +149,8 @@ function plot_all_mnist(gen, model, data=load_mnist_test(); img_height=150, seed end # Dimensionality reduction: - data.dt = MultivariateStats.fit(MultivariateStats.PCA, data.X; maxoutdim=maxoutdim) - + data.dt = MultivariateStats.fit(MultivariateStats.PCA, data.X; maxoutdim = maxoutdim) + # VAE for REVISE: data.generative_model = CounterfactualExplanations.Models.load_mnist_vae() @@ -133,21 +165,26 @@ function plot_all_mnist(gen, model, data=load_mnist_test(); img_height=150, seed if factual != target @info "Generating counterfactual for $(factual) -> $(target)" ce = generate_counterfactual( - x, target, data, model, gen; - initialization=:identity, converge_when=:generator_conditions + x, + target, + data, + model, + gen; + initialization = :identity, + converge_when = :generator_conditions, ) plt = Plots.plot( CounterfactualExplanations.counterfactual(ce) |> ECCCo.convert2mnist, - axis=([], false), - size=(img_height, img_height), - title="$factual → $target", + axis = ([], false), + size = (img_height, img_height), + title = "$factual → $target", ) else plt = Plots.plot( x |> ECCCo.convert2mnist, - axis=([], false), - size=(img_height, img_height), - title="Factual", + axis = ([], false), + size = (img_height, img_height), + title = "Factual", ) end push!(plts, plt) @@ -156,9 +193,9 @@ function plot_all_mnist(gen, model, data=load_mnist_test(); img_height=150, seed plt = Plots.plot( plts..., - 
layout=(length(factuals), length(targets)), - size=(img_height * length(targets), img_height * length(factuals)), - dpi=300 + layout = (length(factuals), length(targets)), + size = (img_height * length(targets), img_height * length(factuals)), + dpi = 300, ) return plt @@ -167,7 +204,7 @@ end using MLDatasets using MosaicViews -function vae_reconstructions(seed=123) +function vae_reconstructions(seed = 123) # Set seed: if !isnothing(seed) @@ -175,30 +212,32 @@ function vae_reconstructions(seed=123) end counterfactual_data = load_mnist() - counterfactual_data.generative_model = CounterfactualExplanations.Models.load_mnist_vae() + counterfactual_data.generative_model = + CounterfactualExplanations.Models.load_mnist_vae() X = counterfactual_data.X - y = counterfactual_data.output_encoder.y + y = counterfactual_data.output_encoder.y images = [] rec_images = [] - for i in 0:9 + for i = 0:9 j = 0 while j < 10 - x = X[:,rand(findall(y .== i))] - x̂ = CounterfactualExplanations.GenerativeModels.reconstruct(vae, x)[1] |> - x̂ -> clamp.((x̂ .+ 1.0) ./ 2.0, 0.0, 1.0) |> - x̂ -> reshape(x̂, 28,28) |> - x̂ -> MLDatasets.convert2image(MNIST, x̂) - x = clamp.((x .+ 1.0) ./ 2.0, 0.0, 1.0) |> - x -> reshape(x, 28,28) |> - x -> MLDatasets.convert2image(MNIST, x) + x = X[:, rand(findall(y .== i))] + x̂ = + CounterfactualExplanations.GenerativeModels.reconstruct(vae, x)[1] |> + x̂ -> + clamp.((x̂ .+ 1.0) ./ 2.0, 0.0, 1.0) |> + x̂ -> reshape(x̂, 28, 28) |> x̂ -> MLDatasets.convert2image(MNIST, x̂) + x = + clamp.((x .+ 1.0) ./ 2.0, 0.0, 1.0) |> + x -> reshape(x, 28, 28) |> x -> MLDatasets.convert2image(MNIST, x) push!(images, x) push!(rec_images, x̂) j += 1 end end - p1 = plot(mosaic(images..., ncol=10), title="Images") - p2 = plot(mosaic(rec_images..., ncol=10), title="Reconstructions") - plt = plot(p1, p2, axis=false, size=(800,375)) + p1 = plot(mosaic(images..., ncol = 10), title = "Images") + p2 = plot(mosaic(rec_images..., ncol = 10), title = "Reconstructions") + plt = plot(p1, p2, axis = false, size = (800, 375)) return plt -end \ No newline at end of file +end diff --git a/experiments/post_processing/post_processing.jl b/experiments/post_processing/post_processing.jl index a79b23b0..85c4105f 100644 --- a/experiments/post_processing/post_processing.jl +++ b/experiments/post_processing/post_processing.jl @@ -1,4 +1,4 @@ include("meta_data.jl") include("artifacts.jl") include("results.jl") -include("plotting.jl") \ No newline at end of file +include("plotting.jl") diff --git a/experiments/post_processing/results.jl b/experiments/post_processing/results.jl index 98b81fab..81d291cb 100644 --- a/experiments/post_processing/results.jl +++ b/experiments/post_processing/results.jl @@ -3,14 +3,31 @@ Helper function to quickly filter a benchmark table for the distance from targets: the smaller this distance, the higher the plausibility. """ -function summarise_outcome(outcome::ExperimentOutcome; measure::Union{Nothing,AbstractArray}=nothing, model::Union{Nothing,AbstractArray}=nothing) +function summarise_outcome( + outcome::ExperimentOutcome; + measure::Union{Nothing,AbstractArray} = nothing, + model::Union{Nothing,AbstractArray} = nothing, +) bmk = outcome.bmk measure = isnothing(measure) ? 
unique(bmk().variable) : measure
-
-    df = groupby(bmk(), [:dataname, :generator, :model, :variable]) |>
-         x -> combine(x, :value => mean => :mean, :value => std => :std) |>
-         x -> subset(x, :variable => ByRow(x -> x ∈ measure))
+    df = bmk()
+    # If the :run column is missing (single runs), add it:
+    if !("run" ∈ names(df))
+        df[!, :run] .= 1
+    end
+    # Aggregate per run:
+    df =
+        groupby(df, [:dataname, :generator, :model, :run, :variable]) |>
+        x ->
+            combine(x, :value => mean => :mean_group, :value => std => :std_group) |>
+            x -> subset(x, :variable => ByRow(x -> x ∈ measure))
+    # Compute mean and std across runs:
+    df =
+        groupby(df, [:dataname, :generator, :model, :variable]) |>
+        x ->
+            combine(x, :mean_group => mean => :mean, :mean_group => std => :std)
+    # Subset:
     if !isnothing(model)
         df = subset(df, :model => ByRow(x -> x ∈ model))
     end
@@ -23,56 +40,64 @@ end

 Helper function to quickly filter a benchmark table for the distance from targets: the smaller this distance, the higher the plausibility.
 """
-plausibility(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, measure=["distance_from_targets_l2"], kwrgs...)
+plausibility(outcome::ExperimentOutcome; kwrgs...) =
+    summarise_outcome(outcome, measure = ["distance_from_targets_l2"], kwrgs...)

 """
     plausibility_image(outcome::ExperimentOutcome)

 Helper function to quickly filter a benchmark table for the distance from targets: the smaller this distance, the higher the plausibility.
 """
-plausibility_image(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, measure=["distance_from_targets_ssim"], kwrgs...)
+plausibility_image(outcome::ExperimentOutcome; kwrgs...) =
+    summarise_outcome(outcome, measure = ["distance_from_targets_ssim"], kwrgs...)

 """
     faithfulness(outcome::ExperimentOutcome)

 Helper function to quickly filter a benchmark table for the distance from energy: the smaller this distance, the higher the faithfulness.
 """
-faithfulness(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, measure=["distance_from_energy_l2"], kwrgs...)
+faithfulness(outcome::ExperimentOutcome; kwrgs...) =
+    summarise_outcome(outcome, measure = ["distance_from_energy_l2"], kwrgs...)

 """
     faithfulness_image(outcome::ExperimentOutcome)

 Helper function to quickly filter a benchmark table for the distance from energy: the smaller this distance, the higher the faithfulness.
 """
-faithfulness_image(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, measure=["distance_from_energy_ssim"], kwrgs...)
+faithfulness_image(outcome::ExperimentOutcome; kwrgs...) =
+    summarise_outcome(outcome, measure = ["distance_from_energy_ssim"], kwrgs...)

 """
     closeness(outcome::ExperimentOutcome)

 Helper function to quickly filter a benchmark table for the distance from the factual: the smaller this distance, the higher the closeness desideratum.
 """
-closeness(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, measure=["distance"], kwrgs...)
+closeness(outcome::ExperimentOutcome; kwrgs...) =
+    summarise_outcome(outcome, measure = ["distance"], kwrgs...)

 """
     validity(outcome::ExperimentOutcome)

 Helper function to quickly filter a benchmark table for the validity: the higher this value, the higher the validity.
 """
-validity(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, measure=["validity"], kwrgs...)
+validity(outcome::ExperimentOutcome; kwrgs...) =
+    summarise_outcome(outcome, measure = ["validity"], kwrgs...)
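# Illustrative usage of the helpers above (a sketch; the file path is
# hypothetical and `outcome` is assumed to come from a finished benchmark run):
#
#   outcome = Serialization.deserialize("results/mnist_outcome.jls")
#   plaus = plausibility(outcome)                   # distance_from_targets_l2: lower is better
#   faith = faithfulness(outcome, model = ["JEM"])  # restrict the summary to one model
#   first(sort(plaus, :mean), 5)
#
# Because `summarise_outcome` aggregates within each `:run` before averaging
# across runs, these helpers behave identically for single- and multi-run benchmarks.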
""" redundancy(outcome::ExperimentOutcome) Helper function to quickly filter a benchmark table for the redundancy: the higher this value, the higher the redundancy. """ -redundancy(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, measure=["redundancy"], kwrgs...) +redundancy(outcome::ExperimentOutcome; kwrgs...) = + summarise_outcome(outcome, measure = ["redundancy"], kwrgs...) """ uncertainty(outcome::ExperimentOutcome) Helper function to quickly filter a benchmark table for the uncertainty: the higher this value, the higher the uncertainty. """ -uncertainty(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, measure=["set_size_penalty"], kwrgs...) +uncertainty(outcome::ExperimentOutcome; kwrgs...) = + summarise_outcome(outcome, measure = ["set_size_penalty"], kwrgs...) """ generator_rank(outcome::ExperimentOutcome; generator::Union{AbstractArray,Nothing}=nothing, measure::Union{AbstractArray,Nothing}=nothing, model::Union{Nothing,String}=nothing) @@ -80,10 +105,10 @@ uncertainty(outcome::ExperimentOutcome; kwrgs...) = summarise_outcome(outcome, m Computes the average rank of a generator across all datasets and variables. """ function generator_rank( - outcome::ExperimentOutcome; - generator::Union{AbstractArray,Nothing}=nothing, - measure::Union{AbstractArray,Nothing}=nothing, - model::Union{Nothing,AbstractArray}=nothing + outcome::ExperimentOutcome; + generator::Union{AbstractArray,Nothing} = nothing, + measure::Union{AbstractArray,Nothing} = nothing, + model::Union{Nothing,AbstractArray} = nothing, ) # Setup: @@ -92,24 +117,31 @@ function generator_rank( measure = isnothing(measure) ? unique(bmk().variable) : measure # Compute: - results = summarise_outcome(outcome, measure=measure, model=model) + results = summarise_outcome(outcome, measure = measure, model = model) # Adjust variables for which higher is better: higher_is_better = [var ∈ ["validity", "redundancy"] for var in results.variable] - results.mean[higher_is_better] .= - results.mean[higher_is_better] + results.mean[higher_is_better] .= -results.mean[higher_is_better] # Compute ranks: - ranked_results = groupby(results, [:dataname, :model, :variable]) |> - x -> combine(x, :mean => sortperm => :rank, :generator) |> - x -> subset(x, :generator => ByRow(x -> x ∈ generator)) |> - x -> groupby(x, [:dataname, :generator, :variable]) |> - x -> combine(x, :rank => mean => :avg_rank) |> - x -> subset(x, :variable => ByRow(x -> x ∈ measure)) + ranked_results = + groupby(results, [:dataname, :model, :variable]) |> + x -> + combine(x, :mean => sortperm => :rank, :generator) |> + x -> + subset(x, :generator => ByRow(x -> x ∈ generator)) |> + x -> + groupby(x, [:dataname, :generator, :variable]) |> + x -> + combine(x, :rank => mean => :avg_rank) |> + x -> subset(x, :variable => ByRow(x -> x ∈ measure)) sort!(ranked_results, [:variable, :avg_rank]) return ranked_results end -generator_rank_plausibility(outcome::ExperimentOutcome; kwrgs...) = generator_rank(outcome, measure=["distance_from_targets_l2"], kwrgs...) - -generator_rank_faithfulness(outcome::ExperimentOutcome; kwrgs...) = generator_rank(outcome, measure=["distance_from_energy_l2"], kwrgs...) +generator_rank_plausibility(outcome::ExperimentOutcome; kwrgs...) = + generator_rank(outcome, measure = ["distance_from_targets_l2"], kwrgs...) -generator_rank_closeness(outcome::ExperimentOutcome; kwrgs...) = generator_rank(outcome, measure=["distance"], kwrgs...) +generator_rank_faithfulness(outcome::ExperimentOutcome; kwrgs...) 
= + generator_rank(outcome, measure = ["distance_from_energy_l2"], kwrgs...) +generator_rank_closeness(outcome::ExperimentOutcome; kwrgs...) = + generator_rank(outcome, measure = ["distance"], kwrgs...) diff --git a/experiments/run_experiments.jl b/experiments/run_experiments.jl index ecf1d226..b0183cb0 100644 --- a/experiments/run_experiments.jl +++ b/experiments/run_experiments.jl @@ -1,11 +1,21 @@ include("setup_env.jl"); # User inputs: -all_data_sets = ["linearly_separable", "moons", "circles", "mnist", "fmnist", "gmsc", "german_credit", "california_housing"] +all_data_sets = [ + "linearly_separable", + "moons", + "circles", + "mnist", + "fmnist", + "gmsc", + "german_credit", + "california_housing", +] if "run-all" in ARGS datanames = all_data_sets elseif any(contains.(ARGS, "data=")) - datanames = [ARGS[findall(contains.(ARGS, "data="))][1] |> x -> replace(x, "data=" => "")] + datanames = + [ARGS[findall(contains.(ARGS, "data="))][1] |> x -> replace(x, "data=" => "")] datanames = replace.(split(datanames[1], ","), " " => "") else @warn "No dataset specified, defaulting to all." @@ -42,12 +52,6 @@ if "german_credit" in datanames include("german_credit.jl") end -# Credit Default -if "credit_default" in datanames - @info "Running Credit Default experiment." - include("credit_default.jl") -end - # California Housing if "california_housing" in datanames @info "Running California Housing experiment." diff --git a/experiments/save_best.jl b/experiments/save_best.jl index eb09cc1a..977e8720 100644 --- a/experiments/save_best.jl +++ b/experiments/save_best.jl @@ -11,26 +11,31 @@ function save_best(outcomes_file_path::String) # Save data: output_path = replace(exper.output_path, "grid_search" => "") params_path = joinpath(output_path, "params") - Serialization.serialize(joinpath(output_path, "$(exper.save_name)_outcome.jls"), outcome) - Serialization.serialize(joinpath(output_path, "$(exper.save_name)_bmk.jls"), outcome.bmk) - Serialization.serialize(joinpath(output_path, "$(exper.save_name)_models.jls"), outcome.model_dict) - meta(outcome; save_output=true, params_path=params_path) + Serialization.serialize( + joinpath(output_path, "$(exper.save_name)_outcome.jls"), + outcome, + ) + Serialization.serialize( + joinpath(output_path, "$(exper.save_name)_bmk.jls"), + outcome.bmk, + ) + Serialization.serialize( + joinpath(output_path, "$(exper.save_name)_models.jls"), + outcome.model_dict, + ) + all_meta(outcome; save_output = true, params_path = params_path) end function bmk2csv(dataname::String) - bmk_path = joinpath( - DEFAULT_OUTPUT_PATH, - "$(replace(lowercase(dataname), " " => "_"))_bmk.jls", - ) + bmk_path = + joinpath(DEFAULT_OUTPUT_PATH, "$(replace(lowercase(dataname), " " => "_"))_bmk.jls") bmk = Serialization.deserialize(bmk_path) - csv_path = joinpath( - DEFAULT_OUTPUT_PATH, - "$(replace(lowercase(dataname), " " => "_"))_bmk.csv", - ) + csv_path = + joinpath(DEFAULT_OUTPUT_PATH, "$(replace(lowercase(dataname), " " => "_"))_bmk.csv") bmk = bmk() if "ce" ∈ names(bmk) - CSV.write(csv_path, bmk[:,Not(:ce)]) + CSV.write(csv_path, bmk[:, Not(:ce)]) else CSV.write(csv_path, bmk) end -end \ No newline at end of file +end diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index d61ad56d..e4f42c68 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -6,7 +6,8 @@ using CounterfactualExplanations.Data using CounterfactualExplanations.DataPreprocessing: train_test_split using CounterfactualExplanations.Evaluation: benchmark, evaluate, Benchmark using 
CounterfactualExplanations.Generators: JSMADescent -using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs +using CounterfactualExplanations.Models: + load_mnist_mlp, load_fashion_mnist_mlp, train, probs using CounterfactualExplanations.Objectives using CounterfactualExplanations.Parallelization using CSV @@ -15,12 +16,14 @@ using DataFrames using Distributions: Normal, Distribution, Categorical, Uniform using ECCCo using Flux +using Flux.Optimise: Optimiser, Descent, Adam, ClipValue using JointEnergyModels using LazyArtifacts using Logging using Metalhead using MLJ: TunedModel, Grid, CV, fitted_params, report -using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!, Supervised +using MLJBase: + multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!, Supervised using MLJEnsembles using MLJFlux using Random @@ -45,6 +48,48 @@ include("post_processing/post_processing.jl") include("utils.jl") include("save_best.jl") +# Number of counterfactuals: +n_ind_specified = false +if any(contains.(ARGS, "n_individuals=")) + n_ind_specified = true + n_individuals = + ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> + x -> replace(x, "n_individuals=" => "") |> x -> parse(Int, x) +else + n_individuals = 100 +end + +"Number of individuals to use in benchmarking." +const N_IND = n_individuals + +"Boolean flag to check if number of individuals was specified." +const N_IND_SPECIFIED = n_ind_specified + +# Number of tasks per process: +if any(contains.(ARGS, "n_each=")) + n_each = + ARGS[findall(contains.(ARGS, "n_each="))][1] |> + x -> replace(x, "n_each=" => "") |> + x -> x == "nothing" ? nothing : parse(Int, x) +else + n_each = 32 +end + +"Number of objects to pass to each process." +const N_EACH = n_each + +# Number of benchmark runs: +if any(contains.(ARGS, "n_runs=")) + n_runs = + ARGS[findall(contains.(ARGS, "n_runs="))][1] |> + x -> replace(x, "n_runs=" => "") |> x -> parse(Int, x) +else + n_runs = 1 +end + +"Number of benchmark runs." +const N_RUNS = n_runs + # Parallelization: plz = nothing @@ -58,13 +103,16 @@ end if "mpi" ∈ ARGS MPI.Init() const USE_MPI = true - plz = MPIParallelizer(MPI.COMM_WORLD; threaded=USE_THREADS) + plz = MPIParallelizer(MPI.COMM_WORLD; threaded = USE_THREADS, n_each = N_EACH) if MPI.Comm_rank(MPI.COMM_WORLD) != 0 global_logger(NullLogger()) else @info "Multi-processing using MPI. Disabling logging on non-root processes." if USE_THREADS @info "Multi-threading using $(Threads.nthreads()) threads." + if Threads.threadid() != 1 + global_logger(NullLogger()) + end end end else @@ -83,7 +131,9 @@ const LATEST_ARTIFACT_PATH = joinpath(artifact_path(ARTIFACT_HASH), ARTIFACT_NAM time_stamped = false if any(contains.(ARGS, "output_path")) @assert sum(contains.(ARGS, "output_path")) == 1 "Only one output path can be specified." - _path = ARGS[findall(contains.(ARGS, "output_path"))][1] |> x -> replace(x, "output_path=" => "") + _path = + ARGS[findall(contains.(ARGS, "output_path"))][1] |> + x -> replace(x, "output_path=" => "") elseif isinteractive() @info "You are running experiments interactively. By default, results will be saved in a temporary directory." _path = tempdir() @@ -98,16 +148,16 @@ const DEFAULT_OUTPUT_PATH = _path const TIME_STAMPED = time_stamped "Boolean flag to only train models." -const ONLY_MODELS = "only_models" ∈ ARGS +const ONLY_MODELS = "only_models" ∈ ARGS "Boolean flag to retrain models." 
-const RETRAIN = "retrain" ∈ ARGS || ONLY_MODELS
+const RETRAIN = "retrain" ∈ ARGS || ONLY_MODELS

 "Default model performance measures."
 const MODEL_MEASURES = Dict(
     :f1score => multiclass_f1score,
     :acc => accuracy,
-    :precision => multiclass_precision
+    :precision => multiclass_precision,
 )

 "Default coverage rate."
@@ -122,7 +172,7 @@ const CE_MEASURES = [
     ECCCo.distance_from_targets_l2,
     CounterfactualExplanations.Evaluation.validity,
     CounterfactualExplanations.Evaluation.redundancy,
-    ECCCo.set_size_penalty
+    ECCCo.set_size_penalty,
 ]

 "Test set proportion."
@@ -131,70 +181,57 @@ const TEST_SIZE = 0.2

 "Boolean flag to check if upload was specified."
 const UPLOAD = "upload" ∈ ARGS

-n_ind_specified = false
-if any(contains.(ARGS, "n_individuals="))
-    n_ind_specified = true
-    n_individuals = ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> x -> replace(x, "n_individuals=" => "") |> x -> parse(Int, x)
-else
-    n_individuals = 100
-end
-
-"Number of individuals to use in benchmarking."
-const N_IND = n_individuals
-
-"Boolean flag to check if number of individuals was specified."
-const N_IND_SPECIFIED = n_ind_specified
-
 "Boolean flag to check if grid search was specified."
 const GRID_SEARCH = "grid_search" ∈ ARGS

 "Generator tuning parameters."
 DEFAULT_GENERATOR_TUNING = (
-    Λ=[
-        [0.1, 0.1, 0.05],
-        [0.1, 0.1, 0.1],
-        [0.1, 0.1, 0.5],
-        [0.1, 0.1, 1.0],
-    ],
-    reg_strength=[0.0, 0.1, 0.25, 0.5, 1.0],
-    opt=[
-        Flux.Optimise.Descent(0.1),
-        Flux.Optimise.Descent(0.05),
-        Flux.Optimise.Descent(0.01),
+    Λ=[[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.1, 0.1, 0.5],],
+    reg_strength = [0.0, 0.1, 0.5],
+    opt = [
+        Descent(0.01),
+        Descent(0.05),
     ],
+    decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
 )

 "Generator tuning parameters for large datasets."
 DEFAULT_GENERATOR_TUNING_LARGE = (
-    Λ=[
-        [0.1, 0.1, 0.1],
-        [0.1, 0.1, 0.2],
-        [0.2, 0.2, 0.2],
-    ],
+    Λ=[[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.1, 0.1, 0.5],],
     reg_strength=[0.0, 0.1, 0.5],
-    opt=[
-        Flux.Optimise.Descent(0.01),
-        Flux.Optimise.Descent(0.05),
+    opt = [
+        Descent(0.01),
+        Descent(0.05),
     ],
+    decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
 )

 "Boolean flag to check if model tuning was specified."
 const TUNE_MODEL = "tune_model" ∈ ARGS

 "Model tuning parameters for small datasets."
-DEFAULT_MODEL_TUNING_SMALL = (
-    n_hidden=[16, 32, 64],
-    n_layers=[1, 2, 3],
-)
+DEFAULT_MODEL_TUNING_SMALL = (n_hidden = [16, 32, 64], n_layers = [1, 2, 3])

 "Model tuning parameters for large datasets."
-DEFAULT_MODEL_TUNING_LARGE = (
-    n_hidden=[32, 64, 128, 512],
-    n_layers=[2, 3, 5],
-)
+DEFAULT_MODEL_TUNING_LARGE = (n_hidden = [32, 64, 128, 512], n_layers = [2, 3, 5])

 "Boolean flag to check if store counterfactual explanations was specified."
 STORE_CE = "store_ce" ∈ ARGS

 "Boolean flag to check if best outcome from grid search should be used."
 FROM_GRID_SEARCH = "from_grid" ∈ ARGS
+
+# Vertical splits for benchmarking:
+if any(contains.(ARGS, "vertical_splits"))
+    @assert sum(contains.(ARGS, "vertical_splits")) == 1 "`vertical_splits` is specified more than once."
+    n_splits =
+        ARGS[findall(contains.(ARGS, "vertical_splits"))][1] |>
+        x ->
+            replace(x, "vertical_splits=" => "") |>
+            x -> parse(Int, x)
+else
+    n_splits = nothing
+end
+
+"Number of vertical splits."
+VERTICAL_SPLITS = n_splits diff --git a/experiments/slurm_header.sh b/experiments/slurm_header.sh new file mode 100644 index 00000000..48da377b --- /dev/null +++ b/experiments/slurm_header.sh @@ -0,0 +1,3 @@ +set -x # keep log of executed commands +export SRUN_CPUS_PER_TASK="$SLURM_CPUS_PER_TASK" # assign extra environment variable to be safe +export OPENBLAS_NUM_THREADS=1 # avoid that OpenBLAS calls too many threads \ No newline at end of file diff --git a/experiments/upload_artifacts.jl b/experiments/upload_artifacts.jl index 83db07f1..72d16295 100644 --- a/experiments/upload_artifacts.jl +++ b/experiments/upload_artifacts.jl @@ -1,2 +1,2 @@ include("setup_env.jl"); -generate_artifacts() \ No newline at end of file +generate_artifacts() diff --git a/experiments/utils.jl b/experiments/utils.jl index 9b6b8834..5b77538c 100644 --- a/experiments/utils.jl +++ b/experiments/utils.jl @@ -1,11 +1,17 @@ using CounterfactualExplanations.Parallelization: ThreadsParallelizer +using Distributions: Uniform +using Flux using LinearAlgebra: norm +using Statistics: mean, std function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer}) if isnothing(parallelizer) || isa(parallelizer, ThreadsParallelizer) return false else - return isa(parallelizer, Base.get_extension(CounterfactualExplanations, :MPIExt).MPIParallelizer) + return isa( + parallelizer, + Base.get_extension(CounterfactualExplanations, :MPIExt).MPIParallelizer, + ) end end @@ -18,7 +24,19 @@ function min_max_scale(x::AbstractArray) end function standardize(x::AbstractArray) - x_norm = (x .- sum(x)/length(x)) ./ std(x) + x_norm = (x .- sum(x) / length(x)) ./ std(x) x_norm = replace(x_norm, NaN => 0.0) return x_norm +end + +function get_learning_rate(opt::Flux.Optimise.AbstractOptimiser) + if hasfield(typeof(opt), :eta) + return opt.eta + elseif hasfield(typeof(opt), :os) + _os = opt.os + opt = _os[findall([:eta in fieldnames(typeof(o)) for o in _os])][1] + return opt.eta + else + throw(ArgumentError("Cannot find learning rate.")) + end end \ No newline at end of file diff --git a/models/california_housing_models.jls b/models/california_housing_models.jls index 24c49c41..ae975c60 100644 Binary files a/models/california_housing_models.jls and b/models/california_housing_models.jls differ diff --git a/models/circles_models.jls b/models/circles_models.jls index 9ec8dce2..bc09cc6b 100644 Binary files a/models/circles_models.jls and b/models/circles_models.jls differ diff --git a/models/german_credit_models.jls b/models/german_credit_models.jls index f7c13d75..ccc586c8 100644 Binary files a/models/german_credit_models.jls and b/models/german_credit_models.jls differ diff --git a/models/gmsc_models.jls b/models/gmsc_models.jls index 6742fe45..556d82c3 100644 Binary files a/models/gmsc_models.jls and b/models/gmsc_models.jls differ diff --git a/models/linearly_separable_models.jls b/models/linearly_separable_models.jls index c296f660..3e4cbb29 100644 Binary files a/models/linearly_separable_models.jls and b/models/linearly_separable_models.jls differ diff --git a/models/moons_models.jls b/models/moons_models.jls index 2884389c..fac66291 100644 Binary files a/models/moons_models.jls and b/models/moons_models.jls differ diff --git a/paper/aaai/rebuttal.pdf b/paper/aaai/rebuttal.pdf new file mode 100644 index 00000000..67f166ef Binary files /dev/null and b/paper/aaai/rebuttal.pdf differ diff --git a/paper/aaai/rebuttal.tex b/paper/aaai/rebuttal.tex new file mode 100644 index 00000000..6c22c14b --- /dev/null +++ b/paper/aaai/rebuttal.tex 
@@ -0,0 +1,241 @@ +%File: anonymous-submission-latex-2024.tex +\documentclass[letterpaper]{article} % DO NOT CHANGE THIS +\usepackage[submission]{aaai24} % DO NOT CHANGE THIS +\usepackage{times} % DO NOT CHANGE THIS +\usepackage{helvet} % DO NOT CHANGE THIS +\usepackage{courier} % DO NOT CHANGE THIS +\usepackage[hyphens]{url} % DO NOT CHANGE THIS +\usepackage{graphicx} % DO NOT CHANGE THIS +\urlstyle{rm} % DO NOT CHANGE THIS +\def\UrlFont{\rm} % DO NOT CHANGE THIS +\usepackage{natbib} % DO NOT CHANGE THIS AND DO NOT ADD ANY OPTIONS TO IT +\usepackage{caption} % DO NOT CHANGE THIS AND DO NOT ADD ANY OPTIONS TO IT +\frenchspacing % DO NOT CHANGE THIS +\setlength{\pdfpagewidth}{8.5in} % DO NOT CHANGE THIS +\setlength{\pdfpageheight}{11in} % DO NOT CHANGE THIS +% +% These are recommended to typeset algorithms but not required. See the subsubsection on algorithms. Remove them if you don't have algorithms in your paper. +\usepackage{algorithm} +% \usepackage{algorithmic} + +% +% These are are recommended to typeset listings but not required. See the subsubsection on listing. Remove this block if you don't have listings in your paper. +% \usepackage{newfloat} +% \usepackage{listings} +% \DeclareCaptionStyle{ruled}{labelfont=normalfont,labelsep=colon,strut=off} % DO NOT CHANGE THIS +% \lstset{% +% basicstyle={\footnotesize\ttfamily},% footnotesize acceptable for monospace +% numbers=left,numberstyle=\footnotesize,xleftmargin=2em,% show line numbers, remove this entire line if you don't want the numbers. +% aboveskip=0pt,belowskip=0pt,% +% showstringspaces=false,tabsize=2,breaklines=true} +% \floatstyle{ruled} +% \newfloat{listing}{tb}{lst}{} +% \floatname{listing}{Listing} +% +% Keep the \pdfinfo as shown here. There's no need +% for you to add the /Title and /Author tags. 
+\pdfinfo{ /TemplateVersion (2024.1) } + +\usepackage{amsfonts} % blackboard math symbols +\usepackage{amsmath} +\usepackage{amsthm} +\usepackage{caption} +\usepackage{graphicx} +\usepackage{algpseudocode} +\usepackage{import} +\usepackage{booktabs} +\usepackage{longtable} +\usepackage{array} +\usepackage{multirow} +\usepackage{placeins} + + +% Numbered Environments: +\newtheorem{definition}{Definition}[section] +\newtheorem{question}{Research Question}[section] + +% Bibliography +% \bibliographystyle{unsrtnat} +% \setcitestyle{numbers,square,comma} + +% Algorithm +\renewcommand{\algorithmicrequire}{\textbf{Input:}} +\renewcommand{\algorithmicensure}{\textbf{Output:}} + +% DISALLOWED PACKAGES +% \usepackage{authblk} -- This package is specifically forbidden +% \usepackage{balance} -- This package is specifically forbidden +% \usepackage{color} (if used in text) -- This package is specifically forbidden +% \usepackage{CJK} -- This package is specifically forbidden +% \usepackage{float} -- This package is specifically forbidden +% \usepackage{flushend} -- This package is specifically forbidden +% \usepackage{fontenc} -- This package is specifically forbidden +% \usepackage{fullpage} -- This package is specifically forbidden +% \usepackage{geometry} -- This package is specifically forbidden +% \usepackage{grffile} -- This package is specifically forbidden +% \usepackage{hyperref} -- This package is specifically forbidden +% \usepackage{navigator} -- This package is specifically forbidden +% (or any other package that embeds links such as navigator or hyperref) +% \usepackage{indentfirst} -- This package is specifically forbidden +% \usepackage{layout} -- This package is specifically forbidden +% \usepackage{multicol} -- This package is specifically forbidden +% \usepackage{nameref} -- This package is specifically forbidden +% \usepackage{savetrees} -- This package is specifically forbidden +% \usepackage{setspace} -- This package is specifically forbidden +% \usepackage{stfloats} -- This package is specifically forbidden +% \usepackage{tabu} -- This package is specifically forbidden +% \usepackage{titlesec} -- This package is specifically forbidden +% \usepackage{tocbibind} -- This package is specifically forbidden +% \usepackage{ulem} -- This package is specifically forbidden +% \usepackage{wrapfig} -- This package is specifically forbidden +% DISALLOWED COMMANDS +% \nocopyright -- Your paper will not be published if you use this command +% \addtolength -- This command may not be used +% \balance -- This command may not be used +% \baselinestretch -- Your paper will not be published if you use this command +% \clearpage -- No page breaks of any kind may be used for the final version of your paper +% \columnsep -- This command may not be used +% \newpage -- No page breaks of any kind may be used for the final version of your paper +% \pagebreak -- No page breaks of any kind may be used for the final version of your paper +% \pagestyle -- This command may not be used +% \tiny -- This is not an acceptable font size. +% \vspace{- -- No negative value may be used in proximity of a caption, figure, table, section, subsection, subsubsection, or reference +% \vskip{- -- No negative value may be used to alter spacing above or below a caption, figure, table, section, subsection, subsubsection, or reference + +\setcounter{secnumdepth}{2} % May be changed to 1 or 2 if section numbers are desired. + +% The file aaai24.sty is the style file for AAAI Press +% proceedings, working notes, and technical reports. +% + +% Title + +% Your title must be in mixed case, not sentence case.
+% That means all verbs (including short verbs like be, is, using, and go), +% nouns, adverbs, adjectives should be capitalized, including both words in hyphenated terms, while +% articles, conjunctions, and prepositions are lower case unless they +% directly follow a colon or long dash +\title{Faithful Model Explanations through\\ +Energy-Constrained Conformal Counterfactuals} +\author{ + %Authors + % All authors must be in the same font size and format. + Written by AAAI Press Staff\textsuperscript{\rm 1}\thanks{With help from the AAAI Publications Committee.}\\ + AAAI Style Contributions by Pater Patel Schneider, + Sunil Issar,\\ + J. Scott Penberthy, + George Ferguson, + Hans Guesgen, + Francisco Cruz\equalcontrib, + Marc Pujol-Gonzalez\equalcontrib +} +\affiliations{ + %Affiliations + \textsuperscript{\rm 1}Association for the Advancement of Artificial Intelligence\\ + % If you have multiple authors and multiple affiliations + % use superscripts in text and roman font to identify them. + % For example, + + % Sunil Issar\textsuperscript{\rm 2}, + % J. Scott Penberthy\textsuperscript{\rm 3}, + % George Ferguson\textsuperscript{\rm 4}, + % Hans Guesgen\textsuperscript{\rm 5} + % Note that the comma should be placed after the superscript + + 1900 Embarcadero Road, Suite 101\\ + Palo Alto, California 94303-3310 USA\\ + % email address must be in roman text type, not monospace or sans serif + proceedings-questions@aaai.org +% +% See more examples next +} + +%Example, Single Author, ->> remove \iffalse,\fi and place them surrounding AAAI title to use it +% \iffalse +% \title{My Publication Title --- Single Author} +% \author { +% Author Name +% } +% \affiliations{ +% Affiliation\\ +% Affiliation Line 2\\ +% name@example.com +% } +% \fi + +% \iffalse +% %Example, Multiple Authors, ->> remove \iffalse,\fi and place them surrounding AAAI title to use it +% \title{My Publication Title --- Multiple Authors} +% \author { +% % Authors +% First Author Name\textsuperscript{\rm 1}, +% Second Author Name\textsuperscript{\rm 2}, +% Third Author Name\textsuperscript{\rm 1} +% } +% \affiliations { +% % Affiliations +% \textsuperscript{\rm 1}Affiliation 1\\ +% \textsuperscript{\rm 2}Affiliation 2\\ +% firstAuthor@affiliation1.com, secondAuthor@affiliation2.com, thirdAuthor@affiliation1.com +% } +% \fi + +\begin{document} + +We thank the reviewers for their thoughtful comments and are pleased by the overall positive response. + +\subsection*{Reviewer \#1} + +\subsubsection{1. Experiment results: linguistic explanation.} + +We will add a linguistic explanation in Section 6 where we highlight that \textit{ECCCo} produces plausible counterfactuals iff the classifier itself has learned plausible explanations for the data. It thus avoids the risk of generating plausible but potentially misleading explanations for models that are highly susceptible to implausible explanations. + +\subsubsection{2. Core innovation: more visualizations.} + +Figure~\ref{fig:poc} shows the relationship between implausibility and the energy constraint for MNIST data. As expected, this relationship is positive, and its strength increases with the model's generative property (the observed relationships are stronger for joint energy models). We will add such images for all datasets to the appendix. We note that our final benchmark results involve around 1.5 million counterfactuals per dataset (not including grid searches).
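To make the construction behind Figure~\ref{fig:poc} concrete, here is a minimal stand-in sketch; the random data and `logits_fn` below are placeholders rather than the paper's actual models or images, and the score follows the energy definition used in `src/penalties.jl` (the negative target logit):

```julia
using LinearAlgebra: norm
using Statistics: mean

# Stand-ins for unperturbed target-class MNIST images (real images are 28x28, i.e. 784 pixels):
X_target = randn(784, 100)

# A randomly drawn image with Gaussian perturbation:
x = X_target[:, 1] .+ 0.5 .* randn(784)

# Horizontal axis: average L2 distance to the unperturbed target-class images.
dist = mean(norm(x .- X_target[:, j]) for j in axes(X_target, 2))

# Vertical axis: energy-constrained score, here the negative target logit;
# `logits_fn` is a placeholder for the trained classifier's logit function.
logits_fn(x) = randn(10)
energy_score = -logits_fn(x)[1]
```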
+ +\begin{figure}[h] + \centering + \includegraphics[width=\linewidth]{../../www/dist_energy.png} + \caption{The L2 distance of randomly drawn MNIST images with Gaussian perturbations from unperturbed images in the target class (horizontal axis) plotted against their energy-constrained score, i.e., the target logit (vertical axis).}\label{fig:poc} +\end{figure} + +\subsubsection{3. Structural clarity.} + +To facilitate comprehension, we will follow the reviewer's advice and add a systematic flowchart either in the appendix or in place of Figure 2. + +\subsection*{Reviewer \#2} + +\subsubsection{4. Why use an embedding?} + +There are two main reasons for using a low-dimensional latent embedding: firstly, to help with plausibility and, secondly, to reduce computational costs. The latter is not currently made explicit in the paper and we will add this in Section 5. The former is discussed in the context of the results for \textit{ECCCo+} in Section 6.3, but we will highlight the following rationale: + +There is indeed a tradeoff between plausibility and faithfulness through the introduction of bias: plausibility is improved because counterfactuals are insensitive to variation captured by higher-order principal components. Intuitively, the generated counterfactuals are therefore less noisy. We think that the bias introduced by PCA may be acceptable, precisely because it `will not add any information on the input distribution' as the reviewer correctly points out. To maintain faithfulness, we want to avoid adding any information through surrogate models as much as possible. + +\subsubsection{5. What is `epsilon' and `s'?} + +From the paper: `[...] the step-size $\epsilon_j$ is typically polynomially decayed.' Intuitively, $\epsilon_j$ determines the size of gradient updates and random noise in each iteration of SGLD. + +Regarding $s(\cdot)$, this was an oversight. In the appendix we explain that `[the calibration dataset] is then used to compute so-called nonconformity scores: $\mathcal{S}=\{s(\mathbf{x}_i,\mathbf{y}_i)\}_{i \in \mathcal{D}_{\text{cal}}}$ where $s: (\mathcal{X},\mathcal{Y}) \mapsto \mathbb{R}$ is referred to as \textit{score function}.' We will add this in Section 4.2 of the paper. + +\subsubsection{6. Euclidean distance.} + +As we mentioned in the additional author response, we investigated different distance metrics and found that the overall qualitative results were largely independent of the choice of metric. For image data, we nonetheless report results for a dissimilarity metric that is more appropriate in this context. All of our distance-based metrics are computed in the feature space. This is because we would indeed expect certain discrepancies between distances evaluated in the feature space and distances evaluated in the latent space of a VAE, for example. In cases where high dimensionality leads to prohibitive computational costs, we suggest working in a lower-dimensional subspace that adds as little information as possible (such as one spanned by the first principal components). + +\subsubsection{7. Model fails to learn plausible explanations.} + +In these cases, \textit{ECCCo} generally achieves lower plausibility while maintaining faithfulness (see also points 1 and 9). + +\subsubsection{8. Faithfulness metric: is it fair?} + +We have taken measures not to unfairly bias our generator for the unfaithfulness metric: instead of penalizing the unfaithfulness metric directly, we penalize model energy in our preferred implementation.
In contrast, \textit{Wachter} penalizes the closeness criterion directly and hence does particularly well in this regard. In the absence of other established faithfulness metrics, we can only point out that \textit{ECCCo} achieves strong performance for other commonly used metrics as well. For \textit{validity}, which corresponds to \textit{fidelity}, \textit{ECCCo} performs strongly. + +Joint energy models (JEMs) are indeed explicitly trained to model $\mathcal{X}|y$, but the faithfulness metric is not computed for samples generated by JEMs. It is computed for counterfactuals generated by constraining model energy and hence there is no obvious source of bias. Our empirical findings support this argument: firstly, \textit{ECCCo} also achieves high faithfulness for classifiers that have not been trained to model $\mathcal{X}|y$; secondly, our additional results in the appendix for \textit{ECCCo-L1} show that if we do indeed explicitly penalize the unfaithfulness metric, we achieve even better results in this regard (also for models not trained to model $\mathcal{X}|y$). + +\subsubsection{9. Add unreliable models.} + +We would argue that the simple multi-layer perceptrons (MLPs) are unreliable, especially compared to ensembles, joint energy models, and convolutional neural networks. Simple MLPs are generally more vulnerable to adversarial attacks, which makes them susceptible to implausible counterfactual explanations, as we point out in Section 3. Our results support this notion, in that the quality of counterfactuals produced by \textit{ECCCo} is higher for more reliable models. Consistent with the reviewer's idea, we originally considered introducing `poisoned' VAEs to illustrate what we identify as the key vulnerability of \textit{REVISE}: if the underlying VAE is misspecified, this will adversely affect counterfactual outcomes as well. We discarded this idea due to limited scope and because we decided that Section 3 sufficiently illustrates our line of thinking. + +\end{document} diff --git a/src/ECCCo.jl b/src/ECCCo.jl index 0cb0f598..5d72878a 100644 --- a/src/ECCCo.jl +++ b/src/ECCCo.jl @@ -15,4 +15,4 @@ export conformal_training_loss export get_lowest_energy_sample export set_size_penalty, distance_from_energy -end \ No newline at end of file +end diff --git a/src/generator.jl b/src/generator.jl index 5974f513..a1595f61 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -2,19 +2,20 @@ using CounterfactualExplanations.Objectives using CounterfactualExplanations.Generators: GradientBasedGenerator "Constructor for `ECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator." -function ECCCoGenerator(; - λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.2,0.4,0.4], - κ::Real=1.0, - temp::Real=0.1, - opt::Union{Nothing,Flux.Optimise.AbstractOptimiser}=nothing, - use_class_loss::Bool=false, - use_energy_delta::Bool=false, - nsamples::Union{Nothing,Int}=nothing, - nmin::Union{Nothing,Int}=nothing, - niter::Union{Nothing,Int}=nothing, - reg_strength::Real=0.1, - dim_reduction::Bool=false, - kwargs...
+function ECCCoGenerator(; + λ::Union{AbstractFloat,Vector{<:AbstractFloat}} = [0.2, 0.4, 0.4], + κ::Real = 1.0, + temp::Real = 0.1, + opt::Union{Nothing,Flux.Optimise.AbstractOptimiser} = nothing, + use_class_loss::Bool = false, + use_energy_delta::Bool = false, + nsamples::Union{Nothing,Int} = nothing, + nmin::Union{Nothing,Int} = nothing, + niter::Union{Nothing,Int} = nothing, + reg_strength::Real = 0.1, + decay::Tuple = (0.1, 1), + dim_reduction::Bool = false, + kwargs..., ) # Default ECCCo parameters @@ -29,21 +30,39 @@ function ECCCoGenerator(; # Loss function if use_class_loss - loss_fun(ce::AbstractCounterfactualExplanation) = conformal_training_loss(ce; temp=temp) + loss_fun(ce::AbstractCounterfactualExplanation) = + conformal_training_loss(ce; temp = temp) else loss_fun = nothing end _energy_penalty = - use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin, niter=niter, reg_strength=reg_strength)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin, niter=niter)) + use_energy_delta ? + ( + ECCCo.energy_delta, + ( + n = nsamples, + nmin = nmin, + niter = niter, + reg_strength = reg_strength, + decay = decay, + ), + ) : (ECCCo.distance_from_energy, (n = nsamples, nmin = nmin, niter = niter)) _penalties = [ - (Objectives.distance_l1, []), - (ECCCo.set_size_penalty, (κ=κ, temp=temp)), + (Objectives.distance_l1, []), + (ECCCo.set_size_penalty, (κ = κ, temp = temp)), _energy_penalty, ] λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ # Generator - return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, dim_reduction=dim_reduction, kwargs...) -end \ No newline at end of file + return GradientBasedGenerator(; + loss = loss_fun, + penalty = _penalties, + λ = λ, + opt = opt, + dim_reduction = dim_reduction, + kwargs..., + ) +end diff --git a/src/losses.jl b/src/losses.jl index 2eada63d..d235f883 100644 --- a/src/losses.jl +++ b/src/losses.jl @@ -6,7 +6,12 @@ using Statistics: mean A configurable classification loss function for Conformal Predictors. """ -function conformal_training_loss(ce::AbstractCounterfactualExplanation; temp::Real=0.1, agg=mean, kwargs...) +function conformal_training_loss( + ce::AbstractCounterfactualExplanation; + temp::Real = 0.1, + agg = mean, + kwargs..., +) conf_model = ce.M.model fitresult = ce.M.fitresult X = CounterfactualExplanations.decode_state(ce) @@ -18,13 +23,17 @@ function conformal_training_loss(ce::AbstractCounterfactualExplanation; temp::Re n_classes = length(ce.data.y_levels) loss_mat = ones(n_classes, n_classes) - loss = map(eachslice(X, dims=ndims(X))) do x - x = ndims(x) == 1 ? x[:,:]' : x + loss = map(eachslice(X, dims = ndims(X))) do x + x = ndims(x) == 1 ? 
x[:, :]' : x ConformalPrediction.ConformalTraining.classification_loss( - conf_model, fitresult, x, y; - temp=temp, loss_matrix = loss_mat, + conf_model, + fitresult, + x, + y; + temp = temp, + loss_matrix = loss_mat, ) end loss = agg(loss)[1] return loss -end \ No newline at end of file +end diff --git a/src/model.jl b/src/model.jl index 445bf4c7..d942e440 100644 --- a/src/model.jl +++ b/src/model.jl @@ -8,7 +8,10 @@ using MLJFlux using MLUtils using Statistics -const CompatibleAtomicModel = Union{<:MLJFlux.MLJFluxProbabilistic,MLJEnsembles.ProbabilisticEnsembleModel{<:MLJFlux.MLJFluxProbabilistic}} +const CompatibleAtomicModel = Union{ + <:MLJFlux.MLJFluxProbabilistic, + MLJEnsembles.ProbabilisticEnsembleModel{<:MLJFlux.MLJFluxProbabilistic}, +} """ ConformalModel <: Models.AbstractDifferentiableModel @@ -20,7 +23,8 @@ struct ConformalModel <: Models.AbstractDifferentiableModel fitresult::Any likelihood::Union{Nothing,Symbol} function ConformalModel(model, fitresult, likelihood) - if likelihood ∈ [:classification_binary, :classification_multi] || isnothing(likelihood) + if likelihood ∈ [:classification_binary, :classification_multi] || + isnothing(likelihood) new(model, fitresult, likelihood) else throw( @@ -38,10 +42,10 @@ end Private function that extracts the chains from a fitted model. """ function _get_chains(fitresult) - + chains = [] - ignore_derivatives() do + ignore_derivatives() do if fitresult isa MLJEnsembles.WrappedEnsemble _chains = map(res -> res[1], fitresult.ensemble) else @@ -49,7 +53,7 @@ function _get_chains(fitresult) end push!(chains, _chains...) end - + return chains end @@ -103,7 +107,11 @@ end Outer constructor for `ConformalModel`. If `fitresult` is not specified, the model is not fitted and `likelihood` is inferred from the model. If `fitresult` is specified, `likelihood` is inferred from the output dimension of the model. If `likelihood` is not specified, it defaults to `:classification_binary`. """ -function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing) +function ConformalModel( + model, + fitresult = nothing; + likelihood::Union{Nothing,Symbol} = nothing, +) # Check if model is fitted and infer likelihood: if isnothing(fitresult) @@ -152,13 +160,13 @@ In the binary case logits are fed through the sigmoid function instead of softma which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier """ function Models.logits(M::ConformalModel, X::AbstractArray) - + fitresult = M.fitresult function predict_logits(fitresult, x) - ŷ = MLUtils.stack(map(chain -> get_logits(chain,x),_get_chains(fitresult))) |> - y -> mean(y, dims=ndims(y)) |> - y -> MLUtils.unstack(y, dims=ndims(y))[1] + ŷ = + MLUtils.stack(map(chain -> get_logits(chain, x), _get_chains(fitresult))) |> + y -> mean(y, dims = ndims(y)) |> y -> MLUtils.unstack(y, dims = ndims(y))[1] if ndims(ŷ) == 2 ŷ = [ŷ] end @@ -203,4 +211,4 @@ function Models.train(M::ConformalModel, data::CounterfactualData; kwrgs...) fit!(mach; kwrgs...) likelihood, _ = CounterfactualExplanations.guess_likelihood(data.output_encoder.y) return ConformalModel(mach.model, mach.fitresult, likelihood) -end \ No newline at end of file +end diff --git a/src/penalties.jl b/src/penalties.jl index 2e611097..b295d174 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -12,8 +12,10 @@ using Statistics: mean Penalty for smooth conformal set size. 
""" function set_size_penalty( - ce::AbstractCounterfactualExplanation; - κ::Real=1.0, temp::Real=0.1, agg=mean + ce::AbstractCounterfactualExplanation; + κ::Real = 1.0, + temp::Real = 0.1, + agg = mean, ) _loss = 0.0 @@ -21,15 +23,17 @@ function set_size_penalty( conf_model = ce.M.model fitresult = ce.M.fitresult X = CounterfactualExplanations.decode_state(ce) - _loss = map(eachslice(X, dims=ndims(X))) do x - x = ndims(x) == 1 ? x[:,:] : x + _loss = map(eachslice(X, dims = ndims(X))) do x + x = ndims(x) == 1 ? x[:, :] : x if target_probs(ce, x)[1] >= 0.5 l = ConformalPrediction.ConformalTraining.smooth_size_loss( - conf_model, fitresult, x'; - κ=κ, - temp=temp + conf_model, + fitresult, + x'; + κ = κ, + temp = temp, )[1] - else + else l = 0.0 end return l @@ -42,19 +46,22 @@ end function energy_delta( ce::AbstractCounterfactualExplanation; - n::Int=50, niter=500, from_buffer=true, agg=mean, - choose_lowest_energy=true, - choose_random=false, - nmin::Int=25, - return_conditionals=false, - reg_strength=0.1, - decay::Real=0.1, - kwargs... + n::Int = 50, + niter = 500, + from_buffer = true, + agg = mean, + choose_lowest_energy = true, + choose_random = false, + nmin::Int = 25, + return_conditionals = false, + reg_strength = 0.1, + decay::Tuple = (0.1, 1), + kwargs..., ) xproposed = CounterfactualExplanations.decode_state(ce) # current state t = get_target_index(ce.data.y_levels, ce.target) - E(x) = -logits(ce.M, x)[t,:] # negative logits for taraget class + E(x) = -logits(ce.M, x)[t, :] # negative logits for taraget class # Generative loss: gen_loss = E(xproposed) @@ -64,7 +71,17 @@ function energy_delta( reg_loss = norm(E(xproposed))^2 reg_loss = reduce((x, y) -> x + y, reg_loss) / length(reg_loss) # aggregate over samples - return gen_loss + reg_strength * reg_loss + # Decay: + iter = total_steps(ce) + ϕ = 1.0 + if iter % decay[2] == 0 + ϕ = exp(-decay[1] * total_steps(ce)) + end + + # Total loss: + ℒ = ϕ * (gen_loss + reg_strength * reg_loss) + + return ℒ end @@ -75,13 +92,16 @@ Computes the distance from the counterfactual to generated conditional samples. """ function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=50, niter=500, from_buffer=true, agg=mean, - choose_lowest_energy=true, - choose_random=false, - nmin::Int=25, - return_conditionals=false, - p::Int=1, - kwargs... + n::Int = 50, + niter = 500, + from_buffer = true, + agg = mean, + choose_lowest_energy = true, + choose_random = false, + nmin::Int = 25, + return_conditionals = false, + p::Int = 1, + kwargs..., ) _loss = 0.0 @@ -93,24 +113,25 @@ function distance_from_energy( ignore_derivatives() do _dict = ce.params if !(:energy_sampler ∈ collect(keys(_dict))) - _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) + _dict[:energy_sampler] = + ECCCo.EnergySampler(ce; niter = niter, nsamples = n, kwargs...) 
end eng_sampler = _dict[:energy_sampler] if choose_lowest_energy nmin = minimum([nmin, size(eng_sampler.buffer)[end]]) - xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin) + xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n = nmin) push!(conditional_samples, xmin) elseif choose_random - push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer)) + push!(conditional_samples, rand(eng_sampler, n; from_buffer = from_buffer)) else push!(conditional_samples, eng_sampler.buffer) end end _loss = map(eachcol(conditional_samples[1])) do xsample - distance(ce; from=xsample, agg=agg, p=p) + distance(ce; from = xsample, agg = agg, p = p) end - _loss = reduce((x,y) -> x + y, _loss) / n # aggregate over samples + _loss = reduce((x, y) -> x + y, _loss) / n # aggregate over samples if return_conditionals return conditional_samples[1] @@ -119,7 +140,8 @@ function distance_from_energy( end -distance_from_energy_l2(ce::AbstractCounterfactualExplanation; kwrgs...) = distance_from_energy(ce; p=2, kwrgs...) +distance_from_energy_l2(ce::AbstractCounterfactualExplanation; kwrgs...) = + distance_from_energy(ce; p = 2, kwrgs...) """ distance_from_energy_cosine(ce::AbstractCounterfactualExplanation) @@ -128,12 +150,15 @@ Computes the cosine distance from the counterfactual to generated conditional sa """ function distance_from_energy_cosine( ce::AbstractCounterfactualExplanation; - n::Int=50, niter=500, from_buffer=true, agg=mean, - choose_lowest_energy=true, - choose_random=false, - nmin::Int=25, - return_conditionals=false, - kwargs... + n::Int = 50, + niter = 500, + from_buffer = true, + agg = mean, + choose_lowest_energy = true, + choose_random = false, + nmin::Int = 25, + return_conditionals = false, + kwargs..., ) _loss = 0.0 @@ -145,15 +170,16 @@ function distance_from_energy_cosine( ignore_derivatives() do _dict = ce.params if !(:energy_sampler ∈ collect(keys(_dict))) - _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) + _dict[:energy_sampler] = + ECCCo.EnergySampler(ce; niter = niter, nsamples = n, kwargs...) end eng_sampler = _dict[:energy_sampler] if choose_lowest_energy nmin = minimum([nmin, size(eng_sampler.buffer)[end]]) - xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin) + xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n = nmin) push!(conditional_samples, xmin) elseif choose_random - push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer)) + push!(conditional_samples, rand(eng_sampler, n; from_buffer = from_buffer)) else push!(conditional_samples, eng_sampler.buffer) end @@ -162,7 +188,7 @@ function distance_from_energy_cosine( _loss = map(eachcol(conditional_samples[1])) do xsample cos_dist(CounterfactualExplanations.counterfactual(ce), xsample) end - _loss = reduce((x,y) -> x + y, _loss) / n # aggregate over samples + _loss = reduce((x, y) -> x + y, _loss) / n # aggregate over samples if return_conditionals return conditional_samples[1] @@ -178,12 +204,15 @@ Computes 1-SSIM from the counterfactual to generated conditional samples where S """ function distance_from_energy_ssim( ce::AbstractCounterfactualExplanation; - n::Int=50, niter=500, from_buffer=true, agg=mean, - choose_lowest_energy=true, - choose_random=false, - nmin::Int=25, - return_conditionals=false, - kwargs... 
+ n::Int = 50, + niter = 500, + from_buffer = true, + agg = mean, + choose_lowest_energy = true, + choose_random = false, + nmin::Int = 25, + return_conditionals = false, + kwargs..., ) _loss = 0.0 @@ -195,15 +224,16 @@ function distance_from_energy_ssim( ignore_derivatives() do _dict = ce.params if !(:energy_sampler ∈ collect(keys(_dict))) - _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) + _dict[:energy_sampler] = + ECCCo.EnergySampler(ce; niter = niter, nsamples = n, kwargs...) end eng_sampler = _dict[:energy_sampler] if choose_lowest_energy nmin = minimum([nmin, size(eng_sampler.buffer)[end]]) - xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin) + xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n = nmin) push!(conditional_samples, xmin) elseif choose_random - push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer)) + push!(conditional_samples, rand(eng_sampler, n; from_buffer = from_buffer)) else push!(conditional_samples, eng_sampler.buffer) end @@ -228,14 +258,14 @@ Computes the distance from the counterfactual to the N-nearest neighbors of the """ function distance_from_targets( ce::AbstractCounterfactualExplanation; - agg=mean, - n_nearest_neighbors::Union{Int,Nothing}=100, - p::Int=1, + agg = mean, + n_nearest_neighbors::Union{Int,Nothing} = 100, + p::Int = 1, ) target_idx = ce.data.output_encoder.labels .== ce.target - target_samples = ce.data.X[:,target_idx] + target_samples = ce.data.X[:, target_idx] x′ = CounterfactualExplanations.counterfactual(ce) - loss = map(eachslice(x′, dims=ndims(x′))) do x + loss = map(eachslice(x′, dims = ndims(x′))) do x Δ = map(eachcol(target_samples)) do xsample norm(x - xsample, p) end @@ -250,7 +280,8 @@ function distance_from_targets( end -distance_from_targets_l2(ce::AbstractCounterfactualExplanation; kwrgs...) = distance_from_targets(ce; p=2, kwrgs...) +distance_from_targets_l2(ce::AbstractCounterfactualExplanation; kwrgs...) = + distance_from_targets(ce; p = 2, kwrgs...) 
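For intuition, the `distance_from_targets` penalty refactored above boils down to the following stand-alone sketch: the average p-norm distance from a counterfactual to its nearest target-class samples. The function name and toy data here are illustrative, not part of the package:

```julia
using LinearAlgebra: norm
using Statistics: mean

# Average p-norm distance from a counterfactual x′ to its nearest target-class samples:
function distance_from_targets_sketch(x′, X_target; n_nearest_neighbors = 100, p = 1)
    Δ = [norm(x′ .- x, p) for x in eachcol(X_target)]  # distance to every target sample
    k = min(n_nearest_neighbors, length(Δ))
    return mean(sort(Δ)[1:k])                          # mean over the k nearest
end

X_target = randn(2, 10)  # toy target-class samples (2 features, 10 samples)
distance_from_targets_sketch([0.0, 0.0], X_target; n_nearest_neighbors = 3)
```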
@@ -261,14 +292,14 @@ Computes the cosine distance from the counterfactual to the N-nearest neighbors """ function distance_from_targets_cosine( ce::AbstractCounterfactualExplanation; - agg=mean, - n_nearest_neighbors::Union{Int,Nothing}=100, + agg = mean, + n_nearest_neighbors::Union{Int,Nothing} = 100, ) target_idx = ce.data.output_encoder.labels .== ce.target - target_samples = ce.data.X[:,target_idx] + target_samples = ce.data.X[:, target_idx] x′ = CounterfactualExplanations.counterfactual(ce) - loss = map(eachslice(x′, dims=ndims(x′))) do x + loss = map(eachslice(x′, dims = ndims(x′))) do x Δ = map(eachcol(target_samples)) do xsample cos_dist(x, xsample) end @@ -290,13 +321,13 @@ Computes the distance (1-SSIM) from the counterfactual to the N-nearest neighbor """ function distance_from_targets_ssim( ce::AbstractCounterfactualExplanation; - agg=mean, - n_nearest_neighbors::Union{Int,Nothing}=100, + agg = mean, + n_nearest_neighbors::Union{Int,Nothing} = 100, ) target_idx = ce.data.output_encoder.labels .== ce.target target_samples = ce.data.X[:, target_idx] x′ = CounterfactualExplanations.counterfactual(ce) - loss = map(eachslice(x′, dims=ndims(x′))) do x + loss = map(eachslice(x′, dims = ndims(x′))) do x Δ = map(eachcol(target_samples)) do xsample ssim_dist(x, xsample) end @@ -310,4 +341,3 @@ function distance_from_targets_ssim( return loss end - diff --git a/src/sampling.jl b/src/sampling.jl index f62f26f8..75776c2c 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -8,7 +8,8 @@ using JointEnergyModels When called on data `x`, softmax logits are returned. In the binary case, outputs are one-hot encoded. """ -(model::AbstractFittedModel)(x) = log.(CounterfactualExplanations.predict_proba(model, nothing, x)) +(model::AbstractFittedModel)(x) = + log.(CounterfactualExplanations.predict_proba(model, nothing, x)) "Base type that stores information relevant to energy-based posterior sampling from `AbstractFittedModel`." mutable struct EnergySampler @@ -36,25 +37,27 @@ function EnergySampler( model::AbstractFittedModel, data::CounterfactualData, y::Any; - opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(), - niter::Int=100, - nsamples::Int=100 + opt::JointEnergyModels.AbstractSamplingRule = ImproperSGLD(), + niter::Int = 100, + nsamples::Int = 100, ) @assert y ∈ data.y_levels || y ∈ 1:length(data.y_levels) K = length(data.y_levels) input_size = size(selectdim(data.X, ndims(data.X), 1)) - 𝒟x = Uniform(extrema(data.X)...) + # Prior distribution: + 𝒟x = prior_sampling_space(data) 𝒟y = Categorical(ones(K) ./ K) - sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size) + # Sampler: + sampler = ConditionalSampler(𝒟x, 𝒟y; input_size = input_size) yidx = get_target_index(data.y_levels, y) # Initiate: energy_sampler = EnergySampler(model, data, sampler, opt, nothing, yidx) # Generate conditional: - generate_samples!(energy_sampler, nsamples, yidx; niter=niter) + generate_samples!(energy_sampler, nsamples, yidx; niter = niter) return energy_sampler end @@ -67,10 +70,7 @@ end Constructor for `EnergySampler` that takes a `CounterfactualExplanation` as input. The underlying model, data and `target` are used for the `EnergySampler`, where `target` is the conditioning value of `y`. """ -function EnergySampler( - ce::CounterfactualExplanation; - kwrgs... -) +function EnergySampler(ce::CounterfactualExplanation; kwrgs...) # Setup: model = ce.M @@ -85,12 +85,12 @@ end Generates `n` samples from `EnergySampler` for conditioning value `y`. 
""" -function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) +function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int = 100) # Generate samples: f(x) = logits(e.model, x) rule = e.opt - xsamples = e.sampler(f, rule; niter=niter, n_samples=n, y=y) + xsamples = e.sampler(f, rule; niter = niter, n_samples = n, y = y) return xsamples end @@ -100,11 +100,12 @@ end Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`. """ -function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100) +function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int = 100) if isnothing(e.buffer) - e.buffer = generate_samples(e, n, y; niter=niter) + e.buffer = generate_samples(e, n, y; niter = niter) else - e.buffer = cat(e.buffer, generate_samples(e, n, y; niter=niter), dims=ndims(e.buffer)) + e.buffer = + cat(e.buffer, generate_samples(e, n, y; niter = niter), dims = ndims(e.buffer)) end e.yidx = y end @@ -114,13 +115,18 @@ end Overloads the `rand` method to randomly draw `n` samples from `EnergySampler`. """ -function Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100) +function Base.rand( + sampler::EnergySampler, + n::Int = 100; + from_buffer = true, + niter::Int = 100, +) ntotal = size(sampler.buffer, 2) idx = rand(1:ntotal, n) if from_buffer X = sampler.buffer[:, idx] else - X = generate_samples(sampler, n, sampler.yidx; niter=niter) + X = generate_samples(sampler, n, sampler.yidx; niter = niter) end return X end @@ -130,10 +136,14 @@ end Chooses the samples with the lowest energy (i.e. highest probability) from `EnergySampler`. """ -function get_lowest_energy_sample(sampler::EnergySampler; n::Int=5) +function get_lowest_energy_sample(sampler::EnergySampler; n::Int = 5) X = sampler.buffer model = sampler.model y = sampler.yidx - x = selectdim(X, ndims(X), energy(sampler.sampler, model, X, y; agg=x -> partialsortperm(x, 1:n))) + x = selectdim( + X, + ndims(X), + energy(sampler.sampler, model, X, y; agg = x -> partialsortperm(x, 1:n)), + ) return x -end \ No newline at end of file +end diff --git a/src/utils.jl b/src/utils.jl index 53909ec8..a2fae145 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ using Images Helper function to add tiny noise to inputs. """ -function pre_process(x; noise::Float32=0.03f0) +function pre_process(x; noise::Float32 = 0.03f0) ϵ = Float32.(randn(size(x)) * noise) x += ϵ return x @@ -31,10 +31,7 @@ end Converts a vector to a 28x28 grey image. """ function convert2mnist(x) - x = (x .+ 1) ./ 2 |> - x -> reshape(x, 28, 28) |> - permutedims |> - x -> Gray.(x) + x = (x .+ 1) ./ 2 |> x -> reshape(x, 28, 28) |> permutedims |> x -> Gray.(x) return x end @@ -58,5 +55,19 @@ Computes 1-SSIM between two images. function ssim_dist(x, y) x = convert2mnist(x) y = convert2mnist(y) - return (1 - assess_ssim(x, y))/2 + return (1 - assess_ssim(x, y)) / 2 +end + +""" + prior_sampling_space(data::CounterfactualData; n_std=3) + +Define the prior sampling space for the data. 
+""" +function prior_sampling_space(data::CounterfactualData; n_std=3) + X = data.X + centers = mean(X, dims=2) + stds = std(X, dims=2) + lower_bound = minimum(centers .- n_std .* stds)[1] + upper_bound = maximum(centers .+ n_std .* stds)[1] + return Uniform(lower_bound, upper_bound) end \ No newline at end of file diff --git a/www/dist_energy.png b/www/dist_energy.png new file mode 100644 index 00000000..6608b944 Binary files /dev/null and b/www/dist_energy.png differ