Skip to content

Commit

Permalink
Refactor remaining non-scalar tests
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 9, 2024
1 parent a724e12 commit 6010711
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 78 deletions.
42 changes: 22 additions & 20 deletions test/MatrixFields/matrix_field_broadcasting.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
#=
julia --project
ENV["CLIMACOMMS_DEVICE"] = "CPU"
ENV["BUILDKITE"] = "true" # to also run opt tests
using Revise; include(joinpath("test", "MatrixFields", "matrix_field_broadcasting.jl"))
=#
using Test

#! format: off
@testset "Scalar Matrix Field Broadcasting" begin
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_1.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_2.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_3.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_4.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_5.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_6.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_7.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_8.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_9.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_10.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_11.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_12.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_13.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_14.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_15.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_16.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_1.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_2.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_3.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_4.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_5.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_6.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_7.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_8.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_9.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_10.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_11.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_12.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_13.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_14.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_15.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_scalar_16.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc()
end

@testset "Non-scalar Matrix Field Broadcasting" begin
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_non_scalar_1.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_non_scalar_2.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_non_scalar_3.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_non_scalar_4.jl")); @info "mem usage" rss = Sys.maxrss() / 2^30
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_non_scalar_1.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_non_scalar_2.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_non_scalar_3.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc(); include(joinpath("matrix_fields_broadcasting", "test_non_scalar_4.jl")); @info "mem usage: rss = $(Sys.maxrss() / 2^30)"
GC.gc()
end
#! format: on
54 changes: 35 additions & 19 deletions test/MatrixFields/matrix_fields_broadcasting/test_non_scalar_2.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,37 @@
if !(@isdefined(test_field_broadcast))
include("test_non_scalar_utils.jl")
#=
julia --project
using Revise; include(joinpath("test", "MatrixFields", "matrix_fields_broadcasting", "test_non_scalar_1.jl"))
=#
import ClimaCore
#! format: off
if !(@isdefined(unit_test_field_broadcast))
include(joinpath(pkgdir(ClimaCore),"test","MatrixFields","matrix_fields_broadcasting","test_non_scalar_utils.jl"))
end

test_field_broadcast(;
test_name = "matrix of covectors times matrix of vectors times matrix \
#! format: on
test_opt = get(ENV, "BUILDKITE", "") == "true"
@testset "matrix of covectors times matrix of vectors times matrix \
of numbers times matrix of covectors times matrix of \
vectors",
get_result = () ->
(@. ᶜᶠmat_AC1 ᶠᶜmat_C12 ᶜᶠmat ᶠᶜmat_AC1 ᶜᶠmat_C12),
set_result! = result ->
(@. result = ᶜᶠmat_AC1 ᶠᶜmat_C12 ᶜᶠmat ᶠᶜmat_AC1 ᶜᶠmat_C12),
ref_set_result! = result -> (@. result =
ᶜᶠmat (
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ᶠᶜmat2 +
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ᶠᶜmat3
) ᶜᶠmat ᶠᶜmat (
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:1) ᶜᶠmat2 +
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ᶜᶠmat3
)),
)
vectors" begin

bc = @lazy @. ᶜᶠmat_AC1 ᶠᶜmat_C12 ᶜᶠmat ᶠᶜmat_AC1 ᶜᶠmat_C12
result = materialize(bc)

ref_set_result! =
result -> (@. result =
ᶜᶠmat (
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ᶠᶜmat2 +
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ᶠᶜmat3
) ᶜᶠmat ᶠᶜmat (
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:1) ᶜᶠmat2 +
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ᶜᶠmat3
))

unit_test_field_broadcast(
result,
bc;
ref_set_result!,
allowed_max_eps_error = 10,
)

test_opt && opt_test_field_broadcast(result, bc; ref_set_result!)
end
58 changes: 37 additions & 21 deletions test/MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
if !(@isdefined(test_field_broadcast))
include("test_non_scalar_utils.jl")
#=
julia --project
using Revise; include(joinpath("test", "MatrixFields", "matrix_fields_broadcasting", "test_non_scalar_1.jl"))
=#
import ClimaCore
#! format: off
if !(@isdefined(unit_test_field_broadcast))
include(joinpath(pkgdir(ClimaCore),"test","MatrixFields","matrix_fields_broadcasting","test_non_scalar_utils.jl"))
end

test_field_broadcast(;
test_name = "matrix of covectors and numbers times matrix of vectors \
#! format: on
test_opt = get(ENV, "BUILDKITE", "") == "true"
@testset "matrix of covectors and numbers times matrix of vectors \
and covectors times matrix of numbers and vectors times \
vector of numbers",
get_result = () ->
(@. ᶜᶠmat_AC1_num ᶠᶜmat_C12_AC1 ᶜᶠmat_num_C12 ᶠvec),
set_result! = result ->
(@. result = ᶜᶠmat_AC1_num ᶠᶜmat_C12_AC1 ᶜᶠmat_num_C12 ᶠvec),
ref_set_result! = result -> (@. result = tuple(
ᶜᶠmat (
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ᶠᶜmat2 +
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ᶠᶜmat3
) ᶜᶠmat ᶠvec,
ᶜᶠmat ᶠᶜmat (
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:1) ᶜᶠmat2 +
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ᶜᶠmat3
) ᶠvec,
)),
)
vector of numbers" begin

