Skip to content

Commit

Permalink
Tidy src/atoms/affine
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Jan 10, 2024
1 parent 8dfc7f6 commit dda073e
Show file tree
Hide file tree
Showing 17 changed files with 223 additions and 436 deletions.
66 changes: 20 additions & 46 deletions src/atoms/affine/add_subtract.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,30 @@
#############################################################################
# add_subtract.jl
# Handles unary negation, addition and subtraction of variables, constants
# and expressions.
# All expressions and atoms are subtpyes of AbstractExpr.
# Please read expressions.jl first.
#############################################################################

### Unary Negation

mutable struct NegateAtom <: AbstractExpr
children::Tuple{AbstractExpr}
size::Tuple{Int,Int}

function NegateAtom(x::AbstractExpr)
children = (x,)
return new(children, x.size)
end
NegateAtom(x::AbstractExpr) = new((x,), x.size)
end
head(io::IO, ::NegateAtom) = print(io, "-")

function Base.sign(x::NegateAtom)
return -sign(x.children[1])
end
Base.sign(x::NegateAtom) = -sign(x.children[1])

function monotonicity(x::NegateAtom)
return (Nonincreasing(),)
end
monotonicity(::NegateAtom) = (Nonincreasing(),)

function curvature(x::NegateAtom)
return ConstVexity()
end
curvature(::NegateAtom) = ConstVexity()

function evaluate(x::NegateAtom)
return -evaluate(x.children[1])
end
evaluate(x::NegateAtom) = -evaluate(x.children[1])

Base.:-(x::AbstractExpr) = NegateAtom(x)

Base.:-(x::Union{Constant,ComplexConstant}) = constant(-evaluate(x))

function new_conic_form!(context::Context{T}, A::NegateAtom) where {T}
subobj = conic_form!(context, only(AbstractTrees.children(A)))
if subobj isa Value
return -subobj
else
return operate(-, T, sign(A), subobj)
end
return operate(-, T, sign(A), subobj)
end

### Addition
mutable struct AdditionAtom <: AbstractExpr
children::Array{AbstractExpr,1}
size::Tuple{Int,Int}
Expand All @@ -65,7 +42,6 @@ mutable struct AdditionAtom <: AbstractExpr
else
error("Cannot add expressions of sizes $(x.size) and $(y.size)")
end

if x.size != y.size
if (x isa Constant || x isa ComplexConstant) && (x.size == (1, 1))
x = constant(fill(evaluate(x), y.size))
Expand All @@ -74,7 +50,6 @@ mutable struct AdditionAtom <: AbstractExpr
y = constant(fill(evaluate(y), x.size))
end
end

# see if we're forming a sum of more than two terms and condense them
children = AbstractExpr[]
if isa(x, AdditionAtom)
Expand All @@ -93,38 +68,37 @@ end

head(io::IO, ::AdditionAtom) = print(io, "+")

function Base.sign(x::AdditionAtom)
return sum(Sign[sign(child) for child in x.children])
# Creating an array of type Sign and adding all the sign of xhildren of x so if anyone is complex the resultant sign would be complex.
end
# Creating an array of type Sign and adding all the sign of children of x,
# so if anyone is complex the resultant sign would be complex.
Base.sign(x::AdditionAtom) = sum(sign.(x.children))

function monotonicity(x::AdditionAtom)
return Monotonicity[Nondecreasing() for child in x.children]
end
monotonicity(x::AdditionAtom) = [Nondecreasing() for _ in x.children]

function curvature(x::AdditionAtom)
return ConstVexity()
end
curvature(::AdditionAtom) = ConstVexity()

function evaluate(x::AdditionAtom)
# broadcast is used in reduction instead of using sum directly to support addition
# between scalars and arrays
# broadcast is used in reduction instead of using sum directly to support
# addition between scalars and arrays
return mapreduce(evaluate, (a, b) -> a .+ b, x.children)
end

