Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev v1.1.0 #12

Merged
merged 2 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MatrixLMnet"
uuid = "436227dc-6146-11e9-240b-9df6ee45b799"
authors = ["Jane Liang", "Zifan Yu", "Gregory Farage", "Saunak Sen"]
version = "1.0.2"
version = "1.1.0"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand All @@ -15,5 +15,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
DataFrames = "^1.3.0"
MLBase = "0.8, 0.9"
MatrixLM = "^0.1.3"
MatrixLM = "~0.2"
Statistics = "1"
julia = "1.6.4, 1"

[extras]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Helium = "3f79f04f-7cac-48b4-bde1-3ad54d8f74fa"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Distributions", "Helium", "Test"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ alphas = reverse(collect(0:0.1:1))

L<sub>1</sub> and L<sub>2</sub>-penalized estimates for matrix linear models can be obtained by running `mlmnet`. In addition to the `RawData` object, `lambdas` and `alphas`, `mlmnet` requires as a keyword argument the function name for an algorithm used to fit Elastic Net penalized estimates. Current methods are: `"cd"` (coordinate descent), `"cd_active"` (active coordinate descent), `"ista"` (ISTA with fixed step size), `"fista"` (FISTA with fixed step size), `"fista_bt"` (FISTA with backtracking), and `"admm"` (ADMM).

An object of type `Mlmnet` will be returned, with variables for the penalized coefficient estimates (`B`) along with the lambda and alpha penalty values used (`lambdas`, `alphas`). By default, `mlmnet` estimates both row and column main effects (X and Z intercepts), but this behavior can be suppressed by setting `hasXIntercept=false` and/or `hasZIntercept=false`; the intercepts will be regularized unless `toXInterceptReg=false` and/or `toZInterceptReg=false`. Individual `X` (row) and `Z` (column) effects can be left unregularized by manually passing in 1d boolean arrays of length `p` and `q` to indicate which effects should be regularized (`true`) or not (`false`) into `toXReg` and `toZReg`. By default, `mlmnet` centers and normalizes the columns of `X` and `Z` to have mean 0 and norm 1 (`toNormalize=true`). Additional keyword arguments include `isVerbose`, which controls message printing; `thresh`, the threshold at which the coefficients are considered to have converged; and `maxiter`, the maximum number of iterations.
An object of type `Mlmnet` will be returned, with variables for the penalized coefficient estimates (`B`) along with the lambda and alpha penalty values used (`lambdas`, `alphas`). By default, `mlmnet` estimates both row and column main effects (X and Z intercepts), but this behavior can be suppressed by setting `addXIntercept=false` and/or `addZIntercept=false`; the intercepts will be regularized unless `toXInterceptReg=false` and/or `toZInterceptReg=false`. Individual `X` (row) and `Z` (column) effects can be left unregularized by manually passing in 1d boolean arrays of length `p` and `q` to indicate which effects should be regularized (`true`) or not (`false`) into `toXReg` and `toZReg`. By default, `mlmnet` centers and normalizes the columns of `X` and `Z` to have mean 0 and norm 1 (`toNormalize=true`). Additional keyword arguments include `isVerbose`, which controls message printing; `thresh`, the threshold at which the coefficients are considered to have converged; and `maxiter`, the maximum number of iterations.

