3
0
mirror of https://github.com/triqs/dft_tools synced 2024-06-29 00:15:00 +02:00

edit SumKDFT class to take gf mesh at initialization and force Sigma to have that mesh

This commit is contained in:
Jonathan Karp 2021-08-27 17:31:07 -04:00 committed by Alexander Hampel
parent d735c4a3b7
commit ec8da69e34
2 changed files with 96 additions and 129 deletions

View File

@ -42,7 +42,7 @@ from scipy.optimize import minimize
class SumkDFT(object): class SumkDFT(object):
"""This class provides a general SumK method for combining ab-initio code and triqs.""" """This class provides a general SumK method for combining ab-initio code and triqs."""
def __init__(self, hdf_file, h_field=0.0, use_dft_blocks=False, def __init__(self, hdf_file, h_field=0.0, mesh=None, beta=40, n_iw=1025, use_dft_blocks=False,
dft_data='dft_input', symmcorr_data='dft_symmcorr_input', parproj_data='dft_parproj_input', dft_data='dft_input', symmcorr_data='dft_symmcorr_input', parproj_data='dft_parproj_input',
symmpar_data='dft_symmpar_input', bands_data='dft_bands_input', transp_data='dft_transp_input', symmpar_data='dft_symmpar_input', bands_data='dft_bands_input', transp_data='dft_transp_input',
misc_data='dft_misc_input',bc_data='dft_bandchar_input',fs_data='dft_fs_input'): misc_data='dft_misc_input',bc_data='dft_bandchar_input',fs_data='dft_fs_input'):
@ -99,6 +99,14 @@ class SumkDFT(object):
self.fs_data = fs_data self.fs_data = fs_data
self.h_field = h_field self.h_field = h_field
if mesh is None:
self.mesh = MeshImFreq(beta=beta, S='Fermion',n_max=n_iw)
elif isinstance(mesh, MeshImFreq) or isinstance(mesh, MeshReFreq):
self.mesh = mesh
else:
raise ValueError('mesh must be a triqs mesh of type MeshImFreq or MeshReFreq')
self.block_structure = BlockStructure() self.block_structure = BlockStructure()
# Read input from HDF: # Read input from HDF:
@ -467,7 +475,7 @@ class SumkDFT(object):
return gf_rotated return gf_rotated
def lattice_gf(self, ik, mu=None, iw_or_w="iw", beta=40, broadening=None, mesh=None, with_Sigma=True, with_dc=True): def lattice_gf(self, ik, mu=None, broadening=None, mesh=None, with_Sigma=True, with_dc=True):
r""" r"""
Calculates the lattice Green function for a given k-point from the DFT Hamiltonian and the self energy. Calculates the lattice Green function for a given k-point from the DFT Hamiltonian and the self energy.
@ -508,9 +516,7 @@ class SumkDFT(object):
mu = self.chemical_potential mu = self.chemical_potential
ntoi = self.spin_names_to_ind[self.SO] ntoi = self.spin_names_to_ind[self.SO]
spn = self.spin_block_names[self.SO] spn = self.spin_block_names[self.SO]
if (iw_or_w != "iw") and (iw_or_w != "w"): if self.Sigma_inp is None:
raise ValueError("lattice_gf: Implemented only for Re/Im frequency functions.")
if not hasattr(self, "Sigma_imp_" + iw_or_w):
with_Sigma = False with_Sigma = False
if broadening is None: if broadening is None:
if mesh is None: if mesh is None:
@ -518,45 +524,32 @@ class SumkDFT(object):
else: # broadening = 2 * \Delta omega, where \Delta omega is the spacing of omega points else: # broadening = 2 * \Delta omega, where \Delta omega is the spacing of omega points
broadening = 2.0 * ((mesh[1] - mesh[0]) / (mesh[2] - 1)) broadening = 2.0 * ((mesh[1] - mesh[0]) / (mesh[2] - 1))
# Are we including Sigma?
if with_Sigma:
Sigma_imp = getattr(self, "Sigma_imp_" + iw_or_w)
sigma_minus_dc = [s.copy() for s in Sigma_imp]
if with_dc:
sigma_minus_dc = self.add_dc(iw_or_w)
if iw_or_w == "iw":
# override beta if Sigma_iw is present
beta = Sigma_imp[0].mesh.beta
mesh = Sigma_imp[0].mesh
elif iw_or_w == "w":
mesh = Sigma_imp[0].mesh
if broadening>0 and mpi.is_master_node():
warn('lattice_gf called with Sigma and broadening > 0 (broadening = {}). You might want to explicitly set the broadening to 0.'.format(broadening))
else:
if iw_or_w == "iw":
if beta is None:
raise ValueError("lattice_gf: Give the beta for the lattice GfReFreq.")
# Default number of Matsubara frequencies
mesh = MeshImFreq(beta=beta, S='Fermion', n_max=1025)
elif iw_or_w == "w":
if mesh is None:
raise ValueError("lattice_gf: Give the mesh=(om_min,om_max,n_points) for the lattice GfReFreq.")
mesh = MeshReFreq(mesh[0], mesh[1], mesh[2])
# Check if G_latt is present # Check if G_latt is present
set_up_G_latt = False # Assume not set_up_G_latt = False # Assume not
if not hasattr(self, "G_latt_" + iw_or_w): if not hasattr(self, "G_latt" ):
# Need to create G_latt_(i)w # Need to create G_latt_(i)w
set_up_G_latt = True set_up_G_latt = True
else: # Check that existing GF is consistent else: # Check that existing GF is consistent
G_latt = getattr(self, "G_latt_" + iw_or_w) G_latt = self.G_latt
GFsize = [gf.target_shape[0] for bname, gf in G_latt] GFsize = [gf.target_shape[0] for bname, gf in G_latt]
unchangedsize = all([self.n_orbitals[ik, ntoi[spn[isp]]] == GFsize[ unchangedsize = all([self.n_orbitals[ik, ntoi[spn[isp]]] == GFsize[
isp] for isp in range(self.n_spin_blocks[self.SO])]) isp] for isp in range(self.n_spin_blocks[self.SO])])
if not unchangedsize: if (not mesh is None) or (not unchangedsize):
set_up_G_latt = True set_up_G_latt = True
if (iw_or_w == "iw") and (self.G_latt_iw.mesh.beta != beta):
set_up_G_latt = True # additional check for ImFreq # Are we including Sigma?
if with_Sigma:
Sigma_imp = self.Sigma_imp
sigma_minus_dc = [s.copy() for s in Sigma_imp]
if with_dc:
sigma_minus_dc = self.add_dc()
if isinstance(self.mesh, MeshReFreq) and broadening > 0 and mpi.is_master_node():
warn('lattice_gf called with Sigma and broadening > 0 (broadening = {}). You might want to explicitly set the broadening to 0.'.format(broadening))
elif not mesh is None:
mesh = MeshReFreq(mesh[0], mesh[1], mesh[2])
if mesh is None:
mesh = self.mesh
# Set up G_latt # Set up G_latt
if set_up_G_latt: if set_up_G_latt:
@ -565,19 +558,19 @@ class SumkDFT(object):
gf_struct = [(spn[isp], block_structure[isp]) gf_struct = [(spn[isp], block_structure[isp])
for isp in range(self.n_spin_blocks[self.SO])] for isp in range(self.n_spin_blocks[self.SO])]
block_ind_list = [block for block, inner in gf_struct] block_ind_list = [block for block, inner in gf_struct]
if iw_or_w == "iw": if isinstance(mesh, MeshImFreq):
glist = lambda: [GfImFreq(indices=inner, mesh=mesh) glist = lambda: [GfImFreq(indices=inner, mesh=mesh)
for block, inner in gf_struct] for block, inner in gf_struct]
elif iw_or_w == "w": else:
glist = lambda: [GfReFreq(indices=inner, mesh=mesh) glist = lambda: [GfReFreq(indices=inner, mesh=mesh)
for block, inner in gf_struct] for block, inner in gf_struct]
G_latt = BlockGf(name_list=block_ind_list, G_latt = BlockGf(name_list=block_ind_list,
block_list=glist(), make_copies=False) block_list=glist(), make_copies=False)
G_latt.zero() G_latt.zero()
if iw_or_w == "iw": if isinstance(mesh, MeshImFreq):
G_latt << iOmega_n G_latt << iOmega_n
elif iw_or_w == "w": else:
G_latt << Omega + 1j * broadening G_latt << Omega + 1j * broadening
idmat = [numpy.identity( idmat = [numpy.identity(
@ -599,7 +592,7 @@ class SumkDFT(object):
sigma_minus_dc[icrsh][bname], gf) sigma_minus_dc[icrsh][bname], gf)
G_latt.invert() G_latt.invert()
setattr(self, "G_latt_" + iw_or_w, G_latt) self.G_latt = G_latt
return G_latt return G_latt
@ -629,19 +622,19 @@ class SumkDFT(object):
assert len(Sigma_imp) == self.n_corr_shells,\ assert len(Sigma_imp) == self.n_corr_shells,\
"put_Sigma: give exactly one Sigma for each corr. shell!" "put_Sigma: give exactly one Sigma for each corr. shell!"
if all((isinstance(gf, Gf) and isinstance(gf.mesh, MeshImFreq)) for bname, gf in Sigma_imp[0]): if isinstance(self.mesh, MeshImFreq) and all(isinstance(gf, Gf) and gf.mesh == self.mesh for bname, gf in Sigma_imp[0]):
# Imaginary frequency Sigma: # Imaginary frequency Sigma:
self.Sigma_imp_iw = [self.block_structure.create_gf(ish=icrsh, mesh=Sigma_imp[icrsh].mesh, space='sumk') self.Sigma_imp = [self.block_structure.create_gf(ish=icrsh, mesh=Sigma_imp[icrsh].mesh, space='sumk')
for icrsh in range(self.n_corr_shells)] for icrsh in range(self.n_corr_shells)]
SK_Sigma_imp = self.Sigma_imp_iw SK_Sigma_imp = self.Sigma_imp
elif all(isinstance(gf, Gf) and isinstance(gf.mesh, MeshReFreq) for bname, gf in Sigma_imp[0]): elif isinstance(self.mesh, MeshReFreq) and all(isinstance(gf, Gf) and gf.mesh == self.mesh for bname, gf in Sigma_imp[0]):
# Real frequency Sigma: # Real frequency Sigma:
self.Sigma_imp_w = [self.block_structure.create_gf(ish=icrsh, mesh=Sigma_imp[icrsh].mesh, gf_function=GfReFreq, space='sumk') self.Sigma_imp = [self.block_structure.create_gf(ish=icrsh, mesh=Sigma_imp[icrsh].mesh, gf_function=GfReFreq, space='sumk')
for icrsh in range(self.n_corr_shells)] for icrsh in range(self.n_corr_shells)]
SK_Sigma_imp = self.Sigma_imp_w SK_Sigma_imp = self.Sigma_imp
else: else:
raise ValueError("put_Sigma: This type of Sigma is not handled, give either BlockGf of GfReFreq or GfImFreq.") raise ValueError("put_Sigma: Sigma_imp must have the same mesh as SumKDFT.mesh.")
# rotation from local to global coordinate system: # rotation from local to global coordinate system:
for icrsh in range(self.n_corr_shells): for icrsh in range(self.n_corr_shells):
@ -654,13 +647,12 @@ class SumkDFT(object):
gf << Sigma_imp[icrsh][bname] gf << Sigma_imp[icrsh][bname]
#warning if real frequency self energy is within the bounds of the band energies #warning if real frequency self energy is within the bounds of the band energies
if isinstance(Sigma_imp[0].mesh, MeshReFreq): if isinstance(self.mesh, MeshReFreq):
if self.min_band_energy is None or self.max_band_energy is None: if self.min_band_energy is None or self.max_band_energy is None:
self.calculate_min_max_band_energies() self.calculate_min_max_band_energies()
for gf in Sigma_imp: mesh = numpy.array([i for i in self.mesh.values()])
Sigma_mesh = numpy.array([i for i in gf.mesh.values()]) if mesh[0] > (self.min_band_energy - self.chemical_potential) or mesh[-1] < (self.max_band_energy - self.chemical_potential):
if Sigma_mesh[0] > (self.min_band_energy - self.chemical_potential) or Sigma_mesh[-1] < (self.max_band_energy - self.chemical_potential): warn('The given Sigma is on a mesh which does not cover the band energy range. The Sigma MeshReFreq runs from %f to %f, while the band energy (minus the chemical potential) runs from %f to %f'%(Sigma_mesh[0], Sigma_mesh[-1], self.min_band_energy, self.max_band_energy))
warn('The given Sigma is on a mesh which does not cover the band energy range. The Sigma MeshReFreq runs from %f to %f, while the band energy (minus the chemical potential) runs from %f to %f'%(Sigma_mesh[0], Sigma_mesh[-1], self.min_band_energy, self.max_band_energy))
def transform_to_sumk_blocks(self, Sigma_imp, Sigma_out=None): def transform_to_sumk_blocks(self, Sigma_imp, Sigma_out=None):
r""" transform Sigma from solver to sumk space r""" transform Sigma from solver to sumk space
@ -703,7 +695,7 @@ class SumkDFT(object):
G_out=Sigma_out[icrsh]) G_out=Sigma_out[icrsh])
return Sigma_out return Sigma_out
def extract_G_loc(self, mu=None, iw_or_w='iw', with_Sigma=True, with_dc=True, broadening=None, def extract_G_loc(self, mu=None, with_Sigma=True, with_dc=True, broadening=None,
transform_to_solver_blocks=True, show_warnings=True): transform_to_solver_blocks=True, show_warnings=True):
r""" r"""
Extracts the local downfolded Green function by the Brillouin-zone integration of the lattice Green's function. Extracts the local downfolded Green function by the Brillouin-zone integration of the lattice Green's function.
@ -739,14 +731,14 @@ class SumkDFT(object):
if mu is None: if mu is None:
mu = self.chemical_potential mu = self.chemical_potential
if iw_or_w == "iw": if isinstance(self.mesh, MeshImFreq):
G_loc = [self.Sigma_imp_iw[icrsh].copy() for icrsh in range( G_loc = [self.Sigma_imp[icrsh].copy() for icrsh in range(
self.n_corr_shells)] # this list will be returned self.n_corr_shells)] # this list will be returned
beta = G_loc[0].mesh.beta beta = G_loc[0].mesh.beta
G_loc_inequiv = [BlockGf(name_block_generator=[(block, GfImFreq(target_shape=(block_dim, block_dim), mesh=G_loc[0].mesh)) for block, block_dim in self.gf_struct_solver[ish].items()], G_loc_inequiv = [BlockGf(name_block_generator=[(block, GfImFreq(target_shape=(block_dim, block_dim), mesh=G_loc[0].mesh)) for block, block_dim in self.gf_struct_solver[ish].items()],
make_copies=False) for ish in range(self.n_inequiv_shells)] make_copies=False) for ish in range(self.n_inequiv_shells)]
elif iw_or_w == "w": else:
G_loc = [self.Sigma_imp_w[icrsh].copy() for icrsh in range( G_loc = [self.Sigma_imp[icrsh].copy() for icrsh in range(
self.n_corr_shells)] # this list will be returned self.n_corr_shells)] # this list will be returned
mesh = G_loc[0].mesh mesh = G_loc[0].mesh
G_loc_inequiv = [BlockGf(name_block_generator=[(block, GfReFreq(target_shape=(block_dim, block_dim), mesh=mesh)) for block, block_dim in self.gf_struct_solver[ish].items()], G_loc_inequiv = [BlockGf(name_block_generator=[(block, GfReFreq(target_shape=(block_dim, block_dim), mesh=mesh)) for block, block_dim in self.gf_struct_solver[ish].items()],
@ -757,13 +749,12 @@ class SumkDFT(object):
ikarray = numpy.array(list(range(self.n_k))) ikarray = numpy.array(list(range(self.n_k)))
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
if iw_or_w == 'iw': if isinstance(self.mesh, MeshImFreq):
G_latt = self.lattice_gf( G_latt = self.lattice_gf(
ik=ik, mu=mu, iw_or_w=iw_or_w, with_Sigma=with_Sigma, with_dc=with_dc, beta=beta) ik=ik, mu=mu, with_Sigma=with_Sigma, with_dc=with_dc)
elif iw_or_w == 'w': else:
mesh_parameters = (G_loc[0].mesh.omega_min,G_loc[0].mesh.omega_max,len(G_loc[0].mesh))
G_latt = self.lattice_gf( G_latt = self.lattice_gf(
ik=ik, mu=mu, iw_or_w=iw_or_w, with_Sigma=with_Sigma, with_dc=with_dc, broadening=broadening, mesh=mesh_parameters) ik=ik, mu=mu, with_Sigma=with_Sigma, with_dc=with_dc, broadening=broadening)
G_latt *= self.bz_weights[ik] G_latt *= self.bz_weights[ik]
for icrsh in range(self.n_corr_shells): for icrsh in range(self.n_corr_shells):
@ -1446,7 +1437,7 @@ class SumkDFT(object):
return trans return trans
def density_matrix(self, method='using_gf', beta=40.0): def density_matrix(self, method='using_gf'):
"""Calculate density matrices in one of two ways. """Calculate density matrices in one of two ways.
Parameters Parameters
@ -1477,8 +1468,7 @@ class SumkDFT(object):
if method == "using_gf": if method == "using_gf":
G_latt_iw = self.lattice_gf( G_latt_iw = self.lattice_gf(ik=ik, mu=self.chemical_potential)
ik=ik, mu=self.chemical_potential, iw_or_w="iw", beta=beta)
G_latt_iw *= self.bz_weights[ik] G_latt_iw *= self.bz_weights[ik]
dm = G_latt_iw.density() dm = G_latt_iw.density()
MMat = [dm[sp] for sp in self.spin_block_names[self.SO]] MMat = [dm[sp] for sp in self.spin_block_names[self.SO]]
@ -1777,7 +1767,7 @@ class SumkDFT(object):
self.dc_imp[icrsh][sp] = numpy.dot(T.conjugate().transpose(), self.dc_imp[icrsh][sp] = numpy.dot(T.conjugate().transpose(),
numpy.dot(self.dc_imp[icrsh][sp], T)) numpy.dot(self.dc_imp[icrsh][sp], T))
def add_dc(self, iw_or_w="iw"): def add_dc(self):
r""" r"""
Subtracts the double counting term from the impurity self energy. Subtracts the double counting term from the impurity self energy.
@ -1796,8 +1786,7 @@ class SumkDFT(object):
""" """
# Be careful: Sigma_imp is already in the global coordinate system!! # Be careful: Sigma_imp is already in the global coordinate system!!
sigma_minus_dc = [s.copy() sigma_minus_dc = [s.copy() for s in self.Sigma_imp]
for s in getattr(self, "Sigma_imp_" + iw_or_w)]
for icrsh in range(self.n_corr_shells): for icrsh in range(self.n_corr_shells):
for bname, gf in sigma_minus_dc[icrsh]: for bname, gf in sigma_minus_dc[icrsh]:
# Transform dc_imp to global coordinate system # Transform dc_imp to global coordinate system
@ -1869,7 +1858,7 @@ class SumkDFT(object):
else: else:
gf_to_symm[key].from_L_G_R(v, ss, v.conjugate().transpose()) gf_to_symm[key].from_L_G_R(v, ss, v.conjugate().transpose())
def total_density(self, mu=None, iw_or_w="iw", with_Sigma=True, with_dc=True, broadening=None): def total_density(self, mu=None, with_Sigma=True, with_dc=True, broadening=None):
r""" r"""
Calculates the total charge within the energy window for a given chemical potential. Calculates the total charge within the energy window for a given chemical potential.
The chemical potential is either given by parameter `mu` or, if it is not specified, The chemical potential is either given by parameter `mu` or, if it is not specified,
@ -1922,7 +1911,7 @@ class SumkDFT(object):
ikarray = numpy.array(list(range(self.n_k))) ikarray = numpy.array(list(range(self.n_k)))
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
G_latt = self.lattice_gf( G_latt = self.lattice_gf(
ik=ik, mu=mu, iw_or_w=iw_or_w, with_Sigma=with_Sigma, with_dc=with_dc, broadening=broadening) ik=ik, mu=mu, with_Sigma=with_Sigma, with_dc=with_dc, broadening=broadening)
dens += self.bz_weights[ik] * G_latt.total_density() dens += self.bz_weights[ik] * G_latt.total_density()
# collect data from mpi: # collect data from mpi:
dens = mpi.all_reduce(mpi.world, dens, lambda x, y: x + y) dens = mpi.all_reduce(mpi.world, dens, lambda x, y: x + y)

View File

@ -41,17 +41,17 @@ class SumkDFTTools(SumkDFT):
Extends the SumkDFT class with some tools for analysing the data. Extends the SumkDFT class with some tools for analysing the data.
""" """
def __init__(self, hdf_file, h_field=0.0, use_dft_blocks=False, dft_data='dft_input', symmcorr_data='dft_symmcorr_input', def __init__(self, hdf_file, h_field=0.0, mesh=None, bets=40, n_iw=1025, use_dft_blocks=False, dft_data='dft_input', symmcorr_data='dft_symmcorr_input',
parproj_data='dft_parproj_input', symmpar_data='dft_symmpar_input', bands_data='dft_bands_input', parproj_data='dft_parproj_input', symmpar_data='dft_symmpar_input', bands_data='dft_bands_input',
transp_data='dft_transp_input', misc_data='dft_misc_input'): transp_data='dft_transp_input', misc_data='dft_misc_input'):
""" """
Initialisation of the class. Parameters are exactly as for SumKDFT. Initialisation of the class. Parameters are exactly as for SumKDFT.
""" """
SumkDFT.__init__(self, hdf_file=hdf_file, h_field=h_field, use_dft_blocks=use_dft_blocks, SumkDFT.__init__(self, hdf_file=hdf_file, h_field=h_field, mesh=mesh, beta=beta, n_iw=n_iw,
dft_data=dft_data, symmcorr_data=symmcorr_data, parproj_data=parproj_data, use_dft_blocks=use_dft_blocks, dft_data=dft_data, symmcorr_data=symmcorr_data,
symmpar_data=symmpar_data, bands_data=bands_data, transp_data=transp_data, parproj_data=parproj_data, symmpar_data=symmpar_data, bands_data=bands_data,
misc_data=misc_data) transp_data=transp_data, misc_data=misc_data)
# Uses .data of only GfReFreq objects. # Uses .data of only GfReFreq objects.
def dos_wannier_basis(self, mu=None, broadening=None, mesh=None, with_Sigma=True, with_dc=True, save_to_file=True): def dos_wannier_basis(self, mu=None, broadening=None, mesh=None, with_Sigma=True, with_dc=True, save_to_file=True):
@ -82,10 +82,9 @@ class SumkDFTTools(SumkDFT):
DOSproj_orb : Dict of numpy arrays DOSproj_orb : Dict of numpy arrays
DOS projected to atoms and resolved into orbital contributions. DOS projected to atoms and resolved into orbital contributions.
""" """
if (mesh is None) and (not with_Sigma): if mesh is None or with_Sigma:
raise ValueError("lattice_gf: Give the mesh=(om_min,om_max,n_points) for the lattice GfReFreq.") assert isinstance(self.mesh, MeshReFreq), "mesh must be given if self.mesh is a MeshImFreq"
if mesh is None: om_mesh = [x.real for x in self.mesh]
om_mesh = [x.real for x in self.Sigma_imp_w[0].mesh]
om_min = om_mesh[0] om_min = om_mesh[0]
om_max = om_mesh[-1] om_max = om_mesh[-1]
n_om = len(om_mesh) n_om = len(om_mesh)
@ -119,7 +118,7 @@ class SumkDFTTools(SumkDFT):
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
G_latt_w = self.lattice_gf( G_latt_w = self.lattice_gf(
ik=ik, mu=mu, iw_or_w="w", broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc) ik=ik, mu=mu, broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc)
G_latt_w *= self.bz_weights[ik] G_latt_w *= self.bz_weights[ik]
# Non-projected DOS # Non-projected DOS
@ -218,10 +217,9 @@ class SumkDFTTools(SumkDFT):
DOSproj_orb : Dict of numpy arrays DOSproj_orb : Dict of numpy arrays
DOS projected to atoms and resolved into orbital contributions. DOS projected to atoms and resolved into orbital contributions.
""" """
if (mesh is None) and (not with_Sigma): if mesh is None or with_Sigma:
raise ValueError("lattice_gf: Give the mesh=(om_min,om_max,n_points) for the lattice GfReFreq.") assert isinstance(self.mesh, MeshReFreq), "mesh must be given if self.mesh is a MeshImFreq"
if mesh is None: om_mesh = [x.real for x in self.mesh]
om_mesh = [x.real for x in self.Sigma_imp_w[0].mesh]
om_min = om_mesh[0] om_min = om_mesh[0]
om_max = om_mesh[-1] om_max = om_mesh[-1]
n_om = len(om_mesh) n_om = len(om_mesh)
@ -255,7 +253,7 @@ class SumkDFTTools(SumkDFT):
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
G_latt_w = self.lattice_gf( G_latt_w = self.lattice_gf(
ik=ik, mu=mu, iw_or_w="w", broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc) ik=ik, mu=mu, broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc)
G_latt_w *= self.bz_weights[ik] G_latt_w *= self.bz_weights[ik]
# Non-projected DOS # Non-projected DOS
@ -350,10 +348,9 @@ class SumkDFTTools(SumkDFT):
if self.symm_op: if self.symm_op:
self.symmpar = Symmetry(self.hdf_file, subgroup=self.symmpar_data) self.symmpar = Symmetry(self.hdf_file, subgroup=self.symmpar_data)
if (mesh is None) and (not with_Sigma): if mesh is None or with_Sigma:
raise ValueError("lattice_gf: Give the mesh=(om_min,om_max,n_points) for the lattice GfReFreq.") assert isinstance(self.mesh, MeshReFreq), "mesh must be given if self.mesh is a MeshImFreq"
if mesh is None: om_mesh = [x.real for x in self.mesh]
om_mesh = [x.real for x in self.Sigma_imp_w[0].mesh]
om_min = om_mesh[0] om_min = om_mesh[0]
om_max = om_mesh[-1] om_max = om_mesh[-1]
n_om = len(om_mesh) n_om = len(om_mesh)
@ -389,7 +386,7 @@ class SumkDFTTools(SumkDFT):
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
G_latt_w = self.lattice_gf( G_latt_w = self.lattice_gf(
ik=ik, mu=mu, iw_or_w="w", broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc) ik=ik, mu=mu, broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc)
G_latt_w *= self.bz_weights[ik] G_latt_w *= self.bz_weights[ik]
# Non-projected DOS # Non-projected DOS
@ -502,10 +499,9 @@ class SumkDFTTools(SumkDFT):
if not value_read: if not value_read:
return value_read return value_read
if (mesh is None) and (not with_Sigma): if mesh is None or with_Sigma:
raise ValueError("lattice_gf: Give the mesh=(om_min,om_max,n_points) for the lattice GfReFreq.") assert isinstance(self.mesh, MeshReFreq), "mesh must be given if self.mesh is a MeshImFreq"
if mesh is None: om_mesh = [x.real for x in self.mesh]
om_mesh = [x.real for x in self.Sigma_imp_w[0].mesh]
om_min = om_mesh[0] om_min = om_mesh[0]
om_max = om_mesh[-1] om_max = om_mesh[-1]
n_om = len(om_mesh) n_om = len(om_mesh)
@ -532,7 +528,7 @@ class SumkDFTTools(SumkDFT):
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
G_latt_w = self.lattice_gf( G_latt_w = self.lattice_gf(
ik=ik, mu=mu, iw_or_w="w", broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc) ik=ik, mu=mu, broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc)
G_latt_w *= self.bz_weights[ik] G_latt_w *= self.bz_weights[ik]
if(nk!=None): if(nk!=None):
for iom in range(n_om): for iom in range(n_om):
@ -667,8 +663,9 @@ class SumkDFTTools(SumkDFT):
if not value_read: if not value_read:
return value_read return value_read
if with_Sigma is True: if with_Sigma is True or mesh is None:
om_mesh = [x.real for x in self.Sigma_imp_w[0].mesh] assert isinstance(self.mesh, MeshReFreq), "SumkDFT.mesh must be real if with_Sigma is True or mesh is not given"
om_mesh = [x.real for x in self.mesh]
#for Fermi Surface calculations #for Fermi Surface calculations
if FS: if FS:
jw=[i for i in range(len(om_mesh)) if om_mesh[i] == 0.0] jw=[i for i in range(len(om_mesh)) if om_mesh[i] == 0.0]
@ -690,17 +687,6 @@ class SumkDFTTools(SumkDFT):
mesh = (om_min, om_max, n_om) mesh = (om_min, om_max, n_om)
if broadening is None: if broadening is None:
broadening=0.0 broadening=0.0
elif mesh is None:
#default is to set "mesh" to be just for the Fermi surface - omega=0.0
om_min = 0.000
om_max = 0.001
n_om = 3
mesh = (om_min, om_max, n_om)
om_mesh = numpy.linspace(om_min, om_max, n_om)
if broadening is None:
broadening=0.01
FS=True
jw=[i for i in range(len(om_mesh)) if((om_mesh[i]<=om_max)and(om_mesh[i]>=om_min))]
else: else:
#a range of frequencies can be used if desired #a range of frequencies can be used if desired
om_min, om_max, n_om = mesh om_min, om_max, n_om = mesh
@ -732,7 +718,7 @@ class SumkDFTTools(SumkDFT):
vkc[ik,:] = numpy.matmul(self.bmat,self.vkl[ik,:]) vkc[ik,:] = numpy.matmul(self.bmat,self.vkl[ik,:])
G_latt_w = self.lattice_gf( G_latt_w = self.lattice_gf(
ik=ik, mu=mu, iw_or_w="w", broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc) ik=ik, mu=mu, broadening=broadening, mesh=mesh, with_Sigma=with_Sigma, with_dc=with_dc)
for iom in range(n_om): for iom in range(n_om):
for bname, gf in G_latt_w: for bname, gf in G_latt_w:
@ -854,7 +840,8 @@ class SumkDFTTools(SumkDFT):
if mu is None: if mu is None:
mu = self.chemical_potential mu = self.chemical_potential
spn = self.spin_block_names[self.SO] spn = self.spin_block_names[self.SO]
mesh = numpy.array([x.real for x in self.Sigma_imp_w[0].mesh]) mesh = [x.real for x in self.mesh]
n_om = len(mesh)
if plot_range is None: if plot_range is None:
om_minplot = mesh[0] - 0.001 om_minplot = mesh[0] - 0.001
@ -881,8 +868,7 @@ class SumkDFTTools(SumkDFT):
ikarray = numpy.array(list(range(self.n_k))) ikarray = numpy.array(list(range(self.n_k)))
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
G_latt_w = self.lattice_gf( G_latt_w = self.lattice_gf(ik=ik, mu=mu, broadening=broadening)
ik=ik, mu=mu, iw_or_w="w", broadening=broadening)
if ishell is None: if ishell is None:
# Non-projected A(k,w) # Non-projected A(k,w)
@ -953,7 +939,7 @@ class SumkDFTTools(SumkDFT):
return Akw return Akw
def partial_charges(self, beta=40, mu=None, with_Sigma=True, with_dc=True): def partial_charges(self, mu=None, with_Sigma=True, with_dc=True):
""" """
Calculates the orbitally-resolved density matrix for all the orbitals considered in the input, consistent with Calculates the orbitally-resolved density matrix for all the orbitals considered in the input, consistent with
the definition of Wien2k. Hence, (possibly non-orthonormal) projectors have to be provided in the partial projectors subgroup of the definition of Wien2k. Hence, (possibly non-orthonormal) projectors have to be provided in the partial projectors subgroup of
@ -995,23 +981,16 @@ class SumkDFTTools(SumkDFT):
# Set up G_loc # Set up G_loc
gf_struct_parproj = [[(sp, self.shells[ish]['dim']) for sp in spn] gf_struct_parproj = [[(sp, self.shells[ish]['dim']) for sp in spn]
for ish in range(self.n_shells)] for ish in range(self.n_shells)]
if with_Sigma: G_loc = [BlockGf(name_block_generator=[(block, GfImFreq(target_shape=(block_dim, block_dim), mesh=self.mesh))
G_loc = [BlockGf(name_block_generator=[(block, GfImFreq(target_shape=(block_dim, block_dim), mesh=self.Sigma_imp_iw[0].mesh)) for block, block_dim in gf_struct_parproj[ish]], make_copies=False)
for block, block_dim in gf_struct_parproj[ish]], make_copies=False) for ish in range(self.n_shells)]
for ish in range(self.n_shells)]
beta = self.Sigma_imp_iw[0].mesh.beta
else:
G_loc = [BlockGf(name_block_generator=[(block, GfImFreq(target_shape=(block_dim, block_dim), beta=beta))
for block, block_dim in gf_struct_parproj[ish]], make_copies=False)
for ish in range(self.n_shells)]
for ish in range(self.n_shells): for ish in range(self.n_shells):
G_loc[ish].zero() G_loc[ish].zero()
ikarray = numpy.array(list(range(self.n_k))) ikarray = numpy.array(list(range(self.n_k)))
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
G_latt_iw = self.lattice_gf( G_latt_iw = self.lattice_gf(ik=ik, mu=mu, with_Sigma=with_Sigma, with_dc=with_dc)
ik=ik, mu=mu, iw_or_w="iw", beta=beta, with_Sigma=with_Sigma, with_dc=with_dc)
G_latt_iw *= self.bz_weights[ik] G_latt_iw *= self.bz_weights[ik]
for ish in range(self.n_shells): for ish in range(self.n_shells):
tmp = G_loc[ish].copy() tmp = G_loc[ish].copy()
@ -1207,7 +1186,7 @@ class SumkDFTTools(SumkDFT):
# Define mesh for Green's function and in the specified energy window # Define mesh for Green's function and in the specified energy window
if (with_Sigma == True): if (with_Sigma == True):
self.omega = numpy.array([round(x.real, 12) self.omega = numpy.array([round(x.real, 12)
for x in self.Sigma_imp_w[0].mesh]) for x in self.mesh])
mesh = None mesh = None
mu = self.chemical_potential mu = self.chemical_potential
n_om = len(self.omega) n_om = len(self.omega)
@ -1225,13 +1204,13 @@ class SumkDFTTools(SumkDFT):
# In the future there should be an option in gf to manipulate the mesh (e.g. truncate) directly. # In the future there should be an option in gf to manipulate the mesh (e.g. truncate) directly.
# For now we stick with this: # For now we stick with this:
for icrsh in range(self.n_corr_shells): for icrsh in range(self.n_corr_shells):
Sigma_save = self.Sigma_imp_w[icrsh].copy() Sigma_save = self.Sigma_imp[icrsh].copy()
spn = self.spin_block_names[self.corr_shells[icrsh]['SO']] spn = self.spin_block_names[self.corr_shells[icrsh]['SO']]
glist = lambda: [GfReFreq(target_shape=(block_dim, block_dim), window=(self.omega[ glist = lambda: [GfReFreq(target_shape=(block_dim, block_dim), window=(self.omega[
0], self.omega[-1]), n_points=n_om) for block, block_dim in self.gf_struct_sumk[icrsh]] 0], self.omega[-1]), n_points=n_om) for block, block_dim in self.gf_struct_sumk[icrsh]]
self.Sigma_imp_w[icrsh] = BlockGf( self.Sigma_imp_w[icrsh] = BlockGf(
name_list=spn, block_list=glist(), make_copies=False) name_list=spn, block_list=glist(), make_copies=False)
for i, g in self.Sigma_imp_w[icrsh]: for i, g in self.Sigma_imp[icrsh]:
for iL in g.indices[0]: for iL in g.indices[0]:
for iR in g.indices[0]: for iR in g.indices[0]:
for iom in range(n_om): for iom in range(n_om):
@ -1267,8 +1246,7 @@ class SumkDFTTools(SumkDFT):
ikarray = numpy.array(list(range(self.n_k))) ikarray = numpy.array(list(range(self.n_k)))
for ik in mpi.slice_array(ikarray): for ik in mpi.slice_array(ikarray):
# Calculate G_w for ik and initialize A_kw # Calculate G_w for ik and initialize A_kw
G_w = self.lattice_gf(ik, mu, iw_or_w="w", beta=beta, G_w = self.lattice_gf(ik, mu, broadening=broadening, mesh=mesh, with_Sigma=with_Sigma)
broadening=broadening, mesh=mesh, with_Sigma=with_Sigma)
A_kw = [numpy.zeros((self.n_orbitals[ik][isp], self.n_orbitals[ik][isp], n_om), dtype=numpy.complex_) A_kw = [numpy.zeros((self.n_orbitals[ik][isp], self.n_orbitals[ik][isp], n_om), dtype=numpy.complex_)
for isp in range(n_inequiv_spin_blocks)] for isp in range(n_inequiv_spin_blocks)]