# Mirror of https://github.com/LCPQ/quantum_package (synced 2024-11-18 12:03:57 +01:00)
# File: quantum_package/scripts/generate_h_apply.py (465 lines, 15 KiB, Python)
# Earliest revision shown: 2014-05-13 13:57:58 +02:00
#!/usr/bin/env python
import os
# Names of the ``$<keyword>`` placeholders that H_apply.__init__ initialises
# to an empty snippet; the set_*/filter_*/unset_* methods of H_apply then
# fill some of them with Fortran code.  Each name is listed exactly once.
# Keys that always receive an explicit value in H_apply.__init__ (for
# example ``params_post``, ``size_max``, ``do_mono_excitations``) are not
# listed here.
keywords = """
check_double_excitation
copy_buffer
declarations
decls_main
deinit_thread
filter1h
filter1p
filter2h2p_double
filter2h2p_single
filter2p
filter_integrals
filter_only_1h1p_double
filter_only_1h1p_single
filter_only_1h2p_double
filter_only_1h2p_single
filter_only_2h2p_double
filter_only_2h2p_single
filter_vvvv_excitation
filterhole
filterparticle
finalization
generate_psi_guess
init_main
init_thread
initialization
keys_work
omp_barrier
omp_do
omp_end_master
omp_end_parallel
omp_enddo
omp_master
omp_parallel
only_2p_double
only_2p_single
parameters
params_main
printout_always
printout_now
skip
subroutine
""".split()
class H_apply(object):
    """Build the Fortran source of an ``H_apply_<sub>`` subroutine.

    The Fortran code is produced by textual substitution: every
    ``$<key>`` placeholder of the concatenated template files is replaced
    by ``self.data[<key>]`` when the instance is rendered with ``repr()``.
    Every name of the module-level ``keywords`` list starts out as an
    empty snippet; the ``set_*`` / ``filter_*`` / ``unset_*`` methods fill
    individual placeholders with Fortran code.
    """

    def read_template(self):
        """Load ``H_apply.template.f`` followed by ``H_apply_nozmq.template.f``.

        Both files are read from ``$QP_ROOT/src/Determinants/`` and their
        concatenation is stored in ``self.template``.

        Raises:
            KeyError: if ``QP_ROOT`` is not set in the environment.
            IOError: if one of the template files cannot be opened.
        """
        directory = os.environ["QP_ROOT"] + '/src/Determinants/'
        chunks = []
        for name in ('H_apply.template.f', 'H_apply_nozmq.template.f'):
            # ``with`` closes the handle even if reading fails.
            with open(directory + name, 'r') as template_file:
                chunks.append(template_file.read())
        self.template = "".join(chunks)

    def __init__(self, sub, SingleRef=False, do_mono_exc=True, do_double_exc=True):
        """Fill the placeholder table with the defaults of a plain H_apply.

        Args:
            sub: suffix of the generated subroutine name (``H_apply_<sub>``).
            SingleRef: when true, ``$filter_integrals`` calls
                ``get_mo_bielec_integrals_existing_ik`` instead of
                accepting every integral pair.
            do_mono_exc: bool substituted (as ``.True.``/``.False.``) for
                ``$do_mono_excitations``.
            do_double_exc: bool substituted for ``$do_double_excitations``;
                it also selects the buffer routine used by
                ``set_perturbation``.
        """
        self.read_template()
        s = {}
        for k in keywords:
            s[k] = ""
        s["subroutine"] = "H_apply_%s" % (sub)
        s["params_post"] = ""
        self.selection_pt2 = None
        self.energy = "CI_electronic_energy"
        self.perturbation = None
        self.do_double_exc = do_double_exc

        s["omp_parallel"] = """ PROVIDE elec_num_tab
!$OMP PARALLEL DEFAULT(SHARED) &
!$OMP PRIVATE(i,j,k,l,keys_out,hole,particle, &
!$OMP occ_particle,occ_hole,j_a,k_a,other_spin, &
!$OMP hole_save,ispin,jj,l_a,ib_jb_pairs,array_pairs, &
!$OMP accu,i_a,hole_tmp,particle_tmp,occ_particle_tmp, &
!$OMP occ_hole_tmp,key_idx,i_b,j_b,key,N_elec_in_key_part_1,&
!$OMP N_elec_in_key_hole_1,N_elec_in_key_part_2, &
!$OMP N_elec_in_key_hole_2,ia_ja_pairs,key_union_hole_part) &
!$OMP SHARED(key_in,N_int,elec_num_tab,mo_tot_num, &
!$OMP hole_1, particl_1, hole_2, particl_2, &
!$OMP elec_alpha_num,i_generator) FIRSTPRIVATE(iproc)"""
        s["omp_end_parallel"] = "!$OMP END PARALLEL"
        s["omp_master"] = "!$OMP MASTER"
        s["omp_end_master"] = "!$OMP END MASTER"
        s["omp_barrier"] = "!$OMP BARRIER"
        s["omp_do"] = "!$OMP DO SCHEDULE (static,1)"
        s["omp_enddo"] = "!$OMP ENDDO NOWAIT"

        # Python booleans -> Fortran logical literals.
        d = {True: '.True.', False: '.False.'}
        s["do_mono_excitations"] = d[do_mono_exc]
        s["do_double_excitations"] = d[do_double_exc]
        s["keys_work"] += "call fill_H_apply_buffer_no_selection(key_idx,keys_out,N_int,iproc)"

        s["filter_integrals"] = "array_pairs = .True."
        if SingleRef:
            s["filter_integrals"] = """
call get_mo_bielec_integrals_existing_ik(i_a,j_a,mo_tot_num,array_pairs,mo_integrals_map)
"""
        s["generate_psi_guess"] = """
! Sort H_jj to find the N_states lowest states
integer :: i
integer, allocatable :: iorder(:)
double precision, allocatable :: H_jj(:)
double precision, external :: diag_h_mat_elem
allocate(H_jj(N_det),iorder(N_det))
!$OMP PARALLEL DEFAULT(NONE) &
!$OMP SHARED(psi_det,N_int,H_jj,iorder,N_det) &
!$OMP PRIVATE(i)
!$OMP DO
do i = 1, N_det
H_jj(i) = diag_h_mat_elem(psi_det(1,1,i),N_int)
iorder(i) = i
enddo
!$OMP END DO
!$OMP END PARALLEL
call dsort(H_jj,iorder,N_det)
do k=1,N_states
psi_coef(iorder(k),k) = 1.d0
enddo
deallocate(H_jj,iorder)
"""
        s["size_max"] = "8192"
        s["copy_buffer"] = """call copy_H_apply_buffer_to_wf
if (s2_eig) then
call make_s2_eigenfunction
endif
SOFT_TOUCH psi_det psi_coef N_det
"""
        s["printout_now"] = """write(output_determinants,*) &
100.*float(i_generator)/float(N_det_generators), '% in ', wall_1-wall_0, 's'"""
        self.data = s

    def __setitem__(self, key, value):
        """Set the Fortran snippet substituted for ``$<key>``."""
        self.data[key] = value

    def __getitem__(self, key):
        """Return the Fortran snippet substituted for ``$<key>``."""
        return self.data[key]

    def __repr__(self):
        """Return the template with every ``$<key>`` replaced by its snippet.

        Longer keys are substituted first, so a key that is a prefix of
        another one (``$foo`` vs ``$foo_bar``) cannot corrupt the longer
        placeholder, and the result no longer depends on dict ordering.
        """
        buffer = self.template
        ordered = sorted(self.data.items(), key=lambda item: len(item[0]), reverse=True)
        for key, value in ordered:
            buffer = buffer.replace('$' + key, value)
        return buffer

    def unset_double_excitations(self):
        """Disable double excitations in the generated code."""
        self["do_double_excitations"] = ".False."
        self["check_double_excitation"] = """
check_double_excitation = .False.
"""

    def filter_vvvv_excitation(self):
        """Emit a test that skips excitations whose four orbitals are all virtual.

        Fix: the former snippet passed the logical ``virt_bitmask(...).ne.x``
        as the second argument of ``IAND``; the comparison now applies to the
        ``IAND`` result.
        NOTE(review): ``b_cycle`` is expected to be ``.True.`` when this
        snippet starts; it is not initialised here -- confirm in the template.
        """
        self["filter_vvvv_excitation"] = """
key_union_hole_part = 0_bit_kind
call set_bit_to_integer(i_a,key_union_hole_part,N_int)
call set_bit_to_integer(j_a,key_union_hole_part,N_int)
call set_bit_to_integer(i_b,key_union_hole_part,N_int)
call set_bit_to_integer(j_b,key_union_hole_part,N_int)
do jtest_vvvv = 1, N_int
if(iand(key_union_hole_part(jtest_vvvv),virt_bitmask(jtest_vvvv,1)).ne.key_union_hole_part(jtest_vvvv))then
b_cycle = .False.
endif
enddo
if(b_cycle) cycle
"""

    def set_filter_holes(self):
        """Emit the ``$filterhole`` test on ``hole(k,other_spin)``."""
        self["filterhole"] = """
if(iand(ibset(0_bit_kind,j),hole(k,other_spin)).eq.0_bit_kind )cycle
"""

    def set_filter_particl(self):
        """Emit the ``$filterparticle`` test on ``hole(k_a,other_spin)``."""
        self["filterparticle"] = """
if(iand(ibset(0_bit_kind,j_a),hole(k_a,other_spin)).eq.0_bit_kind )cycle
"""

    def filter_1h(self):
        """Skip single-excitation keys for which ``is_a_1h(hole)`` is true."""
        self["filter1h"] = """
! ! DIR$ FORCEINLINE
if (is_a_1h(hole)) cycle
"""

    def filter_2p(self):
        """Skip single-excitation keys for which ``is_a_2p(hole)`` is true."""
        self["filter2p"] = """
! ! DIR$ FORCEINLINE
if (is_a_2p(hole)) cycle
"""

    def filter_1p(self):
        """Skip single-excitation keys for which ``is_a_1p(hole)`` is true."""
        self["filter1p"] = """
! ! DIR$ FORCEINLINE
if (is_a_1p(hole)) cycle
"""

    def filter_only_2p(self):
        """Keep only keys for which ``is_a_2p`` is true (singles and doubles)."""
        self["only_2p_single"] = """
! ! DIR$ FORCEINLINE
if (.not. is_a_2p(hole)) cycle
"""
        self["only_2p_double"] = """
! ! DIR$ FORCEINLINE
if (.not. is_a_2p(key)) cycle
"""

    def filter_only_1h1p(self):
        """Keep only keys for which ``is_a_1h1p`` is true (singles and doubles)."""
        self["filter_only_1h1p_single"] = """
! ! DIR$ FORCEINLINE
if (is_a_1h1p(hole).eqv..False.) cycle
"""
        self["filter_only_1h1p_double"] = """
! ! DIR$ FORCEINLINE
if (is_a_1h1p(key).eqv..False.) cycle
"""

    def filter_only_2h2p(self):
        """Keep only keys for which ``is_a_two_holes_two_particles`` is true.

        Fix: the double-excitation snippet used to be stored under
        ``filter_only_1h1p_double``, clobbering the 1h1p filter and leaving
        ``$filter_only_2h2p_double`` empty.
        """
        self["filter_only_2h2p_single"] = """
! ! DIR$ FORCEINLINE
if (is_a_two_holes_two_particles(hole).eqv..False.) cycle
"""
        self["filter_only_2h2p_double"] = """
! ! DIR$ FORCEINLINE
if (is_a_two_holes_two_particles(key).eqv..False.) cycle
"""

    def filter_only_1h2p(self):
        """Keep only keys for which ``is_a_1h2p`` is true (singles and doubles)."""
        self["filter_only_1h2p_single"] = """
! ! DIR$ FORCEINLINE
if (is_a_1h2p(hole).eqv..False.) cycle
"""
        self["filter_only_1h2p_double"] = """
! ! DIR$ FORCEINLINE
if (is_a_1h2p(key).eqv..False.) cycle
"""

    def unset_skip(self):
        """Replace the ``$skip`` snippet by an empty line."""
        self["skip"] = """
"""

    def set_filter_2h_2p(self):
        """Skip keys for which ``is_a_two_holes_two_particles`` is true."""
        self["filter2h2p_double"] = """
if (is_a_two_holes_two_particles(key)) cycle
"""
        self["filter2h2p_single"] = """
if (is_a_two_holes_two_particles(hole)) cycle
"""

    def set_perturbation(self, pert):
        """Turn the generated routine into a perturbative (PT2) driver.

        Args:
            pert: name suffix of the ``perturb_buffer_<pert>`` (or, when
                double excitations are disabled, ``perturb_buffer_by_mono_<pert>``)
                Fortran routine.  ``None`` records that no perturbation is
                used and leaves the snippets unchanged.

        Raises:
            RuntimeError: if a perturbation was already set on this instance.
        """
        if self.perturbation is not None:
            raise RuntimeError("H_apply: a perturbation is already set (%s)" % (self.perturbation,))
        self.perturbation = pert
        if pert is None:
            return

        self.data["parameters"] = ",sum_e_2_pert_in,sum_norm_pert_in,sum_H_pert_diag_in,N_st,Nint"
        self.data["declarations"] = """
integer, intent(in) :: N_st,Nint
double precision, intent(inout) :: sum_e_2_pert_in(N_st)
double precision, intent(inout) :: sum_norm_pert_in(N_st)
double precision, intent(inout) :: sum_H_pert_diag_in(N_st)
double precision :: sum_e_2_pert(N_st)
double precision :: sum_norm_pert(N_st)
double precision :: sum_H_pert_diag(N_st)
double precision, allocatable :: e_2_pert_buffer(:,:)
double precision, allocatable :: coef_pert_buffer(:,:)
ASSERT (Nint == N_int)
"""
        self.data["init_thread"] = """
allocate (e_2_pert_buffer(N_st,size_max), coef_pert_buffer(N_st,size_max))
do k=1,N_st
sum_e_2_pert(k) = 0.d0
sum_norm_pert(k) = 0.d0
sum_H_pert_diag(k) = 0.d0
enddo
"""
        self.data["deinit_thread"] = """
! OMP CRITICAL
do k=1,N_st
sum_e_2_pert_in(k) = sum_e_2_pert_in(k) + sum_e_2_pert(k)
sum_norm_pert_in(k) = sum_norm_pert_in(k) + sum_norm_pert(k)
sum_H_pert_diag_in(k) = sum_H_pert_diag_in(k) + sum_H_pert_diag(k)
enddo
! OMP END CRITICAL
deallocate (e_2_pert_buffer, coef_pert_buffer)
"""
        self.data["size_max"] = "8192"
        self.data["initialization"] = """
PROVIDE psi_selectors_coef psi_selectors E_corr_per_selectors psi_det_sorted_bit
"""
        if self.do_double_exc:
            self.data["keys_work"] = """
! if(check_double_excitation)then
call perturb_buffer_%s(i_generator,keys_out,key_idx,e_2_pert_buffer,coef_pert_buffer,sum_e_2_pert, &
sum_norm_pert,sum_H_pert_diag,N_st,N_int,key_mask,fock_diag_tmp,%s)
""" % (pert, self.energy)
        else:
            self.data["keys_work"] = """
call perturb_buffer_by_mono_%s(i_generator,keys_out,key_idx,e_2_pert_buffer,coef_pert_buffer,sum_e_2_pert, &
sum_norm_pert,sum_H_pert_diag,N_st,N_int,key_mask,fock_diag_tmp,%s)
""" % (pert, self.energy)

        self.data["finalization"] = """
"""
        self.data["copy_buffer"] = ""
        self.data["generate_psi_guess"] = ""
        self.data["params_main"] = "pt2, norm_pert, H_pert_diag, N_st"
        self.data["params_post"] = "," + self.data["params_main"] + ", N_int"
        self.data["decls_main"] = """ integer, intent(in) :: N_st
double precision, intent(inout):: pt2(N_st)
double precision, intent(inout):: norm_pert(N_st)
double precision, intent(inout):: H_pert_diag(N_st)
double precision :: delta_pt2(N_st), norm_psi(N_st), pt2_old(N_st)
PROVIDE N_det_generators
do k=1,N_st
pt2(k) = 0.d0
norm_pert(k) = 0.d0
H_pert_diag(k) = 0.d0
norm_psi(k) = 0.d0
delta_pt2(k) = 0.d0
pt2_old(k) = 0.d0
enddo
write(output_determinants,'(A12, 1X, A8, 3(2X, A9), 2X, A8, 2X, A8, 2X, A8)') &
'N_generators', 'Norm', 'Delta PT2', 'PT2', 'Est. PT2', 'secs'
write(output_determinants,'(A12, 1X, A8, 3(2X, A9), 2X, A8, 2X, A8, 2X, A8)') &
'============', '========', '=========', '=========', '=========', &
'========='
"""
        self.data["printout_always"] = """
do k=1,N_st
norm_psi(k) = norm_psi(k) + psi_coef_generators(i_generator,k)*psi_coef_generators(i_generator,k)
delta_pt2(k) = pt2(k) - pt2_old(k)
enddo
"""
        self.data["printout_now"] = """
do k=1,N_st
write(output_determinants,'(I10, 4(2X, F9.6), 2X, F8.1)') &
i_generator, norm_psi(k), delta_pt2(k), pt2(k), &
pt2(k)/(norm_psi(k)*norm_psi(k)), &
wall_1-wall_0
pt2_old(k) = pt2(k)
enddo
"""
        # Per-thread PT2 buffers and partial sums must not be shared.
        self.data["omp_parallel"] += """&
!$OMP SHARED(N_st) PRIVATE(e_2_pert_buffer,coef_pert_buffer) &
!$OMP PRIVATE(sum_e_2_pert, sum_norm_pert, sum_H_pert_diag)"""

    def set_selection_pt2(self, pert):
        """Set a perturbation (see ``set_perturbation``) and add selection.

        Args:
            pert: perturbation name forwarded to ``set_perturbation``;
                ``None`` leaves the selection snippets unchanged.

        Raises:
            RuntimeError: if a selection (or a perturbation) was already set.
        """
        if self.selection_pt2 is not None:
            raise RuntimeError("H_apply: a PT2 selection is already set (%s)" % (self.selection_pt2,))
        self.set_perturbation(pert)
        self.selection_pt2 = pert
        if pert is None:
            return

        self.data["parameters"] += ",select_max_out"
        self.data["declarations"] += """
double precision, intent(inout) :: select_max_out"""
        self.data["params_post"] += ", select_max(min(i_generator,size(select_max,1)))"
        self.data["size_max"] = "8192"
        self.data["copy_buffer"] = """
call copy_H_apply_buffer_to_wf
if (s2_eig) then
call make_s2_eigenfunction
endif
! SOFT_TOUCH psi_det psi_coef N_det
selection_criterion_min = min(selection_criterion_min, maxval(select_max))*0.1d0
selection_criterion = selection_criterion_min
call write_double(output_determinants,selection_criterion,'Selection criterion')
"""
        # Reset the PT2 buffers before the perturbation call, then feed the
        # selection buffer after it.
        self.data["keys_work"] = """
e_2_pert_buffer = 0.d0
coef_pert_buffer = 0.d0
""" + self.data["keys_work"]
        self.data["keys_work"] += """
call fill_H_apply_buffer_selection(key_idx,keys_out,e_2_pert_buffer, &
coef_pert_buffer,N_st,N_int,iproc,select_max_out)
"""
        self.data["omp_parallel"] += """&
!$OMP REDUCTION (max:select_max_out)"""
        self.data["skip"] = """
if (i_generator < size_select_max) then
if (select_max(i_generator) < selection_criterion_min*selection_criterion_factor) then
! OMP CRITICAL
do k=1,N_st
norm_psi(k) = norm_psi(k) + psi_coef_generators(i_generator,k)*psi_coef_generators(i_generator,k)
pt2_old(k) = 0.d0
enddo
! OMP END CRITICAL
cycle
endif
select_max(i_generator) = 0.d0
endif
"""

    def unset_openmp(self):
        """Blank every ``omp_*`` placeholder to generate serial code."""
        for k in keywords:
            if k.startswith("omp_"):
                self[k] = ""
