Merge pull request #24 from GregFa/dev
v1.1.1

- updated backtransform functions
- updated function to get one-standard-error lambda value
- updated tests
- updated ci.yml
GregFa authored Apr 27, 2024
2 parents cae437d + 7ded28e commit 3b8a6ca
Showing 29 changed files with 629 additions and 298 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -46,6 +46,6 @@ jobs:
env:
JULIA_NUM_THREADS: 4
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
- uses: codecov/codecov-action@v2
with:
file: lcov.info
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "MatrixLMnet"
uuid = "436227dc-6146-11e9-240b-9df6ee45b799"
authors = ["Jane Liang", "Zifan Yu", "Gregory Farage", "Saunak Sen"]
version = "1.1.0"
version = "1.1.1"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@@ -24,7 +24,8 @@ julia = "1.6.4, 1"
[extras]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Helium = "3f79f04f-7cac-48b4-bde1-3ad54d8f74fa"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Distributions", "Helium", "Test"]
test = ["Distributions", "Helium", "StableRNGs", "Test"]
20 changes: 13 additions & 7 deletions src/MatrixLMnet.jl
@@ -19,13 +19,8 @@ using MLBase


export Response, Predictors, RawData, get_X, get_Z, get_Y, contr,
add_intercept, remove_intercept, shuffle_rows, shuffle_cols,
cd!, cd_active!, ista!, fista!, fista_bt!, admm!,
mlmnet, Mlmnet, coef, coef_2d, predict, fitted, resid,
mlmnet_perms,
make_folds, make_folds_conds, mlmnet_cv, Mlmnet_cv,
avg_mse, lambda_min, avg_prop_zero, mlmnet_cv_summary,
mlmnet_bic, Mlmnet_bic, calc_bic, mlmnet_bic_summary
add_intercept, remove_intercept, shuffle_rows, shuffle_cols




@@ -39,25 +34,36 @@ include("methods/ista.jl")
include("methods/fista.jl")
include("methods/fista_bt.jl")
include("methods/admm.jl")
export cd!, cd_active!, ista!, fista!, fista_bt!, admm!

# Top level functions that call Elastic Net algorithms using warm starts
include("mlmnet/mlmnet.jl")
export mlmnet, Mlmnet

# Predictions and residuals
include("utilities/predict.jl")
export coef, predict, fitted, resid#, coef_2d,

# Permutations
include("mlmnet/mlmnet_perms.jl")
export mlmnet_perms

# Cross-validation
include("crossvalidation/mlmnet_cv_helpers.jl")
export make_folds, make_folds_conds
include("crossvalidation/mlmnet_cv.jl")
export mlmnet_cv, Mlmnet_cv
include("crossvalidation/mlmnet_cv_summary.jl")
export calc_avg_mse, lambda_min, calc_avg_prop_zero, mlmnet_cv_summary

# BIC validation
include("bic/mlmnet_bic_helpers.jl")
export calc_bic
include("bic/mlmnet_bic.jl")
export mlmnet_bic, Mlmnet_bic
include("bic/mlmnet_bic_summary.jl")
export mlmnet_bic_summary



end
54 changes: 54 additions & 0 deletions src/crossvalidation/mlmnet_cv.jl
@@ -539,4 +539,58 @@ function mlmnet_cv(data::RawData,
isVerbose=isVerbose, toNormalize=toNormalize,
stepsize=stepsize, setStepsize=setStepsize, dig=dig, funArgs...)

end


"""
mlmnet_cv(data::RawData,
lambdas::Array{Float64,1},
rowFolds::Array{Array{Int64,1},1},
colFolds::Array{Array{Int64,1},1};
method::String="ista",
isNaive::Bool=false,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
toNormalize::Bool=true, isVerbose::Bool=true,
stepsize::Float64=0.01, setStepsize::Bool=true,
dig::Int64=12, funArgs...)
Performs cross-validation for `mlmnet` over user-supplied non-overlapping row and
column folds (for example, folds generated with `make_folds`). Calls the base
`mlmnet_cv` function with `alphas = [1.0]`.
"""

function mlmnet_cv(data::RawData,
lambdas::Array{Float64,1},
rowFolds::Array{Array{Int64,1},1},
colFolds::Array{Array{Int64,1},1};
method::String="ista",
isNaive::Bool=false,
addXIntercept::Bool=true, addZIntercept::Bool=true,
toXReg::BitArray{1}=trues(size(get_X(data), 2)),
toZReg::BitArray{1}=trues(size(get_Z(data), 2)),
toXInterceptReg::Bool=false, toZInterceptReg::Bool=false,
toNormalize::Bool=true, isVerbose::Bool=true,
stepsize::Float64=0.01, setStepsize::Bool=true,
dig::Int64=12, funArgs...)



alphas = [1.0]

# Pass the user-supplied row and column folds to the base mlmnet_cv
# function
mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds;
method=method, isNaive=isNaive,
addXIntercept=addXIntercept, addZIntercept=addZIntercept,
toXReg=toXReg, toZReg=toZReg,
toXInterceptReg=toXInterceptReg,
toZInterceptReg=toZInterceptReg,
isVerbose=isVerbose, toNormalize=toNormalize,
stepsize=stepsize, setStepsize=setStepsize, dig=dig, funArgs...)

end
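
A hedged usage sketch of this new overload with hand-built folds; the dimension convention (Y is n × m, X is n × p, Z is m × q), the fold semantics, and the call to `mlmnet_cv_summary` on the returned `Mlmnet_cv` are assumptions for illustration, not taken from this hunk:

```julia
using MatrixLMnet

# Toy matrix-linear-model data.
n, m, p, q = 60, 6, 4, 3
X, Z, Y = randn(n, p), randn(m, q), randn(n, m)
data = RawData(Response(Y), Predictors(X, Z))

lambdas = [10.0, 1.0, 0.1]

# Non-overlapping row and column folds, typed as Array{Array{Int64,1},1};
# make_folds can generate these instead of building them by hand.
rowFolds = [collect(1:20), collect(21:40), collect(41:60)]
colFolds = [collect(1:3), collect(4:6)]

# Equivalent to calling the base method with alphas = [1.0].
cv = mlmnet_cv(data, lambdas, rowFolds, colFolds; isVerbose = false)
mlmnet_cv_summary(cv)
```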
52 changes: 52 additions & 0 deletions src/crossvalidation/mlmnet_cv_helpers.jl
@@ -111,6 +111,11 @@ function findnotin(a::AbstractArray{Int64,1}, b::AbstractArray{Int64,1})
end
end

# CHECK: does this handle an empty array ([]) as an argument?
# WHY not use setdiff(collect(1:n), vec)?



"""
calc_mse(MLMNets::AbstractArray{Mlmnet,1}, data::RawData,
lambdas::AbstractArray{Float64,1},
@@ -233,4 +238,51 @@ function calc_prop_zero(MLMNets::AbstractArray{Mlmnet,1},
end

return propZero
end
"""
minimize_rows(indices::Vector{CartesianIndex{2}})
Processes a vector of `CartesianIndex` objects representing positions in a 2D matrix
and returns a new vector of `CartesianIndex` objects, one per unique column index,
each holding the smallest row index observed for that column.
# Arguments
- indices = 1d array of `CartesianIndex` objects representing positions in a 2D matrix
# Value
1d array of `CartesianIndex` objects representing the smallest row index for each unique
column index.
# Example
```julia
julia> input_indices = [CartesianIndex(1, 1), CartesianIndex(2, 1), CartesianIndex(3, 2), CartesianIndex(1, 2)]
julia> output_indices = minimize_rows(input_indices)
2-element Vector{CartesianIndex{2}}:
CartesianIndex(1, 2)
CartesianIndex(1, 1)
```
"""
function minimize_rows(indices::Vector{CartesianIndex{2}})
# Dictionary to hold the minimum row index for each column
min_rows = Dict{Int, Int}()

for index in indices
col = index[2]
row = index[1]

# Update the dictionary with the minimum row for each column
if haskey(min_rows, col)
min_rows[col] = min(min_rows[col], row)
else
min_rows[col] = row
end
end

# Create a vector of CartesianIndices from the dictionary
result_indices = [CartesianIndex(min_rows[col], col) for col in keys(min_rows)]

return result_indices
end
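
For comparison, an equivalent Dict-free formulation (an illustrative alternative, not part of this commit) that scans `indices` once per unique column:

```julia
function minimize_rows_alt(indices::Vector{CartesianIndex{2}})
    cols = unique(getindex.(indices, 2))
    # For each column, keep the smallest row index that appears in `indices`.
    return [CartesianIndex(minimum(i[1] for i in indices if i[2] == col), col) for col in cols]
end
```

The Dict version above does a single pass over `indices`, which only matters when the input is large; both return one index per column, though possibly in different orders.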
13 changes: 10 additions & 3 deletions src/crossvalidation/mlmnet_cv_summary.jl
@@ -152,12 +152,19 @@ function lambda_min(MLMNet_cv::Mlmnet_cv)
minIdy = argmin(mseMean)[2]

# Compute standard error across folds for the minimum MSE
mse1StdErr = mseMean[minIdx, minIdy] + mseStd[minIdx, minIdy]
mse1StdErr = mseMean[minIdx, minIdy] + mseStd[minIdx, minIdy]./sqrt(length(MLMNet_cv.rowFolds))
# Find the index of the lambda that is closest to being 1 SE greater than
# the lowest lambda, in the direction of the bigger lambdas
min1StdErrIdx = argmin(abs.(mseMean[1:minIdx[1], 1:minIdy[1]].-mse1StdErr))[1]
min1StdErrIdy = argmin(abs.(mseMean[1:minIdx[1], 1:minIdy[1]].-mse1StdErr))[2]
# Instead, apply the “one-standard-error rule” recommended by Hastie, Tibshirani, and
# Wainwright (2015, 13–14) rather than simply taking the λ that minimizes the CV function:
# for each α, select the largest λ for which the CV function is within one standard error
# of its minimum, then, among these (α, λ) pairs, keep the one with the smallest CV value.
mse1tmp = mse1StdErr .-mseMean;
idxMin1StdErr = minimize_rows(findall(mse1tmp .>= 0))

min1StdErrIdx = idxMin1StdErr[argmin(mseMean[idxMin1StdErr])][1]
min1StdErrIdy = idxMin1StdErr[argmin(mseMean[idxMin1StdErr])][2]

# Pull out summary information for these two lambdas
out = hcat(MLMNet_cv.lambdas[minIdx], MLMNet_cv.alphas[minIdy],
mseMean[minIdx, minIdy], prop_zeroMean[minIdx, minIdy])
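
A small numeric sketch of the selection logic above (made-up values; it assumes rows index λ from largest to smallest, as the earlier `1:minIdx` slicing suggests, and that `minimize_rows` from `mlmnet_cv_helpers.jl` is in scope):

```julia
# CV mean MSE over a 4 × 2 grid of (λ, α) values, plus its fold-to-fold spread.
mseMean = [1.30 1.20;
           1.04 0.98;
           1.00 0.95;   # overall minimum at CartesianIndex(3, 2)
           1.02 0.97]
mseStd  = fill(0.20, 4, 2)
nFolds  = 10

minIdx, minIdy = Tuple(argmin(mseMean))
mse1StdErr = mseMean[minIdx, minIdy] + mseStd[minIdx, minIdy] / sqrt(nFolds)  # ≈ 1.013

hits = findall(mse1StdErr .- mseMean .>= 0)  # cells within one SE of the minimum
cand = minimize_rows(hits)                   # per column (α): smallest row, i.e. largest λ
best = cand[argmin(mseMean[cand])]           # CartesianIndex(2, 2), one λ step above the raw minimum
```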
10 changes: 5 additions & 5 deletions src/mlmnet/mlmnet.jl
@@ -261,6 +261,10 @@ function mlmnet(data::RawData,
error("toZReg does not have same length as number of columns in Z.")
end

# create a copy of data to preserve original values and structure
data = RawData(Response(data.response.Y),Predictors(data.predictors.X, data.predictors.Z))


# Add X and Z intercepts if necessary
# Update toXReg and toZReg accordingly
if addXIntercept==true && data.predictors.hasXIntercept==false
@@ -375,10 +379,7 @@

# Back-transform coefficient estimates, if necessary.
# Case if including both X and Z intercepts:
if toNormalize == true && (addXIntercept==true) && (addZIntercept==true)
backtransform!(coeffs, meansX, meansZ, normsX, normsZ, get_Y(data),
data.predictors.X, data.predictors.Z)
elseif toNormalize == true # Otherwise
if toNormalize == true
backtransform!(coeffs, addXIntercept, addZIntercept, meansX, meansZ,
normsX, normsZ)
end
@@ -435,4 +436,3 @@ function mlmnet(data::RawData,
return rslts
end


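The back-transform call above now routes the intercept handling through its arguments instead of a separate branch. As background, when the columns of X and Z are scaled to unit norm before fitting, coefficients can be mapped back to the original scale by dividing by the corresponding column norms; a generic sketch of that idea (not the package's `backtransform!` internals, which also adjust intercept terms using `meansX`/`meansZ`):

```julia
# If the model was fit as Y ≈ (X ./ normsX') * Bn * (Z ./ normsZ')', then
# X * B * Z' reproduces the same fit with B on the original scale.
function backtransform_sketch(Bn::Matrix{Float64},
                              normsX::Vector{Float64}, normsZ::Vector{Float64})
    return Bn ./ (normsX * normsZ')  # B[i, j] = Bn[i, j] / (normsX[i] * normsZ[j])
end
```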