diff --git a/Project.toml b/Project.toml index aa289fc..78a8391 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixLMnet" uuid = "436227dc-6146-11e9-240b-9df6ee45b799" authors = ["Jane Liang", "Zifan Yu", "Gregory Farage", "Saunak Sen"] -version = "1.0.2" +version = "1.1.0" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" @@ -15,5 +15,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] DataFrames = "^1.3.0" MLBase = "0.8, 0.9" -MatrixLM = "^0.1.3" +MatrixLM = "~0.2" +Statistics = "1" julia = "1.6.4, 1" + +[extras] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Helium = "3f79f04f-7cac-48b4-bde1-3ad54d8f74fa" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["Distributions", "Helium", "Test"] \ No newline at end of file diff --git a/README.md b/README.md index 1612ec5..6b2a069 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ alphas = reverse(collect(0:0.1:1)) L1 and L2-penalized estimates for matrix linear models can be obtained by running `mlmnet`. In addition to the `RawData` object, `lambdas` and `alphas`, `mlmnet` requires as a keyword argument the function name for an algorithm used to fit Elastic Net penalized estimates. Current methods are: `"cd"` (coordinate descent), `"cd_active"` (active coordinate descent), `"ista"` (ISTA with fixed step size), `"fista"` (FISTA with fixed step size), `"fista_bt"` (FISTA with backtracking), and `"admm"` (ADMM). -An object of type `Mlmnet` will be returned, with variables for the penalized coefficient estimates (`B`) along with the lambda and alpha penalty values used (`lambdas`, `alphas`). By default, `mlmnet` estimates both row and column main effects (X and Z intercepts), but this behavior can be suppressed by setting `hasXIntercept=false` and/or `hasZIntercept=false`; the intercepts will be regularized unless `toXInterceptReg=false` and/or `toZInterceptReg=false`. Individual `X` (row) and `Z` (column) effects can be left unregularized by manually passing in 1d boolean arrays of length `p` and `q` to indicate which effects should be regularized (`true`) or not (`false`) into `toXReg` and `toZReg`. By default, `mlmnet` centers and normalizes the columns of `X` and `Z` to have mean 0 and norm 1 (`toNormalize=true`). Additional keyword arguments include `isVerbose`, which controls message printing; `thresh`, the threshold at which the coefficients are considered to have converged; and `maxiter`, the maximum number of iterations. +An object of type `Mlmnet` will be returned, with variables for the penalized coefficient estimates (`B`) along with the lambda and alpha penalty values used (`lambdas`, `alphas`). By default, `mlmnet` estimates both row and column main effects (X and Z intercepts), but this behavior can be suppressed by setting `addXIntercept=false` and/or `addZIntercept=false`; the intercepts will be regularized unless `toXInterceptReg=false` and/or `toZInterceptReg=false`. Individual `X` (row) and `Z` (column) effects can be left unregularized by manually passing in 1d boolean arrays of length `p` and `q` to indicate which effects should be regularized (`true`) or not (`false`) into `toXReg` and `toZReg`. By default, `mlmnet` centers and normalizes the columns of `X` and `Z` to have mean 0 and norm 1 (`toNormalize=true`). Additional keyword arguments include `isVerbose`, which controls message printing; `thresh`, the threshold at which the coefficients are considered to have converged; and `maxiter`, the maximum number of iterations. ``` est = mlmnet(dat, lambdas, alphas, method = "fista_bt") diff --git a/src/bic/mlmnet_bic.jl b/src/bic/mlmnet_bic.jl index 8e514d4..b3713bd 100644 --- a/src/bic/mlmnet_bic.jl +++ b/src/bic/mlmnet_bic.jl @@ -31,7 +31,7 @@ end lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1}; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -55,9 +55,9 @@ Performs BIC validation for `mlmnet`. default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd` - isNaive = boolean flag indicating whether to solve the Naive or non-Naive Elastic-net problem -- hasXIntercept = boolean flag indicating whether or not to include an `X` +- addXIntercept = boolean flag indicating whether or not to include an `X` intercept (row main effects). Defaults to `true`. -- hasZIntercept = boolean flag indicating whether or not to include a `Z` +- addZIntercept = boolean flag indicating whether or not to include a `Z` intercept (column main effects). Defaults to `true`. - toXReg = 1d array of bit flags indicating whether or not to regularize each of the `X` (row) effects. Defaults to 2d array of `true`s with length @@ -94,7 +94,7 @@ function mlmnet_bic(data::RawData, lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1}; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -107,8 +107,8 @@ function mlmnet_bic(data::RawData, # Run mlmnet on each RawData object, in parallel when possible MLMNet = mlmnet(data, lambdas, alphas; method=method, isNaive=isNaive, - hasXIntercept=hasXIntercept, - hasZIntercept=hasZIntercept, + addXIntercept=addXIntercept, + addZIntercept=addZIntercept, toXReg=toXReg, toZReg=toZReg, toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg, diff --git a/src/crossvalidation/mlmnet_cv.jl b/src/crossvalidation/mlmnet_cv.jl index 1421e5a..ea1e10c 100644 --- a/src/crossvalidation/mlmnet_cv.jl +++ b/src/crossvalidation/mlmnet_cv.jl @@ -41,7 +41,7 @@ end rowFolds::Array{Array{Int64,1},1}, colFolds::Array{Array{Int64,1},1}; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -74,9 +74,9 @@ input. default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd` - isNaive = boolean flag indicating whether to solve the Naive or non-Naive Elastic-net problem -- hasXIntercept = boolean flag indicating whether or not to include an `X` +- addXIntercept = boolean flag indicating whether or not to include an `X` intercept (row main effects). Defaults to `true`. -- hasZIntercept = boolean flag indicating whether or not to include a `Z` +- addZIntercept = boolean flag indicating whether or not to include a `Z` intercept (column main effects). Defaults to `true`. - toXReg = 1d array of bit flags indicating whether or not to regularize each of the `X` (row) effects. Defaults to 2d array of `true`s with length @@ -116,7 +116,7 @@ function mlmnet_cv(data::RawData, rowFolds::Array{Array{Int64,1},1}, colFolds::Array{Array{Int64,1},1}; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -156,8 +156,8 @@ function mlmnet_cv(data::RawData, MLMNets = Distributed.pmap(data -> mlmnet(data, lambdas, alphas; method=method, isNaive=isNaive, - hasXIntercept=hasXIntercept, - hasZIntercept=hasZIntercept, + addXIntercept=addXIntercept, + addZIntercept=addZIntercept, toXReg=toXReg, toZReg=toZReg, toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg, @@ -177,7 +177,7 @@ end alphas::AbstractArray{Float64,1}, rowFolds::Array{Array{Int64,1},1}, nColFolds::Int64; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -207,9 +207,9 @@ Calls the base `mlmnet_cv` function. - methods = function name that applies the Elastic-net penalty estimate method; default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd` -- hasXIntercept = boolean flag indicating whether or not to include an `X` +- addXIntercept = boolean flag indicating whether or not to include an `X` intercept (row main effects). Defaults to `true`. -- hasZIntercept = boolean flag indicating whether or not to include a `Z` +- addZIntercept = boolean flag indicating whether or not to include a `Z` intercept (column main effects). Defaults to `true`. - toXReg = 1d array of bit flags indicating whether or not to regularize each of the `X` (row) effects. Defaults to 2d array of `true`s with length @@ -248,7 +248,7 @@ function mlmnet_cv(data::RawData, alphas::AbstractArray{Float64,1}, rowFolds::Array{Array{Int64,1},1}, nColFolds::Int64; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -263,7 +263,7 @@ function mlmnet_cv(data::RawData, # base mlmnet_cv function mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds; method=method, isNaive=isNaive, - hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept, + addXIntercept=addXIntercept, addZIntercept=addZIntercept, toXReg=toXReg, toZReg=toZReg, toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg, @@ -278,7 +278,7 @@ end alphas::AbstractArray{Float64,1}, nRowFolds::Int64, colFolds::Array{Array{Int64,1},1}; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -308,9 +308,9 @@ input. Calls the base `mlmnet_cv` function. - methods = function name that applies the Elastic-net penalty estimate method; default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd` -- hasXIntercept = boolean flag indicating whether or not to include an `X` +- addXIntercept = boolean flag indicating whether or not to include an `X` intercept (row main effects). Defaults to `true`. -- hasZIntercept = boolean flag indicating whether or not to include a `Z` +- addZIntercept = boolean flag indicating whether or not to include a `Z` intercept (column main effects). Defaults to `true`. - toXReg = 1d array of bit flags indicating whether or not to regularize each of the `X` (row) effects. Defaults to 2d array of `true`s with length @@ -349,7 +349,7 @@ function mlmnet_cv(data::RawData, alphas::AbstractArray{Float64,1}, nRowFolds::Int64, colFolds::Array{Array{Int64,1},1}; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -364,7 +364,7 @@ function mlmnet_cv(data::RawData, # base mlmnet_cv function mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds; method=method, isNaive=isNaive, - hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept, + addXIntercept=addXIntercept, addZIntercept=addZIntercept, toXReg=toXReg, toZReg=toZReg, toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg, @@ -379,7 +379,7 @@ end nRowFolds::Int64=10, nColFolds::Int64=10; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -409,9 +409,9 @@ folds randomly generated using calls to `make_folds`. Calls the base - methods = function name that applies the Elastic-net penalty estimate method; default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd` -- hasXIntercept = boolean flag indicating whether or not to include an `X` +- addXIntercept = boolean flag indicating whether or not to include an `X` intercept (row main effects). Defaults to `true`. -- hasZIntercept = boolean flag indicating whether or not to include a `Z` +- addZIntercept = boolean flag indicating whether or not to include a `Z` intercept (column main effects). Defaults to `true`. - toXReg = 1d array of bit flags indicating whether or not to regularize each of the `X` (row) effects. Defaults to 2d array of `true`s with length @@ -465,7 +465,7 @@ function mlmnet_cv(data::RawData, nRowFolds::Int64=10, nColFolds::Int64=10; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -481,7 +481,7 @@ function mlmnet_cv(data::RawData, # function mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds; method=method, isNaive=isNaive, - hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept, + addXIntercept=addXIntercept, addZIntercept=addZIntercept, toXReg=toXReg, toZReg=toZReg, toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg, @@ -496,7 +496,7 @@ end nRowFolds::Int64=10, nColFolds::Int64=10; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -516,7 +516,7 @@ function mlmnet_cv(data::RawData, nRowFolds::Int64=10, nColFolds::Int64=10; method::String="ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(size(get_X(data), 2)), toZReg::BitArray{1}=trues(size(get_Z(data), 2)), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -532,7 +532,7 @@ function mlmnet_cv(data::RawData, # function mlmnet_cv(data, lambdas, alphas, nRowFolds, nColFolds; method=method, isNaive=isNaive, - hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept, + addXIntercept=addXIntercept, addZIntercept=addZIntercept, toXReg=toXReg, toZReg=toZReg, toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg, diff --git a/src/mlmnet/mlmnet.jl b/src/mlmnet/mlmnet.jl index 6ecc8e9..434e5a0 100644 --- a/src/mlmnet/mlmnet.jl +++ b/src/mlmnet/mlmnet.jl @@ -156,7 +156,7 @@ end lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1}; method::String = "ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(data.p), toZReg::BitArray{1}=trues(data.q), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -184,9 +184,9 @@ inputs. default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd` - isNaive = boolean flag indicating whether to solve the Naive or non-Naive Elastic-net problem -- hasXIntercept = boolean flag indicating whether or not to include an `X` +- addXIntercept = boolean flag indicating whether or not to include an `X` intercept (row main effects). Defaults to `true`. -- hasZIntercept = boolean flag indicating whether or not to include a `Z` +- addZIntercept = boolean flag indicating whether or not to include a `Z` intercept (column main effects). Defaults to `true`. - toXReg = 1d array of bit flags indicating whether or not to regularize each of the `X` (row) effects. Defaults to 2d array of `true`s with length @@ -235,7 +235,7 @@ function mlmnet(data::RawData, lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1}; method::String = "ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(data.p), toZReg::BitArray{1}=trues(data.q), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -263,13 +263,13 @@ function mlmnet(data::RawData, # Add X and Z intercepts if necessary # Update toXReg and toZReg accordingly - if hasXIntercept==true && data.predictors.hasXIntercept==false + if addXIntercept==true && data.predictors.hasXIntercept==false data.predictors.X = add_intercept(data.predictors.X) data.predictors.hasXIntercept = true data.p = data.p + 1 toXReg = vcat(toXInterceptReg, toXReg) end - if hasZIntercept==true && data.predictors.hasZIntercept==false + if addZIntercept==true && data.predictors.hasZIntercept==false data.predictors.Z = add_intercept(data.predictors.Z) data.predictors.hasZIntercept = true data.q = data.q + 1 @@ -278,14 +278,14 @@ function mlmnet(data::RawData, # Remove X and Z intercepts in new predictors if necessary # Update toXReg and toZReg accordingly - if hasXIntercept==false && data.predictors.hasXIntercept==true + if addXIntercept==false && data.predictors.hasXIntercept==true data.predictors.X = remove_intercept(data.predictors.X) data.predictors.hasXIntercept = false data.p = data.p - 1 toXReg = toXReg[2:end] end - if hasZIntercept==false && data.predictors.hasZIntercept==true + if addZIntercept==false && data.predictors.hasZIntercept==true data.predictors.Z = remove_intercept(data.predictors.Z) data.predictors.hasZIntercept = false data.q = data.q - 1 @@ -293,10 +293,10 @@ function mlmnet(data::RawData, end # Update toXReg and toZReg accordingly when intercept is already included - if hasXIntercept==true && data.predictors.hasXIntercept==true + if addXIntercept==true && data.predictors.hasXIntercept==true toXReg[1] = toXInterceptReg end - if hasZIntercept==true && data.predictors.hasZIntercept==true + if addZIntercept==true && data.predictors.hasZIntercept==true toZReg[1] = toZInterceptReg end @@ -314,8 +314,8 @@ function mlmnet(data::RawData, Z = copy(get_Z(data)) # Centers and normalizes predictors - meansX, normsX, = normalize!(X, hasXIntercept) - meansZ, normsZ, = normalize!(Z, hasZIntercept) + meansX, normsX, = normalize!(X, addXIntercept) + meansZ, normsZ, = normalize!(Z, addZIntercept) # If X and Z are standardized, set the norm to nothing norms = nothing else @@ -375,11 +375,11 @@ function mlmnet(data::RawData, # Back-transform coefficient estimates, if necessary. # Case if including both X and Z intercepts: - if toNormalize == true && (hasXIntercept==true) && (hasZIntercept==true) + if toNormalize == true && (addXIntercept==true) && (addZIntercept==true) backtransform!(coeffs, meansX, meansZ, normsX, normsZ, get_Y(data), data.predictors.X, data.predictors.Z) elseif toNormalize == true # Otherwise - backtransform!(coeffs, hasXIntercept, hasZIntercept, meansX, meansZ, + backtransform!(coeffs, addXIntercept, addZIntercept, meansX, meansZ, normsX, normsZ) end @@ -399,7 +399,7 @@ function mlmnet(data::RawData, lambdas::AbstractArray{Float64,1}; method::String = "ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(data.p), toZReg::BitArray{1}=trues(data.q), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -413,7 +413,7 @@ function mlmnet(data::RawData, lambdas::AbstractArray{Float64,1}; method::String = "ista", isNaive::Bool=false, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(data.p), toZReg::BitArray{1}=trues(data.q), toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, @@ -425,7 +425,7 @@ function mlmnet(data::RawData, alphas = [1.0] # default LASSO, 饾浖 = 1 rslts = mlmnet(data, lambdas, alphas; method, - isNaive, hasXIntercept, hasZIntercept, + isNaive, addXIntercept, addZIntercept, toXReg, toZReg, toXInterceptReg, toZInterceptReg, toNormalize, isVerbose, diff --git a/src/mlmnet/mlmnet_perms.jl b/src/mlmnet/mlmnet_perms.jl index 1f50caf..b72d7aa 100644 --- a/src/mlmnet/mlmnet_perms.jl +++ b/src/mlmnet/mlmnet_perms.jl @@ -3,7 +3,7 @@ lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1}; method::String = "ista", isNaive::Bool=false, permFun::Function=shuffle_rows, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(data.p), toZReg::BitArray{1}=trues(data.q), toXInterceptReg::Bool=false, @@ -30,9 +30,9 @@ function. Elastic-net problem - permFun = function used to permute `Y`. Defaults to `shuffle_rows` (shuffles rows of `Y`). -- hasXIntercept = boolean flag indicating whether or not to include an `X` +- addXIntercept = boolean flag indicating whether or not to include an `X` intercept (row main effects). Defaults to `true`. -- hasZIntercept = boolean flag indicating whether or not to include a `Z` +- addZIntercept = boolean flag indicating whether or not to include a `Z` intercept (column main effects). Defaults to `true`. - toXReg = 1d array of bit flags indicating whether or not to regularize each of the `X` (row) effects. Defaults to 2d array of `true`s with length @@ -81,7 +81,7 @@ function mlmnet_perms(data::RawData, lambdas::AbstractArray{Float64,1}, alphas::AbstractArray{Float64,1}; method::String = "ista", isNaive::Bool=false, permFun::Function=shuffle_rows, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(data.p), toZReg::BitArray{1}=trues(data.q), toXInterceptReg::Bool=false, @@ -95,7 +95,7 @@ function mlmnet_perms(data::RawData, # Run penalty on the permuted data return mlmnet(dataPerm, lambdas, alphas; method = method, isNaive=isNaive, - hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept, + addXIntercept=addXIntercept, addZIntercept=addZIntercept, toXReg=toXReg, toZReg=toZReg, toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg, toNormalize=toNormalize, isVerbose=isVerbose, @@ -110,7 +110,7 @@ end lambdas::AbstractArray{Float64,1}; method::String = "ista", isNaive::Bool=false, permFun::Function=shuffle_rows, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(data.p), toZReg::BitArray{1}=trues(data.q), toXInterceptReg::Bool=false, @@ -124,7 +124,7 @@ function mlmnet_perms(data::RawData, lambdas::AbstractArray{Float64,1}; method::String = "ista", isNaive::Bool=false, permFun::Function=shuffle_rows, - hasXIntercept::Bool=true, hasZIntercept::Bool=true, + addXIntercept::Bool=true, addZIntercept::Bool=true, toXReg::BitArray{1}=trues(data.p), toZReg::BitArray{1}=trues(data.q), toXInterceptReg::Bool=false, @@ -140,7 +140,7 @@ function mlmnet_perms(data::RawData, # Run L1-L2 penalties on the permuted data rslts = mlmnet(dataPerm, lambdas, alphas; method = method, isNaive=isNaive, - hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept, + addXIntercept=addXIntercept, addZIntercept=addZIntercept, toXReg=toXReg, toZReg=toZReg, toXInterceptReg=toXInterceptReg, toZInterceptReg=toZInterceptReg, toNormalize=toNormalize, isVerbose=isVerbose, diff --git a/src/utilities/std_helpers.jl b/src/utilities/std_helpers.jl index 0d392cb..4f33747 100644 --- a/src/utilities/std_helpers.jl +++ b/src/utilities/std_helpers.jl @@ -95,7 +95,7 @@ end """ backtransform!(B::AbstractArray{Float64,2}, - hasXIntercept::Bool, hasZIntercept::Bool, + addXIntercept::Bool, addZIntercept::Bool, meansX::AbstractArray{Float64,2}, meansZ::AbstractArray{Float64,2}, normsX::AbstractArray{Float64,2}, @@ -108,9 +108,9 @@ or Z. # Arguments - B = 2d array of coefficient estimates B -- hasXIntercept = boolean flag indicating whether or not to X has an +- addXIntercept = boolean flag indicating whether or not to X has an intercept column -- hasZIntercept = boolean flag indicating whether or not to Z has an +- addZIntercept = boolean flag indicating whether or not to Z has an intercept column - meansX = 2d array of column means of X, obtained prior to standardizing X - meansZ = 2d array of column means of Z, obtained prior to standardizing Z @@ -123,33 +123,33 @@ None; back-transforms B in place """ function backtransform!(B::AbstractArray{Float64,2}, - hasXIntercept::Bool, hasZIntercept::Bool, + addXIntercept::Bool, addZIntercept::Bool, meansX::AbstractArray{Float64,2}, meansZ::AbstractArray{Float64,2}, normsX::AbstractArray{Float64,2}, normsZ::AbstractArray{Float64,2}) # Back transform the X intercepts (row main effects), if necessary - if hasXIntercept == true + if addXIntercept == true prodX = (meansX[:,2:end]./normsX[:,2:end])*B[2:end, 2:end] B[1,2:end] = (B[1,2:end]-vec(prodX))./vec(normsZ[:,2:end])/normsX[1,1] end # Back transform the Z intercepts (column main effects), if necessary - if hasZIntercept == true + if addZIntercept == true prodZ = B[2:end, 2:end]*transpose(meansZ[:,2:end]./normsZ[:,2:end]) B[2:end,1] = (B[2:end,1]-prodZ)./transpose(normsX[:,2:end])/ normsZ[1,1] end # Back transform the interactions, if necessary - if (hasXIntercept == true) || (hasZIntercept == true) + if (addXIntercept == true) || (addZIntercept == true) B[2:end, 2:end] = B[2:end, 2:end]./transpose(normsX[:,2:end])./ normsZ[:,2:end] end # Back transform the interactions if not including any main effects - if (hasXIntercept == false) && (hasZIntercept == false) + if (addXIntercept == false) && (addZIntercept == false) B = B./transpose(normsX)./normsZ end end @@ -227,7 +227,7 @@ end """ backtransform!(B::AbstractArray{Float64,4}, - hasXIntercept::Bool, hasZIntercept::Bool, + addXIntercept::Bool, addZIntercept::Bool, meansX::AbstractArray{Float64,2}, meansZ::AbstractArray{Float64,2}, normsX::AbstractArray{Float64,2}, @@ -240,9 +240,9 @@ or Z. # Arguments - B = 4d array of coefficient estimates -- hasXIntercept = boolean flag indicating whether or not to X has an +- addXIntercept = boolean flag indicating whether or not to X has an intercept column -- hasZIntercept = boolean flag indicating whether or not to Z has an +- addZIntercept = boolean flag indicating whether or not to Z has an intercept column - meansX = 2d array of column means of X, obtained prior to standardizing X - meansZ = 2d array of column means of Z, obtained prior to standardizing Z @@ -260,7 +260,7 @@ dimension. """ function backtransform!(B::AbstractArray{Float64,4}, - hasXIntercept::Bool, hasZIntercept::Bool, + addXIntercept::Bool, addZIntercept::Bool, meansX::AbstractArray{Float64,2}, meansZ::AbstractArray{Float64,2}, normsX::AbstractArray{Float64,2}, @@ -271,14 +271,14 @@ function backtransform!(B::AbstractArray{Float64,4}, for j in 1:size(B,4) for i in 1:size(B,3) # Back transform the X intercepts (row main effects), if necessary - if hasXIntercept == true + if addXIntercept == true prodX = (meansX[:,2:end]./normsX[:,2:end])*B[2:end,2:end,i,j] B[1,2:end,i,j] = (B[1,2:end,i,j]-vec(prodX))./vec(normsZ[:,2:end])/ normsX[1,1] end # Back transform the Z intercepts (column main effects), if necessary - if hasZIntercept == true + if addZIntercept == true prodZ = B[2:end,2:end,i,j]*transpose(meansZ[:,2:end]./ normsZ[:,2:end]) B[2:end,1,i,j] = (B[2:end,1,i,j]-prodZ)./transpose(normsX[:,2:end])/ @@ -286,13 +286,13 @@ function backtransform!(B::AbstractArray{Float64,4}, end # Back transform the interactions, if necessary - if (hasXIntercept == true) || (hasZIntercept == true) + if (addXIntercept == true) || (addZIntercept == true) B[2:end,2:end,i,j] = B[2:end,2:end,i,j]./transpose(normsX[:,2:end])./ normsZ[:,2:end] end # Back transform the interactions if not including any main effects - if (hasXIntercept == false) && (hasZIntercept == false) + if (addXIntercept == false) && (addZIntercept == false) B[:,:,i,j] = B[:,:,i,j]./transpose(normsX)./normsZ end end diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index af82aa4..0000000 --- a/test/Project.toml +++ /dev/null @@ -1,9 +0,0 @@ -[deps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -Helium = "3f79f04f-7cac-48b4-bde1-3ad54d8f74fa" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -Distributions = "0.25" -Helium = "0.2" diff --git a/test/generate_testing_dataset.jl b/test/generate_testing_dataset.jl index 7f4dd0f..3527204 100644 --- a/test/generate_testing_dataset.jl +++ b/test/generate_testing_dataset.jl @@ -77,24 +77,24 @@ dat = MatrixLMnet.MatrixLM.RawData(Response(Y), Predictors(X, Z)); ############### # Lasso penalized regression - ista -est = mlmnet(ista!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet(ista!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_ista = est.B[:, :, 3]; # Lasso penalized regression - fista -est = mlmnet(fista!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet(fista!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_fista = est.B[:, :, 3]; # Lasso penalized regression - fista backtracking -est = mlmnet(fista_bt!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet(fista_bt!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_fista_bt = est.B[:, :, 3]; # Lasso penalized regression - admm -est = mlmnet(admm!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet(admm!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_admm = est.B[:, :, 3]; # Lasso penalized regression - cd Random.seed!(rng) -est = mlmnet(cd!, dat, 位, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet(cd!, dat, 位, addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_cd = est.B[:, :, 3]; ################################# @@ -104,31 +104,31 @@ est_B_Lasso_cd = est.B[:, :, 3]; # Lasso penalized regression - ista cv Random.seed!(rng) -est = mlmnet_cv(ista!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet_cv(ista!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Lasso = lambda_min_deprecated(est); smmr_ista = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) # Lasso penalized regression - fista cv Random.seed!(rng) -est = mlmnet_cv(fista!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet_cv(fista!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Lasso = lambda_min_deprecated(est); smmr_fista = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) # Lasso penalized regression - fista-bt cv Random.seed!(rng) -est = mlmnet_cv(fista_bt!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet_cv(fista_bt!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Lasso = lambda_min_deprecated(est); smmr_fistabt = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) # Lasso penalized regression - admm cv Random.seed!(rng) -est = mlmnet_cv(admm!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet_cv(admm!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Lasso = lambda_min_deprecated(est); smmr_admm = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) # Lasso penalized regression - cd cv Random.seed!(rng) -est = mlmnet_cv(cd!, dat, 位, 10, 1, hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est = mlmnet_cv(cd!, dat, 位, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Lasso = lambda_min_deprecated(est); smmr_cd = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) diff --git a/test/mlmnetBicTests.jl b/test/mlmnetBicTests.jl index 95669a7..2782514 100644 --- a/test/mlmnetBicTests.jl +++ b/test/mlmnetBicTests.jl @@ -2,9 +2,9 @@ # Library # ########### # using Random -using MatrixLMnet, Distributions, LinearAlgebra -using Helium -using Test +# using MatrixLMnet, Distributions, LinearAlgebra +# using Helium +# using Test ######################################## # TEST BIC Validation - Simulated Data # @@ -76,20 +76,20 @@ numVersion = VERSION if Int(numVersion.minor) < 7 tolVersion=2e-1 else - tolVersion=1e-6 + tolVersion=1e-5 end #################################### # TEST BIC Validation - Estimation # #################################### -est = mlmnet(dat, 位, 伪; method = "fista_bt", hasXIntercept = false, hasZIntercept=false, isVerbose = false); +est = mlmnet(dat, 位, 伪; method = "fista_bt", addXIntercept = false, addZIntercept=false, isVerbose = false); ############################# # TEST BIC Validation - BIC # ############################# -est_BIC = mlmnet_bic(dat, 位, 伪; method = "fista_bt", hasXIntercept = false, hasZIntercept=false, isVerbose = false); +est_BIC = mlmnet_bic(dat, 位, 伪; method = "fista_bt", addXIntercept = false, addZIntercept=false, isVerbose = false); df_BIC = mlmnet_bic_summary(est_BIC); @@ -106,7 +106,7 @@ for i in 1:length(est.lambdas), j in 1:length(est.alphas) # BIC for (lambdas i, alphas j) k = sum(est.B[:,:,i,j] .!= 0.0, dims = 1) .+ m; - distResids = MvNormal(zeros(m), (sqrt.(sum(resids[:,:,i,j], dims = 1)./n))[:]); + distResids = MvNormal(zeros(m), LinearAlgebra.Diagonal(map(abs2, (sqrt.(sum(resids[:,:,i,j], dims = 1)./n))[:]))); L虃 = loglikelihood(distResids, permutedims(resids[:,:,i,j])) BIC2[i,j] = sum(k)*log(n) - 2*(L虃) end; diff --git a/test/mlmnetCvTests.jl b/test/mlmnetCvTests.jl index 351d25d..0c7e1cc 100644 --- a/test/mlmnetCvTests.jl +++ b/test/mlmnetCvTests.jl @@ -48,7 +48,7 @@ numVersion = VERSION if Int(numVersion.minor) < 7 tolVersion=2e-1 else - tolVersion=1e-6 + tolVersion=1e-5 end ############################################# @@ -57,12 +57,12 @@ end # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - ista cv @@ -78,12 +78,12 @@ println("CV Lasso vs Elastic Net when 伪=1 test 1 - ista: ", # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - fista cv @@ -99,12 +99,12 @@ println("CV Lasso vs Elastic Net when 伪=1 test 2 - fista: ", # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - fista-bt cv @@ -121,12 +121,12 @@ println("CV Lasso vs Elastic Net when 伪=1 test 3 - fista-bt: ", # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - fista-bt cv @@ -142,12 +142,12 @@ println("CV Lasso vs Elastic Net when 伪=1 test 4 - admm: ", # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - cd cv diff --git a/test/mlmnetTests.jl b/test/mlmnetTests.jl index 5435b87..7619d31 100644 --- a/test/mlmnetTests.jl +++ b/test/mlmnetTests.jl @@ -1,12 +1,12 @@ -########### -# Library # -########### -# using MatrixLM -# using Distributions, Random, Statistics, LinearAlgebra, StatsBase -# using Random -using MatrixLMnet -using Helium -using Test +# ########### +# # Library # +# ########### +# # using MatrixLM +# # using Distributions, Random, Statistics, LinearAlgebra, StatsBase +# # using Random +# using MatrixLMnet +# using Helium +# using Test #################################################### @@ -54,11 +54,11 @@ rng = MatrixLMnet.Random.MersenneTwister(2021) # Elastic net penalized regression -est_ista_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_ista_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_ista_1 = est_ista_1.B[:, :, 3, 1]; # Elastic net penalized regression -est_ista_2 = MatrixLMnet.mlmnet(dat, 位, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_ista_2 = MatrixLMnet.mlmnet(dat, 位, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_ista_2 = est_ista_2.B[:, :, 3, 1]; # Lasso penalized regression - ista @@ -71,11 +71,11 @@ println("Lasso vs Elastic Net when 伪=1 test 1 - ista: ", @test (B_Net_ista_1 ############################################# # Elastic net penalized regression -est_fista_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_fista_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_fista_1 = est_fista_1.B[:, :, 3, 1]; # Elastic net penalized regression -est_fista_2 = MatrixLMnet.mlmnet(dat, 位, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_fista_2 = MatrixLMnet.mlmnet(dat, 位, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_fista_2 = est_fista_2.B[:, :, 3, 1]; # Lasso penalized regression - fista @@ -88,11 +88,11 @@ println("Lasso vs Elastic Net when 伪=1 test 2 - fista: ", @test (B_Net_fista_1 ########################################################## # Elastic net penalized regression -est_fistabt_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_fistabt_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_fistabt_1 = est_fistabt_1.B[:, :, 3, 1]; # Elastic net penalized regression -est_fistabt_2 = MatrixLMnet.mlmnet(dat, 位, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_fistabt_2 = MatrixLMnet.mlmnet(dat, 位, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_fistabt_2 = est_fistabt_2.B[:, :, 3, 1]; # Lasso penalized regression - fista-bt @@ -106,11 +106,11 @@ println("Lasso vs Elastic Net when 伪=1 test 3 - fista-bt: ", @test (B_Net_fista ############################################ # Elastic net penalized regression -est_admm_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_admm_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_admm_1 = est_admm_1.B[:, :, 3, 1]; # Elastic net penalized regression -est_admm_2 = MatrixLMnet.mlmnet(dat, 位, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_admm_2 = MatrixLMnet.mlmnet(dat, 位, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_admm_2 = est_admm_2.B[:, :, 3, 1]; # Lasso penalized regression - admm @@ -125,12 +125,12 @@ println("Lasso vs Elastic Net when 伪=1 test 4 - admm: ", @test (B_Net_admm_1 # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est_cd_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_cd_1 = MatrixLMnet.mlmnet(dat, 位, 伪, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_cd_1 = est_cd_1.B[:, :, 3, 1]; # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est_cd_2 = MatrixLMnet.mlmnet(dat, 位, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est_cd_2 = MatrixLMnet.mlmnet(dat, 位, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_cd_2 = est_cd_2.B[:, :, 3, 1]; # Lasso penalized regression - cd diff --git a/test/runtests.jl b/test/runtests.jl index 7ebd2fb..aea01a8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,6 @@ using MatrixLMnet +using Distributions, LinearAlgebra +using Helium using Test @testset "MatrixLMnet" begin diff --git a/test/summaryCvTests.jl b/test/summaryCvTests.jl index 98f2bee..3359ef4 100644 --- a/test/summaryCvTests.jl +++ b/test/summaryCvTests.jl @@ -1,10 +1,10 @@ -########### -# Library # -########### -# using Random -using MatrixLMnet -using Helium -using Test +# ########### +# # Library # +# ########### +# # using Random +# using MatrixLMnet +# using Helium +# using Test ##################################################################### # TEST Cross Validation Lasso vs Elastic Net (饾浖=1) - Simulated Data # @@ -48,7 +48,7 @@ numVersion = VERSION if Int(numVersion.minor) < 7 tolVersion=2e-1 else - tolVersion=1e-6 + tolVersion=1e-5 end @@ -58,7 +58,7 @@ end # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net1 = MatrixLMnet.mlmnet_cv_summary(est1); smmr_min_Net1 = MatrixLMnet.lambda_min(est1); @@ -70,7 +70,7 @@ test_ElasticNet = (smmr_Net1[idxSmmr, :AvgMSE] == smmr_min_Net1[1, :AvgMSE]) && # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "ista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2); smmr_min_Net2 = MatrixLMnet.lambda_min(est2); @@ -90,7 +90,7 @@ println("Summary cross-validation test 1 - ista: ", # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net1 = MatrixLMnet.mlmnet_cv_summary(est1); smmr_min_Net1 = MatrixLMnet.lambda_min(est1); @@ -102,7 +102,7 @@ test_ElasticNet = (smmr_Net1[idxSmmr, :AvgMSE] == smmr_min_Net1[1, :AvgMSE]) && # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2); smmr_min_Net2 = MatrixLMnet.lambda_min(est2); @@ -122,7 +122,7 @@ println("Summary cross-validation test 2 - fista: ", # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net1 = MatrixLMnet.mlmnet_cv_summary(est1); smmr_min_Net1 = MatrixLMnet.lambda_min(est1); @@ -134,7 +134,7 @@ test_ElasticNet = (smmr_Net1[idxSmmr, :AvgMSE] == smmr_min_Net1[1, :AvgMSE]) && # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista_bt", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2); smmr_min_Net2 = MatrixLMnet.lambda_min(est2); @@ -154,7 +154,7 @@ println("Summary cross-validation test 3 - fista_bt: ", # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, 位, 伪, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net1 = MatrixLMnet.mlmnet_cv_summary(est1); smmr_min_Net1 = MatrixLMnet.lambda_min(est1); @@ -166,7 +166,7 @@ test_ElasticNet = (smmr_Net1[idxSmmr, :AvgMSE] == smmr_min_Net1[1, :AvgMSE]) && # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2); smmr_min_Net2 = MatrixLMnet.lambda_min(est2); @@ -186,7 +186,7 @@ println("Summary cross-validation test 4 - admm: ", # Elastic net penalized regression MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "admm", hasZIntercept = false, hasXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, 位, 10, 1, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); # Summaries smmr_Net2 = MatrixLMnet.mlmnet_cv_summary(est2); smmr_min_Net2 = MatrixLMnet.lambda_min(est2); @@ -196,5 +196,5 @@ test_Lasso = (smmr_Net2[idxSmmr, :AvgMSE] == smmr_min_Net2[1, :AvgMSE]) && (smmr_Net2[idxSmmr, :Lambda] == smmr_min_Net2[1, :Lambda]) && (smmr_Net2[idxSmmr, :Alpha] == smmr_min_Net2[1, :Alpha]) -println("Summary cross-validation test 5 - admm: ", +println("Summary cross-validation test 5 - cd: ", @test (test_Lasso)) \ No newline at end of file diff --git a/test/utilitiesTests.jl b/test/utilitiesTests.jl index 9ed236d..d1ba0de 100644 --- a/test/utilitiesTests.jl +++ b/test/utilitiesTests.jl @@ -36,7 +36,7 @@ rng = MatrixLMnet.Random.MersenneTwister(2021) #Test the dimension of results by util functions # ################################################### -est = mlmnet(dat, 位, 伪, method = "cd", hasZIntercept = true, hasXIntercept = true, isVerbose = true) +est = mlmnet(dat, 位, 伪, method = "cd", addZIntercept = true, addXIntercept = true, isVerbose = true) predicted = predict(est, est.data.predictors) @@ -59,7 +59,7 @@ end newPredictors = Predictors(X, Z, false, false) predicted = predict(est, newPredictors) -est2 = mlmnet(dat, 位, 伪, method = "cd", hasZIntercept = false, hasXIntercept = false, isVerbose = true) +est2 = mlmnet(dat, 位, 伪, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = true) newPredictors2 = Predictors(hcat(ones(size(X, 1)), X), hcat(ones(size(Z, 1)), Z), true, true) predicted2 = predict(est2, newPredictors2)