Skip to content

Commit

Permalink
matrix operations now use and check row/column names; indexing by col…
Browse files Browse the repository at this point in the history
…umn names
  • Loading branch information
john-d-fox committed Sep 9, 2024
1 parent bb369d4 commit 3e2e9a5
Show file tree
Hide file tree
Showing 3 changed files with 402 additions and 10 deletions.
10 changes: 9 additions & 1 deletion R/latexMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,13 @@ latexMatrix <- function(
x
}

if (0 != anyDuplicated(rownames, incomparables=""))
stop("there are duplicated row names")
cnames <- colnames
cnames <- sub("\\\\phantom\\{.*\\}", "", cnames)
if (0 != anyDuplicated(cnames, incomparables=""))
stop("there are duplicated column names")

if (is.null(matrix)) matrix <- "pmatrix"

end.at.n.minus.1 <- gsub(" ", "", end.at) == c("n-1", "m-1")
Expand All @@ -302,7 +309,6 @@ latexMatrix <- function(
# start composing output string:

result <- paste0(if (fractions) "\\renewcommand*{\\arraystretch}{1.5} \n",
# if (!missing(lhs)) paste0(lhs, " = \n"),
"\\begin{", matrix, "} \n"
)

Expand Down Expand Up @@ -858,6 +864,8 @@ as.double.latexMatrix <- function(x, locals=list(), ...){
`[.latexMatrix` <- function(x, i, j, ..., drop){
numericDimensions(x)
X <- getBody(x)
if (!is.null(nms <- rownames(x))) rownames(X) <- nms
if (!is.null(nms <- colnames(x))) colnames(X) <- nms
X <- X[i, j, drop=FALSE]
X <- latexMatrix(X)
updateWrapper(X, getWrapper(x))
Expand Down
55 changes: 51 additions & 4 deletions R/latexMatrixOperations.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,14 @@ matsum.latexMatrix <- function(A, ..., as.numeric=TRUE){
}

numericDimensions(A)
for (M in matrices) numericDimensions(M)

dimnames <- dimnames(A)
for (M in matrices) {
numericDimensions(M)
if (!isTRUE(all.equal(dimnames, dimnames(M)))){
stop("matrix dimension names don't match")
}
}

wrapper <- getWrapper(A)

Expand All @@ -187,6 +194,7 @@ matsum.latexMatrix <- function(A, ..., as.numeric=TRUE){
}
A <- latexMatrix(A)
A <- updateWrapper(A, wrapper)
Dimnames(A) <- dimnames
A
}

Expand All @@ -211,6 +219,7 @@ matdiff.latexMatrix <- function(A, B=NULL, as.numeric=TRUE, ...){
# unary -
if (is.null(B)){
numericDimensions(A)
dimnames <- dimnames(A)
if (as.numeric && is.numeric(A)){
A <- as.numeric(A)
A <- -A
Expand All @@ -221,6 +230,7 @@ matdiff.latexMatrix <- function(A, B=NULL, as.numeric=TRUE, ...){
}
A <- latexMatrix(A)
A <- updateWrapper(A, getWrapper(A))
Dimnames(A) <- dimnames
return(A)
}
if (!inherits(B, "latexMatrix")){
Expand All @@ -229,6 +239,10 @@ matdiff.latexMatrix <- function(A, B=NULL, as.numeric=TRUE, ...){
}
numericDimensions(A)
numericDimensions(B)
if (!isTRUE(all.equal(dimnames(A), dimnames(B)))){
stop("matrix dimension names don't match")
}
dimnames <- dimnames(A)
dimA <- Dim(A)
dimB <- Dim(B)
if (!all(dimA == dimB))
Expand All @@ -246,6 +260,7 @@ matdiff.latexMatrix <- function(A, B=NULL, as.numeric=TRUE, ...){
}
A <- latexMatrix(A)
A <- updateWrapper(A, wrapper)
Dimnames(A) <- dimnames
A
}

Expand Down Expand Up @@ -278,6 +293,7 @@ matdiff.latexMatrix <- function(A, B=NULL, as.numeric=TRUE, ...){
A <- getBody(e2)
dimA <- dim(A)
wrapper <- getWrapper(e2)
dimnames <- dimnames(e2)
result <- matrix(if (swapped) {
paste(sapply(A, parenthesize), latexMultSymbol, e1)
} else{
Expand All @@ -287,6 +303,7 @@ matdiff.latexMatrix <- function(A, B=NULL, as.numeric=TRUE, ...){
result <- latexMatrix(result)
result <- updateWrapper(result, getWrapper(e2))
result$dim <- Dim(e2)
Dimnames(result) <- dimnames
result
}

Expand Down Expand Up @@ -352,6 +369,17 @@ matmult.latexMatrix <- function(X, ..., simplify=TRUE,
numericDimensions(X)
for (M in matrices) numericDimensions(M)

n.matrices <- length(matrices)
if (n.matrices > 1){
for (i in 1:(n.matrices - 1)){
if (!isTRUE(all.equal(colnames(M[[i]]),
rownames(M[[i + 1]])))){
stop("matrix dimension names don't match")
}
}
}
dimnames <- list(rownames = rownames(X),
colnames = colnames(matrices[[n.matrices]]))
wrapper <- getWrapper(X)

if (as.numeric && is.numeric(X) && all(sapply(matrices, is.numeric))){
Expand Down Expand Up @@ -385,6 +413,7 @@ matmult.latexMatrix <- function(X, ..., simplify=TRUE,
}
X <- latexMatrix(X)
X <- updateWrapper(X, wrapper)
Dimnames(X) <- dimnames
return(X)

}
Expand All @@ -408,6 +437,8 @@ matpower.latexMatrix <- function(X, power, simplify=TRUE,

numericDimensions(X)
dimX <- Dim(X)
dimnames <- dimnames(X)

if (dimX[1] != dimX[2]) stop ("X is not square")
if (power != round(power) || power < -1)
stop("'power' must be an integer >= -1")
Expand All @@ -417,6 +448,7 @@ matpower.latexMatrix <- function(X, power, simplify=TRUE,
if (power == 0){
result <- latexMatrix(diag(dimX[1]))
result <- updateWrapper(result, wrapper)
Dimnames(result) <- dimnames
return(result)
}

Expand Down Expand Up @@ -444,7 +476,10 @@ matpower.latexMatrix <- function(X, power, simplify=TRUE,
result
}
}
Xp <- updateWrapper(Xp, wrapper)
if (inherits(Xp, "latexMatrix")){
Xp <- updateWrapper(Xp, wrapper)
Dimnames(Xp) <- dimnames
}
return(Xp)
}

Expand Down Expand Up @@ -473,7 +508,9 @@ t.latexMatrix <- function(x){
numericDimensions(x)
result <- latexMatrix(t(getBody(x)))
result <- updateWrapper(result, getWrapper(x))
result$dim <- rev(Dim(x))
dimnames <- dimnames(x)
Dimnames(result) <- list(rownames = dimnames[[2]],
colnames = dimnames[[1]])
result
}

Expand Down Expand Up @@ -525,10 +562,14 @@ solve.latexMatrix <- function (a, b, simplify=FALSE, as.numeric=TRUE,
if (Nrow(a) != Ncol(a)) stop("matrix 'a' must be square")
if (!missing(b)) warning("'b' argument to solve() ignored")

dimnames <- dimnames(a)

if (as.numeric && is.numeric(a)){
a.inv <- solve(as.numeric(a))
a.inv <- latexMatrix(a.inv)
return(updateWrapper(a.inv, getWrapper(a)))
a.inv <- updateWrapper(a.inv, getWrapper(a))
Dimnames(a.inv) <- dimnames
return(a.inv)
}

det <- determinant(a)
Expand All @@ -554,6 +595,7 @@ solve.latexMatrix <- function (a, b, simplify=FALSE, as.numeric=TRUE,
A_inv <- t(A_inv) # adjoint
result <- latexMatrix(A_inv)
result <- updateWrapper(result, getWrapper(a))
Dimnames(result) <- dimnames

if (!simplify) {
return(result)
Expand All @@ -575,6 +617,11 @@ setMethod("kronecker",
numericDimensions(X)
numericDimensions(Y)

if (!is.null(unlist(dimnames(X))) &&
!is.null(unlist(dimnames(X)))){
message("Note: dimension names are ignored")
}

latexMultSymbol <- getLatexMultSymbol()

Xmat <- getBody(X)
Expand Down
Loading

0 comments on commit 3e2e9a5

Please sign in to comment.