10
0
mirror of https://github.com/LCPQ/quantum_package synced 2024-11-09 07:33:53 +01:00
quantum_package/scripts/generate_h_apply.py

137 lines
4.8 KiB
Python
Raw Normal View History

2014-05-13 13:57:58 +02:00
#!/usr/bin/env python
import os
# Fortran template text for H_apply subroutines.  H_apply.__repr__ replaces
# each "$<keyword>" placeholder in it with the corresponding fragment.
# Requires the QPACKAGE_ROOT environment variable (KeyError if unset).
_template_path = os.path.join(os.environ["QPACKAGE_ROOT"],
                              "src", "Dets", "H_apply_template.f")
# `with` guarantees the file is closed even if read() fails; a descriptive
# name avoids shadowing the Python 2 builtin `file`.
with open(_template_path, 'r') as _template_file:
    template = _template_file.read()
# Placeholder names used in the H_apply template ("$subroutine",
# "$parameters", ...).  H_apply.__init__ initialises each of them to an
# empty fragment before filling in the ones it knows about.
keywords = """
subroutine
parameters
initialization
declarations
keys_work_locked
keys_work_unlocked
finalization
""".split()
2014-05-17 14:20:55 +02:00
class H_apply(object):
    """Build the Fortran source of one ``H_apply_<sub>`` subroutine.

    ``self.data`` maps placeholder names to Fortran source fragments.
    ``repr(instance)`` (and therefore ``str``/``print``) returns the
    module-level ``template`` with every ``$<name>`` replaced by its fragment.
    ``set_perturbation`` and ``set_selection_pt2`` rewrite fragments to add
    perturbative accumulation and selection to the generated code.
    """

    def __init__(self, sub, openmp=True):
        """Initialise the default fragments.

        :param sub: suffix of the generated subroutine name, ``H_apply_<sub>``.
        :param openmp: if true, ``_OpenMP`` is appended to the subroutine name
            and OpenMP directives / lock calls are emitted; if false, every
            OpenMP fragment is empty and ``omp_test_lock`` becomes ``.False.``.
        """
        s = {}
        # Every template keyword starts out as an empty fragment.
        for k in keywords:
            s[k] = ""
        s["subroutine"] = "H_apply_%s" % (sub)
        self.openmp = openmp
        if openmp:
            s["subroutine"] += "_OpenMP"

        # Each may be set at most once (see set_selection_pt2 / set_perturbation).
        self.selection_pt2 = None
        self.perturbation = None

        # OpenMP directives and lock calls, collected separately so the
        # non-OpenMP path can blank exactly these and nothing else.
        omp = {}
        omp["omp_parallel"] = """!$OMP PARALLEL DEFAULT(SHARED) &
!$OMP PRIVATE(i,j,k,l,keys_out,hole,particle, &
!$OMP occ_particle,occ_hole,j_a,k_a,other_spin, &
!$OMP hole_save,ispin,jj,l_a,hij_elec,hij_tab, &
!$OMP accu,i_a,hole_tmp,particle_tmp,occ_particle_tmp, &
!$OMP occ_hole_tmp,key_idx,i_b,j_b,key,N_elec_in_key_part_1,&
!$OMP N_elec_in_key_hole_1,N_elec_in_key_part_2, &
!$OMP N_elec_in_key_hole_2,ia_ja_pairs) &
!$OMP SHARED(key_in,N_int,elec_num_tab,mo_tot_num, &
!$OMP hole_1, particl_1, hole_2, particl_2, &
!$OMP lck,thresh,elec_alpha_num)"""
        omp["omp_init_lock"] = "call omp_init_lock(lck)"
        omp["omp_set_lock"] = "call omp_set_lock(lck)"
        omp["omp_unset_lock"] = "call omp_unset_lock(lck)"
        # An expression (no `call`): substituted where a logical is expected.
        omp["omp_test_lock"] = "omp_test_lock(lck)"
        omp["omp_destroy_lock"] = "call omp_destroy_lock(lck)"
        omp["omp_end_parallel"] = "!$OMP END PARALLEL"
        omp["omp_master"] = "!$OMP MASTER"
        omp["omp_end_master"] = "!$OMP END MASTER"
        omp["omp_barrier"] = "!$OMP BARRIER"
        omp["omp_do"] = "!$OMP DO SCHEDULE (static)"
        omp["omp_enddo"] = "!$OMP ENDDO NOWAIT"
        if not openmp:
            # Blank only the OpenMP fragments: blanking every entry of `s`
            # would also erase the subroutine name set above.
            for k in omp:
                omp[k] = ""
            # Without a lock, the lock test is a constant false.
            omp["omp_test_lock"] = ".False."
        s.update(omp)

        s["size_max"] = str(1024 * 128)
        s["set_i_H_j_threshold"] = """
thresh = H_apply_threshold
"""
        self.data = s

    def __setitem__(self, key, value):
        """Set the Fortran fragment substituted for ``$<key>``."""
        self.data[key] = value

    def __getitem__(self, key):
        """Return the Fortran fragment substituted for ``$<key>``."""
        return self.data[key]

    def __repr__(self):
        """Return the template with every ``$<key>`` replaced by its fragment.

        Substitutions are applied one key at a time in dictionary order.
        """
        buffer = template
        for key, value in self.data.items():
            buffer = buffer.replace('$' + key, value)
        return buffer

    def set_perturbation(self, pert):
        """Make the generated subroutine accumulate a perturbative correction.

        Adds the ``sum_e_2_pert_in``, ``sum_norm_pert_in``,
        ``sum_H_pert_diag_in``, ``N_st`` and ``Nint`` arguments and a call to
        ``perturb_buffer_<pert>`` in the unlocked key-processing section.

        :param pert: perturbation name, or ``None`` to leave the fragments
            unchanged.
        :raises RuntimeError: if a perturbation has already been set.
        """
        if self.perturbation is not None:
            raise RuntimeError(
                "H_apply: perturbation already set to %r" % (self.perturbation,))
        self.perturbation = pert
        if pert is not None:
            self.data["parameters"] = ",sum_e_2_pert_in,sum_norm_pert_in,sum_H_pert_diag_in,N_st,Nint"
            self.data["declarations"] = """
integer, intent(in) :: N_st,Nint
double precision, intent(inout) :: sum_e_2_pert_in(N_st)
double precision, intent(inout) :: sum_norm_pert_in(N_st)
double precision, intent(inout) :: sum_H_pert_diag_in
double precision :: sum_e_2_pert(N_st)
double precision :: sum_norm_pert(N_st)
double precision :: sum_H_pert_diag
double precision :: e_2_pert_buffer(N_st,size_max)
double precision :: coef_pert_buffer(N_st,size_max)
"""
            self.data["size_max"] = "256"
            # Work on local copies of the in/out sums; copied back below.
            self.data["initialization"] = """
sum_e_2_pert = sum_e_2_pert_in
sum_norm_pert = sum_norm_pert_in
sum_H_pert_diag = sum_H_pert_diag_in
"""
            self.data["keys_work_unlocked"] += """
call perturb_buffer_%s(keys_out,key_idx,e_2_pert_buffer,coef_pert_buffer,sum_e_2_pert, &
sum_norm_pert,sum_H_pert_diag,N_st,Nint)
""" % (pert,)
            self.data["finalization"] = """
sum_e_2_pert_in = sum_e_2_pert
sum_norm_pert_in = sum_norm_pert
sum_H_pert_diag_in = sum_H_pert_diag
"""
            if self.openmp:
                # The lock calls are dropped and the sums are combined with a
                # REDUCTION clause instead.  NOTE(review): presumably the lock
                # is unnecessary once the sums are reduced - confirm.
                self.data["omp_set_lock"] = ""
                self.data["omp_unset_lock"] = ""
                self.data["omp_test_lock"] = ".False."
                self.data["omp_parallel"] += """&
!$OMP SHARED(N_st,Nint) PRIVATE(e_2_pert_buffer,coef_pert_buffer) &
!$OMP REDUCTION(+:sum_e_2_pert, sum_norm_pert, sum_H_pert_diag)"""

    def set_selection_pt2(self, pert):
        """Enable perturbative selection on top of ``set_perturbation(pert)``.

        The per-key buffers are zeroed before the perturbation call, and the
        locked section calls ``fill_H_apply_buffer_selection``.

        :param pert: perturbation name, or ``None`` to leave the fragments
            unchanged.
        :raises RuntimeError: if selection (or, via ``set_perturbation``, a
            perturbation) has already been set.
        """
        if self.selection_pt2 is not None:
            raise RuntimeError(
                "H_apply: selection_pt2 already set to %r" % (self.selection_pt2,))
        self.set_perturbation(pert)
        self.selection_pt2 = pert
        if pert is not None:
            self.data["size_max"] = str(1024 * 128)
            self.data["keys_work_unlocked"] = """
e_2_pert_buffer = 0.d0
coef_pert_buffer = 0.d0
""" + self.data["keys_work_unlocked"]
            self.data["keys_work_locked"] = """
call fill_H_apply_buffer_selection(key_idx,keys_out,e_2_pert_buffer,coef_pert_buffer,N_st,N_int)
"""
            if self.openmp:
                # set_perturbation removed the lock calls; the locked section
                # is used again here, so restore them.  Only in OpenMP mode:
                # otherwise omp_init_lock is blank and `lck` is never
                # initialised.
                self.data["omp_test_lock"] = "omp_test_lock(lck)"
                self.data["omp_set_lock"] = "call omp_set_lock(lck)"
                self.data["omp_unset_lock"] = "call omp_unset_lock(lck)"
2014-05-13 13:57:58 +02:00