-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
223 additions
and
436 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,90 +1,60 @@ | ||
############################################################################# | ||
# diag.jl | ||
# Returns the kth diagonal of a matrix expression | ||
# All expressions and atoms are subtpyes of AbstractExpr. | ||
# Please read expressions.jl first. | ||
############################################################################# | ||
|
||
# k >= min(num_cols, num_rows) || k <= -min(num_rows, num_cols) | ||
|
||
### Diagonal | ||
### Represents the kth diagonal of an mxn matrix as a (min(m, n) - k) x 1 vector | ||
|
||
mutable struct DiagAtom <: AbstractExpr | ||
children::Tuple{AbstractExpr} | ||
size::Tuple{Int,Int} | ||
k::Int | ||
|
||
function DiagAtom(x::AbstractExpr, k::Int = 0) | ||
(num_rows, num_cols) = x.size | ||
|
||
if k >= min(num_rows, num_cols) || k <= -min(num_rows, num_cols) | ||
K = min(x.size[1], x.size[2]) | ||
if !(-K < k < K) | ||
error("Bounds error in calling diag") | ||
end | ||
|
||
children = (x,) | ||
return new(children, (min(num_rows, num_cols) - k, 1), k) | ||
return new((x,), (K - k, 1), k) | ||
end | ||
end | ||
|
||
head(io::IO, ::DiagAtom) = print(io, "diag") | ||
|
||
## Type Definition Ends | ||
|
||
function Base.sign(x::DiagAtom) | ||
return sign(x.children[1]) | ||
end | ||
Base.sign(x::DiagAtom) = sign(x.children[1]) | ||
|
||
# The monotonicity | ||
function monotonicity(x::DiagAtom) | ||
return (Nondecreasing(),) | ||
end | ||
monotonicity(::DiagAtom) = (Nondecreasing(),) | ||
|
||
# If we have h(x) = f o g(x), the chain rule says h''(x) = g'(x)^T f''(g(x))g'(x) + f'(g(x))g''(x); | ||
# this represents the first term | ||
function curvature(x::DiagAtom) | ||
return ConstVexity() | ||
end | ||
curvature(::DiagAtom) = ConstVexity() | ||
|
||
function evaluate(x::DiagAtom) | ||
return LinearAlgebra.diag(evaluate(x.children[1]), x.k) | ||
end | ||
|
||
## API begins | ||
LinearAlgebra.diag(x::AbstractExpr, k::Int = 0) = DiagAtom(x, k) | ||
## API ends | ||
|
||
# Finds the "k"-th diagonal of x as a column vector | ||
# If k == 0, it returns the main diagonal and so on | ||
# Let x be of size m x n and d be the diagonal | ||
# Finds the "k"-th diagonal of x as a column vector. | ||
# | ||
# If k == 0, it returns the main diagonal and so on. | ||
# | ||
# Let x be of size m x n and d be the diagonal. | ||
# | ||
# Since x is vectorized, the way canonicalization works is: | ||
# | ||
# 1. We calculate the size of the diagonal (sz_diag) and the first index | ||
# of vectorized x that will be part of d | ||
# of vectorized x that will be part of d | ||
# 2. We create the coefficient matrix for vectorized x, called coeff of size | ||
# sz_diag x mn | ||
# sz_diag x mn | ||
# 3. We populate coeff with 1s at the correct indices | ||
# The canonical form will then be: | ||
# coeff * x - d = 0 | ||
# | ||
# The canonical form will then be: coeff * x - d = 0 | ||
function new_conic_form!(context::Context{T}, x::DiagAtom) where {T} | ||
(num_rows, num_cols) = x.children[1].size | ||
k = x.k | ||
|
||
if k >= 0 | ||
start_index = k * num_rows + 1 | ||
sz_diag = Base.min(num_rows, num_cols - k) | ||
num_rows, num_cols = x.children[1].size | ||
if x.k >= 0 | ||
start_index = x.k * num_rows + 1 | ||
sz_diag = Base.min(num_rows, num_cols - x.k) | ||
else | ||
start_index = -k + 1 | ||
sz_diag = Base.min(num_rows + k, num_cols) | ||
start_index = -x.k + 1 | ||
sz_diag = Base.min(num_rows + x.k, num_cols) | ||
end | ||
|
||
select_diag = spzeros(T, sz_diag, length(x.children[1])) | ||
for i in 1:sz_diag | ||
select_diag[i, start_index] = 1 | ||
start_index += num_rows + 1 | ||
end | ||
|
||
child_obj = conic_form!(context, only(AbstractTrees.children(x))) | ||
obj = operate(add_operation, T, sign(x), select_diag, child_obj) | ||
return obj | ||
return operate(add_operation, T, sign(x), select_diag, child_obj) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,71 +1,49 @@ | ||
############################################################################# | ||
# diagm.jl | ||
# Converts a vector of size n into an n x n diagonal | ||
# All expressions and atoms are subtpyes of AbstractExpr. | ||
# Please read expressions.jl first. | ||
############################################################################# | ||
|
||
mutable struct DiagMatrixAtom <: AbstractExpr | ||
children::Tuple{AbstractExpr} | ||
size::Tuple{Int,Int} | ||
|
||
function DiagMatrixAtom(x::AbstractExpr) | ||
(num_rows, num_cols) = x.size | ||
|
||
num_rows, num_cols = x.size | ||
if num_rows == 1 | ||
sz = num_cols | ||
return new((x,), (num_cols, num_cols)) | ||
elseif num_cols == 1 | ||
sz = num_rows | ||
return new((x,), (num_rows, num_rows)) | ||
else | ||
throw( | ||
ArgumentError( | ||
"Only vectors are allowed for diagm/Diagonal. Did you mean to use diag?", | ||
), | ||
) | ||
msg = "Only vectors are allowed for diagm/Diagonal. Did you mean to use diag?" | ||
throw(ArgumentError(msg)) | ||
end | ||
|
||
children = (x,) | ||
return new(children, (sz, sz)) | ||
end | ||
end | ||
|
||
head(io::IO, ::DiagMatrixAtom) = print(io, "diagm") | ||
|
||
function Base.sign(x::DiagMatrixAtom) | ||
return sign(x.children[1]) | ||
end | ||
Base.sign(x::DiagMatrixAtom) = sign(x.children[1]) | ||
|
||
# The monotonicity | ||
function monotonicity(x::DiagMatrixAtom) | ||
return (Nondecreasing(),) | ||
end | ||
monotonicity(::DiagMatrixAtom) = (Nondecreasing(),) | ||
|
||
# If we have h(x) = f o g(x), the chain rule says h''(x) = g'(x)^T f''(g(x))g'(x) + f'(g(x))g''(x); | ||
# this represents the first term | ||
function curvature(x::DiagMatrixAtom) | ||
return ConstVexity() | ||
end | ||
curvature(::DiagMatrixAtom) = ConstVexity() | ||
|
||
function evaluate(x::DiagMatrixAtom) | ||
return LinearAlgebra.Diagonal(vec(evaluate(x.children[1]))) | ||
end | ||
|
||
function LinearAlgebra.diagm((d, x)::Pair{<:Integer,<:AbstractExpr}) | ||
d == 0 || throw(ArgumentError("only the main diagonal is supported")) | ||
if d != 0 | ||
throw(ArgumentError("only the main diagonal is supported")) | ||
end | ||
return DiagMatrixAtom(x) | ||
end | ||
|
||
LinearAlgebra.Diagonal(x::AbstractExpr) = DiagMatrixAtom(x) | ||
|
||
LinearAlgebra.diagm(x::AbstractExpr) = DiagMatrixAtom(x) | ||
|
||
function new_conic_form!(context::Context{T}, x::DiagMatrixAtom) where {T} | ||
obj = conic_form!(context, only(AbstractTrees.children(x))) | ||
|
||
sz = x.size[1] | ||
I = collect(1:sz+1:sz*sz) | ||
J = collect(1:sz) | ||
V = one(T) | ||
coeff = create_sparse(T, I, J, V, sz * sz, sz) | ||
# coeff = create_sparse(, 1:sz, one(T), | ||
|
||
return operate(add_operation, T, sign(x), coeff, obj) | ||
end |
Oops, something went wrong.