Skip to content

Commit

Permalink
Improve fordot module
Browse files Browse the repository at this point in the history
  • Loading branch information
gha3mi committed Jan 24, 2024
1 parent 6b6884c commit eabab46
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 48 deletions.
75 changes: 32 additions & 43 deletions src/fordot.f90
Original file line number Diff line number Diff line change
Expand Up @@ -10,67 +10,56 @@ module fordot

interface dot_product
procedure :: dot_R0R1R1_rel
procedure :: dot_R0R1R1_rel_default
procedure :: dot_R0R1R1_rel_coarray
end interface

contains

!> author: Seyed Ali Ghasemi
pure function dot_R0R1R1_rel_default(u,v,option) result(a)
pure function dot_R0R1R1_rel(u,v,option) result(a)
real(rk), intent(in), contiguous :: u(:)
real(rk), intent(in), contiguous :: v(:)
character(*), intent(in) :: option
real(rk) :: a
a = dot_opts(u, v, option)
end function dot_R0R1R1_rel_default
end function dot_R0R1R1_rel



!> author: Seyed Ali Ghasemi
#if defined(USE_COARRAY)
impure function dot_R0R1R1_rel(u,v,method,option) result(a)
#else
pure function dot_R0R1R1_rel(u,v,method,option) result(a)
#endif
real(rk), intent(in), contiguous :: u(:)
real(rk), intent(in), contiguous :: v(:)
character(*), intent(in) :: method
impure function dot_R0R1R1_rel_coarray(u,v,option,coarray) result(a)
real(rk), intent(in) :: u(:)
real(rk), intent(in) :: v(:)
character(*), intent(in) :: option
real(rk) :: a

select case (method)
logical, intent(in) :: coarray
#if defined(USE_COARRAY)
case ('coarray')

block
integer :: i, im, nimg, m
integer :: block_size(num_images()), start_elem(num_images()), end_elem(num_images())
real(rk), allocatable :: a_block[:], u_block(:)[:], v_block(:)[:]
im = this_image()
nimg = num_images()
m = size(u)
call compute_block_ranges(size(u), nimg, block_size, start_elem, end_elem)
allocate(u_block(block_size(im))[*], v_block(block_size(im))[*], a_block[*])
u_block(:)[im] = u(start_elem(im):end_elem(im))
v_block(:)[im] = v(start_elem(im):end_elem(im))
a_block[im] = dot_opts(u_block(:)[im],v_block(:)[im],option)
call co_sum(a_block, result_image=1)
a = a_block[1]
! sync all
! if (im == 1) then
! a = 0.0_rk
! do i = 1, nimg
! a = a + a_block[i]
! end do
! end if
end block

integer :: i, im, nimg, m
integer :: block_size(num_images()), start_elem(num_images()), end_elem(num_images())
real(rk), allocatable :: a_block[:], u_block(:)[:], v_block(:)[:]

im = this_image()
nimg = num_images()
m = size(u)
call compute_block_ranges(size(u), nimg, block_size, start_elem, end_elem)
allocate(u_block(block_size(im))[*], v_block(block_size(im))[*], a_block[*])
u_block(:)[im] = u(start_elem(im):end_elem(im))
v_block(:)[im] = v(start_elem(im):end_elem(im))
a_block[im] = dot_opts(u_block(:)[im],v_block(:)[im],option)
! call co_sum(a_block, result_image=1)
! a = a_block[1]
sync all
if (im == 1) then
a = 0.0_rk
do i = 1, nimg
a = a + a_block[i]
end do
end if
#else
a = dot_product(u, v, option)
#endif
case ('default')
a = dot_opts(u, v, option)
end select

end function dot_R0R1R1_rel
end function dot_R0R1R1_rel_coarray



Expand All @@ -90,7 +79,7 @@ pure subroutine compute_block_ranges(d, nimg, block_size, start_elem, end_elem)
end_elem(i) = start_elem(i) + block_size(i) - 1
end do
! Check if the block sizes are valid.
if (minval(block_size) <= 0) error stop 'fordot: reduce the number of images of coarray.'
if (minval(block_size) <= 0) error stop 'ForDot: reduce the number of images of coarray.'
end subroutine compute_block_ranges

end module fordot
10 changes: 5 additions & 5 deletions test/test2.f90
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ program test_dot2

a_ref = dot_product(u,v)

a = dot_product(u,v,'coarray')
a = dot_product(u,v,coarray=.true.)
if (im==1) call ut%check(a, a_ref, tol=1e-5_rk, msg='test_dot2.1')

a = dot_product(u,v, 'coarray', 'm1')
a = dot_product(u,v, coarray=.true., 'm1')
if (im==1) call ut%check(a, a_ref, tol=1e-5_rk, msg='test_dot2.2')

a = dot_product(u,v, 'coarray', 'm2')
a = dot_product(u,v, coarray=.true., 'm2')
if (im==1) call ut%check(a, a_ref, tol=1e-5_rk, msg='test_dot2.3')

a = dot_product(u,v, 'coarray', 'm3')
a = dot_product(u,v, coarray=.true., 'm3')
if (im==1) call ut%check(a, a_ref, tol=1e-5_rk, msg='test_dot2.4')

a = dot_product(u,v, 'coarray', 'm4')
a = dot_product(u,v, coarray=.true., 'm4')
if (im==1) call ut%check(a, a_ref, tol=1e-5_rk, msg='test_dot2.5')

end program test_dot2
Expand Down

0 comments on commit eabab46

Please sign in to comment.