class H_apply_zmq(H_apply):
    """Variant of ``H_apply`` that renders the ZMQ template.

    It differs from the base class in the template file it loads and in
    the PT2 / selection snippets it emits.
    """

    def read_template(self):
        """Load ``H_apply.template.f`` followed by ``H_apply_zmq.template.f``.

        Both files are read from ``$QP_ROOT/src/Determinants/`` and their
        concatenation is stored in ``self.template``.

        Raises:
            KeyError: if ``QP_ROOT`` is not set in the environment.
            IOError: if one of the template files cannot be opened.
        """
        directory = os.environ["QP_ROOT"] + '/src/Determinants/'
        chunks = []
        for name in ('H_apply.template.f', 'H_apply_zmq.template.f'):
            # ``with`` closes the handle even if reading fails.
            with open(directory + name, 'r') as template_file:
                chunks.append(template_file.read())
        self.template = "".join(chunks)

    def set_perturbation(self, pert):
        """Set up PT2 as in ``H_apply`` and adapt the snippets for ZMQ.

        Progress printouts are disabled, ``energy(k)`` is initialised from
        ``self.energy``, and ``$copy_buffer`` sums the per-generator
        ``*_generators`` arrays into ``pt2``, ``norm_pert`` and
        ``H_pert_diag``.  As in the base class, ``pert=None`` leaves the
        snippets untouched.
        """
        H_apply.set_perturbation(self, pert)
        if pert is None:
            # Same guard as the base class: without a perturbation there are
            # no pt2/norm_pert/H_pert_diag arguments to declare or fill.
            return
        self.data["printout_now"] = ""
        self.data["printout_always"] = ""
        self.data["decls_main"] = """ integer, intent(in) :: N_st
double precision, intent(inout):: pt2(N_st)
double precision, intent(inout):: norm_pert(N_st)
double precision, intent(inout):: H_pert_diag(N_st)
double precision :: delta_pt2(N_st), norm_psi(N_st), pt2_old(N_st)
PROVIDE N_det_generators
do k=1,N_st
pt2(k) = 0.d0
norm_pert(k) = 0.d0
H_pert_diag(k) = 0.d0
norm_psi(k) = 0.d0
energy(k) = %s(k)
enddo
""" % (self.energy)
        self.data["copy_buffer"] = """
do i=1,N_det_generators
do k=1,N_st
pt2(k) = pt2(k) + pt2_generators(k,i)
norm_pert(k) = norm_pert(k) + norm_pert_generators(k,i)
H_pert_diag(k) = H_pert_diag(k) + H_pert_diag_generators(k,i)
enddo
enddo
"""

    def set_selection_pt2(self, pert):
        """Set up selection as in ``H_apply`` with a ZMQ-specific ``$skip``.

        ``pert=None`` leaves ``$skip`` untouched, because in that case the
        base class declares no ``select_max``/``pt2`` variables.
        """
        H_apply.set_selection_pt2(self, pert)
        if pert is None:
            return
        self.data["skip"] = """
if (i_generator < size_select_max) then
if (select_max(i_generator) < selection_criterion_min*selection_criterion_factor) then
do k=1,N_st
pt2(k) = select_max(i_generator)
enddo
cycle
endif
select_max(i_generator) = 0.d0
endif
"""