```
est = mlmnet(dat, lambdas, alphas, method = "fista_bt")
Expand Down
12 changes: 6 additions & 6 deletions src/bic/mlmnet_bic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ end
lambdas::AbstractArray{Float64,1},
alphas::AbstractArray{Float64,1};
method::String="ista", isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand All @@ -55,9 +55,9 @@ Performs BIC validation for `mlmnet`.
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
- isNaive = boolean flag indicating whether to solve the Naive or non-Naive
Elastic-net problem
- hasXIntercept = boolean flag indicating whether or not to include an `X`
- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
- hasZIntercept = boolean flag indicating whether or not to include a `Z`
- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
Expand Down Expand Up @@ -94,7 +94,7 @@ function mlmnet_bic(data::RawData,
lambdas::AbstractArray{Float64,1},
alphas::AbstractArray{Float64,1};
method::String="ista", isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand All @@ -107,8 +107,8 @@ function mlmnet_bic(data::RawData,
# Run mlmnet on each RawData object, in parallel when possible
MLMNet = mlmnet(data, lambdas, alphas;
method=method, isNaive=isNaive,
hasXIntercept=hasXIntercept,
hasZIntercept=hasZIntercept,
addXIntercept=addXIntercept,
addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
Expand Down
48 changes: 24 additions & 24 deletions src/crossvalidation/mlmnet_cv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ end
rowFolds::Array{Array{Int64,1},1},
colFolds::Array{Array{Int64,1},1};
method::String="ista", isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand Down Expand Up @@ -74,9 +74,9 @@ input.
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
- isNaive = boolean flag indicating whether to solve the Naive or non-Naive
Elastic-net problem
- hasXIntercept = boolean flag indicating whether or not to include an `X`
- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
- hasZIntercept = boolean flag indicating whether or not to include a `Z`
- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
Expand Down Expand Up @@ -116,7 +116,7 @@ function mlmnet_cv(data::RawData,
rowFolds::Array{Array{Int64,1},1},
colFolds::Array{Array{Int64,1},1};
method::String="ista", isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand Down Expand Up @@ -156,8 +156,8 @@ function mlmnet_cv(data::RawData,

MLMNets = Distributed.pmap(data -> mlmnet(data, lambdas, alphas;
method=method, isNaive=isNaive,
hasXIntercept=hasXIntercept,
hasZIntercept=hasZIntercept,
addXIntercept=addXIntercept,
addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
Expand All @@ -177,7 +177,7 @@ end
alphas::AbstractArray{Float64,1},
rowFolds::Array{Array{Int64,1},1}, nColFolds::Int64;
method::String="ista", isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand Down Expand Up @@ -207,9 +207,9 @@ Calls the base `mlmnet_cv` function.

- methods = function name that applies the Elastic-net penalty estimate method;
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
- hasXIntercept = boolean flag indicating whether or not to include an `X`
- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
- hasZIntercept = boolean flag indicating whether or not to include a `Z`
- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
Expand Down Expand Up @@ -248,7 +248,7 @@ function mlmnet_cv(data::RawData,
alphas::AbstractArray{Float64,1},
rowFolds::Array{Array{Int64,1},1}, nColFolds::Int64;
method::String="ista", isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand All @@ -263,7 +263,7 @@ function mlmnet_cv(data::RawData,
# base mlmnet_cv function
mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds;
method=method, isNaive=isNaive,
hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
Expand All @@ -278,7 +278,7 @@ end
alphas::AbstractArray{Float64,1},
nRowFolds::Int64, colFolds::Array{Array{Int64,1},1};
method::String="ista", isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand Down Expand Up @@ -308,9 +308,9 @@ input. Calls the base `mlmnet_cv` function.

- methods = function name that applies the Elastic-net penalty estimate method;
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
- hasXIntercept = boolean flag indicating whether or not to include an `X`
- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
- hasZIntercept = boolean flag indicating whether or not to include a `Z`
- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
Expand Down Expand Up @@ -349,7 +349,7 @@ function mlmnet_cv(data::RawData,
alphas::AbstractArray{Float64,1},
nRowFolds::Int64, colFolds::Array{Array{Int64,1},1};
method::String="ista", isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand All @@ -364,7 +364,7 @@ function mlmnet_cv(data::RawData,
# base mlmnet_cv function
mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds;
method=method, isNaive=isNaive,
hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
Expand All @@ -379,7 +379,7 @@ end
nRowFolds::Int64=10, nColFolds::Int64=10;
method::String="ista",
isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand Down Expand Up @@ -409,9 +409,9 @@ folds randomly generated using calls to `make_folds`. Calls the base

- methods = function name that applies the Elastic-net penalty estimate method;
default is `ista`, and the other methods are `fista`, `fista_bt`, `admm` and `cd`
- hasXIntercept = boolean flag indicating whether or not to include an `X`
- addXIntercept = boolean flag indicating whether or not to include an `X`
intercept (row main effects). Defaults to `true`.
- hasZIntercept = boolean flag indicating whether or not to include a `Z`
- addZIntercept = boolean flag indicating whether or not to include a `Z`
intercept (column main effects). Defaults to `true`.
- toXReg = 1d array of bit flags indicating whether or not to regularize each
of the `X` (row) effects. Defaults to 2d array of `true`s with length
Expand Down Expand Up @@ -465,7 +465,7 @@ function mlmnet_cv(data::RawData,
nRowFolds::Int64=10, nColFolds::Int64=10;
method::String="ista",
isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand All @@ -481,7 +481,7 @@ function mlmnet_cv(data::RawData,
# function
mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds;
method=method, isNaive=isNaive,
hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
Expand All @@ -496,7 +496,7 @@ end
nRowFolds::Int64=10, nColFolds::Int64=10;
method::String="ista",
isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand All @@ -516,7 +516,7 @@ function mlmnet_cv(data::RawData,
nRowFolds::Int64=10, nColFolds::Int64=10;
method::String="ista",
isNaive::Bool=false,
hasXIntercept::Bool=true, hasZIntercept::Bool=true,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
Expand All @@ -532,7 +532,7 @@ function mlmnet_cv(data::RawData,
# function
mlmnet_cv(data, lambdas, alphas, nRowFolds, nColFolds;
method=method, isNaive=isNaive,
hasXIntercept=hasXIntercept, hasZIntercept=hasZIntercept,
addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
Expand Down
Loading
Loading