10
0
mirror of https://github.com/QuantumPackage/qp2.git synced 2024-11-19 04:22:32 +01:00

OPTIMZATIONS IN 4-ind integ

This commit is contained in:
AbdAmmar 2023-09-06 21:03:22 +02:00
parent 7076fcd202
commit 8739c26509
2 changed files with 116 additions and 102 deletions

View File

@ -18,10 +18,11 @@ program bi_ort_ints
! call test_5idx ! call test_5idx
! call test_5idx2 ! call test_5idx2
call test_4idx() call test_4idx()
call test_4idx_n4() !call test_4idx_n4()
!call test_4idx2() !call test_4idx2()
!call test_5idx2 !call test_5idx2
!call test_5idx !call test_5idx
end end
subroutine test_5idx2 subroutine test_5idx2
@ -340,7 +341,7 @@ subroutine test_4idx()
implicit none implicit none
integer :: i, j, k, l integer :: i, j, k, l
double precision :: accu, contrib, new, ref, thr double precision :: accu, contrib, new, ref, thr, norm
thr = 1d-10 thr = 1d-10
@ -348,6 +349,7 @@ subroutine test_4idx()
PROVIDE three_e_4_idx_direct_bi_ort PROVIDE three_e_4_idx_direct_bi_ort
accu = 0.d0 accu = 0.d0
norm = 0.d0
do i = 1, mo_num do i = 1, mo_num
do j = 1, mo_num do j = 1, mo_num
do k = 1, mo_num do k = 1, mo_num
@ -356,7 +358,6 @@ subroutine test_4idx()
new = three_e_4_idx_direct_bi_ort (l,k,j,i) new = three_e_4_idx_direct_bi_ort (l,k,j,i)
ref = three_e_4_idx_direct_bi_ort_old(l,k,j,i) ref = three_e_4_idx_direct_bi_ort_old(l,k,j,i)
contrib = dabs(new - ref) contrib = dabs(new - ref)
accu += contrib
if(contrib .gt. thr) then if(contrib .gt. thr) then
print*, ' problem in three_e_4_idx_direct_bi_ort' print*, ' problem in three_e_4_idx_direct_bi_ort'
print*, l, k, j, i print*, l, k, j, i
@ -364,11 +365,14 @@ subroutine test_4idx()
stop stop
endif endif
accu += contrib
norm += dabs(ref)
enddo enddo
enddo enddo
enddo enddo
enddo enddo
print*, ' accu on three_e_4_idx_direct_bi_ort = ', accu / dble(mo_num)**4
print*, ' accu on three_e_4_idx_direct_bi_ort (%) = ', 100.d0 * accu / norm
! --- ! ---
@ -376,6 +380,7 @@ subroutine test_4idx()
PROVIDE three_e_4_idx_exch13_bi_ort PROVIDE three_e_4_idx_exch13_bi_ort
accu = 0.d0 accu = 0.d0
norm = 0.d0
do i = 1, mo_num do i = 1, mo_num
do j = 1, mo_num do j = 1, mo_num
do k = 1, mo_num do k = 1, mo_num
@ -384,7 +389,6 @@ subroutine test_4idx()
new = three_e_4_idx_exch13_bi_ort (l,k,j,i) new = three_e_4_idx_exch13_bi_ort (l,k,j,i)
ref = three_e_4_idx_exch13_bi_ort_old(l,k,j,i) ref = three_e_4_idx_exch13_bi_ort_old(l,k,j,i)
contrib = dabs(new - ref) contrib = dabs(new - ref)
accu += contrib
if(contrib .gt. thr) then if(contrib .gt. thr) then
print*, ' problem in three_e_4_idx_exch13_bi_ort' print*, ' problem in three_e_4_idx_exch13_bi_ort'
print*, l, k, j, i print*, l, k, j, i
@ -392,11 +396,14 @@ subroutine test_4idx()
stop stop
endif endif
accu += contrib
norm += dabs(ref)
enddo enddo
enddo enddo
enddo enddo
enddo enddo
print*, ' accu on three_e_4_idx_exch13_bi_ort = ', accu / dble(mo_num)**4
print*, ' accu on three_e_4_idx_exch13_bi_ort (%) = ', 100.d0 * accu / norm
! --- ! ---
@ -404,6 +411,7 @@ subroutine test_4idx()
PROVIDE three_e_4_idx_cycle_1_bi_ort PROVIDE three_e_4_idx_cycle_1_bi_ort
accu = 0.d0 accu = 0.d0
norm = 0.d0
do i = 1, mo_num do i = 1, mo_num
do j = 1, mo_num do j = 1, mo_num
do k = 1, mo_num do k = 1, mo_num
@ -412,7 +420,6 @@ subroutine test_4idx()
new = three_e_4_idx_cycle_1_bi_ort (l,k,j,i) new = three_e_4_idx_cycle_1_bi_ort (l,k,j,i)
ref = three_e_4_idx_cycle_1_bi_ort_old(l,k,j,i) ref = three_e_4_idx_cycle_1_bi_ort_old(l,k,j,i)
contrib = dabs(new - ref) contrib = dabs(new - ref)
accu += contrib
if(contrib .gt. thr) then if(contrib .gt. thr) then
print*, ' problem in three_e_4_idx_cycle_1_bi_ort' print*, ' problem in three_e_4_idx_cycle_1_bi_ort'
print*, l, k, j, i print*, l, k, j, i
@ -420,11 +427,14 @@ subroutine test_4idx()
stop stop
endif endif
accu += contrib
norm += dabs(ref)
enddo enddo
enddo enddo
enddo enddo
enddo enddo
print*, ' accu on three_e_4_idx_cycle_1_bi_ort = ', accu / dble(mo_num)**4
print*, ' accu on three_e_4_idx_cycle_1_bi_ort (%) = ', 100.d0 * accu / norm
! --- ! ---
@ -432,6 +442,7 @@ subroutine test_4idx()
PROVIDE three_e_4_idx_exch23_bi_ort PROVIDE three_e_4_idx_exch23_bi_ort
accu = 0.d0 accu = 0.d0
norm = 0.d0
do i = 1, mo_num do i = 1, mo_num
do j = 1, mo_num do j = 1, mo_num
do k = 1, mo_num do k = 1, mo_num
@ -440,7 +451,6 @@ subroutine test_4idx()
new = three_e_4_idx_exch23_bi_ort (l,k,j,i) new = three_e_4_idx_exch23_bi_ort (l,k,j,i)
ref = three_e_4_idx_exch23_bi_ort_old(l,k,j,i) ref = three_e_4_idx_exch23_bi_ort_old(l,k,j,i)
contrib = dabs(new - ref) contrib = dabs(new - ref)
accu += contrib
if(contrib .gt. thr) then if(contrib .gt. thr) then
print*, ' problem in three_e_4_idx_exch23_bi_ort' print*, ' problem in three_e_4_idx_exch23_bi_ort'
print*, l, k, j, i print*, l, k, j, i
@ -448,13 +458,18 @@ subroutine test_4idx()
stop stop
endif endif
accu += contrib
norm += dabs(ref)
enddo enddo
enddo enddo
enddo enddo
enddo enddo
print*, ' accu on three_e_4_idx_exch23_bi_ort = ', accu / dble(mo_num)**4
print*, ' accu on three_e_4_idx_exch23_bi_ort (%) = ', 100.d0 * accu / norm
! --- ! ---
return return
end end

