3
0
mirror of https://github.com/triqs/dft_tools synced 2024-12-23 04:43:42 +01:00

vectorize loop over frequencies in spaghettis

This commit is contained in:
Jonathan Karp 2021-10-13 17:37:33 -04:00
parent baec3b2f31
commit 3b9a9dab9c
2 changed files with 11 additions and 18 deletions

View File

@ -850,15 +850,15 @@ class SumkDFTTools(SumkDFT):
if mu is None: if mu is None:
mu = self.chemical_potential mu = self.chemical_potential
spn = self.spin_block_names[self.SO] spn = self.spin_block_names[self.SO]
mesh = [x.real for x in self.Sigma_imp_w[0].mesh] mesh = numpy.array([x.real for x in self.Sigma_imp_w[0].mesh])
n_om = len(mesh)
if plot_range is None: if plot_range is None:
om_minplot = mesh[0] - 0.001 om_minplot = mesh[0] - 0.001
om_maxplot = mesh[n_om - 1] + 0.001 om_maxplot = mesh[-1] + 0.001
else: else:
om_minplot = plot_range[0] om_minplot = plot_range[0]
om_maxplot = plot_range[1] om_maxplot = plot_range[1]
n_om = len(mesh[(mesh > om_minplot)&(mesh < om_maxplot)])
if ishell is None: if ishell is None:
Akw = {sp: numpy.zeros([self.n_k, n_om], numpy.float_) Akw = {sp: numpy.zeros([self.n_k, n_om], numpy.float_)
@ -882,14 +882,11 @@ class SumkDFTTools(SumkDFT):
if ishell is None: if ishell is None:
# Non-projected A(k,w) # Non-projected A(k,w)
for iom in range(n_om): for bname, gf in G_latt_w:
if (mesh[iom] > om_minplot) and (mesh[iom] < om_maxplot): Akw[bname][ik] = -gf.data[numpy.where((mesh > om_minplot)&(mesh < om_maxplot))].imag.trace(axis1=1, axis2=2)/numpy.pi
for bname, gf in G_latt_w: # shift Akw for plotting stacked k-resolved eps(k)
Akw[bname][ik, iom] += gf.data[iom, :, # curves
:].imag.trace() / (-1.0 * numpy.pi) Akw[bname][ik] += ik * plot_shift
# shift Akw for plotting stacked k-resolved eps(k)
# curves
Akw[bname][ik, iom] += ik * plot_shift
else: # ishell not None else: # ishell not None
# Projected A(k,w): # Projected A(k,w):
@ -907,13 +904,9 @@ class SumkDFTTools(SumkDFT):
G_loc[bname] << self.rotloc( G_loc[bname] << self.rotloc(
ishell, gf, direction='toLocal', shells='all') ishell, gf, direction='toLocal', shells='all')
for iom in range(n_om): for ish in range(self.shells[ishell]['dim']):
if (mesh[iom] > om_minplot) and (mesh[iom] < om_maxplot): for sp in spn:
for ish in range(self.shells[ishell]['dim']): Akw[sp][ish, ik] = -gf.data[numpy.where((mesh > om_minplot)&(mesh < om_maxplot))].imag.trace(axis1=1, axis2=2)/numpy.pi
for sp in spn:
Akw[sp][ish, ik, iom] = G_loc[sp].data[
iom, ish, ish].imag / (-1.0 * numpy.pi)
# Collect data from mpi # Collect data from mpi
for sp in spn: for sp in spn:
Akw[sp] = mpi.all_reduce(mpi.world, Akw[sp], lambda x, y: x + y) Akw[sp] = mpi.all_reduce(mpi.world, Akw[sp], lambda x, y: x + y)

Binary file not shown.