3
0
mirror of https://github.com/triqs/dft_tools synced 2024-11-07 22:53:50 +01:00

Modified ProjectorShell object accordingly

* Modified ProjectorShell to retrieve dictionary 'ions' from
  the input and construct a list of equivalence classes (ion sorts).
This commit is contained in:
Oleg Peil 2018-05-04 16:06:31 +02:00 committed by Manuel
parent 7471691219
commit 0fa24a28ef

View File

@ -70,7 +70,7 @@ class ProjectorShell:
""" """
def __init__(self, sh_pars, proj_raw, proj_params, kmesh, structure, nc_flag): def __init__(self, sh_pars, proj_raw, proj_params, kmesh, structure, nc_flag):
self.lorb = sh_pars['lshell'] self.lorb = sh_pars['lshell']
self.ion_list = sh_pars['ion_list'] self.ions = sh_pars['ions']
self.user_index = sh_pars['user_index'] self.user_index = sh_pars['user_index']
self.nc_flag = nc_flag self.nc_flag = nc_flag
# try: # try:
@ -81,8 +81,17 @@ class ProjectorShell:
self.lm1 = self.lorb**2 self.lm1 = self.lorb**2
self.lm2 = (self.lorb+1)**2 self.lm2 = (self.lorb+1)**2
self.nion = self.ions['nion']
# Extract ion list and equivalence classes (ion sorts)
self.ion_list = sorted(it.chain(*self.ions['ion_list']))
self.ion_sort = []
for ion in self.ion_list:
for icl, eq_cl in enumerate(self.ions['ion_list']):
if ion in eq_cl:
self.ion_sort.append(icl + 1) # Enumerate classes starting from 1
break
self.ndim = self.extract_tmatrices(sh_pars) self.ndim = self.extract_tmatrices(sh_pars)
self.nion = len(self.ion_list)
self.extract_projectors(proj_raw, proj_params, kmesh, structure) self.extract_projectors(proj_raw, proj_params, kmesh, structure)
@ -106,7 +115,7 @@ class ProjectorShell:
Flag 'self.do_transform' is introduced for the optimization purposes Flag 'self.do_transform' is introduced for the optimization purposes
to avoid superfluous matrix multiplications. to avoid superfluous matrix multiplications.
""" """
nion = len(self.ion_list) nion = self.nion
nm = self.lm2 - self.lm1 nm = self.lm2 - self.lm1
if 'tmatrices' in sh_pars: if 'tmatrices' in sh_pars:
@ -213,7 +222,8 @@ class ProjectorShell:
""" """
assert self.nc_flag == False, "Non-collinear case is not implemented" assert self.nc_flag == False, "Non-collinear case is not implemented"
nion = len(self.ion_list) # nion = len(self.ion_list)
nion = self.nion
nlm = self.lm2 - self.lm1 nlm = self.lm2 - self.lm1
_, ns, nk, nb = proj_raw.shape _, ns, nk, nb = proj_raw.shape