View File

@ -64,120 +64,117 @@
!$OMP END DO !$OMP END DO
!$OMP END PARALLEL !$OMP END PARALLEL
! loops approach to break the O(N^4) scaling in memory
call set_multiple_levels_omp(.false.)
!$OMP PARALLEL &
!$OMP DEFAULT (NONE) &
!$OMP PRIVATE (n, ipoint, tmp_loc_1, tmp_loc_2, tmp_2d, tmp1, tmp2) &
!$OMP SHARED (mo_num, n_points_final_grid, i, k, &
!$OMP mos_l_in_r_array_transp, mos_r_in_r_array_transp, &
!$OMP int2_grad1_u12_bimo_t, final_weight_at_r_vector, &
!$OMP tmp_aux_1, tmp_aux_2, &
!$OMP three_e_4_idx_direct_bi_ort, three_e_4_idx_exch13_bi_ort, &
!$OMP three_e_4_idx_exch23_bi_ort, three_e_4_idx_cycle_1_bi_ort)
allocate(tmp_2d(mo_num,mo_num)) allocate(tmp_2d(mo_num,mo_num))
allocate(tmp1(n_points_final_grid,4,mo_num)) allocate(tmp1(n_points_final_grid,4,mo_num))
allocate(tmp2(n_points_final_grid,4,mo_num)) allocate(tmp2(n_points_final_grid,4,mo_num))
! loops approach to break the O(N^4) scaling in memory !$OMP DO
do k = 1, mo_num do k = 1, mo_num
! ---
do i = 1, mo_num do i = 1, mo_num
!$OMP PARALLEL & ! ---
!$OMP DEFAULT (NONE) &
!$OMP PRIVATE (n, ipoint, tmp_loc_1, tmp_loc_2) &
!$OMP SHARED (mo_num, n_points_final_grid, i, k, &
!$OMP mos_l_in_r_array_transp, mos_r_in_r_array_transp, &
!$OMP int2_grad1_u12_bimo_t, final_weight_at_r_vector, &
!$OMP tmp_aux_2, tmp1)
!$OMP DO
do n = 1, mo_num
do ipoint = 1, n_points_final_grid
tmp_loc_1 = mos_l_in_r_array_transp(ipoint,k) * mos_r_in_r_array_transp(ipoint,i) do n = 1, mo_num
tmp_loc_2 = tmp_aux_2(ipoint,n) do ipoint = 1, n_points_final_grid
tmp1(ipoint,1,n) = int2_grad1_u12_bimo_t(ipoint,1,n,n) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,1,k,i) * tmp_loc_2 tmp_loc_1 = mos_l_in_r_array_transp(ipoint,k) * mos_r_in_r_array_transp(ipoint,i)
tmp1(ipoint,2,n) = int2_grad1_u12_bimo_t(ipoint,2,n,n) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,2,k,i) * tmp_loc_2 tmp_loc_2 = tmp_aux_2(ipoint,n)
tmp1(ipoint,3,n) = int2_grad1_u12_bimo_t(ipoint,3,n,n) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,3,k,i) * tmp_loc_2
tmp1(ipoint,4,n) = int2_grad1_u12_bimo_t(ipoint,1,n,n) * int2_grad1_u12_bimo_t(ipoint,1,k,i) &
+ int2_grad1_u12_bimo_t(ipoint,2,n,n) * int2_grad1_u12_bimo_t(ipoint,2,k,i) &
+ int2_grad1_u12_bimo_t(ipoint,3,n,n) * int2_grad1_u12_bimo_t(ipoint,3,k,i)
tmp1(ipoint,1,n) = int2_grad1_u12_bimo_t(ipoint,1,n,n) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,1,k,i) * tmp_loc_2
tmp1(ipoint,2,n) = int2_grad1_u12_bimo_t(ipoint,2,n,n) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,2,k,i) * tmp_loc_2
tmp1(ipoint,3,n) = int2_grad1_u12_bimo_t(ipoint,3,n,n) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,3,k,i) * tmp_loc_2
tmp1(ipoint,4,n) = int2_grad1_u12_bimo_t(ipoint,1,n,n) * int2_grad1_u12_bimo_t(ipoint,1,k,i) &
+ int2_grad1_u12_bimo_t(ipoint,2,n,n) * int2_grad1_u12_bimo_t(ipoint,2,k,i) &
+ int2_grad1_u12_bimo_t(ipoint,3,n,n) * int2_grad1_u12_bimo_t(ipoint,3,k,i)
enddo
enddo enddo
enddo
!$OMP END DO
!$OMP END PARALLEL
call dgemm( 'T', 'N', mo_num, mo_num, 4*n_points_final_grid, 1.d0 & call dgemm( 'T', 'N', mo_num, mo_num, 4*n_points_final_grid, 1.d0 &
, tmp_aux_1(1,1,1), 4*n_points_final_grid, tmp1(1,1,1), 4*n_points_final_grid & , tmp_aux_1(1,1,1), 4*n_points_final_grid, tmp1(1,1,1), 4*n_points_final_grid &
, 0.d0, tmp_2d(1,1), mo_num) , 0.d0, tmp_2d(1,1), mo_num)
!$OMP PARALLEL DO PRIVATE(j,m) do j = 1, mo_num
do j = 1, mo_num do m = 1, mo_num
do m = 1, mo_num three_e_4_idx_direct_bi_ort(m,j,k,i) = -tmp_2d(m,j)
three_e_4_idx_direct_bi_ort(m,j,k,i) = -tmp_2d(m,j) enddo
enddo enddo
enddo
!$OMP END PARALLEL DO
! ---
do n = 1, mo_num
do ipoint = 1, n_points_final_grid
!$OMP PARALLEL & tmp_loc_1 = mos_l_in_r_array_transp(ipoint,k) * mos_r_in_r_array_transp(ipoint,n)
!$OMP DEFAULT (NONE) & tmp_loc_2 = mos_l_in_r_array_transp(ipoint,n) * mos_r_in_r_array_transp(ipoint,i)
!$OMP PRIVATE (n, ipoint, tmp_loc_1, tmp_loc_2) &
!$OMP SHARED (mo_num, n_points_final_grid, i, k, &
!$OMP mos_l_in_r_array_transp, mos_r_in_r_array_transp, &
!$OMP int2_grad1_u12_bimo_t, final_weight_at_r_vector, &
!$OMP tmp1, tmp2)
!$OMP DO
do n = 1, mo_num
do ipoint = 1, n_points_final_grid
tmp_loc_1 = mos_l_in_r_array_transp(ipoint,k) * mos_r_in_r_array_transp(ipoint,n) tmp1(ipoint,1,n) = int2_grad1_u12_bimo_t(ipoint,1,n,i) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,1,k,n) * tmp_loc_2
tmp_loc_2 = mos_l_in_r_array_transp(ipoint,n) * mos_r_in_r_array_transp(ipoint,i) tmp1(ipoint,2,n) = int2_grad1_u12_bimo_t(ipoint,2,n,i) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,2,k,n) * tmp_loc_2
tmp1(ipoint,3,n) = int2_grad1_u12_bimo_t(ipoint,3,n,i) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,3,k,n) * tmp_loc_2
tmp1(ipoint,4,n) = int2_grad1_u12_bimo_t(ipoint,1,n,i) * int2_grad1_u12_bimo_t(ipoint,1,k,n) &
+ int2_grad1_u12_bimo_t(ipoint,2,n,i) * int2_grad1_u12_bimo_t(ipoint,2,k,n) &
+ int2_grad1_u12_bimo_t(ipoint,3,n,i) * int2_grad1_u12_bimo_t(ipoint,3,k,n)
tmp1(ipoint,1,n) = int2_grad1_u12_bimo_t(ipoint,1,n,i) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,1,k,n) * tmp_loc_2 tmp2(ipoint,1,n) = final_weight_at_r_vector(ipoint) * int2_grad1_u12_bimo_t(ipoint,1,i,n)
tmp1(ipoint,2,n) = int2_grad1_u12_bimo_t(ipoint,2,n,i) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,2,k,n) * tmp_loc_2 tmp2(ipoint,2,n) = final_weight_at_r_vector(ipoint) * int2_grad1_u12_bimo_t(ipoint,2,i,n)
tmp1(ipoint,3,n) = int2_grad1_u12_bimo_t(ipoint,3,n,i) * tmp_loc_1 + int2_grad1_u12_bimo_t(ipoint,3,k,n) * tmp_loc_2 tmp2(ipoint,3,n) = final_weight_at_r_vector(ipoint) * int2_grad1_u12_bimo_t(ipoint,3,i,n)
tmp1(ipoint,4,n) = int2_grad1_u12_bimo_t(ipoint,1,n,i) * int2_grad1_u12_bimo_t(ipoint,1,k,n) & tmp2(ipoint,4,n) = final_weight_at_r_vector(ipoint) * mos_l_in_r_array_transp(ipoint,i) * mos_r_in_r_array_transp(ipoint,n)
+ int2_grad1_u12_bimo_t(ipoint,2,n,i) * int2_grad1_u12_bimo_t(ipoint,2,k,n) & enddo
+ int2_grad1_u12_bimo_t(ipoint,3,n,i) * int2_grad1_u12_bimo_t(ipoint,3,k,n)
tmp2(ipoint,1,n) = final_weight_at_r_vector(ipoint) * int2_grad1_u12_bimo_t(ipoint,1,i,n)
tmp2(ipoint,2,n) = final_weight_at_r_vector(ipoint) * int2_grad1_u12_bimo_t(ipoint,2,i,n)
tmp2(ipoint,3,n) = final_weight_at_r_vector(ipoint) * int2_grad1_u12_bimo_t(ipoint,3,i,n)
tmp2(ipoint,4,n) = final_weight_at_r_vector(ipoint) * mos_l_in_r_array_transp(ipoint,i) * mos_r_in_r_array_transp(ipoint,n)
enddo enddo
enddo
!$OMP END DO
!$OMP END PARALLEL
call dgemm( 'T', 'N', mo_num, mo_num, 4*n_points_final_grid, 1.d0 & ! ---
, tmp1(1,1,1), 4*n_points_final_grid, tmp_aux_1(1,1,1), 4*n_points_final_grid &
, 0.d0, tmp_2d(1,1), mo_num)
!$OMP PARALLEL DO PRIVATE(j,m) call dgemm( 'T', 'N', mo_num, mo_num, 4*n_points_final_grid, 1.d0 &
do j = 1, mo_num , tmp1(1,1,1), 4*n_points_final_grid, tmp_aux_1(1,1,1), 4*n_points_final_grid &
do m = 1, mo_num , 0.d0, tmp_2d(1,1), mo_num)
three_e_4_idx_exch13_bi_ort(m,j,k,i) = -tmp_2d(m,j)
do j = 1, mo_num
do m = 1, mo_num
three_e_4_idx_exch13_bi_ort(m,j,k,i) = -tmp_2d(m,j)
enddo
enddo enddo
enddo
!$OMP END PARALLEL DO
call dgemm( 'T', 'N', mo_num, mo_num, 4*n_points_final_grid, 1.d0 & ! ---
, tmp1(1,1,1), 4*n_points_final_grid, tmp2(1,1,1), 4*n_points_final_grid &
, 0.d0, tmp_2d(1,1), mo_num)
!$OMP PARALLEL DO PRIVATE(j,m) call dgemm( 'T', 'N', mo_num, mo_num, 4*n_points_final_grid, 1.d0 &
do j = 1, mo_num , tmp1(1,1,1), 4*n_points_final_grid, tmp2(1,1,1), 4*n_points_final_grid &
do m = 1, mo_num , 0.d0, tmp_2d(1,1), mo_num)
three_e_4_idx_cycle_1_bi_ort(m,i,k,j) = -tmp_2d(m,j)
do j = 1, mo_num
do m = 1, mo_num
three_e_4_idx_cycle_1_bi_ort(m,i,k,j) = -tmp_2d(m,j)
enddo
enddo enddo
enddo
!$OMP END PARALLEL DO
enddo ! i ! ---
enddo ! i
! ---
do j = 1, mo_num do j = 1, mo_num
!$OMP PARALLEL &
!$OMP DEFAULT (NONE) &
!$OMP PRIVATE (n, ipoint, tmp_loc_1, tmp_loc_2) &
!$OMP SHARED (mo_num, n_points_final_grid, j, k, &
!$OMP mos_l_in_r_array_transp, mos_r_in_r_array_transp, &
!$OMP int2_grad1_u12_bimo_t, final_weight_at_r_vector, &
!$OMP tmp1, tmp2)
!$OMP DO
do n = 1, mo_num do n = 1, mo_num
do ipoint = 1, n_points_final_grid do ipoint = 1, n_points_final_grid
@ -197,31 +194,33 @@
tmp2(ipoint,4,n) = final_weight_at_r_vector(ipoint) * mos_l_in_r_array_transp(ipoint,k) * mos_r_in_r_array_transp(ipoint,n) tmp2(ipoint,4,n) = final_weight_at_r_vector(ipoint) * mos_l_in_r_array_transp(ipoint,k) * mos_r_in_r_array_transp(ipoint,n)
enddo enddo
enddo enddo
!$OMP END DO
!$OMP END PARALLEL
call dgemm( 'T', 'N', mo_num, mo_num, 4*n_points_final_grid, 1.d0 & call dgemm( 'T', 'N', mo_num, mo_num, 4*n_points_final_grid, 1.d0 &
, tmp1(1,1,1), 4*n_points_final_grid, tmp2(1,1,1), 4*n_points_final_grid & , tmp1(1,1,1), 4*n_points_final_grid, tmp2(1,1,1), 4*n_points_final_grid &
, 0.d0, tmp_2d(1,1), mo_num) , 0.d0, tmp_2d(1,1), mo_num)
!$OMP PARALLEL DO PRIVATE(i,m)
do i = 1, mo_num do i = 1, mo_num
do m = 1, mo_num do m = 1, mo_num
three_e_4_idx_exch23_bi_ort(m,j,k,i) = -tmp_2d(m,i) three_e_4_idx_exch23_bi_ort(m,j,k,i) = -tmp_2d(m,i)
enddo enddo
enddo enddo
!$OMP END PARALLEL DO
enddo ! j enddo ! j
! ---
enddo !k enddo !k
!$OMP END DO
deallocate(tmp_2d) deallocate(tmp_2d)
deallocate(tmp1) deallocate(tmp1)
deallocate(tmp2) deallocate(tmp2)
!$OMP END PARALLEL
deallocate(tmp_aux_1) deallocate(tmp_aux_1)
deallocate(tmp_aux_2) deallocate(tmp_aux_2)
call wall_time(wall1) call wall_time(wall1)
print *, ' wall time for three_e_4_idx_bi_ort', wall1 - wall0 print *, ' wall time for three_e_4_idx_bi_ort', wall1 - wall0
call print_memory_usage() call print_memory_usage()