# 10
# 1
# Mirror of https://gitlab.com/scemama/qmcchem.git, synced 2024-06-02 11:25:18 +02:00
# qmcchem/src/opt_Jast/opt_jast_freegrad.py
# 2021-04-27 02:48:03 +02:00
#
# 226 lines
# 6.9 KiB
# Python

#!/usr/bin/env python3
import sys, os
QMCCHEM_PATH=os.environ["QMCCHEM_PATH"]
sys.path.insert(0,QMCCHEM_PATH+"/EZFIO/Python/")
from ezfio import ezfio
from datetime import datetime
import time
import numpy as np
import subprocess
import atexit
import scipy as sp
import scipy.optimize
from math import sqrt
#------------------------------------------------------------------------------
def make_atom_map(nucl_labels=None):
    """Group nuclei that carry the same label.

    Parameters
    ----------
    nucl_labels : sequence, optional
        Label of each nucleus, indexed by nucleus number. Defaults to
        ``ezfio.nuclei_nucl_label`` of the currently opened EZFIO file.

    Returns
    -------
    list of list of int
        One list of nucleus indices per distinct label, ordered by the first
        appearance of that label. For labels ``["O", "H", "H"]`` (H2O) the
        result is ``[[0], [1, 2]]``.
    """
    if nucl_labels is None:
        nucl_labels = ezfio.nuclei_nucl_label
    groups = {}
    # dicts keep insertion order, so the groups come out in the order in
    # which each label is first encountered
    for index, label in enumerate(nucl_labels):
        groups.setdefault(label, []).append(index)
    return list(groups.values())
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def get_params_pen():
    """Return the ``jast_pen`` value of each atom type as a numpy array.

    For every group of nuclei in the module-level ``atom_map``, the value
    stored in ``ezfio.jastrow_jast_pen`` for the group's first member is
    used as the representative.
    """
    pen_all = ezfio.jastrow_jast_pen
    representatives = []
    for group in atom_map:
        representatives.append(pen_all[group[0]])
    return np.array(representatives)
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def set_params_pen(x):
    """Write per-atom-type ``jast_pen`` values back to the EZFIO file.

    ``x[t]`` is copied to every nucleus listed in ``atom_map[t]``. Entries of
    ``x`` beyond ``len(atom_map)`` are never read, and nuclei not listed in
    ``atom_map`` keep their current value.
    """
    new_pen = list(ezfio.jastrow_jast_pen)
    for type_idx, members in enumerate(atom_map):
        for nucleus in members:
            new_pen[nucleus] = x[type_idx]
    ezfio.set_jastrow_jast_pen(new_pen)
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def get_params_b():
    """Return the ``jast_b_up_up`` Jastrow parameter stored in EZFIO."""
    return ezfio.get_jastrow_jast_b_up_up()
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def set_params_b(b):
    """Store ``b`` as both the up-up and the up-down ``jast_b`` parameter."""
    # same value for both spin pairings; up-up is written first
    for setter_name in ("set_jastrow_jast_b_up_up",
                        "set_jastrow_jast_b_up_dn"):
        getattr(ezfio, setter_name)(b)
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def get_energy():
    """Return ``(energy, error)`` for ``e_loc`` from ``qmcchem result``.

    The last line of the command output must contain exactly three
    whitespace-separated numbers; the second and third are returned as
    floats. ``(None, None)`` is returned when the output is blank.
    Raises ``subprocess.CalledProcessError`` if ``qmcchem`` fails.
    """
    output = subprocess.check_output(
        ['qmcchem', 'result', '-e', 'e_loc', EZFIO_file], encoding='UTF-8')
    if not output.strip():
        return None, None
    last_line = output.splitlines()[-1]
    fields = [float(tok) for tok in last_line.split()]
    _, energy, error = fields
    return energy, error
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def get_variance():
    """Return ``(variance, error)`` for ``e_loc_qmcvar`` from ``qmcchem result``.

    The last line of the command output must contain exactly three
    whitespace-separated numbers; the second and third are returned as
    floats. ``(None, None)`` is returned when the output is blank.
    Raises ``subprocess.CalledProcessError`` if ``qmcchem`` fails.
    """
    output = subprocess.check_output(
        ['qmcchem', 'result', '-e', 'e_loc_qmcvar', EZFIO_file], encoding='UTF-8')
    if not output.strip():
        return None, None
    last_line = output.splitlines()[-1]
    fields = [float(tok) for tok in last_line.split()]
    _, variance, error = fields
    return variance, error
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def set_vmc_params(block_time, total_time):
    """Configure the next QMC run with ``qmcchem edit``.

    Passes ``total_time`` via ``-t`` and ``block_time`` via ``-l``, together
    with the fixed options ``-c -j Simple`` (presumably a Jastrow-type
    selection; see the ``qmcchem edit`` help to confirm).
    """
    cmd = ['qmcchem', 'edit', '-c', '-j', 'Simple']
    cmd += ['-t', str(total_time), '-l', str(block_time)]
    cmd.append(EZFIO_file)
    subprocess.check_output(cmd)
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def run_qmc():
    """Run ``qmcchem run`` on the EZFIO file and return its stdout as bytes."""
    command = ['qmcchem', 'run', EZFIO_file]
    output = subprocess.check_output(command)
    return output
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def stop_qmc():
    """Invoke ``qmcchem stop`` on the EZFIO file; its output is discarded."""
    command = ['qmcchem', 'stop', EZFIO_file]
    subprocess.check_output(command)
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
def f(x):
    """Objective for the optimizer: VMC ``e_loc`` energy at parameters ``x``.

    ``x[:-1]`` holds one ``jast_pen`` value per atom type (order of
    ``atom_map``) and ``x[-1]`` the ``jast_b`` value. The parameters are
    written to the EZFIO file, then ``qmcchem run`` is repeated until the
    reported error is at most ``thresh``. Repetition also stops after more
    than ``ii_max`` attempts, or as soon as the energy lies more than two
    error bars above the best energy seen so far (``memo_energy['fmin']``).

    Reads/updates the module-level globals ``i_fev`` (incremented once per
    non-cached evaluation), ``memo_energy`` (cache + best energy) and reads
    ``thresh``.

    Returns the last valid energy. Raises ``RuntimeError`` if no attempt
    produced a readable ``e_loc`` estimate.
    """
    global i_fev
    global memo_energy
    print(' eval {} of f on:'.format(i_fev))
    print(' nuc param Jast = {}'.format(x[:-1]))
    print(' b param Jast = {}'.format(x[-1]))
    # Exact, hashable cache key: str(ndarray) rounds to 8 digits and could
    # let distinct nearby points share one cached energy.
    h = tuple(float(v) for v in x)
    if h in memo_energy:
        return memo_energy[h]

    i_fev = i_fev + 1

    set_params_pen(x[:-1])
    set_params_b(x[-1])
    block_time_f = 30
    total_time_f = 65
    set_vmc_params(block_time_f, total_time_f)

    loc_err = 10.
    ii = 0
    ii_max = 5
    energy = None
    err = None
    while thresh < loc_err:
        run_qmc()
        e_run, err_run = get_energy()
        if e_run is None or err_run is None:
            # Count unreadable results as attempts too; previously this
            # path skipped the counter and could loop forever.
            ii = ii + 1
            if ii_max < ii:
                break
            continue
        energy, err = e_run, err_run
        if memo_energy['fmin'] < (energy - 2. * err):
            # More than two error bars above the best energy so far:
            # refining this point further is pointless.
            print(" %d energy: %f %f " % (ii, energy, err))
            break
        loc_err = err
        ii = ii + 1
        print(" %d energy: %f %f " % (ii, energy, err))
        if ii_max < ii:
            break
    print(" ")

    if energy is None or err is None:
        raise RuntimeError(
            "qmcchem result gave no e_loc estimate after {} attempts".format(ii))

    # Cache exactly the value returned, so a revisited point yields the
    # same objective value as its first evaluation.
    memo_energy[h] = energy
    memo_energy['fmin'] = min(energy, memo_energy['fmin'])
    return energy