function new_conic_form!(context::Context{T}, x::AdditionAtom) where {T}
obj = operate(
return operate(
+,
T,
sign(x),
(conic_form!(context, c) for c in AbstractTrees.children(x))...,
)
return obj
end

Base.:+(x::AbstractExpr, y::AbstractExpr) = AdditionAtom(x, y)

Base.:+(x::Value, y::AbstractExpr) = AdditionAtom(constant(x), y)

Base.:+(x::AbstractExpr, y::Value) = AdditionAtom(x, constant(y))

Base.:-(x::AbstractExpr, y::AbstractExpr) = x + (-y)

Base.:-(x::Value, y::AbstractExpr) = constant(x) + (-y)

Base.:-(x::AbstractExpr, y::Value) = x + constant(-y)
27 changes: 9 additions & 18 deletions src/atoms/affine/conjugate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,18 @@ mutable struct ConjugateAtom <: AbstractExpr
children::Tuple{AbstractExpr}
size::Tuple{Int,Int}

function ConjugateAtom(x::AbstractExpr)
children = (x,)
return new(children, (x.size[1], x.size[2]))
end
ConjugateAtom(x::AbstractExpr) = new((x,), x.size)
end

head(io::IO, ::ConjugateAtom) = print(io, "conj")

function Base.sign(x::ConjugateAtom)
return sign(x.children[1])
end
Base.sign(x::ConjugateAtom) = sign(x.children[1])

function monotonicity(x::ConjugateAtom)
return (Nondecreasing(),)
end
monotonicity(::ConjugateAtom) = (Nondecreasing(),)

function curvature(x::ConjugateAtom)
return ConstVexity()
end
curvature(::ConjugateAtom) = ConstVexity()

function evaluate(x::ConjugateAtom)
return conj(evaluate(x.children[1]))
end
evaluate(x::ConjugateAtom) = conj(evaluate(x.children[1]))

Check warning on line 16 in src/atoms/affine/conjugate.jl

View check run for this annotation

Codecov / codecov/patch

src/atoms/affine/conjugate.jl#L16

Added line #L16 was not covered by tests

function new_conic_form!(context::Context{T}, x::ConjugateAtom) where {T}
objective = conic_form!(context, only(AbstractTrees.children(x)))
Expand All @@ -33,9 +23,10 @@ end
function Base.conj(x::AbstractExpr)
if sign(x) == ComplexSign()
return ConjugateAtom(x)
else
return x
end
return x
end

Base.conj(x::Constant) = x

Base.conj(x::ComplexConstant) = ComplexConstant(real(x), -imag(x))
5 changes: 0 additions & 5 deletions src/atoms/affine/conv.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
#############################################################################
# conv.jl
# Handles convolution between a constant vector and an expression vector.
#############################################################################

function conv(x::Value, y::AbstractExpr)
if (size(x, 2) != 1 && length(size(x)) != 1) || size(y, 2) != 1
error("convolution only supported between two vectors")
Expand Down
76 changes: 23 additions & 53 deletions src/atoms/affine/diag.jl
Original file line number Diff line number Diff line change
@@ -1,90 +1,60 @@
#############################################################################
# diag.jl
# Returns the kth diagonal of a matrix expression
# All expressions and atoms are subtpyes of AbstractExpr.
# Please read expressions.jl first.
#############################################################################

# k >= min(num_cols, num_rows) || k <= -min(num_rows, num_cols)

### Diagonal
### Represents the kth diagonal of an mxn matrix as a (min(m, n) - k) x 1 vector

mutable struct DiagAtom <: AbstractExpr
children::Tuple{AbstractExpr}
size::Tuple{Int,Int}
k::Int

function DiagAtom(x::AbstractExpr, k::Int = 0)
(num_rows, num_cols) = x.size

if k >= min(num_rows, num_cols) || k <= -min(num_rows, num_cols)
K = min(x.size[1], x.size[2])
if !(-K < k < K)
error("Bounds error in calling diag")
end

children = (x,)
return new(children, (min(num_rows, num_cols) - k, 1), k)
return new((x,), (K - k, 1), k)
end
end

head(io::IO, ::DiagAtom) = print(io, "diag")

## Type Definition Ends

function Base.sign(x::DiagAtom)
return sign(x.children[1])
end
Base.sign(x::DiagAtom) = sign(x.children[1])

# The monotonicity
function monotonicity(x::DiagAtom)
return (Nondecreasing(),)
end
monotonicity(::DiagAtom) = (Nondecreasing(),)

# If we have h(x) = f o g(x), the chain rule says h''(x) = g'(x)^T f''(g(x))g'(x) + f'(g(x))g''(x);
# this represents the first term
function curvature(x::DiagAtom)
return ConstVexity()
end
curvature(::DiagAtom) = ConstVexity()

function evaluate(x::DiagAtom)
return LinearAlgebra.diag(evaluate(x.children[1]), x.k)
end

## API begins
LinearAlgebra.diag(x::AbstractExpr, k::Int = 0) = DiagAtom(x, k)
## API ends

# Finds the "k"-th diagonal of x as a column vector
# If k == 0, it returns the main diagonal and so on
# Let x be of size m x n and d be the diagonal
# Finds the "k"-th diagonal of x as a column vector.
#
# If k == 0, it returns the main diagonal and so on.
#
# Let x be of size m x n and d be the diagonal.
#
# Since x is vectorized, the way canonicalization works is:
#
# 1. We calculate the size of the diagonal (sz_diag) and the first index
# of vectorized x that will be part of d
# of vectorized x that will be part of d
# 2. We create the coefficient matrix for vectorized x, called coeff of size
# sz_diag x mn
# sz_diag x mn
# 3. We populate coeff with 1s at the correct indices
# The canonical form will then be:
# coeff * x - d = 0
#
# The canonical form will then be: coeff * x - d = 0
function new_conic_form!(context::Context{T}, x::DiagAtom) where {T}
(num_rows, num_cols) = x.children[1].size
k = x.k

if k >= 0
start_index = k * num_rows + 1
sz_diag = Base.min(num_rows, num_cols - k)
num_rows, num_cols = x.children[1].size
if x.k >= 0
start_index = x.k * num_rows + 1
sz_diag = Base.min(num_rows, num_cols - x.k)
else
start_index = -k + 1
sz_diag = Base.min(num_rows + k, num_cols)
start_index = -x.k + 1
sz_diag = Base.min(num_rows + x.k, num_cols)

Check warning on line 51 in src/atoms/affine/diag.jl

View check run for this annotation

Codecov / codecov/patch

src/atoms/affine/diag.jl#L50-L51

Added lines #L50 - L51 were not covered by tests
end

select_diag = spzeros(T, sz_diag, length(x.children[1]))
for i in 1:sz_diag
select_diag[i, start_index] = 1
start_index += num_rows + 1
end

child_obj = conic_form!(context, only(AbstractTrees.children(x)))
obj = operate(add_operation, T, sign(x), select_diag, child_obj)
return obj
return operate(add_operation, T, sign(x), select_diag, child_obj)
end
48 changes: 13 additions & 35 deletions src/atoms/affine/diagm.jl
Original file line number Diff line number Diff line change
@@ -1,71 +1,49 @@
#############################################################################
# diagm.jl
# Converts a vector of size n into an n x n diagonal
# All expressions and atoms are subtpyes of AbstractExpr.
# Please read expressions.jl first.
#############################################################################

mutable struct DiagMatrixAtom <: AbstractExpr
children::Tuple{AbstractExpr}
size::Tuple{Int,Int}

function DiagMatrixAtom(x::AbstractExpr)
(num_rows, num_cols) = x.size

num_rows, num_cols = x.size
if num_rows == 1
sz = num_cols
return new((x,), (num_cols, num_cols))

Check warning on line 8 in src/atoms/affine/diagm.jl

View check run for this annotation

Codecov / codecov/patch

src/atoms/affine/diagm.jl#L8

Added line #L8 was not covered by tests
elseif num_cols == 1
sz = num_rows
return new((x,), (num_rows, num_rows))
else
throw(
ArgumentError(
"Only vectors are allowed for diagm/Diagonal. Did you mean to use diag?",
),
)
msg = "Only vectors are allowed for diagm/Diagonal. Did you mean to use diag?"
throw(ArgumentError(msg))
end

children = (x,)
return new(children, (sz, sz))
end
end

head(io::IO, ::DiagMatrixAtom) = print(io, "diagm")

function Base.sign(x::DiagMatrixAtom)
return sign(x.children[1])
end
Base.sign(x::DiagMatrixAtom) = sign(x.children[1])

# The monotonicity
function monotonicity(x::DiagMatrixAtom)
return (Nondecreasing(),)
end
monotonicity(::DiagMatrixAtom) = (Nondecreasing(),)

# If we have h(x) = f o g(x), the chain rule says h''(x) = g'(x)^T f''(g(x))g'(x) + f'(g(x))g''(x);
# this represents the first term
function curvature(x::DiagMatrixAtom)
return ConstVexity()
end
curvature(::DiagMatrixAtom) = ConstVexity()

function evaluate(x::DiagMatrixAtom)
return LinearAlgebra.Diagonal(vec(evaluate(x.children[1])))
end

function LinearAlgebra.diagm((d, x)::Pair{<:Integer,<:AbstractExpr})
d == 0 || throw(ArgumentError("only the main diagonal is supported"))
if d != 0
throw(ArgumentError("only the main diagonal is supported"))

Check warning on line 32 in src/atoms/affine/diagm.jl

View check run for this annotation

Codecov / codecov/patch

src/atoms/affine/diagm.jl#L31-L32

Added lines #L31 - L32 were not covered by tests
end
return DiagMatrixAtom(x)
end

LinearAlgebra.Diagonal(x::AbstractExpr) = DiagMatrixAtom(x)

LinearAlgebra.diagm(x::AbstractExpr) = DiagMatrixAtom(x)

function new_conic_form!(context::Context{T}, x::DiagMatrixAtom) where {T}
obj = conic_form!(context, only(AbstractTrees.children(x)))

sz = x.size[1]
I = collect(1:sz+1:sz*sz)
J = collect(1:sz)
V = one(T)
coeff = create_sparse(T, I, J, V, sz * sz, sz)
# coeff = create_sparse(, 1:sz, one(T),

return operate(add_operation, T, sign(x), coeff, obj)
end
Loading

0 comments on commit dda073e

Please sign in to comment.