diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 245a028..71eb3aa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,6 +46,6 @@ jobs: env: JULIA_NUM_THREADS: 4 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v1 + - uses: codecov/codecov-action@v2 with: file: lcov.info diff --git a/Project.toml b/Project.toml index 7201fd0..d36e043 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixLMnet" uuid = "436227dc-6146-11e9-240b-9df6ee45b799" authors = ["Jane Liang", "Zifan Yu", "Gregory Farage", "Saunak Sen"] -version = "1.1.0" +version = "1.1.1" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" @@ -24,7 +24,8 @@ julia = "1.6.4, 1" [extras] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Helium = "3f79f04f-7cac-48b4-bde1-3ad54d8f74fa" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Distributions", "Helium", "Test"] +test = ["Distributions", "Helium", "StableRNGs", "Test"] diff --git a/src/MatrixLMnet.jl b/src/MatrixLMnet.jl index 5b6bf1c..5a6429e 100644 --- a/src/MatrixLMnet.jl +++ b/src/MatrixLMnet.jl @@ -19,13 +19,8 @@ using MLBase export Response, Predictors, RawData, get_X, get_Z, get_Y, contr, - add_intercept, remove_intercept, shuffle_rows, shuffle_cols, - cd!, cd_active!, ista!, fista!, fista_bt!, admm!, - mlmnet, Mlmnet, coef, coef_2d, predict, fitted, resid, - mlmnet_perms, - make_folds, make_folds_conds, mlmnet_cv, Mlmnet_cv, - avg_mse, lambda_min, avg_prop_zero, mlmnet_cv_summary, - mlmnet_bic, Mlmnet_bic, calc_bic, mlmnet_bic_summary + add_intercept, remove_intercept, shuffle_rows, shuffle_cols + @@ -39,25 +34,36 @@ include("methods/ista.jl") include("methods/fista.jl") include("methods/fista_bt.jl") include("methods/admm.jl") +export cd!, cd_active!, ista!, fista!, fista_bt!, admm! # Top level functions that call Elastic Net algorithms using warm starts include("mlmnet/mlmnet.jl") +export mlmnet, Mlmnet # Predictions and residuals include("utilities/predict.jl") +export coef, predict, fitted, resid#, coef_2d, # Permutations include("mlmnet/mlmnet_perms.jl") +export mlmnet_perms # Cross-validation include("crossvalidation/mlmnet_cv_helpers.jl") +export make_folds, make_folds_conds include("crossvalidation/mlmnet_cv.jl") +export mlmnet_cv, Mlmnet_cv include("crossvalidation/mlmnet_cv_summary.jl") +export calc_avg_mse, lambda_min, calc_avg_prop_zero, mlmnet_cv_summary # BIC validation include("bic/mlmnet_bic_helpers.jl") +export calc_bic include("bic/mlmnet_bic.jl") +export mlmnet_bic, Mlmnet_bic include("bic/mlmnet_bic_summary.jl") +export mlmnet_bic_summary + end diff --git a/src/crossvalidation/mlmnet_cv.jl b/src/crossvalidation/mlmnet_cv.jl index ea1e10c..2ecf0b3 100644 --- a/src/crossvalidation/mlmnet_cv.jl +++ b/src/crossvalidation/mlmnet_cv.jl @@ -539,4 +539,58 @@ function mlmnet_cv(data::RawData, isVerbose=isVerbose, toNormalize=toNormalize, stepsize=stepsize, setStepsize=setStepsize, dig=dig, funArgs...) +end + + +""" + mlmnet_cv(data::RawData, + lambdas::Array{Float64,1}, + rowFolds::Array{Array{Int64,1},1}, + colFolds::Array{Array{Int64,1},1} + method::String="ista", + isNaive::Bool=false, + addXIntercept::Bool=true, addZIntercept::Bool=true, + toXReg::BitArray{1}=trues(size(get_X(data), 2)), + toZReg::BitArray{1}=trues(size(get_Z(data), 2)), + toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, + toNormalize::Bool=true, isVerbose::Bool=true, + stepsize::Float64=0.01, setStepsize::Bool=true, + dig::Int64=12, funArgs...) + +Performs cross-validation for `mlmnet` using non-overlapping row and column +folds randomly generated using calls to `make_folds`. Calls the base +`mlmnet_cv` function. + + +""" + +function mlmnet_cv(data::RawData, + lambdas::Array{Float64,1}, + rowFolds::Array{Array{Int64,1},1}, + colFolds::Array{Array{Int64,1},1}; + method::String="ista", + isNaive::Bool=false, + addXIntercept::Bool=true, addZIntercept::Bool=true, + toXReg::BitArray{1}=trues(size(get_X(data), 2)), + toZReg::BitArray{1}=trues(size(get_Z(data), 2)), + toXInterceptReg::Bool=false, toZInterceptReg::Bool=false, + toNormalize::Bool=true, isVerbose::Bool=true, + stepsize::Float64=0.01, setStepsize::Bool=true, + dig::Int64=12, funArgs...) + + + + alphas = [1.0] + + # Pass in randomly generated row and column folds to the base mlmnet_cv + # function + mlmnet_cv(data, lambdas, alphas, rowFolds, colFolds; + method=method, isNaive=isNaive, + addXIntercept=addXIntercept, addZIntercept=addZIntercept, + toXReg=toXReg, toZReg=toZReg, + toXInterceptReg=toXInterceptReg, + toZInterceptReg=toZInterceptReg, + isVerbose=isVerbose, toNormalize=toNormalize, + stepsize=stepsize, setStepsize=setStepsize, dig=dig, funArgs...) + end \ No newline at end of file diff --git a/src/crossvalidation/mlmnet_cv_helpers.jl b/src/crossvalidation/mlmnet_cv_helpers.jl index 863902f..f223a85 100644 --- a/src/crossvalidation/mlmnet_cv_helpers.jl +++ b/src/crossvalidation/mlmnet_cv_helpers.jl @@ -111,6 +111,11 @@ function findnotin(a::AbstractArray{Int64,1}, b::AbstractArray{Int64,1}) end end +# CHECK if [] is an argument +# WHY not using setdiff(collect(1:n), vec) + + + """ calc_mse(MLMNets::AbstractArray{Mlmnet,1}, data::RawData, lambdas::AbstractArray{Float64,1}, @@ -233,4 +238,51 @@ function calc_prop_zero(MLMNets::AbstractArray{Mlmnet,1}, end return propZero +end +""" + minimize_rows(indices::Vector{CartesianIndex{2}}) + +Processes a vector of `CartesianIndex` objects representing positions in a 2D matrix +and returns a new vector of CartesianIndex objects. Each element in the resulting vector +should represent the smallest row index for each unique column index. + +# Arguments + +- indices = 1d array of `CartesianIndex` objects representing positions in a 2D matrix + +# Value + +1d array of `CartesianIndex` objects representing the smallest row index for each unique +column index. + +# Example +```julia +julia> input_indices = [CartesianIndex(1, 1), CartesianIndex(2, 1), CartesianIndex(3, 2), CartesianIndex(1, 2)] + +julia> output_indices = minimize_rows(input_indices) +2-element Vector{CartesianIndex{2}}: + CartesianIndex(1, 2) + CartesianIndex(1, 1) +``` +""" +function minimize_rows(indices::Vector{CartesianIndex{2}}) + # Dictionary to hold the minimum row index for each column + min_rows = Dict{Int, Int}() + + for index in indices + col = index[2] + row = index[1] + + # Update the dictionary with the minimum row for each column + if haskey(min_rows, col) + min_rows[col] = min(min_rows[col], row) + else + min_rows[col] = row + end + end + + # Create a vector of CartesianIndices from the dictionary + result_indices = [CartesianIndex(min_rows[col], col) for col in keys(min_rows)] + + return result_indices end \ No newline at end of file diff --git a/src/crossvalidation/mlmnet_cv_summary.jl b/src/crossvalidation/mlmnet_cv_summary.jl index 78fab07..6c4ca00 100644 --- a/src/crossvalidation/mlmnet_cv_summary.jl +++ b/src/crossvalidation/mlmnet_cv_summary.jl @@ -152,12 +152,19 @@ function lambda_min(MLMNet_cv::Mlmnet_cv) minIdy = argmin(mseMean)[2] # Compute standard error across folds for the minimum MSE - mse1StdErr = mseMean[minIdx, minIdy] + mseStd[minIdx, minIdy] + mse1StdErr = mseMean[minIdx, minIdy] + mseStd[minIdx, minIdy]./sqrt(length(MLMNet_cv.rowFolds)) # Find the index of the lambda that is closest to being 1 SE greater than # the lowest lambda, in the direction of the bigger lambdas - min1StdErrIdx = argmin(abs.(mseMean[1:minIdx[1], 1:minIdy[1]].-mse1StdErr))[1] - min1StdErrIdy = argmin(abs.(mseMean[1:minIdx[1], 1:minIdy[1]].-mse1StdErr))[2] + # the “one-standard-error rule” recommended by Hastie, Tibshirani, and Wainwright (2015, 13–14) + # instead of the λ that minimizes the CV function. The one-standard-error rule selects, for each α, + # the largest λ for which the CV function is within a standard error of the minimum of the CV function. + # Then, from among these (α,λ) pairs, the one with the smallest value of the CV function is selected. + mse1tmp = mse1StdErr .-mseMean; + idxMin1StdErr = minimize_rows(findall(mse1tmp .>= 0)) + min1StdErrIdx = idxMin1StdErr[argmin(mseMean[idxMin1StdErr])][1] + min1StdErrIdy = idxMin1StdErr[argmin(mseMean[idxMin1StdErr])][2] + # Pull out summary information for these two lambdas out = hcat(MLMNet_cv.lambdas[minIdx], MLMNet_cv.alphas[minIdy], mseMean[minIdx, minIdy], prop_zeroMean[minIdx, minIdy]) diff --git a/src/mlmnet/mlmnet.jl b/src/mlmnet/mlmnet.jl index 434e5a0..f4a29c6 100644 --- a/src/mlmnet/mlmnet.jl +++ b/src/mlmnet/mlmnet.jl @@ -261,6 +261,10 @@ function mlmnet(data::RawData, error("toZReg does not have same length as number of columns in Z.") end + # create a copy of data to preserve original values and structure + data = RawData(Response(data.response.Y),Predictors(data.predictors.X, data.predictors.Z)) + + # Add X and Z intercepts if necessary # Update toXReg and toZReg accordingly if addXIntercept==true && data.predictors.hasXIntercept==false @@ -375,10 +379,7 @@ function mlmnet(data::RawData, # Back-transform coefficient estimates, if necessary. # Case if including both X and Z intercepts: - if toNormalize == true && (addXIntercept==true) && (addZIntercept==true) - backtransform!(coeffs, meansX, meansZ, normsX, normsZ, get_Y(data), - data.predictors.X, data.predictors.Z) - elseif toNormalize == true # Otherwise + if toNormalize == true backtransform!(coeffs, addXIntercept, addZIntercept, meansX, meansZ, normsX, normsZ) end @@ -435,4 +436,3 @@ function mlmnet(data::RawData, return rslts end - diff --git a/src/utilities/std_helpers.jl b/src/utilities/std_helpers.jl index 4f33747..8f03fb5 100644 --- a/src/utilities/std_helpers.jl +++ b/src/utilities/std_helpers.jl @@ -21,7 +21,7 @@ function normalize!(A::AbstractArray{Float64,2}, hasIntercept::Bool) if hasIntercept == true means = mean(A, dims=1) A[:,2:end] = A[:,2:end].-transpose(means[2:end]) - else # Otherwise, subtract the column means from all columns + else # If no intercept do not center. means = Array{Float64}(undef, 1, size(A,2)) end @@ -36,195 +36,6 @@ function normalize!(A::AbstractArray{Float64,2}, hasIntercept::Bool) return means, norms end - -""" - backtransform!(B::AbstractArray{Float64,2}, - meansX::AbstractArray{Float64,2}, - meansZ::AbstractArray{Float64,2}, - normsX::AbstractArray{Float64,2}, - normsZ::AbstractArray{Float64,2}, - Y::AbstractArray{Float64,2}, - Xold::AbstractArray{Float64,2}, - Zold::AbstractArray{Float64,2}) - -Back-transform coefficient estimates B in place if X and Z were standardized -prior to the estimation-- when both X and Z include intercept columns. - -# Arguments - -- B = 2d array of coefficient estimates B -- meansX = 2d array of column means of X, obtained prior to standardizing X -- meansZ = 2d array of column means of Z, obtained prior to standardizing Z -- normsX = 2d array of column norms of X, obtained prior to standardizing X -- normsZ = 2d array of column norms of Z, obtained prior to standardizing Z -- Y = 2d array of response matrix Y -- Xold = 2d array row covariates X prior to standardization -- Zold = 2d array column covariates Z prior to standardization - -# Value - -None; back-transforms B in place - -""" -function backtransform!(B::AbstractArray{Float64,2}, - meansX::AbstractArray{Float64,2}, - meansZ::AbstractArray{Float64,2}, - normsX::AbstractArray{Float64,2}, - normsZ::AbstractArray{Float64,2}, - Y::AbstractArray{Float64,2}, - Xold::AbstractArray{Float64,2}, - Zold::AbstractArray{Float64,2}) - - # Back transform the X intercepts (row main effects) - prodX = (meansX[:,2:end]./normsX[:,2:end])*B[2:end, 2:end] - B[1,2:end] = (B[1,2:end]-vec(prodX))./vec(normsZ[:,2:end])/normsX[1,1] - - # Back transform the Z intercepts (column main effects) - prodZ = B[2:end, 2:end]*transpose(meansZ[:,2:end]./normsZ[:,2:end]) - B[2:end,1] = (B[2:end,1]-prodZ)./transpose(normsX[:,2:end])/normsZ[1,1] - - # Back transform the interactions - B[2:end, 2:end] = B[2:end, 2:end]./transpose(normsX[:,2:end])./ - normsZ[:,2:end] - - # Re-estimate intercept - B[1,1] = 0 - B[1,1] = mean(Y-Xold*B*transpose(Zold)) -end - - -""" - backtransform!(B::AbstractArray{Float64,2}, - addXIntercept::Bool, addZIntercept::Bool, - meansX::AbstractArray{Float64,2}, - meansZ::AbstractArray{Float64,2}, - normsX::AbstractArray{Float64,2}, - normsZ::AbstractArray{Float64,2}) - -Back-transform coefficient estimates B in place if X and Z were standardized -prior to the estimation-- when not including intercept columns for either X -or Z. - -# Arguments - -- B = 2d array of coefficient estimates B -- addXIntercept = boolean flag indicating whether or not to X has an - intercept column -- addZIntercept = boolean flag indicating whether or not to Z has an - intercept column -- meansX = 2d array of column means of X, obtained prior to standardizing X -- meansZ = 2d array of column means of Z, obtained prior to standardizing Z -- normsX = 2d array of column norms of X, obtained prior to standardizing X -- normsZ = 2d array of column norms of Z, obtained prior to standardizing Z - -# Value - -None; back-transforms B in place - -""" -function backtransform!(B::AbstractArray{Float64,2}, - addXIntercept::Bool, addZIntercept::Bool, - meansX::AbstractArray{Float64,2}, - meansZ::AbstractArray{Float64,2}, - normsX::AbstractArray{Float64,2}, - normsZ::AbstractArray{Float64,2}) - - # Back transform the X intercepts (row main effects), if necessary - if addXIntercept == true - prodX = (meansX[:,2:end]./normsX[:,2:end])*B[2:end, 2:end] - B[1,2:end] = (B[1,2:end]-vec(prodX))./vec(normsZ[:,2:end])/normsX[1,1] - end - - # Back transform the Z intercepts (column main effects), if necessary - if addZIntercept == true - prodZ = B[2:end, 2:end]*transpose(meansZ[:,2:end]./normsZ[:,2:end]) - B[2:end,1] = (B[2:end,1]-prodZ)./transpose(normsX[:,2:end])/ - normsZ[1,1] - end - - # Back transform the interactions, if necessary - if (addXIntercept == true) || (addZIntercept == true) - B[2:end, 2:end] = B[2:end, 2:end]./transpose(normsX[:,2:end])./ - normsZ[:,2:end] - end - - # Back transform the interactions if not including any main effects - if (addXIntercept == false) && (addZIntercept == false) - B = B./transpose(normsX)./normsZ - end -end - - -""" - backtransform!(B::AbstractArray{Float64,4}, - meansX::AbstractArray{Float64,2}, - meansZ::AbstractArray{Float64,2}, - normsX::AbstractArray{Float64,2}, - normsZ::AbstractArray{Float64,2}, - Y::AbstractArray{Float64,2}, - Xold::AbstractArray{Float64,2}, - Zold::AbstractArray{Float64,2}) - -Back-transform coefficient estimates B in place if X and Z were standardized -prior to the estimation-- when both X and Z include intercept columns. - -# Arguments - -- B = 4d array of coefficient estimates B -- meansX = 2d array of column means of X, obtained prior to standardizing X -- meansZ = 2d array of column means of Z, obtained prior to standardizing Z -- normsX = 2d array of column norms of X, obtained prior to standardizing X -- normsZ = 2d array of column norms of Z, obtained prior to standardizing Z -- Y = 2d array of response matrix Y -- Xold = 2d array row covariates X prior to standardization -- Zold = 2d array column covariates Z prior to standardization - -# Value - -None; back-transforms B in place - -# Some notes - -B is a 4d array in which each coefficient matrix is stored along the third and fourth -dimension. - -""" -function backtransform!(B::AbstractArray{Float64,4}, - meansX::AbstractArray{Float64,2}, - meansZ::AbstractArray{Float64,2}, - normsX::AbstractArray{Float64,2}, - normsZ::AbstractArray{Float64,2}, - Y::AbstractArray{Float64,2}, - Xold::AbstractArray{Float64,2}, - Zold::AbstractArray{Float64,2}) - - # Iterate through the first dimension of B to back-transform each - # coefficient matrix. - for j in 1:size(B,4) - for i in 1:size(B,3) - # Back transform the X intercepts (row main effects) - prodX = (meansX[:,2:end]./normsX[:,2:end])*B[2:end, 2:end,i,j] - B[1,2:end,i,j] = (B[1,2:end,i,j]-vec(prodX))./vec(normsZ[:,2:end])/ - normsX[1,1] - - # Back transform the Z intercepts (column main effects) - prodZ = B[2:end,2:end,i,j]*transpose(meansZ[:,2:end]./normsZ[:,2:end]) - B[2:end,1,i,j] = (B[2:end,1,i,j]-prodZ)./transpose(normsX[:,2:end])/ - normsZ[1,1] - - # Back transform the interactions - B[2:end,2:end,i,j] = B[2:end,2:end,i,j]./transpose(normsX[:,2:end])./ - normsZ[:,2:end] - - # Re-estimate intercept - B[1,1,i,j] = 0 - B[1,1,i,j] = mean(Y-Xold*B[:,:,i,j]*transpose(Zold)) - end - end -end - - - """ backtransform!(B::AbstractArray{Float64,4}, addXIntercept::Bool, addZIntercept::Bool, @@ -270,31 +81,22 @@ function backtransform!(B::AbstractArray{Float64,4}, # coefficient matrix: for j in 1:size(B,4) for i in 1:size(B,3) - # Back transform the X intercepts (row main effects), if necessary - if addXIntercept == true - prodX = (meansX[:,2:end]./normsX[:,2:end])*B[2:end,2:end,i,j] - B[1,2:end,i,j] = (B[1,2:end,i,j]-vec(prodX))./vec(normsZ[:,2:end])/ - normsX[1,1] - end - - # Back transform the Z intercepts (column main effects), if necessary - if addZIntercept == true - prodZ = B[2:end,2:end,i,j]*transpose(meansZ[:,2:end]./ - normsZ[:,2:end]) - B[2:end,1,i,j] = (B[2:end,1,i,j]-prodZ)./transpose(normsX[:,2:end])/ - normsZ[1,1] - end - - # Back transform the interactions, if necessary - if (addXIntercept == true) || (addZIntercept == true) - B[2:end,2:end,i,j] = B[2:end,2:end,i,j]./transpose(normsX[:,2:end])./ - normsZ[:,2:end] + + # reverse scale from X and Z normalization + B[:,:,i,j] = (B[:,:,i,j]./permutedims(normsX))./normsZ + + # Back transform the X intercepts (row main effects) + if addXIntercept == true + # B̂intercept = Ȳ - X̄ B̂coef. Since X centered, B̂intercept|Xcentered = Ȳ + B[1,:,i,j] = B[1,:,i,j] - vec(meansX[:,2:end]*B[2:end,:,i,j]) end - - # Back transform the interactions if not including any main effects - if (addXIntercept == false) && (addZIntercept == false) - B[:,:,i,j] = B[:,:,i,j]./transpose(normsX)./normsZ + + # Back transform the Z intercepts (column main effects) + if addZIntercept == true + # B̂intercept = Ȳ - Z̄ B̂coef. Since X centered, B̂intercept|Xcentered = Ȳ + B[:,1,i,j] = B[:,1,i,j] - B[:,2:end,i,j]*permutedims(meansZ[:,2:end]) end + end end end \ No newline at end of file diff --git a/test/cv_helpersTests.jl b/test/cv_helpersTests.jl new file mode 100644 index 0000000..d8c30ac --- /dev/null +++ b/test/cv_helpersTests.jl @@ -0,0 +1,142 @@ +########### +# Library # +########### +# using Random +using MatrixLMnet +using Helium +using Test + +##################################################################### +# TEST Cross Validation Lasso vs Elastic Net (𝛼=1) - Simulated Data # +##################################################################### + +# Data testing directory name +dataDir = realpath(joinpath(@__DIR__,"data")) + +# Get predictors +X = Helium.readhe(joinpath(dataDir, "Xmat.he")) + +# Get response +Y = Helium.readhe(joinpath(dataDir, "Ymat.he")) + +# Get Z matrix +Z = Helium.readhe(joinpath(dataDir, "Zmat.he")) + + +# Build raw data object from MatrixLM.jl +dat = RawData(Response(Y), Predictors(X, Z)); + +dat.n + +test = make_folds(dat.n, 10, max(10, 1)) +testcol = make_folds(dat.m, 1, max(10, 1)) + + +# Check https://mldatautilsjl.readthedocs.io/en/latest/ +# check resid + +################### +# Test make_folds # +################### + +####################################################### +# Test 1: Basic functionality with default parameters # +####################################################### +@test length(make_folds(100)) == 10 +@test all([length(fold) ≈ 90 for fold in make_folds(100)]) # Approximate because the folds may not be exactly equal + +################################## +# Test 2: Non-default `k` values # +################################## +@test length(make_folds(100, 5)) == 5 +@test length(make_folds(100, 100)) == 100 # Each fold should have exactly one element +@test length(make_folds(100, 20)) == 20 + +################################# +# Test 3: `k` Equals 1 Scenario # +################################# +@test length(make_folds(100, 1, 5)) == 5 # Should repeat 5 times +@test all([length(fold) == 100 for fold in make_folds(100, 1, 5)]) # Each fold contains all indices + +############################## +# Test 4: Invalid `k` values # +############################## +@test_throws ErrorException make_folds(100, 0) +@test_throws ErrorException make_folds(100, -1) + +################################# +# Test 5: Type Check (Optional) # +################################# +@test_throws MethodError make_folds(100.0, 10) # Float instead of Int for `n` +@test_throws MethodError make_folds(100, "10") # String instead of Int for `k` + + + + +######################### +# Test make_folds_conds # +######################### + + + + +################## +# Test findnotin # +################## + +# Basic Functionality +@testset "Basic Functionality" begin + a = [1, 2, 3, 4] + b = [3, 4, 5, 6] + @test MatrixLMnet.findnotin(a, b) == [5, 6] +end + +# No Missing Elements +@testset "No Missing Elements" begin + a = [1, 2, 3, 4] + b = [1, 2, 3, 4] + @test isempty(MatrixLMnet.findnotin(a, b)) +end + +# All Elements Missing +@testset "All Elements Missing" begin + a = [1, 2, 3, 4] + b = [5, 6, 7, 8] + @test MatrixLMnet.findnotin(a, b) == b +end + +# Duplicate Elements +@testset "Duplicate Elements" begin + a = [1, 2, 3] + b = [2, 2, 4, 4] + @test MatrixLMnet.findnotin(a, b) == [4, 4] # Expect duplicates in b that are not in a to be returned as-is +end + + + +a= [1, 2, 3, 4] +b = collect(1:10) +setdiff(b, a) +setdiff(b, []) +setdiff(b, b) + +b[b .∉ Ref(a)] + + +################## +# Cross-validate # +################## + +function compute_MatrixLMnet() + + +end + +scores = MatrixLMnet.cross_validate( + inds -> compute_center(data[:, inds]), # training function + (c, inds) -> compute_rmse(c, data[:, inds]), # evaluation function + n, # total number of samples + Kfold(n, 5)) # cross validation plan: 5-fold + + + diff --git a/test/data/B_admm.he b/test/data/B_admm.he index 12d17f8..0d08e97 100644 Binary files a/test/data/B_admm.he and b/test/data/B_admm.he differ diff --git a/test/data/B_cd.he b/test/data/B_cd.he index 8f03d7e..5f43d8b 100644 Binary files a/test/data/B_cd.he and b/test/data/B_cd.he differ diff --git a/test/data/B_fista.he b/test/data/B_fista.he index 71117b2..f762dc5 100644 Binary files a/test/data/B_fista.he and b/test/data/B_fista.he differ diff --git a/test/data/B_fistabt.he b/test/data/B_fistabt.he index f2d637b..341d185 100644 Binary files a/test/data/B_fistabt.he and b/test/data/B_fistabt.he differ diff --git a/test/data/B_ista.he b/test/data/B_ista.he index 05dd55f..eada11b 100644 Binary files a/test/data/B_ista.he and b/test/data/B_ista.he differ diff --git a/test/data/Xmat.he b/test/data/Xmat.he index 7fb300c..c0050bf 100644 Binary files a/test/data/Xmat.he and b/test/data/Xmat.he differ diff --git a/test/data/Ymat.he b/test/data/Ymat.he index 5d80fda..d8c448c 100644 Binary files a/test/data/Ymat.he and b/test/data/Ymat.he differ diff --git a/test/data/col_folds.he b/test/data/col_folds.he new file mode 100644 index 0000000..75aa20c Binary files /dev/null and b/test/data/col_folds.he differ diff --git a/test/data/row_folds.he b/test/data/row_folds.he new file mode 100644 index 0000000..a7c9ec7 Binary files /dev/null and b/test/data/row_folds.he differ diff --git a/test/data/smmr_admm.he b/test/data/smmr_admm.he index 9c01551..e2b5277 100644 Binary files a/test/data/smmr_admm.he and b/test/data/smmr_admm.he differ diff --git a/test/data/smmr_cd.he b/test/data/smmr_cd.he index d008571..5a9f9c9 100644 Binary files a/test/data/smmr_cd.he and b/test/data/smmr_cd.he differ diff --git a/test/data/smmr_fista.he b/test/data/smmr_fista.he index 31b4178..a9d1c15 100644 Binary files a/test/data/smmr_fista.he and b/test/data/smmr_fista.he differ diff --git a/test/data/smmr_fistabt.he b/test/data/smmr_fistabt.he index 9c3037b..8e19b31 100644 Binary files a/test/data/smmr_fistabt.he and b/test/data/smmr_fistabt.he differ diff --git a/test/data/smmr_ista.he b/test/data/smmr_ista.he index bf6aa0d..b61a2c2 100644 Binary files a/test/data/smmr_ista.he and b/test/data/smmr_ista.he differ diff --git a/test/generate_testing_dataset.jl b/test/generate_testing_dataset.jl index 3527204..a91902c 100644 --- a/test/generate_testing_dataset.jl +++ b/test/generate_testing_dataset.jl @@ -7,9 +7,10 @@ ########### # Library # ########### -# using Distributions, Random, Statistics, LinearAlgebra, StatsBase +using Distributions, Random, Statistics, LinearAlgebra, StatsBase +using StableRNGs # using DataFrames, MLBase, Distributed -using MatrixLMnet #v0.1.0 +using MatrixLMnet #v1.1.1 # using Test using Helium @@ -38,7 +39,7 @@ The pairwise correlation between 𝑋ᵢ and 𝑋ⱼ was set to be 𝑐𝑜𝑟( Here, the Z matrix is an identity matrix. =# -rng = MersenneTwister(2021) +rng = rng = StableRNG(123) # MatrixLMnet.Random.MersenneTwister(2021); # Simulation parameters p = 8; # Number of predictors @@ -69,7 +70,7 @@ dat = MatrixLMnet.MatrixLM.RawData(Response(Y), Predictors(X, Z)); # Hyper parameters λ = [10.0, 5.0, 3.0] - +α =[1.0] ############### @@ -77,59 +78,66 @@ dat = MatrixLMnet.MatrixLM.RawData(Response(Y), Predictors(X, Z)); ############### # Lasso penalized regression - ista -est = mlmnet(ista!, dat, λ, addZIntercept = false, addXIntercept = false, isVerbose = false); +est = mlmnet(dat, λ, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_ista = est.B[:, :, 3]; # Lasso penalized regression - fista -est = mlmnet(fista!, dat, λ, addZIntercept = false, addXIntercept = false, isVerbose = false); +est = mlmnet(dat, λ, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_fista = est.B[:, :, 3]; # Lasso penalized regression - fista backtracking -est = mlmnet(fista_bt!, dat, λ, addZIntercept = false, addXIntercept = false, isVerbose = false); +MatrixLMnet.Random.seed!(rng, 2024) +est = mlmnet(dat, λ, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_fista_bt = est.B[:, :, 3]; # Lasso penalized regression - admm -est = mlmnet(admm!, dat, λ, addZIntercept = false, addXIntercept = false, isVerbose = false); +est = mlmnet(dat, λ, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_admm = est.B[:, :, 3]; # Lasso penalized regression - cd -Random.seed!(rng) -est = mlmnet(cd!, dat, λ, addZIntercept = false, addXIntercept = false, isVerbose = false); +MatrixLMnet.Random.seed!(rng, 2024) +est = mlmnet(dat, λ, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); est_B_Lasso_cd = est.B[:, :, 3]; ################################# # TEST Lasso - Crossvalidation # ################################# +# Generate random row and column folds +nRowFolds = 10 +nColFolds = 1 + +MatrixLMnet.Random.seed!(rng, 2024) +rowFolds = make_folds(dat.n, nRowFolds, max(nRowFolds, nColFolds)) + +MatrixLMnet.Random.seed!(rng, 2024) +colFolds = make_folds(dat.m, nColFolds, max(nRowFolds, nColFolds)) # Lasso penalized regression - ista cv -Random.seed!(rng) -est = mlmnet_cv(ista!, dat, λ, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); -smmr_Lasso = lambda_min_deprecated(est); +est = mlmnet_cv(dat, λ, rowFolds, colFolds; method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); +smmr_Lasso = lambda_min(est); smmr_ista = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) # Lasso penalized regression - fista cv -Random.seed!(rng) -est = mlmnet_cv(fista!, dat, λ, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); -smmr_Lasso = lambda_min_deprecated(est); +est = mlmnet_cv(dat, λ, rowFolds, colFolds; method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); +smmr_Lasso = lambda_min(est); smmr_fista = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) # Lasso penalized regression - fista-bt cv -Random.seed!(rng) -est = mlmnet_cv(fista_bt!, dat, λ, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); -smmr_Lasso = lambda_min_deprecated(est); +MatrixLMnet.Random.seed!(rng, 2024) +est = mlmnet_cv(dat, λ, rowFolds, colFolds; method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); +smmr_Lasso = lambda_min(est); smmr_fistabt = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) # Lasso penalized regression - admm cv -Random.seed!(rng) -est = mlmnet_cv(admm!, dat, λ, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); -smmr_Lasso = lambda_min_deprecated(est); +est = mlmnet_cv(dat, λ, rowFolds, colFolds; method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); +smmr_Lasso = lambda_min(est); smmr_admm = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) # Lasso penalized regression - cd cv -Random.seed!(rng) -est = mlmnet_cv(cd!, dat, λ, 10, 1, addZIntercept = false, addXIntercept = false, isVerbose = false); -smmr_Lasso = lambda_min_deprecated(est); +MatrixLMnet.Random.seed!(rng, 2024) +est = mlmnet_cv(dat, λ, rowFolds, colFolds; method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); +smmr_Lasso = lambda_min(est); smmr_cd = hcat(smmr_Lasso.AvgMSE, smmr_Lasso.AvgPropZero) @@ -145,6 +153,22 @@ Helium.writehe(X, joinpath(dataDir, "Xmat.he")) Helium.writehe(Y, joinpath(dataDir, "Ymat.he")) Helium.writehe(Z, joinpath(dataDir, "Zmat.he")) + +# Save folding indexes +mrowFolds = Matrix{Int64}(undef, length(rowFolds[1]), length(rowFolds)) +for i in 1:length(rowFolds) + mrowFolds[:,i] = rowFolds[i] +end + +mcolFolds = Matrix{Int64}(undef, length(colFolds[1]), length(colFolds)) +for i in 1:length(colFolds) + mcolFolds[:,i] = colFolds[i] +end + +Helium.writehe(mrowFolds, joinpath(dataDir, "row_folds.he")) +Helium.writehe(mcolFolds, joinpath(dataDir, "col_folds.he")) + + # Save estimates results Helium.writehe(est_B_Lasso_ista, joinpath(dataDir, "B_ista.he")) Helium.writehe(est_B_Lasso_fista, joinpath(dataDir, "B_fista.he")) diff --git a/test/mlmnetCV_helpersTests.jl b/test/mlmnetCV_helpersTests.jl new file mode 100644 index 0000000..bb7f02c --- /dev/null +++ b/test/mlmnetCV_helpersTests.jl @@ -0,0 +1,31 @@ + +########### +# Library # +########### +# using MatrixLMnet +# using Helium +# using Test + +###################### +# TEST minimize_rows # +###################### + +@testset "Testing minimize_rows function" begin + # Test with a single element + @test MatrixLMnet.minimize_rows([CartesianIndex(1, 1)]) == [CartesianIndex(1, 1)] + + # Test with multiple elements in the same column + @test MatrixLMnet.minimize_rows([CartesianIndex(3, 1), CartesianIndex(2, 1)]) == [CartesianIndex(2, 1)] + + # Test with multiple elements in different columns + @test MatrixLMnet.minimize_rows([CartesianIndex(1, 1), CartesianIndex(2, 2), CartesianIndex(3, 2)]) == [CartesianIndex(2, 2), CartesianIndex(1, 1)] + + # Test with non-sequential columns + @test MatrixLMnet.minimize_rows([CartesianIndex(3, 5), CartesianIndex(2, 1)]) == [CartesianIndex(3, 5), CartesianIndex(2, 1)] + + # Test with duplicate indices + @test MatrixLMnet.minimize_rows([CartesianIndex(2, 2), CartesianIndex(2, 2)]) == [CartesianIndex(2, 2)] + + # Test with a more complex scenario + @test MatrixLMnet.minimize_rows([CartesianIndex(4, 3), CartesianIndex(1, 3), CartesianIndex(2, 4), CartesianIndex(3, 4)]) == [CartesianIndex(2, 4), CartesianIndex(1, 3)] +end diff --git a/test/mlmnetCvTests.jl b/test/mlmnetCvTests.jl index 0c7e1cc..4fba615 100644 --- a/test/mlmnetCvTests.jl +++ b/test/mlmnetCvTests.jl @@ -42,8 +42,6 @@ dat = RawData(Response(Y), Predictors(X, Z)); λ = [10.0, 5.0, 3.0] α = [1.0] -rng = 2021#MatrixLMnet.Random.MersenneTwister(2021) - numVersion = VERSION if Int(numVersion.minor) < 7 tolVersion=2e-1 @@ -51,18 +49,24 @@ else tolVersion=1e-5 end +# Folding indices +mrow_folds = Helium.readhe(joinpath(dataDir, "row_folds.he")) +row_folds = [mrow_folds[:,i] for i in 1:size(mrow_folds, 2)] + +mcol_folds = Helium.readhe(joinpath(dataDir, "col_folds.he")) +col_folds = [mcol_folds[:,i] for i in 1:size(mcol_folds, 2)] + +rng = StableRNG(123) # MatrixLMnet.Random.MersenneTwister(2021); ############################################# # TEST 1 Lasso vs Elastic Net (𝛼=1) - ista # ############################################# # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, row_folds, col_folds, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, λ, 10, 1, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, λ, row_folds, col_folds, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - ista cv @@ -77,13 +81,11 @@ println("CV Lasso vs Elastic Net when α=1 test 1 - ista: ", ############################################# # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, row_folds, col_folds, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, λ, 10, 1, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, λ, row_folds, col_folds, method = "fista", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - fista cv @@ -98,13 +100,13 @@ println("CV Lasso vs Elastic Net when α=1 test 2 - fista: ", ########################################################## # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); +MatrixLMnet.Random.seed!(rng, 2024) +est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, row_folds, col_folds, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, λ, 10, 1, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); +MatrixLMnet.Random.seed!(rng, 2024) +est2 = MatrixLMnet.mlmnet_cv(dat, λ, row_folds, col_folds, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - fista-bt cv @@ -120,13 +122,11 @@ println("CV Lasso vs Elastic Net when α=1 test 3 - fista-bt: ", ############################################ # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); +est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, row_folds, col_folds, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, λ, 10, 1, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); +est2 = MatrixLMnet.mlmnet_cv(dat, λ, row_folds, col_folds, method = "admm", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - fista-bt cv @@ -141,13 +141,13 @@ println("CV Lasso vs Elastic Net when α=1 test 4 - admm: ", ########################################## # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, 10, 1, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); +MatrixLMnet.Random.seed!(rng, 2024) +est1 = MatrixLMnet.mlmnet_cv(dat, λ, α, row_folds, col_folds, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net1 = MatrixLMnet.lambda_min(est1); # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) -est2 = MatrixLMnet.mlmnet_cv(dat, λ, 10, 1, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); +MatrixLMnet.Random.seed!(rng, 2024) +est2 = MatrixLMnet.mlmnet_cv(dat, λ, row_folds, col_folds, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); smmr_Net2 = MatrixLMnet.lambda_min(est2); # Lasso penalized regression - cd cv diff --git a/test/mlmnetTests.jl b/test/mlmnetTests.jl index 7619d31..bbd9c1f 100644 --- a/test/mlmnetTests.jl +++ b/test/mlmnetTests.jl @@ -4,9 +4,9 @@ # # using MatrixLM # # using Distributions, Random, Statistics, LinearAlgebra, StatsBase # # using Random -# using MatrixLMnet -# using Helium -# using Test +using MatrixLMnet +using Helium +using Test #################################################### @@ -46,12 +46,11 @@ dat = RawData(Response(Y), Predictors(X, Z)); λ = [10.0, 5.0, 3.0] α = [1.0] -rng = MatrixLMnet.Random.MersenneTwister(2021) - -############################################ -# TEST 1 Lasso vs Elastic Net (𝛼=1) - ista # -############################################ +rng = StableRNG(123) # MatrixLMnet.Random.MersenneTwister(2021); +############################################# +# TEST 1a Lasso vs Elastic Net (𝛼=1) - ista # +############################################# # Elastic net penalized regression est_ista_1 = MatrixLMnet.mlmnet(dat, λ, α, method = "ista", addZIntercept = false, addXIntercept = false, isVerbose = false); @@ -88,10 +87,12 @@ println("Lasso vs Elastic Net when α=1 test 2 - fista: ", @test (B_Net_fista_1 ########################################################## # Elastic net penalized regression +MatrixLMnet.Random.seed!(rng, 2024) est_fistabt_1 = MatrixLMnet.mlmnet(dat, λ, α, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_fistabt_1 = est_fistabt_1.B[:, :, 3, 1]; # Elastic net penalized regression +MatrixLMnet.Random.seed!(rng, 2024) est_fistabt_2 = MatrixLMnet.mlmnet(dat, λ, method = "fista_bt", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_fistabt_2 = est_fistabt_2.B[:, :, 3, 1]; @@ -124,12 +125,12 @@ println("Lasso vs Elastic Net when α=1 test 4 - admm: ", @test (B_Net_admm_1 ########################################## # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) +MatrixLMnet.Random.seed!(rng, 2024) est_cd_1 = MatrixLMnet.mlmnet(dat, λ, α, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_cd_1 = est_cd_1.B[:, :, 3, 1]; # Elastic net penalized regression -MatrixLMnet.Random.seed!(rng) +MatrixLMnet.Random.seed!(rng, 2024) est_cd_2 = MatrixLMnet.mlmnet(dat, λ, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = false); B_Net_cd_2 = est_cd_2.B[:, :, 3, 1]; @@ -139,4 +140,16 @@ B_cd = Helium.readhe(joinpath(dataDir, "B_cd.he")) println("Lasso vs Elastic Net when α=1 test 5 - cd: ", @test ≈(B_Net_cd_1, B_cd; atol=1.2e-4) && ≈(B_Net_cd_2, B_cd; atol=1.2e-4)) +################################## +# TEST 6 Data remains unchanged # +################################## + +# Elastic net penalized regression +original_dat_predictors_colsize = size(dat.predictors.X, 2); +est_ista_1 = MatrixLMnet.mlmnet(dat, λ, α, method = "ista", addZIntercept = false, addXIntercept = true, isVerbose = false); + + +println("Test that original data remains unchanged test 6: ", + @test original_dat_predictors_colsize == size(dat.predictors.X, 2)) + println("Tests mlmnet finished!") diff --git a/test/runtests.jl b/test/runtests.jl index aea01a8..5cc5ebd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using MatrixLMnet +using StableRNGs using Distributions, LinearAlgebra using Helium using Test @@ -9,4 +10,6 @@ using Test include("mlmnetCvTests.jl") include("summaryCvTests.jl") include("mlmnetBicTests.jl") + include("utilitiesTests.jl") + include("mlmnetCV_helpersTests.jl") end \ No newline at end of file diff --git a/test/utilitiesTests.jl b/test/utilitiesTests.jl index d1ba0de..58a1956 100644 --- a/test/utilitiesTests.jl +++ b/test/utilitiesTests.jl @@ -32,13 +32,13 @@ dat = RawData(Response(Y), Predictors(X, Z)); rng = MatrixLMnet.Random.MersenneTwister(2021) -################################################### -#Test the dimension of results by util functions # -################################################### +###################################################### +# TEST 1: the dimension of results by util functions # +###################################################### est = mlmnet(dat, λ, α, method = "cd", addZIntercept = true, addXIntercept = true, isVerbose = true) -predicted = predict(est, est.data.predictors) +predicted = MatrixLMnet.predict(est, est.data.predictors) #Test the function predict #println(size(predicted[:,:,1,1] )) @@ -53,14 +53,210 @@ coef = MatrixLMnet.coef_3d(est) end -################################ -#Test2: test predict function ## -################################ +################################# +# TEST 2: test predict function # +################################# newPredictors = Predictors(X, Z, false, false) -predicted = predict(est, newPredictors) +predicted = MatrixLMnet.predict(est, newPredictors) est2 = mlmnet(dat, λ, α, method = "cd", addZIntercept = false, addXIntercept = false, isVerbose = true) newPredictors2 = Predictors(hcat(ones(size(X, 1)), X), hcat(ones(size(Z, 1)), Z), true, true) -predicted2 = predict(est2, newPredictors2) +predicted2 = MatrixLMnet.predict(est2, newPredictors2) -@test size(predicted[:,:,1,1] ) == size(predicted2[:,:,1,1] ) \ No newline at end of file +@test size(predicted[:,:,1,1] ) == size(predicted2[:,:,1,1] ) + + + +####################################### +# TEST 3: test backtransform function # +####################################### + +using MatrixLMnet: normalize!, mean, norm +using Distributions, Random +#= +Description: +Simulate a dataset to test the `backtransform!()` function included +in the `mlmnet()` function. The backtransform!() back-transform +coefficient estimates B in place if X and Z were centered and/or normalized. +All four cases are tested: + - no X intercept, no Z intercept + - X has an intercept, no Z intercept + - no X intercept, Z has an intercept + - X has an intercept, Z has an intercept +=# + +################### +# Simulated Data # +################### +# Model: 𝐘 = 𝐗 𝛃 𝐙' + 𝜎𝜖, with 𝜖∼𝑁(0,1) +rng = MersenneTwister(2024) + +d = Normal(1.0, 1.0); +# Matrices dimensions +n = 240; m = 7; p = 9; q = 4; + +# Simulate the coefficients matrix B +list_coefs = [0,1,1.5,2,2.5,3,.5] +B = rand(list_coefs, p, q) + +# Simulate predictors +X = hcat(ones(n), rand(d, n, p-1)); + +# Simulate Z +list_Z_coefs = [0,1] +Z = hcat(ones(m), rand(list_Z_coefs, m, q-1)) + +# Simulate Y +σ = 3; +Y = X*B*Z' + σ*rand(Normal(0, 1), n, m); + +X = X[:, 2:end]; +Z = Z[:, 2:end]; + +mlmdata = RawData(Response(Y), Predictors(X, Z)); + +############################################################################# +# TEST 3-a test backtransform: addXIntercept = false, addZIntercept = false # +############################################################################# +# MLM +mlm_est = MatrixLMnet.MatrixLM.mlm( + mlmdata, + addXIntercept = false, + addZIntercept = false +); + +# MLMnet +mlmnet_est = mlmnet( + mlmdata, + [0.0], [0.0], # lambda and alpha are set to 0 + method = "fista", stepsize = 0.01, + toNormalize = true, + isNaive = false, + addXIntercept = false, + addZIntercept = false, + isVerbose = false, + thresh = 1e-16 +); + +println("Backtransform test α=0 and λ=0 test 3-a: ", @test isapprox(mlm_est.B, mlmnet_est.B, atol = 1e-3)) +# hcat(mlm_est.B, mlmnet_est.B) + +############################################################################ +# TEST 3-b test backtransform: addXIntercept = true, addZIntercept = false # +############################################################################ +# MLM +mlm_est = MatrixLMnet.MatrixLM.mlm( + mlmdata, + addXIntercept = true, + addZIntercept = false +); + +# MLMnet +mlmnet_est = mlmnet( + mlmdata, + [0.0], [0.0], # lambda and alpha are set to 0 + method = "fista", stepsize = 0.01, + toNormalize = true, + isNaive = true, + addXIntercept = true, + addZIntercept = false, + isVerbose = false, + thresh = 1e-16 +); + +println("Backtransform test α=0 and λ=0 test 3-b: ", @test isapprox(mlm_est.B, mlmnet_est.B, atol = 1e-3)) + +############################################################################ +# TEST 3-c test backtransform: addXIntercept = false, addZIntercept = true # +############################################################################ +# MLM +mlm_est = MatrixLMnet.MatrixLM.mlm( + mlmdata, + addXIntercept = false, + addZIntercept = true +); + +# MLMnet +mlmnet_est = mlmnet( + mlmdata, + [0.0], [0.0], # lambda and alpha are set to 0 + method = "fista", stepsize = 0.01, + toNormalize = true, + isNaive = true, + addXIntercept = false, + addZIntercept = true, + isVerbose = false, + thresh = 1e-16 +); + +println("Backtransform test α=0 and λ=0 test 3-c: ", @test isapprox(mlm_est.B, mlmnet_est.B, atol = 1e-3)) + +########################################################################### +# TEST 3-d test backtransform: addXIntercept = true, addZIntercept = true # +########################################################################### +# MLM +mlm_est = MatrixLMnet.MatrixLM.mlm( + mlmdata, + addXIntercept = true, + addZIntercept = true +); + +# MLMnet +mlmnet_est = mlmnet( + mlmdata, + [0.0], [0.0], # lambda and alpha are set to 0 + method = "fista", stepsize = 0.01, + toNormalize = true, + isNaive = true, + addXIntercept = true, + addZIntercept = true, + isVerbose = false, + thresh = 1e-16 +); + +println("Backtransform test α=0 and λ=0 test 3-d: ", @test isapprox(mlm_est.B, mlmnet_est.B, atol = 1e-5)) + +################################### +# TEST 4: test normalize function # +################################### + +function is_normalized(A) + for col in eachcol(A) + if norm(col) ≈ 1.0 || all(iszero, col) + continue + else + return false + end + end + return true +end + +using MatrixLMnet: normalize!, norm, mean + +@testset "With Intercept" begin + A = hcat(ones(10), rand(Float64, 10, 3) * 100) + original_A = copy(A) + means, norms = normalize!(A, true) + + # Check normalization + @test is_normalized(A) + @test means == mean(original_A, dims=1) + @test !(A[:, 1] ≈ zeros(size(A, 1), 1)) + @test ones(1, size(A, 2)) ≈ mapslices(col -> norm(col), A, dims = 1) +end + +# @testset "Without Intercept" begin + A = rand(Float64, 10, 3) * 100 + original_A = copy(A) + means, norms = normalize!(A, false) + + # Check normalization + @test is_normalized(A) + # do not test means, since no intercept no centering + @test ones(1, size(A, 2)) ≈ mapslices(col -> norm(col), A, dims = 1) +# end + + + + +println("Tests utilities finished!")