Skip to content

Commit cd038e8

Browse files
committed
WIP: auto diff
1 parent d56aa35 commit cd038e8

File tree

1 file changed

+60
-21
lines changed

1 file changed

+60
-21
lines changed

R/multiscaleSVDxpts.R

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3038,32 +3038,61 @@ project_to_nonneg_orthogonal_alt <- function(X, tol = 1e-6, max_iter = 10) {
30383038
Y
30393039
}
30403040

3041+
3042+
#' Gradient of the Invariant Orthogonality Measure
3043+
#'
3044+
#' This function computes the gradient of the orthogonality defect measure with respect to the input matrix `A`.
3045+
#' The gradient is useful for optimization techniques that require gradient information. The gradient will be zero
3046+
#' for matrices where `AtA` equals the diagonal matrix `D`.
3047+
#'
3048+
#' @param A A numeric matrix.
3049+
#' @return A numeric matrix representing the gradient of the orthogonality defect measure.
3050+
#' @examples
3051+
#' A <- matrix(runif(20), nrow = 10, ncol = 2)
3052+
#' gradient_invariant_orthogonality(A)
3053+
#' @export
3054+
gradient_invariant_orthogonality <- function(A) {
3055+
# Step 1: Compute norm_A_F2
3056+
norm_A_F2 <- sum(A^2)
3057+
if (norm_A_F2 == 0) {
3058+
stop("Norm is zero, cannot compute gradient")
3059+
}
3060+
3061+
# Step 2: Compute AtA
3062+
AtA <- t(A) %*% A
3063+
3064+
# Step 3: Compute Frobenius norm of AtA_normalized
3065+
norm_AtA_normalized_F2 <- norm(AtA / norm_A_F2, "F")^2
3066+
3067+
# Step 4: Compute gradient
3068+
gradient <- (2 / norm_A_F2^2) * (A %*% AtA - norm_AtA_normalized_F2 * A)
3069+
3070+
return(gradient)
3071+
}
3072+
3073+
3074+
gradient_invariant_orthogonality2 <- function(A) {
3075+
gradient_invariant_orthogonality(A) - gradient_invariant_orthogonality(diag(ncol(A)))
3076+
}
3077+
30413078
#' Calculate the invariant orthogonality defect that is zero for diagonal matrices
30423079
#'
30433080
#' @param A Input matrix (n x p, where n >> p)
30443081
#' @return The invariant orthogonality defect that is zero for diagonal matrices
30453082
#' @export
30463083
invariant_orthogonality_defect_diag_zero <- function(A) {
3047-
A=as.matrix(A)
3048-
if (!is.matrix(A) || !is.numeric(A)) {
3049-
stop("invariant_orthogonality_defect_diag_zero: 'A' must be a numeric matrix")
3050-
}
30513084
norm_A_F2 <- sum(A^2)
3052-
if (norm_A_F2 == 0) {
3053-
return( 0 )
3054-
}
30553085
AtA <- t(A) %*% A
3056-
AtA_normalized <- AtA / norm_A_F2
3057-
3086+
AtA_normalized <- AtA / norm_A_F2
30583087
column_sums_sq <- colSums(A^2)
30593088
D <- diag(column_sums_sq / norm_A_F2)
3060-
3061-
orthogonality_defect <- norm(AtA_normalized - D, "F")^2
3062-
3089+
orthogonality_defect <- sum( (AtA_normalized - D)^2)
30633090
return(orthogonality_defect)
30643091
}
30653092

30663093

3094+
3095+
30673096
#' Gradient of the Invariant Orthogonality Defect Measure
30683097
#'
30693098
#' This function computes the gradient of the orthogonality defect measure with respect to the input matrix `A`.
@@ -3076,7 +3105,13 @@ invariant_orthogonality_defect_diag_zero <- function(A) {
30763105
#' A <- matrix(runif(20), nrow = 10, ncol = 2)
30773106
#' gradient_invariant_orthogonality_defect_diag_zero(A)
30783107
#' @export
3079-
gradient_invariant_orthogonality_defect_diag_zero <- function(A) {
3108+
gradient_invariant_orthogonality_defect_diag_zero<- function(A) {
3109+
#### place holder until we get the correct analytical derivative
3110+
f1=invariant_orthogonality_defect_diag_zero
3111+
matrix( salad::d( f1( salad::dual(A) ) ), nrow=nrow(A) )
3112+
}
3113+
3114+
gradient_invariant_orthogonality_defect_diag_zero_old1 <- function(A) {
30803115
A <- as.matrix(A)
30813116
if (!is.matrix(A) || !is.numeric(A)) {
30823117
stop("gradient_invariant_orthogonality_defect_diag_zero: 'A' must be a numeric matrix")
@@ -3135,7 +3170,7 @@ gradient_invariant_orthogonality_defect_diag_zero_old2 <- function(A) {
31353170
return(gradient)
31363171
}
31373172

3138-
gradient_invariant_orthogonality_defect_diag_zero_old <- function(A) {
3173+
gradient_invariant_orthogonality_defect_diag_zero_old3 <- function(A) {
31393174
A=as.matrix(A)
31403175
if (!is.matrix(A) || !is.numeric(A)) {
31413176
stop("gradient_invariant_orthogonality_defect_diag_zero: 'A' must be a numeric matrix")
@@ -4579,21 +4614,25 @@ simlr.search <- function(
45794614

45804615
if ( nrow(options_df) > 1 ) {
45814616
rowsel = 1:(nrow(options_df)-1)
4582-
if ( all( finalE > options_df$final_energy[rowsel] ) & verbose > 0 ) {
4583-
print( paste("improvement" ) )
4584-
print( parameters )
4617+
if ( all( finalE > options_df$final_energy[rowsel] ) ) {
45854618
bestresult = simlrX$simlr_result
45864619
bestsig = simlrX$significance
4587-
print( head( bestresult$v[[ length(bestresult$v)]] ))
4620+
bestparams = parameters
4621+
if ( verbose > 0 ) {
4622+
print( paste("improvement" ) )
4623+
print( parameters )
4624+
print( head( bestresult$v[[ length(bestresult$v)]] ))
4625+
}
45884626
}
4589-
}
4627+
} else { bestresult=bestsig=bestparams=NA }
45904628
}
45914629
if ( verbose ) {
45924630
print( options_df[ which.max(options_df$final_energy),] )
45934631
cat("el finito\n")
45944632
}
4595-
# return(options_df)
4596-
return( list( parameters=options_df, simlr_result=bestresult, significance=bestsig ))
4633+
outlist = list( simlr_result=bestresult, significance=bestsig, parameters=options_df )
4634+
return( outlist )
4635+
# return( list( parameters=options_df, simlr_result=bestresult, significance=bestsig ))
45974636
}
45984637

45994638

0 commit comments

Comments
 (0)