diff --git a/Project.toml b/Project.toml index ac70b392..a617446f 100644 --- a/Project.toml +++ b/Project.toml @@ -43,7 +43,7 @@ Reexport = "1" SparseArrays = "1.8" Statistics = "1" StatsBase = "0.34" -TensorCast = "0.3 - 0.4" +TensorCast = "0.4" TensorTrains = "0.7, 0.8" Tullio = "0.3" UnPack = "1" diff --git a/src/mpems.jl b/src/mpems.jl index f96bf4d3..0eedc9f9 100644 --- a/src/mpems.jl +++ b/src/mpems.jl @@ -71,9 +71,9 @@ function mpem2(B::MPEM3{F}) where {F} for t in Iterators.take(eachindex(B), length(B)-1) U, λ, V = svd(M) m = length(λ) - @cast Cᵗ[m, k, xᵢᵗ, xⱼᵗ] := U[(xᵢᵗ, xⱼᵗ, m), k] k:m, xᵢᵗ:qᵢᵗ, xⱼᵗ:qⱼᵗ + @cast Cᵗ[m, k, xᵢᵗ, xⱼᵗ] := U[(xᵢᵗ, xⱼᵗ, m), k] k∈1:m, xᵢᵗ∈1:qᵢᵗ, xⱼᵗ∈1:qⱼᵗ C[t] = Cᵗ - @cast Vt[m, n, xᵢᵗ⁺¹] := V'[m, (n, xᵢᵗ⁺¹)] xᵢᵗ⁺¹:qᵢᵗ⁺¹ + @cast Vt[m, n, xᵢᵗ⁺¹] := V'[m, (n, xᵢᵗ⁺¹)] xᵢᵗ⁺¹∈1:qᵢᵗ⁺¹ Bᵗ⁺¹ = B[t+1] @tullio Bᵗ⁺¹_new[m, n, xᵢᵗ⁺¹, xⱼᵗ⁺¹, xᵢᵗ⁺²] := λ[m] * Vt[m, l, xᵢᵗ⁺¹] * Bᵗ⁺¹[l, n, xᵢᵗ⁺¹, xⱼᵗ⁺¹, xᵢᵗ⁺²] @@ -119,9 +119,9 @@ function mpem2(B::PeriodicMPEM3{F}) where {F} for t in eachindex(B) U, λ, V = svd(M) m = length(λ) - @cast Cᵗ[m, k, xᵢᵗ, xⱼᵗ] := U[(xᵢᵗ, xⱼᵗ, m), k] k:m, xᵢᵗ:qᵢᵗ, xⱼᵗ:qⱼᵗ + @cast Cᵗ[m, k, xᵢᵗ, xⱼᵗ] := U[(xᵢᵗ, xⱼᵗ, m), k] k∈1:m, xᵢᵗ∈1:qᵢᵗ, xⱼᵗ∈1:qⱼᵗ C[t] = Cᵗ - @cast Vt[m, n, xᵢᵗ⁺¹] := V'[m, (n, xᵢᵗ⁺¹)] xᵢᵗ⁺¹:qᵢᵗ⁺¹ + @cast Vt[m, n, xᵢᵗ⁺¹] := V'[m, (n, xᵢᵗ⁺¹)] xᵢᵗ⁺¹∈1:qᵢᵗ⁺¹ if t < length(B) Bᵗ⁺¹ = B[t+1] @tullio Bᵗ⁺¹_new[m, n, xᵢᵗ⁺¹, xⱼᵗ⁺¹, xᵢᵗ⁺²] := λ[m] *