Skip to content

Commit

Permalink
Improve mat_mat_block and mat_vec_block functions
Browse files Browse the repository at this point in the history
  • Loading branch information
gha3mi committed Jan 30, 2024
1 parent afc00b6 commit b87ebce
Showing 1 changed file with 49 additions and 37 deletions.
86 changes: 49 additions & 37 deletions src/formatmul.f90
Original file line number Diff line number Diff line change
Expand Up @@ -311,49 +311,53 @@ end subroutine compute_block_ranges
!> author: Seyed Ali Ghasemi
pure function mat_mat_block_rel(a, b, transA, transB, option, nblock) result(c)
real(rk), intent(in), contiguous :: a(:,:), b(:,:)
character(*), intent(in), optional :: option
logical, intent(in), optional :: transA, transB
real(rk), allocatable :: c(:,:)
integer :: ib
character(*), intent(in), optional :: option
integer, intent(in) :: nblock
real(rk), allocatable :: c(:,:)
integer :: ib, se, ee
integer :: block_size(nblock), start_elem(nblock), end_elem(nblock)

if (present(transA) .and. present(transB)) then
if (.not.transA .and. .not.transB) then
! AB
allocate(C(size(A,1), size(B,2)), source=0.0_rk)
call compute_block_ranges(size(B,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
C(:, start_elem(ib):end_elem(ib)) = &
C(:, start_elem(ib):end_elem(ib)) + matmul(A, B(:,start_elem(ib):end_elem(ib)), transA, transB, option)
se = start_elem(ib)
ee = end_elem(ib)
C(:, se:ee) = &
C(:, se:ee) + matmul(A, B(:,se:ee), transA, transB, option)
end do
else if (transA .and. transB) then
! ATBT
allocate(C(size(A,2), size(B,1)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
C(start_elem(ib):end_elem(ib), :) = &
C(start_elem(ib):end_elem(ib), :) + matmul(A(:, start_elem(ib):end_elem(ib)), B, transA, transB, option)
se = start_elem(ib)
ee = end_elem(ib)
C(se:ee, :) = &
C(se:ee, :) + matmul(A(:, se:ee), B, transA, transB, option)
end do
else if (transA .and. .not.transB) then
! ATB
allocate(C(size(A,2), size(B,2)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
C(start_elem(ib):end_elem(ib), :) = &
C(start_elem(ib):end_elem(ib), :) + matmul(A(:, start_elem(ib):end_elem(ib)), B, transA, transB, option)
se = start_elem(ib)
ee = end_elem(ib)
C(se:ee, :) = &
C(se:ee, :) + matmul(A(:, se:ee), B, transA, transB, option)
end do
else if (.not.transA .and. transB) then
! ABT
allocate(C(size(A,1), size(B,1)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
se = start_elem(ib)
ee = end_elem(ib)
C(:, :) = C(:, :) + &
matmul(A(:, start_elem(ib):end_elem(ib)), B(:,start_elem(ib):end_elem(ib)), transA, transB, option)
matmul(A(:, se:ee), B(:,se:ee), transA, transB, option)
end do
end if
else if (present(transA) .or. present(transB)) then
Expand All @@ -362,50 +366,55 @@ pure function mat_mat_block_rel(a, b, transA, transB, option, nblock) result(c)
! ATB
allocate(C(size(A,2), size(B,2)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
C(start_elem(ib):end_elem(ib), :) = &
C(start_elem(ib):end_elem(ib), :) + matmul(A(:, start_elem(ib):end_elem(ib)), B, transA, transB, option)
se = start_elem(ib)
ee = end_elem(ib)
C(se:ee, :) = &
C(se:ee, :) + matmul(A(:, se:ee), B, transA, transB, option)
end do
else if (.not.transA) then
! ABT
allocate(C(size(A,1), size(B,1)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
se = start_elem(ib)
ee = end_elem(ib)
C(:, :) = C(:, :) + &
matmul(A(:, start_elem(ib):end_elem(ib)), B(:,start_elem(ib):end_elem(ib)), transA, transB, option)
matmul(A(:, se:ee), B(:,se:ee), transA, transB, option)
end do
end if
else if (present(transB)) then
if (transB) then
! ABT
allocate(C(size(A,1), size(B,1)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
se = start_elem(ib)
ee = end_elem(ib)
C(:, :) = C(:, :) + &
matmul(A(:, start_elem(ib):end_elem(ib)), B(:,start_elem(ib):end_elem(ib)), transA, transB, option)
matmul(A(:, se:ee), B(:,se:ee), transA, transB, option)
end do
else if (.not.transB) then
! ATB
allocate(C(size(A,2), size(B,2)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
C(start_elem(ib):end_elem(ib), :) = &
C(start_elem(ib):end_elem(ib), :) + matmul(A(:, start_elem(ib):end_elem(ib)), B, transA, transB, option)
se = start_elem(ib)
ee = end_elem(ib)
C(se:ee, :) = &
C(se:ee, :) + matmul(A(:, se:ee), B, transA, transB, option)
end do
end if
end if
else if (.not.present(transA) .and. .not.present(transB)) then
! AB
allocate(C(size(A,1), size(B,2)), source=0.0_rk)
call compute_block_ranges(size(B,2), nblock, block_size, start_elem, end_elem)
c = 0.0_rk
do ib = 1, nblock
C(:, start_elem(ib):end_elem(ib)) = &
C(:, start_elem(ib):end_elem(ib)) + matmul(A, B(:,start_elem(ib):end_elem(ib)), transA, transB, option)
se = start_elem(ib)
ee = end_elem(ib)
C(:, se:ee) = &
C(:, se:ee) + matmul(A, B(:,se:ee), transA, transB, option)
end do
end if

Expand All @@ -417,11 +426,11 @@ end function mat_mat_block_rel
!> author: Seyed Ali Ghasemi
pure function mat_vec_block_rel(A, v, transA, option, nblock) result(w)
real(rk), intent(in), contiguous :: A(:,:), v(:)
character(*), intent(in), optional :: option
logical, intent(in), optional :: transA
real(rk), allocatable :: w(:)
integer :: ib
character(*), intent(in), optional :: option
integer, intent(in) :: nblock
real(rk), allocatable :: w(:)
integer :: ib, se, ee
integer :: block_size(nblock), start_elem(nblock), end_elem(nblock)


Expand All @@ -430,29 +439,32 @@ pure function mat_vec_block_rel(A, v, transA, option, nblock) result(w)
! ATv
allocate(w(size(A,2)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
w = 0.0_rk
do ib = 1, nblock
w(start_elem(ib):end_elem(ib)) = &
w(start_elem(ib):end_elem(ib)) + matmul(A(:,start_elem(ib):end_elem(ib)), v, transA, option)
se = start_elem(ib)
ee = end_elem(ib)
w(se:ee) = &
w(se:ee) + matmul(A(:,se:ee), v, transA, option)
end do
else if (.not. transA) then
! Av
allocate(w(size(A,1)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
w = 0.0_rk
do ib = 1, nblock
se = start_elem(ib)
ee = end_elem(ib)
w(:) = &
w(:) + matmul(A(:,start_elem(ib):end_elem(ib)), v(start_elem(ib):end_elem(ib)), transA, option)
w(:) + matmul(A(:,se:ee), v(se:ee), transA, option)
end do
end if
else if (.not. present(transA)) then
! Av
allocate(w(size(A,1)), source=0.0_rk)
call compute_block_ranges(size(A,2), nblock, block_size, start_elem, end_elem)
w = 0.0_rk
do ib = 1, nblock
se = start_elem(ib)
ee = end_elem(ib)
w(:) = &
w(:) + matmul(A(:,start_elem(ib):end_elem(ib)), v(start_elem(ib):end_elem(ib)), transA, option)
w(:) + matmul(A(:,se:ee), v(se:ee), transA, option)
end do
end if

Expand Down

0 comments on commit b87ebce

Please sign in to comment.