From dc4dd48798d35dfd5c1ef017dafb2bf6b523cb24 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Thu, 3 Jun 2021 19:13:23 +0200
Subject: [PATCH] Improve `sinkhorn_gibbs` (#90)

---
 Project.toml            |   5 +-
 src/OptimalTransport.jl | 129 ++++++++++++++++++++++------------
 src/utils.jl            |  56 +++++++++++++++++
 test/entropic.jl        | 110 ++++++++++++++++++++--------------
 test/gpu/simple_gpu.jl  |  37 ++++++++++--
 test/runtests.jl        |   3 +
 test/utils.jl           |  98 ++++++++++++++++++++++++++++++
 7 files changed, 329 insertions(+), 109 deletions(-)
 create mode 100644 src/utils.jl
 create mode 100644 test/utils.jl

diff --git a/Project.toml b/Project.toml
index 7868580e..e727f4d9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "OptimalTransport"
 uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
 authors = ["zsteve "]
-version = "0.3.7"
+version = "0.3.8"
 
 [deps]
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -25,6 +25,7 @@ StatsBase = "0.33.8"
 julia = "1"
 
 [extras]
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 PythonOT = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -33,4 +34,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tulip = "6dd1b50a-3aae-11e9-10b5-ef983d2400fa"
 
 [targets]
-test = ["Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip"]
+test = ["ForwardDiff", "Pkg", "PythonOT", "Random", "SafeTestsets", "Test", "Tulip"]
diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl
index e81c9802..b0fe52f7 100644
--- a/src/OptimalTransport.jl
+++ b/src/OptimalTransport.jl
@@ -22,15 +22,10 @@ export ot_cost, ot_plan, wasserstein, squared2wasserstein
 
 const MOI = MathOptInterface
 
+include("utils.jl")
 include("exact.jl")
 include("wasserstein.jl")
 
-dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
-function dot_matwise(x::AbstractArray, y::AbstractMatrix)
-    xmat = reshape(x, size(x, 1) * size(x, 2), :)
-    return reshape(reshape(y, 1, :) * xmat, size(x)[3:end])
-end
-
 """
     sinkhorn_gibbs(
         μ, ν, K; atol=0, rtol=atol > 0 ? 0 : √eps, check_convergence=10, maxiter=1_000
@@ -58,11 +53,12 @@ isapprox(sum(G; dims=2), μ; atol=atol, rtol=rtol, norm=x -> norm(x, 1))
 The default `rtol` depends on the types of `μ`, `ν`, and `K`. After `maxiter` iterations,
 the computation is stopped.
 
-Note that for a common kernel `K`, multiple histograms may be provided for a batch computation by passing `μ` and `ν`
-as matrices whose columns `μ[:, i]` and `ν[:, i]` correspond to pairs of histograms.
-The output are then matrices `u` and `v` such that `u[:, i]` and `v[:, i]` are the dual variables for `μ[:, i]` and `ν[:, i]`.
-
-In addition, the case where one of `μ` or `ν` is a single histogram and the other a matrix of histograms is supported.
+Batch computations for multiple histograms with a common Gibbs kernel `K` can be performed
+by passing `μ` or `ν` as matrices whose columns correspond to histograms. It is required
+that the number of source and target marginals is equal or that a single source or a single
+target marginal is provided (either as a matrix or as a vector). The optimal transport plans
+are returned as a three-dimensional array where `γ[:, :, i]` is the optimal transport plan
+for the `i`th pair of source and target marginals.
""" function sinkhorn_gibbs( μ, @@ -87,43 +83,66 @@ function sinkhorn_gibbs( :sinkhorn_gibbs, ) end - if (size(μ, 2) != size(ν, 2)) && (min(size(μ, 2), size(ν, 2)) > 1) - throw( - DimensionMismatch( - "Error: number of columns in μ and ν must coincide, if both are matrix valued", - ), - ) - end - all(sum(μ; dims=1) .≈ sum(ν; dims=1)) || - throw(ArgumentError("source and target marginals must have the same mass")) + + # checks + size2 = checksize2(μ, ν) + checkbalanced(μ, ν) # set default values of tolerances T = float(Base.promote_eltype(μ, ν, K)) _atol = atol === nothing ? 0 : atol _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol - # initial iteration - u = if isequal(size(μ, 2), size(ν, 2)) - similar(μ) - else - repeat(similar(μ[:, 1]); outer=(1, max(size(μ, 2), size(ν, 2)))) + # initialize iterates + u = similar(μ, T, size(μ, 1), size2...) + v = similar(ν, T, size(ν, 1), size2...) + fill!(v, one(T)) + + # arrays for convergence check + Kv = similar(u) + mul!(Kv, K, v) + tmp = similar(u) + norm_μ = μ isa AbstractVector ? sum(abs, μ) : sum(abs, μ; dims=1) + if u isa AbstractMatrix + tmp2 = similar(u) + norm_uKv = similar(u, 1, size2...) + norm_diff = similar(u, 1, size2...) + _isconverged = similar(u, Bool, 1, size2...) end - u .= μ ./ vec(sum(K; dims=2)) - v = ν ./ (K' * u) - tmp1 = K * v - tmp2 = similar(u) - norm_μ = sum(abs, μ; dims=1) # for convergence check isconverged = false check_step = check_convergence === nothing ? 10 : check_convergence - for iter in 0:maxiter - if iter % check_step == 0 - # check source marginal - # do not overwrite `tmp1` but reuse it for computing `u` if not converged - @. tmp2 = u * tmp1 - norm_uKv = sum(abs, tmp2; dims=1) - @. tmp2 = μ - tmp2 - norm_diff = sum(abs, tmp2; dims=1) + to_check_step = check_step + for iter in 1:maxiter + # reduce counter + to_check_step -= 1 + + # compute next iterate + u .= μ ./ Kv + mul!(v, K', u) + v .= ν ./ v + mul!(Kv, K, v) + + # check source marginal + # always check convergence after the final iteration + if to_check_step <= 0 || iter == maxiter + # reset counter + to_check_step = check_step + + # do not overwrite `Kv` but reuse it for computing `u` if not converged + tmp .= u .* Kv + if u isa AbstractMatrix + tmp2 .= abs.(tmp) + sum!(norm_uKv, tmp2) + else + norm_uKv = sum(abs, tmp) + end + tmp .= abs.(μ .- tmp) + if u isa AbstractMatrix + sum!(norm_diff, tmp) + else + norm_diff = sum(tmp) + end @debug "Sinkhorn algorithm (" * string(iter) * @@ -133,20 +152,17 @@ function sinkhorn_gibbs( string(maximum(norm_diff)) # check stopping criterion - if all(@. norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv))) + isconverged = if u isa AbstractMatrix + @. _isconverged = norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv)) + all(_isconverged) + else + norm_diff < max(_atol, _rtol * max(norm_μ, norm_uKv)) + end + if isconverged @debug "Sinkhorn algorithm ($iter/$maxiter): converged" - isconverged = true break end end - - # perform next iteration - if iter < maxiter - @. u = μ / tmp1 - mul!(v, K', u) - @. v = ν / v - mul!(tmp1, K, v) - end end if !isconverged @@ -156,13 +172,6 @@ function sinkhorn_gibbs( return u, v end -function add_singleton(x::AbstractArray, ::Val{dim}) where {dim} - shape = ntuple(ndims(x) + 1) do i - return i < dim ? size(x, i) : (i > dim ? size(x, i - 1) : 1) - end - return reshape(x, shape) -end - """ sinkhorn( μ, ν, C, ε; atol=0, rtol=atol > 0 ? 
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 00000000..e079b325
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,56 @@
+"""
+    add_singleton(x::AbstractArray, ::Val{dim}) where {dim}
+
+Add an additional dimension `dim` of size 1 to array `x`.
+"""
+function add_singleton(x::AbstractArray, ::Val{dim}) where {dim}
+    shape = ntuple(max(ndims(x) + 1, dim)) do i
+        return i < dim ? size(x, i) : (i > dim ? size(x, i - 1) : 1)
+    end
+    return reshape(x, shape)
+end
+
+"""
+    dot_matwise(x::AbstractArray, y::AbstractArray)
+
+Compute the inner product of all matrices in `x` and `y`.
+
+At least one of `x` and `y` has to be a matrix.
+"""
+dot_matwise(x::AbstractMatrix, y::AbstractMatrix) = dot(x, y)
+function dot_matwise(x::AbstractArray, y::AbstractMatrix)
+    xmat = reshape(x, size(x, 1) * size(x, 2), :)
+    return reshape(reshape(y, 1, :) * xmat, size(x)[3:end])
+end
+dot_matwise(x::AbstractMatrix, y::AbstractArray) = dot_matwise(y, x)
+
+"""
+    checksize2(x::AbstractVecOrMat, y::AbstractVecOrMat)
+
+Check that arrays `x` and `y` are compatible, then return a tuple containing their
+broadcasted second dimension.
+"""
+checksize2(::AbstractVector, ::AbstractVector) = ()
+function checksize2(μ::AbstractVecOrMat, ν::AbstractVecOrMat)
+    size_μ_2 = size(μ, 2)
+    size_ν_2 = size(ν, 2)
+    if size_μ_2 > 1 && size_ν_2 > 1 && size_μ_2 != size_ν_2
+        throw(DimensionMismatch("size of source and target marginals is not compatible"))
+    end
+    return (max(size_μ_2, size_ν_2),)
+end
+
+"""
+    checkbalanced(μ::AbstractVecOrMat, ν::AbstractVecOrMat)
+
+Check that source and target marginals `μ` and `ν` are balanced.
+"""
+function checkbalanced(μ::AbstractVector, ν::AbstractVector)
+    sum(μ) ≈ sum(ν) || throw(ArgumentError("source and target marginals are not balanced"))
+    return nothing
+end
+function checkbalanced(x::AbstractVecOrMat, y::AbstractVecOrMat)
+    all(isapprox.(sum(x; dims=1), sum(y; dims=1))) ||
+        throw(ArgumentError("source and target marginals are not balanced"))
+    return nothing
+end
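
The new helpers above are internal (not exported), but their behaviour is easy to see at the REPL. A small sketch with illustrative sizes, calling them with the module prefix as the new tests do:

using OptimalTransport, LinearAlgebra

u = rand(3)
OptimalTransport.add_singleton(u, Val(2))   # 3×1 matrix, ready for broadcasting

K = rand(3, 4)
plans = rand(3, 4, 5)                        # batch of 5 matrices
OptimalTransport.dot_matwise(plans, K)       # 5 inner products dot(plans[:, :, i], K)

μ = fill(1 / 3, 3, 5)                        # 5 source histograms
ν = fill(1 / 4, 4)                           # a single target histogram
OptimalTransport.checksize2(μ, ν)            # (5,): broadcasted batch size
OptimalTransport.checkbalanced(μ, ν)         # returns nothing; unbalanced input throws
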
+""" +function checkbalanced(μ::AbstractVector, ν::AbstractVector) + sum(μ) ≈ sum(ν) || throw(ArgumentError("source and target marginals are not balanced")) + return nothing +end +function checkbalanced(x::AbstractVecOrMat, y::AbstractVecOrMat) + all(isapprox.(sum(x; dims=1), sum(y; dims=1))) || + throw(ArgumentError("source and target marginals are not balanced")) + return nothing +end diff --git a/test/entropic.jl b/test/entropic.jl index 081d6644..1ba70182 100644 --- a/test/entropic.jl +++ b/test/entropic.jl @@ -1,6 +1,8 @@ using OptimalTransport using Distances +using ForwardDiff +using LogExpFunctions using PythonOT: PythonOT using Random @@ -23,7 +25,7 @@ Random.seed!(100) # create random cost matrix C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2) - # compute optimal transport map (Julia implementation + POT) + # compute optimal transport plan (Julia implementation + POT) eps = 0.01 γ = sinkhorn(μ, ν, C, eps; maxiter=5_000, rtol=1e-9) γ_pot = POT.sinkhorn(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9) @@ -40,13 +42,34 @@ Random.seed!(100) c_pot = POT.sinkhorn2(μ, ν, C, eps; numItermax=5_000, stopThr=1e-9)[1] @test c_pot ≈ c - # ensure that provided map is used and correct + # ensure that provided plan is used and correct c2 = sinkhorn2(similar(μ), similar(ν), C, rand(); plan=γ) @test c2 ≈ c c2_w_regularization = sinkhorn2( similar(μ), similar(ν), C, eps; plan=γ, regularization=true ) @test c2_w_regularization ≈ c_w_regularization + + # batches of histograms + d = 10 + for (size2_μ, size2_ν) in + (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) + # generate uniform histograms + μ = fill(1 / M, (M, size2_μ...)) + ν = fill(1 / N, (N, size2_ν...)) + + # compute optimal transport plan and check that it is consistent with the + # plan for individual histograms + γ_all = sinkhorn(μ, ν, C, eps; maxiter=5_000, rtol=1e-9) + @test size(γ_all) == (M, N, d) + @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3)) + + # compute optimal transport cost and check that it is consistent with the + # cost for individual histograms + c_all = sinkhorn2(μ, ν, C, eps; maxiter=5_000, rtol=1e-9) + @test size(c_all) == (d,) + @test all(x ≈ c for x in c_all) + end end # different element type @@ -58,7 +81,7 @@ Random.seed!(100) # create random cost matrix C = pairwise(SqEuclidean(), rand(Float32, 1, M), rand(Float32, 1, N); dims=2) - # compute optimal transport map (Julia implementation + POT) + # compute optimal transport plan (Julia implementation + POT) eps = 0.01f0 γ = sinkhorn(μ, ν, C, eps; maxiter=5_000, rtol=1e-6) @test eltype(γ) === Float32 @@ -80,52 +103,51 @@ Random.seed!(100) c_pot = POT.sinkhorn2(μ, ν, C, eps; numItermax=5_000, stopThr=1e-6)[1] @test Float32(c_pot) ≈ c rtol = 1e-3 - # batch + # batches of histograms d = 10 - μ = fill(Float32(1 / M), (M, d)) - ν = fill(Float32(1 / N), N) - - γ_all = sinkhorn(μ, ν, C, eps; maxiter=5_000, rtol=1e-6) - γ_pot = [ - POT.sinkhorn(μ[:, i], vec(ν), C, eps; numItermax=5_000, stopThr=1e-6) for - i in 1:d - ] - @test all([ - isapprox(Float32.(γ_pot[i]), γ_all[:, :, i]; rtol=1e-3) for i in 1:d - ]) - @test eltype(γ_all) == Float32 + for (size2_μ, size2_ν) in + (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,))) + # generate uniform histograms + μ = fill(Float32(1 / M), (M, size2_μ...)) + ν = fill(Float32(1 / N), (N, size2_ν...)) + + # compute optimal transport plan and check that it is consistent with the + # plan for individual histograms + γ_all = sinkhorn(μ, ν, C, eps; maxiter=5_000, rtol=1e-6) + @test 
+            @test eltype(γ_all) === Float32
+            @test size(γ_all) == (M, N, d)
+            @test all(view(γ_all, :, :, i) ≈ γ for i in axes(γ_all, 3))
+
+            # compute optimal transport cost and check that it is consistent with the
+            # cost for individual histograms
+            c_all = sinkhorn2(μ, ν, C, eps; maxiter=5_000, rtol=1e-6)
+            @test eltype(c_all) === Float32
+            @test size(c_all) == (d,)
+            @test all(x ≈ c for x in c_all)
+        end
     end
 
-    @testset "batch" begin
-        # create two sets of batch histograms
-        d = 10
-        μ = rand(Float64, (M, d))
-        μ = μ ./ sum(μ; dims=1)
-        ν = rand(Float64, (N, d))
-        ν = ν ./ sum(ν; dims=1)
-
-        # create random cost matrix
+    # https://github.com/JuliaOptimalTransport/OptimalTransport.jl/issues/86
+    @testset "AD" begin
+        # uniform histograms with random cost matrix
+        μ = fill(1 / M, M)
+        ν = fill(1 / N, N)
         C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
 
-        # compute optimal transport map (Julia implementation + POT)
-        eps = 0.01
-        γ_all = sinkhorn(μ, ν, C, eps; maxiter=5_000)
-        γ_pot = [POT.sinkhorn(μ[:, i], ν[:, i], C, eps; numItermax=5_000) for i in 1:d]
-        @test all([isapprox(γ_all[:, :, i], γ_pot[i]; rtol=1e-6) for i in 1:d])
-
-        c_all = sinkhorn2(μ, ν, C, eps; maxiter=5_000)
-        c_pot = [
-            POT.sinkhorn2(μ[:, i], ν[:, i], C, eps; numItermax=5_000)[1] for i in 1:d
-        ]
-        @test c_all ≈ c_pot rtol = 1e-6
-
-        γ_all = sinkhorn(μ[:, 1], ν, C, eps; maxiter=5_000)
-        γ_pot = [POT.sinkhorn(μ[:, 1], ν[:, i], C, eps; numItermax=5_000) for i in 1:d]
-        @test all([isapprox(γ_all[:, :, i], γ_pot[i]; rtol=1e-6) for i in 1:d])
-
-        γ_all = sinkhorn(μ, ν[:, 1], C, eps; maxiter=5_000)
-        γ_pot = [POT.sinkhorn(μ[:, i], ν[:, 1], C, eps; numItermax=5_000) for i in 1:d]
-        @test all([isapprox(γ_all[:, :, i], γ_pot[i]; rtol=1e-6) for i in 1:d])
+        # compute gradients with respect to source and target marginals separately and
+        # together
+        ε = 0.01
+        ForwardDiff.gradient(zeros(N)) do xs
+            sinkhorn2(μ, softmax(xs), C, ε; regularization=true)
+        end
+        ForwardDiff.gradient(zeros(M)) do xs
+            sinkhorn2(softmax(xs), ν, C, ε; regularization=true)
+        end
+        ForwardDiff.gradient(zeros(M + N)) do xs
+            sinkhorn2(
+                softmax(xs[1:M]), softmax(xs[(M + 1):end]), C, ε; regularization=true
+            )
+        end
     end
 
     @testset "deprecations" begin
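
The new "AD" testset above addresses issue #86 by checking that ForwardDiff can differentiate through `sinkhorn2`. Condensed to its core, and with illustrative sizes and marginals, the pattern is:

using OptimalTransport, Distances, ForwardDiff, LogExpFunctions

M, N = 50, 60
μ = fill(1 / M, M)
C = pairwise(SqEuclidean(), rand(1, M), rand(1, N); dims=2)
ε = 0.01

# gradient of the regularized cost w.r.t. a softmax-parameterized target histogram
∇ν = ForwardDiff.gradient(zeros(N)) do xs
    sinkhorn2(μ, softmax(xs), C, ε; regularization=true)
end
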
diff --git a/test/gpu/simple_gpu.jl b/test/gpu/simple_gpu.jl
index c2284dff..d82780ae 100644
--- a/test/gpu/simple_gpu.jl
+++ b/test/gpu/simple_gpu.jl
@@ -20,23 +20,52 @@ Random.seed!(100)
         m = 200
         μ = rand(Float32, m)
         μ ./= sum(μ)
+        cu_μ = cu(μ)
 
         # target histogram
         n = 250
         ν = rand(Float32, n)
         ν ./= sum(ν)
+        cu_ν = cu(ν)
 
         # random cost matrix
         C = pairwise(SqEuclidean(), randn(Float32, 1, m), randn(Float32, 1, n); dims=2)
+        cu_C = cu(C)
 
         # compute transport plan and cost on the GPU
         ε = 0.01f0
-        γ = sinkhorn(cu(μ), cu(ν), cu(C), ε)
-        c = sinkhorn2(cu(μ), cu(ν), cu(C), ε)
+        γ = sinkhorn(cu_μ, cu_ν, cu_C, ε)
+        @test γ isa CuArray{Float32,2}
+        c = sinkhorn2(cu_μ, cu_ν, cu_C, ε)
+        @test c isa Float32
 
         # compare with results on the CPU
-        @test γ ≈ cu(sinkhorn(μ, ν, C, ε))
-        @test c ≈ cu(sinkhorn2(μ, ν, C, ε))
+        γ_cpu = sinkhorn(μ, ν, C, ε)
+        @test convert(Array, γ) ≈ γ_cpu
+        @test c ≈ sinkhorn2(μ, ν, C, ε)
+
+        # batches of histograms
+        d = 10
+        for (size2_μ, size2_ν) in
+            (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,)))
+            # generate uniform histograms
+            μ_batch = repeat(cu_μ, 1, size2_μ...)
+            ν_batch = repeat(cu_ν, 1, size2_ν...)
+
+            # compute optimal transport plan and check that it is consistent with the
+            # plan for individual histograms
+            γ_all = sinkhorn(μ_batch, ν_batch, cu_C, ε)
+            @test γ_all isa CuArray{Float32,3}
+            @test size(γ_all) == (m, n, d)
+            @test all(γi ≈ γ_cpu for γi in eachslice(convert(Array, γ_all); dims=3))
+
+            # compute optimal transport cost and check that it is consistent with the
+            # cost for individual histograms
+            c_all = sinkhorn2(μ_batch, ν_batch, cu_C, ε)
+            @test c_all isa CuArray{Float32,1}
+            @test size(c_all) == (d,)
+            @test all(ci ≈ c for ci in convert(Array, c_all))
+        end
     end
 
     @testset "sinkhorn_unbalanced" begin
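
A compressed sketch of what the extended GPU test exercises, assuming a working CUDA.jl setup; the array sizes are illustrative. Batched problems follow the same pattern as on the CPU, only with `CuArray` inputs:

using CUDA, Distances, OptimalTransport

m, n, d = 200, 250, 10
μ = cu(fill(1.0f0 / m, m, d))   # batch of d source histograms on the GPU
ν = cu(fill(1.0f0 / n, n))      # a single target histogram
C = cu(pairwise(SqEuclidean(), randn(Float32, 1, m), randn(Float32, 1, n); dims=2))

γ = sinkhorn(μ, ν, C, 0.01f0)   # CuArray{Float32,3} of transport plans
c = sinkhorn2(μ, ν, C, 0.01f0)  # CuArray{Float32,1} of transport costs
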
diff --git a/test/runtests.jl b/test/runtests.jl
index ab84a013..771f9b79 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -8,6 +8,9 @@ const GROUP = get(ENV, "GROUP", "All")
 
 @testset "OptimalTransport" begin
     if GROUP == "All" || GROUP == "OptimalTransport"
+        @safetestset "Utilities" begin
+            include("utils.jl")
+        end
         @safetestset "Exact OT" begin
             include("exact.jl")
         end
diff --git a/test/utils.jl b/test/utils.jl
new file mode 100644
index 00000000..3a81df4d
--- /dev/null
+++ b/test/utils.jl
@@ -0,0 +1,98 @@
+using OptimalTransport
+
+using LinearAlgebra
+using Random
+using Test
+
+Random.seed!(100)
+
+@testset "utils.jl" begin
+    @testset "add_singleton" begin
+        x = rand(3)
+        y = @inferred(OptimalTransport.add_singleton(x, Val(1)))
+        @test size(y) == (1, length(x))
+        @test vec(y) == x
+
+        y = @inferred(OptimalTransport.add_singleton(x, Val(2)))
+        @test size(y) == (length(x), 1)
+        @test vec(y) == x
+
+        x = rand(3, 4)
+        y = @inferred(OptimalTransport.add_singleton(x, Val(1)))
+        @test size(y) == (1, size(x, 1), size(x, 2))
+        @test vec(y) == vec(x)
+
+        y = @inferred(OptimalTransport.add_singleton(x, Val(2)))
+        @test size(y) == (size(x, 1), 1, size(x, 2))
+        @test vec(y) == vec(x)
+
+        y = @inferred(OptimalTransport.add_singleton(x, Val(3)))
+        @test size(y) == (size(x, 1), size(x, 2), 1)
+        @test vec(y) == vec(x)
+    end
+
+    @testset "dot_matwise" begin
+        l, m, n = 4, 5, 3
+        x = rand(l, m)
+        y = rand(l, m)
+        @test OptimalTransport.dot_matwise(x, y) == dot(x, y)
+
+        y = rand(l, m, n)
+        @test OptimalTransport.dot_matwise(x, y) ≈
+            mapreduce(vcat, (view(y, :, :, i) for i in axes(y, 3))) do yi
+                dot(x, yi)
+            end
+        @test OptimalTransport.dot_matwise(y, x) == OptimalTransport.dot_matwise(x, y)
+    end
+
+    @testset "checksize2" begin
+        x = rand(5)
+        y = rand(10)
+        @test OptimalTransport.checksize2(x, y) === ()
+
+        d = 4
+        for (size2_x, size2_y) in
+            (((), (d,)), ((1,), (d,)), ((d,), ()), ((d,), (1,)), ((d,), (d,)))
+            x = rand(5, size2_x...)
+            y = rand(10, size2_y...)
+            @test OptimalTransport.checksize2(x, y) == (d,)
+        end
+
+        x = rand(5, 4)
+        y = rand(10, 3)
+        @test_throws DimensionMismatch OptimalTransport.checksize2(x, y)
+    end
+
+    @testset "checkbalanced" begin
+        mass = rand()
+
+        x1 = rand(20)
+        x1 .*= mass / sum(x1)
+        y1 = rand(30)
+        y1 .*= mass / sum(y1)
+        @test OptimalTransport.checkbalanced(x1, y1) === nothing
+        @test OptimalTransport.checkbalanced(y1, x1) === nothing
+        @test_throws ArgumentError OptimalTransport.checkbalanced(rand() .* x1, y1)
+        @test_throws ArgumentError OptimalTransport.checkbalanced(x1, rand() .* y1)
+
+        y2 = rand(30, 5)
+        y2 .*= mass ./ sum(y2; dims=1)
+        @test OptimalTransport.checkbalanced(x1, y2) === nothing
+        @test OptimalTransport.checkbalanced(y2, x1) === nothing
+        @test_throws ArgumentError OptimalTransport.checkbalanced(rand() .* x1, y2)
+        @test_throws ArgumentError OptimalTransport.checkbalanced(
+            x1, y2 .* hcat(rand(), ones(1, size(y2, 2) - 1))
+        )
+
+        x2 = rand(20, 5)
+        x2 .*= mass ./ sum(x2; dims=1)
+        @test OptimalTransport.checkbalanced(x2, y2) === nothing
+        @test OptimalTransport.checkbalanced(y2, x2) === nothing
+        @test_throws ArgumentError OptimalTransport.checkbalanced(
+            x2 .* hcat(ones(1, size(x2, 2) - 1), rand()), y2
+        )
+        @test_throws ArgumentError OptimalTransport.checkbalanced(
+            x2, y2 .* hcat(rand(), ones(1, size(y2, 2) - 1))
+        )
+    end
+end