#------------------------------------------------------------------------------
##
###
##
#------------------------------------------------------------------------------
if __name__ == '__main__':
    t0 = time.time()
    # ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ #
    # EZFIO directory: first command-line argument, else the default below.
    if len(sys.argv) > 1:
        EZFIO_file = sys.argv[1]
    else:
        EZFIO_file = "/home/ammar/qp2/src/svdwf/h2o_optJast"
    # PARAMETERS
    # target error on e_loc for each evaluation of f
    thresh = 1.e-2
    # maximum allowed number of function evaluations (Powell 'maxfev')
    N_fev = 5
    # ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ ~ #
    ezfio.set_file(EZFIO_file)
    print(" Today's date:", datetime.now() )
    print(" EZFIO file = {}".format(EZFIO_file))

    # map nuclei to a list
    # for H2O this will give: atom_map = [[0], [1, 2]]
    atom_map = make_atom_map()
    n_par = len(atom_map)  # nb of nuclear parameters
    n_par = n_par + 1      # e-e parameter b

    # Starting point: hard-coded values (to restart from the EZFIO content,
    # use get_params_pen() / get_params_b() instead).
    # x = get_params_pen()
    x = [1.29386006, 0.21362821]
    print(' initial pen: {}'.format(x))
    #b_par = get_params_b()
    b_par = 1.5291090863304375
    print(' initial b: {}'.format(b_par))
    x.append(b_par)
    if len(x) != n_par:
        raise ValueError(
            "initial parameter vector has {} entries, but {} atom type(s) "
            "+ b require {}".format(len(x), len(atom_map), n_par))

    # i_fev starts at 1 and f increments it once per non-cached evaluation
    i_fev = 1
    bnds = [(0.001, 9.99) for _ in range(n_par)]
    memo_energy = {'fmin': 100.}
    opt = sp.optimize.minimize(f, x, method="Powell", bounds=bnds,
                               options={'disp': True,
                                        'ftol': 0.2,
                                        'xtol': 0.2,
                                        'maxfev': N_fev})
    print(" x = " + str(opt))
    # Write the optimum back: opt.x holds the jast_pen values followed by b.
    # Without this the EZFIO file would keep the last *evaluated* b.
    set_params_pen(opt['x'][:-1])
    set_params_b(opt['x'][-1])
    print(' number of function evaluations = {}'.format(i_fev - 1))
    print(' memo_energy: {}'.format(memo_energy))
    print(" end after {:.3f} minutes".format((time.time()-t0)/60.) )