From 1fad981ae9cc1783bf0cb2ef0e1381cafbc864bc Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Thu, 18 Apr 2019 22:40:45 +0800 Subject: [PATCH] update --- src/operations.jl | 17 +++++++++-------- test/operations.jl | 3 +++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/operations.jl b/src/operations.jl index c116572..a52ed1e 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -25,36 +25,37 @@ end # basic arithmatics # neg -Base.:-(reg::ArrayRegOrAdjointArrayReg) = ArrayReg(-state(reg)) +Base.:-(reg::ArrayReg) = ArrayReg(-state(reg)) # +, - for op in [:+, :-] - @eval function Base.$op(lhs::ArrayRegOrAdjointArrayReg{B}, rhs::ArrayRegOrAdjointArrayReg{B}) where B + @eval function Base.$op(lhs::ArrayReg{B}, rhs::ArrayReg{B}) where B return ArrayReg(($op)(state(lhs), state(rhs))) end end # *, / for op in [:*, :/] - @eval function Base.$op(lhs::RT, rhs::Number) where {B, RT <: ArrayRegOrAdjointArrayReg{B}} + @eval function Base.$op(lhs::RT, rhs::Number) where {B, RT <: ArrayReg{B}} ArrayReg{B}($op(state(lhs), rhs)) end if op == :* - @eval function Base.$op(lhs::Number, rhs::RT) where {B, RT <: ArrayRegOrAdjointArrayReg{B}} + @eval function Base.$op(lhs::Number, rhs::RT) where {B, RT <: ArrayReg{B}} ArrayReg{B}(($op)(lhs, state(rhs))) end end end for op in [:(==), :≈] - @eval function Base.$op(lhs::ArrayRegOrAdjointArrayReg{B}, rhs::ArrayRegOrAdjointArrayReg{B}) where B - ($op)(state(lhs), state(rhs)) + for AT in [:ArrayReg, :AdjointArrayReg] + @eval function Base.$op(lhs::$AT, rhs::$AT) + ($op)(state(lhs), state(rhs)) + end end end -Base.:*(op::AbstractMatrix, r::ArrayRegOrAdjointArrayReg) = op * state(r) -Base.:*(bra::AdjointArrayReg{1}, ket::ArrayReg{1}) = dot(state(bra), state(ket)) +Base.:*(bra::AdjointArrayReg{1}, ket::ArrayReg{1}) = dot(parent(bra).state, state(ket)) Base.:*(bra::AdjointArrayReg{B}, ket::ArrayReg{B}) where B = bra .* ket # broadcast diff --git a/test/operations.jl b/test/operations.jl index a87e0f0..9288866 100644 --- a/test/operations.jl +++ b/test/operations.jl @@ -22,4 +22,7 @@ end reg = rand_state(4) @test all(state(reg + (-reg)).==0) @test all(state(reg*2 - reg/0.5) .== 0) + + reg = rand_state(3) + @test reg'*reg ≈ 1 end