bc = @lazy @. ᶜᶠmat_AC1_num ᶠᶜmat_C12_AC1 ᶜᶠmat_num_C12 ᶠvec
result = materialize(bc)

ref_set_result! =
result -> (@. result = tuple(
ᶜᶠmat (
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ᶠᶜmat2 +
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ᶠᶜmat3
) ᶜᶠmat ᶠvec,
ᶜᶠmat ᶠᶜmat (
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:1) ᶜᶠmat2 +
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ᶜᶠmat3
) ᶠvec,
))

unit_test_field_broadcast(
result,
bc;
ref_set_result!,
allowed_max_eps_error = 10,
)

test_opt && opt_test_field_broadcast(result, bc; ref_set_result!)
end
47 changes: 32 additions & 15 deletions test/MatrixFields/matrix_fields_broadcasting/test_non_scalar_4.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
if !(@isdefined(test_field_broadcast))
include("test_non_scalar_utils.jl")
#=
julia --project
using Revise; include(joinpath("test", "MatrixFields", "matrix_fields_broadcasting", "test_non_scalar_1.jl"))
=#
import ClimaCore
#! format: off
if !(@isdefined(unit_test_field_broadcast))
include(joinpath(pkgdir(ClimaCore),"test","MatrixFields","matrix_fields_broadcasting","test_non_scalar_utils.jl"))
end

test_field_broadcast(;
test_name = "matrix of nested values times matrix of nested values \
#! format: on
test_opt = get(ENV, "BUILDKITE", "") == "true"
@testset "matrix of nested values times matrix of nested values \
times matrix of numbers times matrix of numbers times \
vector of nested values",
get_result = () -> (@. ᶜᶠmat_NT ᶠᶜmat ᶜᶠmat ᶠᶜmat_NT ᶜvec_NT),
set_result! = result ->
(@. result = ᶜᶠmat_NT ᶠᶜmat ᶜᶠmat ᶠᶜmat_NT ᶜvec_NT),
ref_set_result! = result -> (@. result = nested_type(
ᶜᶠmat ᶠᶜmat ᶜᶠmat ᶠᶜmat ᶜvec,
ᶜᶠmat2 ᶠᶜmat ᶜᶠmat ᶠᶜmat2 ᶜvec,
ᶜᶠmat3 ᶠᶜmat ᶜᶠmat ᶠᶜmat3 ᶜvec,
)),
)
vector of nested values" begin

bc = @lazy @. ᶜᶠmat_NT ᶠᶜmat ᶜᶠmat ᶠᶜmat_NT ᶜvec_NT
result = materialize(bc)

ref_set_result! =
result -> (@. result = nested_type(
ᶜᶠmat ᶠᶜmat ᶜᶠmat ᶠᶜmat ᶜvec,
ᶜᶠmat2 ᶠᶜmat ᶜᶠmat ᶠᶜmat2 ᶜvec,
ᶜᶠmat3 ᶠᶜmat ᶜᶠmat ᶠᶜmat3 ᶜvec,
))

unit_test_field_broadcast(
result,
bc;
ref_set_result!,
allowed_max_eps_error = 10,
)

test_opt && opt_test_field_broadcast(result, bc; ref_set_result!)
end
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ UnitTest("MatrixFields - broadcasting (14)" ,"MatrixFields/matrix_fields_
UnitTest("MatrixFields - broadcasting (15)" ,"MatrixFields/matrix_fields_broadcasting/test_scalar_15.jl"),
UnitTest("MatrixFields - broadcasting (16)" ,"MatrixFields/matrix_fields_broadcasting/test_scalar_16.jl"),
UnitTest("MatrixFields - non-scalar broadcasting (1)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_1.jl"),
# UnitTest("MatrixFields - non-scalar broadcasting (2)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_2.jl"),
# UnitTest("MatrixFields - non-scalar broadcasting (3)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl"),
# UnitTest("MatrixFields - non-scalar broadcasting (4)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_4.jl"),
UnitTest("MatrixFields - non-scalar broadcasting (2)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_2.jl"),
UnitTest("MatrixFields - non-scalar broadcasting (3)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl"),
UnitTest("MatrixFields - non-scalar broadcasting (4)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_4.jl"),
# UnitTest("MatrixFields - matrix field broadcast" ,"MatrixFields/matrix_field_broadcasting.jl"), # too long
# UnitTest("MatrixFields - operator matrices" ,"MatrixFields/operator_matrices.jl"), # too long
# UnitTest("MatrixFields - field matrix solvers" ,"MatrixFields/field_matrix_solvers.jl"), # too long
Expand Down

0 comments on commit 6010711

Please sign in to comment.