From d1507c937a510fc6b85a562fbe12c0c00db31513 Mon Sep 17 00:00:00 2001 From: Anthony Scemama Date: Wed, 1 Feb 2017 19:43:38 +0100 Subject: [PATCH] Optimized selection --- plugins/Full_CI_ZMQ/selection.irp.f | 163 +++++++++++++++++----------- 1 file changed, 102 insertions(+), 61 deletions(-) diff --git a/plugins/Full_CI_ZMQ/selection.irp.f b/plugins/Full_CI_ZMQ/selection.irp.f index afcb51db..587618c8 100644 --- a/plugins/Full_CI_ZMQ/selection.irp.f +++ b/plugins/Full_CI_ZMQ/selection.irp.f @@ -55,7 +55,7 @@ subroutine get_mask_phase(det, phasemask) implicit none integer(bit_kind), intent(in) :: det(N_int, 2) - integer(1), intent(out) :: phasemask(N_int*bit_kind_size, 2) + integer(1), intent(out) :: phasemask(2,N_int*bit_kind_size) integer :: s, ni, i logical :: change @@ -65,7 +65,7 @@ subroutine get_mask_phase(det, phasemask) do ni=1,N_int do i=0,bit_kind_size-1 if(BTEST(det(ni, s), i)) change = .not. change - if(change) phasemask((ni-1)*bit_kind_size + i + 1, s) = 1_1 + if(change) phasemask(s, (ni-1)*bit_kind_size + i + 1) = 1_1 end do end do end do @@ -104,18 +104,20 @@ double precision function get_phase_bi(phasemask, s1, s2, h1, p1, h2, p2) use bitmasks implicit none - integer(1), intent(in) :: phasemask(N_int*bit_kind_size, 2) + integer(1), intent(in) :: phasemask(2,*) integer, intent(in) :: s1, s2, h1, h2, p1, p2 logical :: change - integer(1) :: np - double precision, parameter :: res(0:1) = (/1d0, -1d0/) + integer(1) :: np1 + integer :: np + double precision, save :: res(0:1) = (/1d0, -1d0/) - np = phasemask(h1,s1) + phasemask(p1,s1) + phasemask(h2,s2) + phasemask(p2,s2) - if(p1 < h1) np = np + 1_1 - if(p2 < h2) np = np + 1_1 + np1 = phasemask(s1,h1) + phasemask(s1,p1) + phasemask(s2,h2) + phasemask(s2,p2) + np = np1 + if(p1 < h1) np = np + 1 + if(p2 < h2) np = np + 1 - if(s1 == s2 .and. max(h1, p1) > min(h2, p2)) np = np + 1_1 - get_phase_bi = res(iand(np,1_1)) + if(s1 == s2 .and. 
max(h1, p1) > min(h2, p2)) np = np + 1 + get_phase_bi = res(iand(np,1)) end @@ -125,7 +127,7 @@ subroutine get_m2(gen, phasemask, bannedOrb, vect, mask, h, p, sp, coefs) implicit none integer(bit_kind), intent(in) :: gen(N_int, 2), mask(N_int, 2) - integer(1), intent(in) :: phasemask(N_int*bit_kind_size, 2) + integer(1), intent(in) :: phasemask(2,N_int*bit_kind_size) logical, intent(in) :: bannedOrb(mo_tot_num) double precision, intent(in) :: coefs(N_states) double precision, intent(inout) :: vect(N_states, mo_tot_num) @@ -184,7 +186,7 @@ subroutine get_m1(gen, phasemask, bannedOrb, vect, mask, h, p, sp, coefs) implicit none integer(bit_kind), intent(in) :: gen(N_int, 2), mask(N_int, 2) - integer(1), intent(in) :: phasemask(N_int*bit_kind_size, 2) + integer(1), intent(in) :: phasemask(2,N_int*bit_kind_size) logical, intent(in) :: bannedOrb(mo_tot_num) double precision, intent(in) :: coefs(N_states) double precision, intent(inout) :: vect(N_states, mo_tot_num) @@ -246,7 +248,7 @@ subroutine get_m0(gen, phasemask, bannedOrb, vect, mask, h, p, sp, coefs) implicit none integer(bit_kind), intent(in) :: gen(N_int, 2), mask(N_int, 2) - integer(1), intent(in) :: phasemask(N_int*bit_kind_size, 2) + integer(1), intent(in) :: phasemask(2,N_int*bit_kind_size) logical, intent(in) :: bannedOrb(mo_tot_num) double precision, intent(in) :: coefs(N_states) double precision, intent(inout) :: vect(N_states, mo_tot_num) @@ -337,8 +339,10 @@ subroutine select_doubles(i_generator,hole_mask,particle_mask,fock_diag_tmp,E0,p end do do i=1,N_det - nt = 0 - do j=1,N_int + mobMask(1,1) = iand(negMask(1,1), psi_det_sorted(1,1,i)) + mobMask(1,2) = iand(negMask(1,2), psi_det_sorted(1,2,i)) + nt = popcnt(mobMask(1, 1)) + popcnt(mobMask(1, 2)) + do j=2,N_int mobMask(j,1) = iand(negMask(j,1), psi_det_sorted(j,1,i)) mobMask(j,2) = iand(negMask(j,2), psi_det_sorted(j,2,i)) nt = nt + popcnt(mobMask(j, 1)) + popcnt(mobMask(j, 2)) @@ -578,9 +582,18 @@ subroutine splash_pq(mask, sp, det, i_gen, N_sel, 
bannedOrb, banned, mat, intere do i=1, N_sel ! interesting(0) !i = interesting(ii) + if (interesting(i) < 0) then + stop 'prefetch interesting(i)' + endif + - nt = 0 - do j=1,N_int + mobMask(1,1) = iand(negMask(1,1), det(1,1,i)) + mobMask(1,2) = iand(negMask(1,2), det(1,2,i)) + nt = popcnt(mobMask(1, 1)) + popcnt(mobMask(1, 2)) + + if(nt > 4) cycle + + do j=2,N_int mobMask(j,1) = iand(negMask(j,1), det(j,1,i)) mobMask(j,2) = iand(negMask(j,2), det(j,2,i)) nt = nt + popcnt(mobMask(j, 1)) + popcnt(mobMask(j, 2)) @@ -588,25 +601,7 @@ subroutine splash_pq(mask, sp, det, i_gen, N_sel, bannedOrb, banned, mat, intere if(nt > 4) cycle - do j=1,N_int - perMask(j,1) = iand(mask(j,1), not(det(j,1,i))) - perMask(j,2) = iand(mask(j,2), not(det(j,2,i))) - end do - - call bitstring_to_list(perMask(1,1), h(1,1), h(0,1), N_int) - call bitstring_to_list(perMask(1,2), h(1,2), h(0,2), N_int) - - call bitstring_to_list(mobMask(1,1), p(1,1), p(0,1), N_int) - call bitstring_to_list(mobMask(1,2), p(1,2), p(0,2), N_int) - - if(interesting(i) < i_gen) then - if(nt == 4) call past_d2(banned, p, sp) - if(nt == 3) call past_d1(bannedOrb, p) - else - if(interesting(i) /= i_gen) then - continue - else -! bandon = .true. 
+ if (interesting(i) == i_gen) then if(sp == 3) then do j=1,mo_tot_num do k=1,mo_tot_num @@ -620,14 +615,32 @@ subroutine splash_pq(mask, sp, det, i_gen, N_sel, bannedOrb, banned, mat, intere end do end do end if - end if - if(nt == 4) then - call get_d2(det(1,1,i), psi_phasemask(1,1,interesting(i)), bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i))) - else if(nt == 3) then - call get_d1(det(1,1,i), psi_phasemask(1,1,interesting(i)), bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i))) - else - call get_d0(det(1,1,i), psi_phasemask(1,1,interesting(i)), bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i))) - end if + end if + + call bitstring_to_list_in_selection(mobMask(1,1), p(1,1), p(0,1), N_int) + call bitstring_to_list_in_selection(mobMask(1,2), p(1,2), p(0,2), N_int) + + perMask(1,1) = iand(mask(1,1), not(det(1,1,i))) + perMask(1,2) = iand(mask(1,2), not(det(1,2,i))) + do j=2,N_int + perMask(j,1) = iand(mask(j,1), not(det(j,1,i))) + perMask(j,2) = iand(mask(j,2), not(det(j,2,i))) + end do + + call bitstring_to_list_in_selection(perMask(1,1), h(1,1), h(0,1), N_int) + call bitstring_to_list_in_selection(perMask(1,2), h(1,2), h(0,2), N_int) + + if (interesting(i) >= i_gen) then + if(nt == 4) then + call get_d2(det(1,1,i), psi_phasemask(1,1,interesting(i)), bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i))) + else if(nt == 3) then + call get_d1(det(1,1,i), psi_phasemask(1,1,interesting(i)), bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i))) + else + call get_d0(det(1,1,i), psi_phasemask(1,1,interesting(i)), bannedOrb, banned, mat, mask, h, p, sp, psi_selectors_coef_transp(1, interesting(i))) + end if + else + if(nt == 4) call past_d2(banned, p, sp) + if(nt == 3) call past_d1(bannedOrb, p) end if end do end @@ -638,7 +651,7 @@ subroutine get_d2(gen, phasemask, bannedOrb, banned, mat, mask, h, p, 
sp, coefs) implicit none integer(bit_kind), intent(in) :: mask(N_int, 2), gen(N_int, 2) - integer(1), intent(in) :: phasemask(N_int*bit_kind_size, 2) + integer(1), intent(in) :: phasemask(2,N_int*bit_kind_size) logical, intent(in) :: bannedOrb(mo_tot_num, 2), banned(mo_tot_num, mo_tot_num,2) double precision, intent(in) :: coefs(N_states) double precision, intent(inout) :: mat(N_states, mo_tot_num, mo_tot_num) @@ -687,20 +700,20 @@ subroutine get_d2(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs) end if end do else - do i = 1,2 + h1 = h(1,1) + h2 = h(1,2) do j = 1,2 - puti = p(i, 1) putj = p(j, 2) - - if(banned(puti,putj,bant)) cycle - p1 = p(turn2(i), 1) p2 = p(turn2(j), 2) - h1 = h(1,1) - h2 = h(1,2) - - hij = integral8(p1, p2, h1, h2) * get_phase_bi(phasemask, 1, 2, h1, p1, h2, p2) - mat(:, puti, putj) += coefs * hij - end do + do i = 1,2 + puti = p(i, 1) + + if(banned(puti,putj,bant)) cycle + p1 = p(turn2(i), 1) + + hij = integral8(p1, p2, h1, h2) * get_phase_bi(phasemask, 1, 2, h1, p1, h2, p2) + mat(:, puti, putj) += coefs * hij + end do end do end if @@ -756,7 +769,7 @@ subroutine get_d1(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs) implicit none integer(bit_kind), intent(in) :: mask(N_int, 2), gen(N_int, 2) - integer(1),intent(in) :: phasemask(N_int*bit_kind_size, 2) + integer(1),intent(in) :: phasemask(2,N_int*bit_kind_size) logical, intent(in) :: bannedOrb(mo_tot_num, 2), banned(mo_tot_num, mo_tot_num,2) integer(bit_kind) :: det(N_int, 2) double precision, intent(in) :: coefs(N_states) @@ -925,7 +938,7 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs) implicit none integer(bit_kind), intent(in) :: gen(N_int, 2), mask(N_int, 2) - integer(1), intent(in) :: phasemask(N_int*bit_kind_size, 2) + integer(1), intent(in) :: phasemask(2,N_int*bit_kind_size) logical, intent(in) :: bannedOrb(mo_tot_num, 2), banned(mo_tot_num, mo_tot_num,2) integer(bit_kind) :: det(N_int, 2) double precision, intent(in) 
:: coefs(N_states) @@ -953,8 +966,8 @@ subroutine get_d0(gen, phasemask, bannedOrb, banned, mat, mask, h, p, sp, coefs) call apply_particles(mask, 1,p1,2,p2, det, ok, N_int) call i_h_j(gen, det, N_int, hij) else - hij = integral8(p1, p2, h1, h2) * get_phase_bi(phasemask, 1, 2, h1, p1, h2, p2) phase = get_phase_bi(phasemask, 1, 2, h1, p1, h2, p2) + hij = integral8(p1, p2, h1, h2) * phase end if mat(:, p1, p2) += coefs(:) * hij end do @@ -1059,9 +1072,37 @@ subroutine spot_isinwf(mask, det, i_gen, N, banned, fullMatch, interesting) myMask(j, 2) = iand(det(j, 2, i), negMask(j, 2)) end do - call bitstring_to_list(myMask(1,1), list(1), na, N_int) - call bitstring_to_list(myMask(1,2), list(na+1), nb, N_int) + call bitstring_to_list_in_selection(myMask(1,1), list(1), na, N_int) + call bitstring_to_list_in_selection(myMask(1,2), list(na+1), nb, N_int) banned(list(1), list(2)) = .true. end do genl end + +subroutine bitstring_to_list_in_selection( string, list, n_elements, Nint) + use bitmasks + implicit none + BEGIN_DOC + ! Gives the indices(+1) of the bits set to 1 in the bit string + END_DOC + integer, intent(in) :: Nint + integer(bit_kind), intent(in) :: string(Nint) + integer, intent(out) :: list(Nint*bit_kind_size) + integer, intent(out) :: n_elements + + integer :: i, ishift + integer(bit_kind) :: l + + n_elements = 0 + ishift = 2 + do i=1,Nint + l = string(i) + do while (l /= 0_bit_kind) + n_elements = n_elements+1 + list(n_elements) = ishift+popcnt(l-1_bit_kind) - popcnt(l) + l = iand(l,l-1_bit_kind) + enddo + ishift = ishift + bit_kind_size + enddo + +end