Skip to content

Commit

Permalink
Fix sum(::EinExpr) method
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Feb 7, 2024
1 parent ebb683d commit 0e68598
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/EinExpr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ Explicit sum over `indices`.
See [`sum!`](@ref) for inplace modification.
"""
function Base.sum(path::EinExpr{L}, inds::Union{L,AbstractVecOrTuple{L}}) where {L}
function Base.sum(path::EinExpr{L}, inds::AbstractVecOrTuple{L}) where {L}
i = .!isdisjoint.((inds,), head.(args(path)))

subinds = head.(args(path)[findall(i)])
Expand All @@ -215,6 +215,7 @@ function Base.sum(path::EinExpr{L}, inds::Union{L,AbstractVecOrTuple{L}}) where

return EinExpr(head(path), (EinExpr(suboutput, args(path)[findall(i)]), args(path)[findall(.!i)]...))
end
Base.sum(path::EinExpr{L}, inds::L) where {L} = sum(path, (inds,))

"""
sum(tensors::Vector{EinExpr}; skip = [])
Expand Down Expand Up @@ -287,3 +288,17 @@ AbstractTrees.ParentLinks(::Type{EinExpr}) = ImplicitParents()
AbstractTrees.SiblingLinks(::Type{EinExpr}) = ImplicitSiblings()
AbstractTrees.ChildIndexing(::Type{EinExpr}) = IndexedChildren()
AbstractTrees.NodeType(::Type{EinExpr}) = HasNodeType()

# Utils
function sumtraces(path::EinExpr)
do_not_contract_inds = hyperinds(path) path.head
_args = map(path.args) do arg
selfinds = nonunique(arg.head)
isempty(selfinds) && return arg

skip = selfinds do_not_contract_inds
sum([arg]; skip)
end

EinExpr(path.head, _args)
end

0 comments on commit 0e68598

Please sign in to comment.