Skip to content

Commit dfa0f3d

Browse files
feat: re-implement co_sum
1 parent 7c1b65a commit dfa0f3d

File tree

2 files changed

+109
-7
lines changed

2 files changed

+109
-7
lines changed

src/caffeine/collective_subroutines/co_sum_s.f90

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,20 @@
77
contains
88

99
module procedure prif_co_sum
10-
call unimplemented("prif_co_sum")
10+
call contiguous_co_sum(a, result_image, stat, errmsg, errmsg_alloc)
1111
end procedure
1212

13+
subroutine contiguous_co_sum(a, result_image, stat, errmsg, errmsg_alloc)
14+
type(*), intent(inout), target, contiguous :: a(..)
15+
integer(c_int), intent(in), optional :: result_image
16+
integer(c_int), intent(out), optional :: stat
17+
character(len=*), intent(inout), optional :: errmsg
18+
character(len=:), intent(inout), allocatable, optional :: errmsg_alloc
19+
20+
if (present(stat)) stat=0
21+
22+
call caf_co_sum( &
23+
a, optional_value(result_image), int(product(shape(a)), c_size_t), current_team%info%gex_team)
24+
end subroutine
25+
1326
end submodule co_sum_s

test/caf_co_sum_test.f90

Lines changed: 95 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
module caf_co_sum_test
2+
use iso_c_binding, only: c_int32_t, c_int64_t, c_float, c_double
23
use prif, only : prif_co_sum, prif_num_images, prif_this_image_no_coarray
34
use veggies, only: result_t, test_item_t, assert_equals, describe, it, succeed
45

@@ -23,32 +24,120 @@ function test_prif_co_sum() result(tests)
2324

2425
function check_32_bit_integer() result(result_)
2526
type(result_t) :: result_
26-
result_ = succeed("temporarily")
27+
28+
integer(c_int32_t), parameter :: values(*) = [1, 19, 5, 13, 11, 7, 17, 3]
29+
integer :: me, ni, i
30+
integer(c_int32_t) :: my_val, expected
31+
32+
call prif_this_image_no_coarray(this_image=me)
33+
call prif_num_images(ni)
34+
35+
my_val = values(mod(me-1, size(values))+1)
36+
call prif_co_sum(my_val)
37+
38+
expected = sum([(values(mod(i-1,size(values))+1), i = 1, ni)])
39+
result_ = assert_equals(expected, my_val)
2740
end function
2841

2942
function check_64_bit_integer() result(result_)
3043
type(result_t) :: result_
31-
result_ = succeed("temporarily")
44+
45+
integer(c_int64_t), parameter :: values(*,*) = reshape([1, 19, 5, 13, 11, 7, 17, 3], [2, 4])
46+
integer :: me, ni, i
47+
integer(c_int64_t), dimension(size(values,1)) :: my_val, expected
48+
49+
call prif_this_image_no_coarray(this_image=me)
50+
call prif_num_images(ni)
51+
52+
my_val = values(:, mod(me-1, size(values,2))+1)
53+
call prif_co_sum(my_val)
54+
55+
expected = sum(reshape([(values(:, mod(i-1,size(values,2))+1), i = 1, ni)], [size(values,1),ni]), dim=2)
56+
result_ = assert_equals(int(expected), int(my_val))
3257
end function
3358

3459
function check_32_bit_real() result(result_)
3560
type(result_t) :: result_
36-
result_ = succeed("temporarily")
61+
62+
real(c_float), parameter :: values(*,*,*) = reshape([1, 19, 5, 13, 11, 7, 17, 3], [2,2,2])
63+
integer :: me, ni, i
64+
real(c_float), dimension(size(values,1), size(values,2)) :: my_val, expected
65+
66+
call prif_this_image_no_coarray(this_image=me)
67+
call prif_num_images(ni)
68+
69+
my_val = values(:, :, mod(me-1, size(values,3))+1)
70+
call prif_co_sum(my_val)
71+
72+
expected = sum(reshape([(values(:,:,mod(i-1,size(values,3))+1), i = 1, ni)], [size(values,1), size(values,2), ni]), dim=3)
73+
result_ = assert_equals(real(expected,kind=c_double), real(my_val,kind=c_double))
3774
end function
3875

3976
function check_64_bit_real() result(result_)
4077
type(result_t) :: result_
41-
result_ = succeed("temporarily")
78+
79+
real(c_double), parameter :: values(*,*) = reshape([1, 19, 5, 13, 11, 7, 17, 3], [2, 4])
80+
integer :: me, ni, i
81+
real(c_double), dimension(size(values,1)) :: my_val, expected
82+
83+
call prif_this_image_no_coarray(this_image=me)
84+
call prif_num_images(ni)
85+
86+
my_val = values(:, mod(me-1, size(values,2))+1)
87+
call prif_co_sum(my_val)
88+
89+
expected = sum(reshape([(values(:, mod(i-1,size(values,2))+1), i = 1, ni)], [size(values,1),ni]), dim=2)
90+
result_ = assert_equals(expected, my_val)
4291
end function
4392

4493
function check_32_bit_complex() result(result_)
4594
type(result_t) :: result_
46-
result_ = succeed("temporarily")
95+
96+
complex(c_float), parameter :: values(*,*,*) = reshape( &
97+
[ cmplx(1., 53.), cmplx(3., 47.) &
98+
, cmplx(5., 43.), cmplx(7., 41.) &
99+
, cmplx(11., 37.), cmplx(13., 31.) &
100+
, cmplx(17., 29.), cmplx(19., 23.) &
101+
], &
102+
[2,2,2])
103+
integer :: me, ni, i
104+
complex(c_float), dimension(size(values,1),size(values,2)) :: my_val, expected
105+
106+
call prif_this_image_no_coarray(this_image=me)
107+
call prif_num_images(ni)
108+
109+
my_val = values(:, :, mod(me-1, size(values,3))+1)
110+
call prif_co_sum(my_val)
111+
112+
expected = sum(reshape([(values(:,:,mod(i-1,size(values,3))+1), i = 1, ni)], [size(values,1), size(values,2), ni]), dim=3)
113+
result_ = &
114+
assert_equals(real(expected, kind=c_double), real(my_val, kind=c_double)) &
115+
.and.assert_equals(real(aimag(expected), kind=c_double), real(aimag(my_val), kind=c_double))
47116
end function
48117

49118
function check_64_bit_complex() result(result_)
50119
type(result_t) :: result_
51-
result_ = succeed("temporarily")
120+
121+
complex(c_double), parameter :: values(*,*) = reshape( &
122+
[ cmplx(1., 53.), cmplx(3., 47.) &
123+
, cmplx(5., 43.), cmplx(7., 41.) &
124+
, cmplx(11., 37.), cmplx(13., 31.) &
125+
, cmplx(17., 29.), cmplx(19., 23.) &
126+
], &
127+
[2,4])
128+
integer :: me, ni, i
129+
complex(c_double), dimension(size(values,1)) :: my_val, expected
130+
131+
call prif_this_image_no_coarray(this_image=me)
132+
call prif_num_images(ni)
133+
134+
my_val = values(:, mod(me-1, size(values,2))+1)
135+
call prif_co_sum(my_val)
136+
137+
expected = sum(reshape([(values(:,mod(i-1,size(values,2))+1), i = 1, ni)], [size(values,1), ni]), dim=2)
138+
result_ = &
139+
assert_equals(real(expected), real(my_val)) &
140+
.and.assert_equals(aimag(expected), aimag(my_val))
52141
end function
53142

54143
end module caf_co_sum_test

0 commit comments

Comments
 (0)