diff --git a/Project.toml b/Project.toml
index aa289fc..78a8391 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "MatrixLMnet"
uuid = "436227dc-6146-11e9-240b-9df6ee45b799"
authors = ["Jane Liang", "Zifan Yu", "Gregory Farage", "Saunak Sen"]
-version = "1.0.2"
+version = "1.1.0"
[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@@ -15,5 +15,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
DataFrames = "^1.3.0"
MLBase = "0.8, 0.9"
-MatrixLM = "^0.1.3"
+MatrixLM = "~0.2"
+Statistics = "1"
julia = "1.6.4, 1"
+
+[extras]
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+Helium = "3f79f04f-7cac-48b4-bde1-3ad54d8f74fa"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["Distributions", "Helium", "Test"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 1612ec5..6b2a069 100644
--- a/README.md
+++ b/README.md
@@ -82,7 +82,7 @@ alphas = reverse(collect(0:0.1:1))
L1 and L2-penalized estimates for matrix linear models can be obtained by running `mlmnet`. In addition to the `RawData` object, `lambdas` and `alphas`, `mlmnet` requires as a keyword argument the function name for an algorithm used to fit Elastic Net penalized estimates. Current methods are: `"cd"` (coordinate descent), `"cd_active"` (active coordinate descent), `"ista"` (ISTA with fixed step size), `"fista"` (FISTA with fixed step size), `"fista_bt"` (FISTA with backtracking), and `"admm"` (ADMM).
-An object of type `Mlmnet` will be returned, with variables for the penalized coefficient estimates (`B`) along with the lambda and alpha penalty values used (`lambdas`, `alphas`). By default, `mlmnet` estimates both row and column main effects (X and Z intercepts), but this behavior can be suppressed by setting `hasXIntercept=false` and/or `hasZIntercept=false`; the intercepts will be regularized unless `toXInterceptReg=false` and/or `toZInterceptReg=false`. Individual `X` (row) and `Z` (column) effects can be left unregularized by manually passing in 1d boolean arrays of length `p` and `q` to indicate which effects should be regularized (`true`) or not (`false`) into `toXReg` and `toZReg`. By default, `mlmnet` centers and normalizes the columns of `X` and `Z` to have mean 0 and norm 1 (`toNormalize=true`). Additional keyword arguments include `isVerbose`, which controls message printing; `thresh`, the threshold at which the coefficients are considered to have converged; and `maxiter`, the maximum number of iterations.
+An object of type `Mlmnet` will be returned, with variables for the penalized coefficient estimates (`B`) along with the lambda and alpha penalty values used (`lambdas`, `alphas`). By default, `mlmnet` estimates both row and column main effects (X and Z intercepts), but this behavior can be suppressed by setting `addXIntercept=false` and/or `addZIntercept=false`; the intercepts will be regularized unless `toXInterceptReg=false` and/or `toZInterceptReg=false`. Individual `X` (row) and `Z` (column) effects can be left unregularized by manually passing in 1d boolean arrays of length `p` and `q` to indicate which effects should be regularized (`true`) or not (`false`) into `toXReg` and `toZReg`. By default, `mlmnet` centers and normalizes the columns of `X` and `Z` to have mean 0 and norm 1 (`toNormalize=true`). Additional keyword arguments include `isVerbose`, which controls message printing; `thresh`, the threshold at which the coefficients are considered to have converged; and `maxiter`, the maximum number of iterations.
```
est = mlmnet(dat, lambdas, alphas, method = "fista_bt")
diff --git a/src/bic/mlmnet_bic.jl b/src/bic/mlmnet_bic.jl
index 8e514d4..b3713bd 100644
--- a/src/bic/mlmnet_bic.jl
+++ b/src/bic/mlmnet_bic.jl
@@ -31,7 +31,7 @@ end
lambdas::AbstractArray{Float64,1},
alphas::AbstractArray{Float64,1};
method::String="ista", isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -55,9 +55,9 @@ Performs BIC validation for `mlmnet`.
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
- isNaive = boolean flag indicating whether to solve the Naive or non-Naive
Elastic-net problem
-- hasXIntercept = boolean flag indicating whether or not to include an `X`
+- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
-- hasZIntercept = boolean flag indicating whether or not to include a `Z`
+- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
@@ -94,7 +94,7 @@ function mlmnet_bic(data::RawData,
lambdas::AbstractArray{Float64,1},
alphas::AbstractArray{Float64,1};
method::String="ista", isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -107,8 +107,8 @@ function mlmnet_bic(data::RawData,
# Run mlmnet on each RawData object, in parallel when possible
MLMNet = mlmnet(data, lambdas, alphas;
method=method, isNaive=isNaive,
- hasXIntercept=hasXIntercept,
- hasZIntercept=hasZIntercept,
+ addXIntercept=addXIntercept,
+ addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
diff --git a/src/crossvalidation/mlmnet_cv.jl b/src/crossvalidation/mlmnet_cv.jl
index 1421e5a..ea1e10c 100644
--- a/src/crossvalidation/mlmnet_cv.jl
+++ b/src/crossvalidation/mlmnet_cv.jl
@@ -41,7 +41,7 @@ end
rowFolds::Array{Array{Int64,1},1},
colFolds::Array{Array{Int64,1},1};
method::String="ista", isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -74,9 +74,9 @@ input.
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
- isNaive = boolean flag indicating whether to solve the Naive or non-Naive
Elastic-net problem
-- hasXIntercept = boolean flag indicating whether or not to include an `X`
+- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
-- hasZIntercept = boolean flag indicating whether or not to include a `Z`
+- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
@@ -116,7 +116,7 @@ function mlmnet_cv(data::RawData,
rowFolds::Array{Array{Int64,1},1},
colFolds::Array{Array{Int64,1},1};
method::String="ista", isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -156,8 +156,8 @@ function mlmnet_cv(data::RawData,
MLMNets = Distributed.pmap(data -> mlmnet(data, lambdas, alphas;
method=method, isNaive=isNaive,
- hasXIntercept=hasXIntercept,
- hasZIntercept=hasZIntercept,
+ addXIntercept=addXIntercept,
+ addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
@@ -177,7 +177,7 @@ end
alphas::AbstractArray{Float64,1},
rowFolds::Array{Array{Int64,1},1}, nColFolds::Int64;
method::String="ista", isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -207,9 +207,9 @@ Calls the base `mlmnet_cv` function.
- methods = function name that applies the Elastic-net penalty estimate method;
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
-- hasXIntercept = boolean flag indicating whether or not to include an `X`
+- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
-- hasZIntercept = boolean flag indicating whether or not to include a `Z`
+- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
@@ -248,7 +248,7 @@ function mlmnet_cv(data::RawData,
alphas::AbstractArray{Float64,1},
rowFolds::Array{Array{Int64,1},1}, nColFolds::Int64;
method::String="ista", isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -263,7 +263,7 @@ function mlmnet_cv(data::RawData,
# base mlmnet_cv function
mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds;
method=method, isNaive=isNaive,
- hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
+ addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
@@ -278,7 +278,7 @@ end
alphas::AbstractArray{Float64,1},
nRowFolds::Int64, colFolds::Array{Array{Int64,1},1};
method::String="ista", isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -308,9 +308,9 @@ input. Calls the base `mlmnet_cv` function.
- methods = function name that applies the Elastic-net penalty estimate method;
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
-- hasXIntercept = boolean flag indicating whether or not to include an `X`
+- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
-- hasZIntercept = boolean flag indicating whether or not to include a `Z`
+- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
@@ -349,7 +349,7 @@ function mlmnet_cv(data::RawData,
alphas::AbstractArray{Float64,1},
nRowFolds::Int64, colFolds::Array{Array{Int64,1},1};
method::String="ista", isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -364,7 +364,7 @@ function mlmnet_cv(data::RawData,
# base mlmnet_cv function
mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds;
method=method, isNaive=isNaive,
- hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
+ addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
@@ -379,7 +379,7 @@ end
nRowFolds::Int64=10, nColFolds::Int64=10;
method::String="ista",
isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -409,9 +409,9 @@ folds randomly generated using calls to `make_folds`. Calls the base
- methods = function name that applies the Elastic-net penalty estimate method;
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
-- hasXIntercept = boolean flag indicating whether or not to include an `X`
+- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
-- hasZIntercept = boolean flag indicating whether or not to include a `Z`
+- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
@@ -465,7 +465,7 @@ function mlmnet_cv(data::RawData,
nRowFolds::Int64=10, nColFolds::Int64=10;
method::String="ista",
isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -481,7 +481,7 @@ function mlmnet_cv(data::RawData,
# function
mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds;
method=method, isNaive=isNaive,
- hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
+ addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
@@ -496,7 +496,7 @@ end
nRowFolds::Int64=10, nColFolds::Int64=10;
method::String="ista",
isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -516,7 +516,7 @@ function mlmnet_cv(data::RawData,
nRowFolds::Int64=10, nColFolds::Int64=10;
method::String="ista",
isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -532,7 +532,7 @@ function mlmnet_cv(data::RawData,
# function
mlmnet_cv(data, lambdas, alphas, nRowFolds, nColFolds;
method=method, isNaive=isNaive,
- hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
+ addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
diff --git a/src/mlmnet/mlmnet.jl b/src/mlmnet/mlmnet.jl
index 6ecc8e9..434e5a0 100644
--- a/src/mlmnet/mlmnet.jl
+++ b/src/mlmnet/mlmnet.jl
@@ -156,7 +156,7 @@ end
lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1};
method::String = "ista",
isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(data.p),
toZReg::BitArray{1}=trues(data.q),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -184,9 +184,9 @@ inputs.
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
- isNaive = boolean flag indicating whether to solve the Naive or non-Naive
Elastic-net problem
-- hasXIntercept = boolean flag indicating whether or not to include an `X`
+- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
-- hasZIntercept = boolean flag indicating whether or not to include a `Z`
+- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
@@ -235,7 +235,7 @@ function mlmnet(data::RawData,
lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1};
method::String = "ista",
isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(data.p),
toZReg::BitArray{1}=trues(data.q),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -263,13 +263,13 @@ function mlmnet(data::RawData,
# Add X and Z intercepts if necessary
# Update toXReg and toZReg accordingly
- if hasXIntercept==true && data.predictors.hasXIntercept==false
+ if addXIntercept==true && data.predictors.hasXIntercept==false
data.predictors.X = add_intercept(data.predictors.X)
data.predictors.hasXIntercept = true
data.p = data.p + 1
toXReg = vcat(toXInterceptReg, toXReg)
end
- if hasZIntercept==true && data.predictors.hasZIntercept==false
+ if addZIntercept==true && data.predictors.hasZIntercept==false
data.predictors.Z = add_intercept(data.predictors.Z)
data.predictors.hasZIntercept = true
data.q = data.q + 1
@@ -278,14 +278,14 @@ function mlmnet(data::RawData,
# Remove X and Z intercepts in new predictors if necessary
# Update toXReg and toZReg accordingly
- if hasXIntercept==false && data.predictors.hasXIntercept==true
+ if addXIntercept==false && data.predictors.hasXIntercept==true
data.predictors.X = remove_intercept(data.predictors.X)
data.predictors.hasXIntercept = false
data.p = data.p - 1
toXReg = toXReg[2:end]
end
- if hasZIntercept==false && data.predictors.hasZIntercept==true
+ if addZIntercept==false && data.predictors.hasZIntercept==true
data.predictors.Z = remove_intercept(data.predictors.Z)
data.predictors.hasZIntercept = false
data.q = data.q - 1
@@ -293,10 +293,10 @@ function mlmnet(data::RawData,
end
# Update toXReg and toZReg accordingly when intercept is already included
- if hasXIntercept==true && data.predictors.hasXIntercept==true
+ if addXIntercept==true && data.predictors.hasXIntercept==true
toXReg[1] = toXInterceptReg
end
- if hasZIntercept==true && data.predictors.hasZIntercept==true
+ if addZIntercept==true && data.predictors.hasZIntercept==true
toZReg[1] = toZInterceptReg
end
@@ -314,8 +314,8 @@ function mlmnet(data::RawData,
Z = copy(get_Z(data))
# Centers and normalizes predictors
- meansX, normsX, = normalize!(X, hasXIntercept)
- meansZ, normsZ, = normalize!(Z, hasZIntercept)
+ meansX, normsX, = normalize!(X, addXIntercept)
+ meansZ, normsZ, = normalize!(Z, addZIntercept)
# If X and Z are standardized, set the norm to nothing
norms = nothing
else
@@ -375,11 +375,11 @@ function mlmnet(data::RawData,
# Back-transform coefficient estimates, if necessary.
# Case if including both X and Z intercepts:
- if toNormalize == true && (hasXIntercept==true) && (hasZIntercept==true)
+ if toNormalize == true && (addXIntercept==true) && (addZIntercept==true)
backtransform!(coeffs, meansX, meansZ, normsX, normsZ, get_Y(data),
data.predictors.X, data.predictors.Z)
elseif toNormalize == true # Otherwise
- backtransform!(coeffs, hasXIntercept, hasZIntercept, meansX, meansZ,
+ backtransform!(coeffs, addXIntercept, addZIntercept, meansX, meansZ,
normsX, normsZ)
end
@@ -399,7 +399,7 @@ function mlmnet(data::RawData,
lambdas::AbstractArray{Float64,1};
method::String = "ista",
isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(data.p),
toZReg::BitArray{1}=trues(data.q),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -413,7 +413,7 @@ function mlmnet(data::RawData,
lambdas::AbstractArray{Float64,1};
method::String = "ista",
isNaive::Bool=false,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(data.p),
toZReg::BitArray{1}=trues(data.q),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
@@ -425,7 +425,7 @@ function mlmnet(data::RawData,
alphas = [1.0] # default LASSO, 饾浖 = 1
rslts = mlmnet(data, lambdas, alphas; method,
- isNaive, hasXIntercept, hasZIntercept,
+ isNaive, addXIntercept, addZIntercept,
toXReg, toZReg,
toXInterceptReg, toZInterceptReg,
toNormalize, isVerbose,
diff --git a/src/mlmnet/mlmnet_perms.jl b/src/mlmnet/mlmnet_perms.jl
index 1f50caf..b72d7aa 100644
--- a/src/mlmnet/mlmnet_perms.jl
+++ b/src/mlmnet/mlmnet_perms.jl
@@ -3,7 +3,7 @@
lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1};
method::String = "ista", isNaive::Bool=false,
permFun::Function=shuffle_rows,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(data.p),
toZReg::BitArray{1}=trues(data.q),
toXInterceptReg::Bool=false,
@@ -30,9 +30,9 @@ function.
Elastic-net problem
- permFun = function used to permute `Y`. Defaults to `shuffle_rows`
(shuffles rows of `Y`).
-- hasXIntercept = boolean flag indicating whether or not to include an `X`
+- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
-- hasZIntercept = boolean flag indicating whether or not to include a `Z`
+- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
@@ -81,7 +81,7 @@ function mlmnet_perms(data::RawData,
lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1};
method::String = "ista", isNaive::Bool=false,
permFun::Function=shuffle_rows,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(data.p),
toZReg::BitArray{1}=trues(data.q),
toXInterceptReg::Bool=false,
@@ -95,7 +95,7 @@ function mlmnet_perms(data::RawData,
# Run penalty on the permuted data
return mlmnet(dataPerm, lambdas, alphas;
method = method, isNaive=isNaive,
- hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
+ addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg,
toNormalize=toNormalize, isVerbose=isVerbose,
@@ -110,7 +110,7 @@ end
lambdas::AbstractArray{Float64,1};
method::String = "ista", isNaive::Bool=false,
permFun::Function=shuffle_rows,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(data.p),
toZReg::BitArray{1}=trues(data.q),
toXInterceptReg::Bool=false,
@@ -124,7 +124,7 @@ function mlmnet_perms(data::RawData,
lambdas::AbstractArray{Float64,1};
method::String = "ista", isNaive::Bool=false,
permFun::Function=shuffle_rows,
- hasXIntercept::Bool=true, hasZIntercept::Bool=true,
+ addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(data.p),
toZReg::BitArray{1}=trues(data.q),
toXInterceptReg::Bool=false,
@@ -140,7 +140,7 @@ function mlmnet_perms(data::RawData,
# Run L1-L2 penalties on the permuted data
rslts = mlmnet(dataPerm, lambdas, alphas;
method = method, isNaive=isNaive,
- hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
+ addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg,
toNormalize=toNormalize, isVerbose=isVerbose,
diff --git a/src/utilities/std_helpers.jl b/src/utilities/std_helpers.jl
index 0d392cb..4f33747 100644
--- a/src/utilities/std_helpers.jl
+++ b/src/utilities/std_helpers.jl
@@ -95,7 +95,7 @@ end
"""
backtransform!(B::AbstractArray{Float64,2},
- hasXIntercept::Bool, hasZIntercept::Bool,
+ addXIntercept::Bool, addZIntercept::Bool,
meansX::AbstractArray{Float64,2},
meansZ::AbstractArray{Float64,2},
normsX::AbstractArray{Float64,2},
@@ -108,9 +108,9 @@ or Z.
# Arguments
- B = 2d array of coefficient estimates B
-- hasXIntercept = boolean flag indicating whether or not to X has an
+- addXIntercept = boolean flag indicating whether or not to X has an
intercept column
-- hasZIntercept = boolean flag indicating whether or not to Z has an
+- addZIntercept = boolean flag indicating whether or not to Z has an
intercept column
- meansX = 2d array of column means of X, obtained prior to standardizing X
- meansZ = 2d array of column means of Z, obtained prior to standardizing Z
@@ -123,33 +123,33 @@ None; back-transforms B in place
"""
function backtransform!(B::AbstractArray{Float64,2},
- hasXIntercept::Bool, hasZIntercept::Bool,
+ addXIntercept::Bool, addZIntercept::Bool,
meansX::AbstractArray{Float64,2},
meansZ::AbstractArray{Float64,2},
normsX::AbstractArray{Float64,2},
normsZ::AbstractArray{Float64,2})
# Back transform the X intercepts (row main effects), if necessary
- if hasXIntercept == true
+ if addXIntercept == true
prodX = (meansX[:,2:end]./normsX[:,2:end])*B[2:end, 2:end]
B[1,2:end] = (B[1,2:end]-vec(prodX))./vec(normsZ[:,2:end])/normsX[1,1]
end
# Back transform the Z intercepts (column main effects), if necessary
- if hasZIntercept == true
+ if addZIntercept == true
prodZ = B[2:end, 2:end]*transpose(meansZ[:,2:end]./normsZ[:,2:end])
B[2:end,1] = (B[2:end,1]-prodZ)./transpose(normsX[:,2:end])/
normsZ[1,1]
end
# Back transform the interactions, if necessary
- if (hasXIntercept == true) || (hasZIntercept == true)
+ if (addXIntercept == true) || (addZIntercept == true)
B[2:end, 2:end] = B[2:end, 2:end]./transpose(normsX[:,2:end])./
normsZ[:,2:end]
end
# Back transform the interactions if not including any main effects
- if (hasXIntercept == false) && (hasZIntercept == false)
+ if (addXIntercept == false) && (addZIntercept == false)
B = B./transpose(normsX)./normsZ
end
end
@@ -227,7 +227,7 @@ end
"""
backtransform!(B::AbstractArray{Float64,4},
- hasXIntercept::Bool, hasZIntercept::Bool,
+ addXIntercept::Bool, addZIntercept::Bool,
meansX::AbstractArray{Float64,2},
meansZ::AbstractArray{Float64,2},
normsX::AbstractArray{Float64,2},
@@ -240,9 +240,9 @@ or Z.
# Arguments
- B = 4d array of coefficient estimates
-- hasXIntercept = boolean flag indicating whether or not to X has an
+- addXIntercept = boolean flag indicating whether or not to X has an
intercept column
-- hasZIntercept = boolean flag indicating whether or not to Z has an
+- addZIntercept = boolean flag indicating whether or not to Z has an
intercept column
- meansX = 2d array of column means of X, obtained prior to standardizing X
- meansZ = 2d array of column means of Z, obtained prior to standardizing Z
@@ -260,7 +260,7 @@ dimension.
"""
function backtransform!(B::AbstractArray{Float64,4},
- hasXIntercept::Bool, hasZIntercept::Bool,
+ addXIntercept::Bool, addZIntercept::Bool,
meansX::AbstractArray{Float64,2},
meansZ::AbstractArray{Float64,2},
normsX::AbstractArray{Float64,2},
@@ -271,14 +271,14 @@ function backtransform!(B::AbstractArray{Float64,4},
for j in 1:size(B,4)
for i in 1:size(B,3)
# Back transform the X intercepts (row main effects), if necessary
- if hasXIntercept == true
+ if addXIntercept == true
prodX = (meansX[:,2:end]./normsX[:,2:end])*B[2:end,2:end,i,j]
B[1,2:end,i,j] = (B[1,2:end,i,j]-vec(prodX))./vec(normsZ[:,2:end])/
normsX[1,1]
end
# Back transform the Z intercepts (column main effects), if necessary
- if hasZIntercept == true
+ if addZIntercept == true
prodZ = B[2:end,2:end,i,j]*transpose(meansZ[:,2:end]./
normsZ[:,2:end])
B[2:end,1,i,j] = (B[2:end,1,i,j]-prodZ)./transpose(normsX[:,2:end])/
@@ -286,13 +286,13 @@ function backtransform!(B::AbstractArray{Float64,4},
end
# Back transform the interactions, if necessary
- if (hasXIntercept == true) || (hasZIntercept == true)
+ if (addXIntercept == true) || (addZIntercept == true)
B[2:end,2:end,i,j] = B[2:end,2:end,i,j]./transpose(normsX[:,2:end])./
normsZ[:,2:end]
end
# Back transform the interactions if not including any main effects
- if (hasXIntercept == false) && (hasZIntercept == false)
+ if (addXIntercept == false) && (addZIntercept == false)
B[:,:,i,j] = B[:,:,i,j]./transpose(normsX)./normsZ
end
end
diff --git a/test/Project.toml b/test/Project.toml
deleted file mode 100644
index af82aa4..0000000
--- a/test/Project.toml
+++ /dev/null
@@ -1,9 +0,0 @@
-[deps]
-Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-Helium = "3f79f04f-7cac-48b4-bde1-3ad54d8f74fa"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-
-[compat]
-Distributions = "0.25"
-Helium = "0.2"
diff --git a/test/generate_testing_dataset.jl b/test/generate_testing_dataset.jl
index 7f4dd0f..3527204 100644
--- a/test/generate_testing_dataset.jl
+++ b/test/generate_testing_dataset.jl
@@ -77,24 +77,24 @@ dat = MatrixLMnet.MatrixLM.RawData(Response(Y), Predictors(X, Z));
###############
# Lasso penalized regression - ista
-est = mlmnet(ista!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet(ista!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false);
est_B_Lasso_ista = est.B[:, :, 3];
# Lasso penalized regression - fista
-est = mlmnet(fista!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet(fista!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false);
est_B_Lasso_fista = est.B[:, :, 3];
# Lasso penalized regression - fista backtracking
-est = mlmnet(fista_bt!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet(fista_bt!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false);
est_B_Lasso_fista_bt = est.B[:, :, 3];
# Lasso penalized regression - admm
-est = mlmnet(admm!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet(admm!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false);
est_B_Lasso_admm = est.B[:, :, 3];
# Lasso penalized regression - cd
Random.seed!(rng)
-est = mlmnet(cd!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet(cd!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false);
est_B_Lasso_cd = est.B[:, :, 3];
#################################
@@ -104,31 +104,31 @@ est_B_Lasso_cd = est.B[:, :, 3];
# Lasso penalized regression - ista cv
Random.seed!(rng)
-est = mlmnet_cv(ista!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet_cv(ista!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Lasso = lambda_min_deprecated(est);
smmr_ista = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero)
# Lasso penalized regression - fista cv
Random.seed!(rng)
-est = mlmnet_cv(fista!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet_cv(fista!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Lasso = lambda_min_deprecated(est);
smmr_fista = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero)
# Lasso penalized regression - fista-bt cv
Random.seed!(rng)
-est = mlmnet_cv(fista_bt!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet_cv(fista_bt!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Lasso = lambda_min_deprecated(est);
smmr_fistabt = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero)
# Lasso penalized regression - admm cv
Random.seed!(rng)
-est = mlmnet_cv(admm!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet_cv(admm!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Lasso = lambda_min_deprecated(est);
smmr_admm = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero)
# Lasso penalized regression - cd cv
Random.seed!(rng)
-est = mlmnet_cv(cd!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est = mlmnet_cv(cd!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Lasso = lambda_min_deprecated(est);
smmr_cd = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero)
diff --git a/test/mlmnetBicTests.jl b/test/mlmnetBicTests.jl
index 95669a7..2782514 100644
--- a/test/mlmnetBicTests.jl
+++ b/test/mlmnetBicTests.jl
@@ -2,9 +2,9 @@
# Library #
###########
# using Random
-using MatrixLMnet, Distributions, LinearAlgebra
-using Helium
-using Test
+# using MatrixLMnet, Distributions, LinearAlgebra
+# using Helium
+# using Test
########################################
# TEST BIC Validation - Simulated Data #
@@ -76,20 +76,20 @@ numVersion = VERSION
if Int(numVersion.minor) < 7
tolVersion=2e-1
else
- tolVersion=1e-6
+ tolVersion=1e-5
end
####################################
# TEST BIC Validation - Estimation #
####################################
-est = mlmnet(dat, 位, 伪; method = "fista_bt", hasXIntercept = false, hasZIntercept=false, isVerbose = false);
+est = mlmnet(dat, 位, 伪; method = "fista_bt", addXIntercept = false, addZIntercept=false, isVerbose = false);
#############################
# TEST BIC Validation - BIC #
#############################
-est_BIC = mlmnet_bic(dat, 位, 伪; method = "fista_bt", hasXIntercept = false, hasZIntercept=false, isVerbose = false);
+est_BIC = mlmnet_bic(dat, 位, 伪; method = "fista_bt", addXIntercept = false, addZIntercept=false, isVerbose = false);
df_BIC = mlmnet_bic_summary(est_BIC);
@@ -106,7 +106,7 @@ for i in 1:length(est.lambdas), j in 1:length(est.alphas)
# BIC for (lambdas i, alphas j)
k = sum(est.B[:,:,i,j] .!= 0.0, dims = 1) .+ m;
- distResids = MvNormal(zeros(m), (sqrt.(sum(resids[:,:,i,j], dims = 1)./n))[:]);
+ distResids = MvNormal(zeros(m), LinearAlgebra.Diagonal(map(abs2, (sqrt.(sum(resids[:,:,i,j], dims = 1)./n))[:])));
L虃 = loglikelihood(distResids, permutedims(resids[:,:,i,j]))
BIC2[i,j] = sum(k)*log(n) - 2*(L虃)
end;
diff --git a/test/mlmnetCvTests.jl b/test/mlmnetCvTests.jl
index 351d25d..0c7e1cc 100644
--- a/test/mlmnetCvTests.jl
+++ b/test/mlmnetCvTests.jl
@@ -48,7 +48,7 @@ numVersion = VERSION
if Int(numVersion.minor) < 7
tolVersion=2e-1
else
- tolVersion=1e-6
+ tolVersion=1e-5
end
#############################################
@@ -57,12 +57,12 @@ end
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net1 = MatrixLMnet.lambda_min(est1);
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net2 = MatrixLMnet.lambda_min(est2);
# Lasso penalized regression - ista cv
@@ -78,12 +78,12 @@ println("CV Lasso vs Elastic Net when 伪=1 test 1 - ista: ",
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net1 = MatrixLMnet.lambda_min(est1);
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net2 = MatrixLMnet.lambda_min(est2);
# Lasso penalized regression - fista cv
@@ -99,12 +99,12 @@ println("CV Lasso vs Elastic Net when 伪=1 test 2 - fista: ",
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net1 = MatrixLMnet.lambda_min(est1);
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net2 = MatrixLMnet.lambda_min(est2);
# Lasso penalized regression - fista-bt cv
@@ -121,12 +121,12 @@ println("CV Lasso vs Elastic Net when 伪=1 test 3 - fista-bt: ",
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net1 = MatrixLMnet.lambda_min(est1);
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net2 = MatrixLMnet.lambda_min(est2);
# Lasso penalized regression - fista-bt cv
@@ -142,12 +142,12 @@ println("CV Lasso vs Elastic Net when 伪=1 test 4 - admm: ",
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net1 = MatrixLMnet.lambda_min(est1);
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false);
smmr_Net2 = MatrixLMnet.lambda_min(est2);
# Lasso penalized regression - cd cv
diff --git a/test/mlmnetTests.jl b/test/mlmnetTests.jl
index 5435b87..7619d31 100644
--- a/test/mlmnetTests.jl
+++ b/test/mlmnetTests.jl
@@ -1,12 +1,12 @@
-###########
-# Library #
-###########
-# using MatrixLM
-# using Distributions, Random, Statistics, LinearAlgebra, StatsBase
-# using Random
-using MatrixLMnet
-using Helium
-using Test
+# ###########
+# # Library #
+# ###########
+# # using MatrixLM
+# # using Distributions, Random, Statistics, LinearAlgebra, StatsBase
+# # using Random
+# using MatrixLMnet
+# using Helium
+# using Test
####################################################
@@ -54,11 +54,11 @@ rng = MatrixLMnet.Random.MersenneTwister(2021)
# Elastic net penalized regression
-est_ista_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_ista_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_ista_1 = est_ista_1.B[:, :, 3, 1];
# Elastic net penalized regression
-est_ista_2 = MatrixLMnet.mlmnet(dat, 位, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_ista_2 = MatrixLMnet.mlmnet(dat, 位, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_ista_2 = est_ista_2.B[:, :, 3, 1];
# Lasso penalized regression - ista
@@ -71,11 +71,11 @@ println("Lasso vs Elastic Net when 伪=1 test 1 - ista: ", @test (B_Net_ista_1
#############################################
# Elastic net penalized regression
-est_fista_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_fista_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_fista_1 = est_fista_1.B[:, :, 3, 1];
# Elastic net penalized regression
-est_fista_2 = MatrixLMnet.mlmnet(dat, 位, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_fista_2 = MatrixLMnet.mlmnet(dat, 位, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_fista_2 = est_fista_2.B[:, :, 3, 1];
# Lasso penalized regression - fista
@@ -88,11 +88,11 @@ println("Lasso vs Elastic Net when 伪=1 test 2 - fista: ", @test (B_Net_fista_1
##########################################################
# Elastic net penalized regression
-est_fistabt_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_fistabt_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_fistabt_1 = est_fistabt_1.B[:, :, 3, 1];
# Elastic net penalized regression
-est_fistabt_2 = MatrixLMnet.mlmnet(dat, 位, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_fistabt_2 = MatrixLMnet.mlmnet(dat, 位, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_fistabt_2 = est_fistabt_2.B[:, :, 3, 1];
# Lasso penalized regression - fista-bt
@@ -106,11 +106,11 @@ println("Lasso vs Elastic Net when 伪=1 test 3 - fista-bt: ", @test (B_Net_fista
############################################
# Elastic net penalized regression
-est_admm_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_admm_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_admm_1 = est_admm_1.B[:, :, 3, 1];
# Elastic net penalized regression
-est_admm_2 = MatrixLMnet.mlmnet(dat, 位, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_admm_2 = MatrixLMnet.mlmnet(dat, 位, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_admm_2 = est_admm_2.B[:, :, 3, 1];
# Lasso penalized regression - admm
@@ -125,12 +125,12 @@ println("Lasso vs Elastic Net when 伪=1 test 4 - admm: ", @test (B_Net_admm_1
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est_cd_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_cd_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_cd_1 = est_cd_1.B[:, :, 3, 1];
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est_cd_2 = MatrixLMnet.mlmnet(dat, 位, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est_cd_2 = MatrixLMnet.mlmnet(dat, 位, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false);
B_Net_cd_2 = est_cd_2.B[:, :, 3, 1];
# Lasso penalized regression - cd
diff --git a/test/runtests.jl b/test/runtests.jl
index 7ebd2fb..aea01a8 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,4 +1,6 @@
using MatrixLMnet
+using Distributions, LinearAlgebra
+using Helium
using Test
@testset "MatrixLMnet" begin
diff --git a/test/summaryCvTests.jl b/test/summaryCvTests.jl
index 98f2bee..3359ef4 100644
--- a/test/summaryCvTests.jl
+++ b/test/summaryCvTests.jl
@@ -1,10 +1,10 @@
-###########
-# Library #
-###########
-# using Random
-using MatrixLMnet
-using Helium
-using Test
+# ###########
+# # Library #
+# ###########
+# # using Random
+# using MatrixLMnet
+# using Helium
+# using Test
#####################################################################
# TEST Cross Validation Lasso vs Elastic Net (饾浖=1) - Simulated Data #
@@ -48,7 +48,7 @@ numVersion = VERSION
if Int(numVersion.minor) < 7
tolVersion=2e-1
else
- tolVersion=1e-6
+ tolVersion=1e-5
end
@@ -58,7 +58,7 @@ end
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net1 = MatrixLMnet.mlmnet_cv_summary(est1);
smmr_min_Net1 = MatrixLMnet.lambda_min(est1);
@@ -70,7 +70,7 @@ test_ElasticNet = (smmr_Net1[idxSmmr, :AvgMSE] == smmr_min_Net1[1, :AvgMSE]) &&
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2);
smmr_min_Net2 = MatrixLMnet.lambda_min(est2);
@@ -90,7 +90,7 @@ println("Summary cross-validation test 1 - ista: ",
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net1 = MatrixLMnet.mlmnet_cv_summary(est1);
smmr_min_Net1 = MatrixLMnet.lambda_min(est1);
@@ -102,7 +102,7 @@ test_ElasticNet = (smmr_Net1[idxSmmr, :AvgMSE] == smmr_min_Net1[1, :AvgMSE]) &&
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2);
smmr_min_Net2 = MatrixLMnet.lambda_min(est2);
@@ -122,7 +122,7 @@ println("Summary cross-validation test 2 - fista: ",
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net1 = MatrixLMnet.mlmnet_cv_summary(est1);
smmr_min_Net1 = MatrixLMnet.lambda_min(est1);
@@ -134,7 +134,7 @@ test_ElasticNet = (smmr_Net1[idxSmmr, :AvgMSE] == smmr_min_Net1[1, :AvgMSE]) &&
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2);
smmr_min_Net2 = MatrixLMnet.lambda_min(est2);
@@ -154,7 +154,7 @@ println("Summary cross-validation test 3 - fista_bt: ",
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net1 = MatrixLMnet.mlmnet_cv_summary(est1);
smmr_min_Net1 = MatrixLMnet.lambda_min(est1);
@@ -166,7 +166,7 @@ test_ElasticNet = (smmr_Net1[idxSmmr, :AvgMSE] == smmr_min_Net1[1, :AvgMSE]) &&
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2);
smmr_min_Net2 = MatrixLMnet.lambda_min(est2);
@@ -186,7 +186,7 @@ println("Summary cross-validation test 4 - admm: ",
# Elastic net penalized regression
MatrixLMnet.Random.seed!(rng)
-est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false);
+est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false);
# Summaries
smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2);
smmr_min_Net2 = MatrixLMnet.lambda_min(est2);
@@ -196,5 +196,5 @@ test_Lasso = (smmr_Net2[idxSmmr, :AvgMSE] == smmr_min_Net2[1, :AvgMSE]) &&
(smmr_Net2[idxSmmr, :Lambda] == smmr_min_Net2[1, :Lambda]) &&
(smmr_Net2[idxSmmr, :Alpha] == smmr_min_Net2[1, :Alpha])
-println("Summary cross-validation test 5 - admm: ",
+println("Summary cross-validation test 5 - cd: ",
@test (test_Lasso))
\ No newline at end of file
diff --git a/test/utilitiesTests.jl b/test/utilitiesTests.jl
index 9ed236d..d1ba0de 100644
--- a/test/utilitiesTests.jl
+++ b/test/utilitiesTests.jl
@@ -36,7 +36,7 @@ rng = MatrixLMnet.Random.MersenneTwister(2021)
#Test the dimension of results by util functions #
###################################################
-est = mlmnet(dat, 位, 伪, method = "cd", hasZIntercept = true, hasXIntercept = true, isVerbose = true)
+est = mlmnet(dat, 位, 伪, method = "cd", addZIntercept = true, addXIntercept = true, isVerbose = true)
predicted = predict(est, est.data.predictors)
@@ -59,7 +59,7 @@ end
newPredictors = Predictors(X, Z, false, false)
predicted = predict(est, newPredictors)
-est2 = mlmnet(dat, 位, 伪, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = true)
+est2 = mlmnet(dat, 位, 伪, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = true)
newPredictors2 = Predictors(hcat(ones(size(X, 1)), X), hcat(ones(size(Z, 1)), Z), true, true)
predicted2 = predict(est2, newPredictors2)