Skip to content

Commit

Permalink
improve performance of conv (#634)
Browse files Browse the repository at this point in the history
* improve performance of `conv`

* tweak

* rm comment

* format

* Apply suggestions from code review

---------

Co-authored-by: Oscar Dowson <odow@users.noreply.github.com>
  • Loading branch information
ericphanson and odow authored May 7, 2024
1 parent 96904be commit f2b37dd
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
30 changes: 26 additions & 4 deletions src/reformulations/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,37 @@
# Use of this source code is governed by a BSD-style license that can be found
# in the LICENSE file or at https://opensource.org/license/bsd-2-clause

"""
conv1D_matrix(h::AbstractVector, n::Integer) -> SparseMatrixCSC
Create a sparse matrix `A` such that if `x` has length `n`,
then we have `A * x ≈ conv1d(h, x)`.
"""
function conv1D_matrix(h::AbstractVector, n::Integer)
m = length(h)
Is = Int[]
Js = Int[]
Vs = eltype(h)[]
sizehint!(Is, n * m)
sizehint!(Js, n * m)
sizehint!(Vs, n * m)
# build matrix by columns
for j in 1:n
append!(Is, j:(j+m-1))
append!(Js, (j for _ in 1:m))
append!(Vs, h)
end
return SparseArrays.sparse(Is, Js, Vs, m + n - 1, n)
end

function conv(x::Value, y::AbstractExpr)
if length(x) != size(x, 1) || size(y, 2) > 1
error("convolution only supported between two vectors")
end
m, n = length(x), size(y, 1)
X = spzeros(eltype(x), m + n - 1, n)
for i in 1:n, j in 1:m
X[i+j-1, i] = x[j]
if length(x) == 0
throw(ArgumentError("convolution with empty vector not supported"))
end
X = conv1D_matrix(x, length(y))
return X * y
end

Expand Down
4 changes: 4 additions & 0 deletions test/test_atoms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,10 @@ function test_conv()
ErrorException("convolution only supported between two vectors"),
conv([1, 2], Variable(2, 2)),
)
@test_throws(
ArgumentError("convolution with empty vector not supported"),
conv([], Variable(2)),
)
return
end

Expand Down
25 changes: 19 additions & 6 deletions test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -660,23 +660,36 @@ function test_logsumexp_stability()
return
end

# simple 1D convolution implementation
function _conv(h, x)
m = length(h)
n = length(x)
zero_pad_x(i) = 1 <= i <= n ? x[i] : 0
return [sum(h[j] * zero_pad_x(i - j + 1) for j in 1:m) for i in 1:m+n-1]
end

function test_conv_issue_364()
n = 3
m = 11
h = rand(m)
x = rand(n)
hvar = Variable(m)
hvar.value = h
function _conv(h, x)
m = length(h)
n = length(x)
zero_pad_x(i) = 1 <= i <= n ? x[i] : 0
return [sum(h[j] * zero_pad_x(i - j + 1) for j in 1:m) for i in 1:m+n-1]
end
@test evaluate(conv(hvar, x)) _conv(h, x)
return
end

function test_conv1D_matrix()
for (x_len, y_len) in ((20, 5), (5, 20), (5, 5), (1, 1), (2, 3))
for im1 in (im, 0), im2 in (im, 0)
x = randn(x_len) + randn(x_len) * im1
y = randn(y_len) + randn(y_len) * im2
@test Convex.conv1D_matrix(x, length(y)) * y _conv(x, y)
end
end
return
end

function test_conj_issue_416()
A = [1 1im; -1im 1]
X = ComplexVariable(2, 2)
Expand Down

0 comments on commit f2b37dd

Please sign in to comment.