2021-08-18 12:40:27 +02:00
import os
import shutil
import numpy as np
2021-08-18 14:42:10 +02:00
import trexio as tr
#=========================================================#
#======== SETUP THE BACK END AND OUTPUT FILE NAME ========#
#=========================================================#

# Back end under test: tr.TREXIO_HDF5 or tr.TREXIO_TEXT.
# The named constants are used instead of raw integers (0 / 1) so the choice
# does not depend on the numeric values the trexio module assigns to them.
TEST_TREXIO_BACKEND = tr.TREXIO_HDF5

# The TEXT back end output is a directory, the HDF5 back end output a file.
OUTPUT_FILENAME_TEXT = 'test_py_swig.dir'
OUTPUT_FILENAME_HDF5 = 'test_py_swig.h5'

# define TREXIO file name for the selected back end
if TEST_TREXIO_BACKEND == tr.TREXIO_HDF5:
    output_filename = OUTPUT_FILENAME_HDF5
elif TEST_TREXIO_BACKEND == tr.TREXIO_TEXT:
    output_filename = OUTPUT_FILENAME_TEXT
else:
    raise ValueError('Specify one of the supported back ends as TEST_TREXIO_BACKEND')
# Remove a TREXIO file/directory left over from a previous run, if any.
# Only a missing output is expected here; any other OS error (e.g. missing
# permissions) propagates instead of being silently reported as "nothing
# to remove", which would hide a stale file from the test.
try:
    if TEST_TREXIO_BACKEND == tr.TREXIO_HDF5:
        os.remove(output_filename)
    elif TEST_TREXIO_BACKEND == tr.TREXIO_TEXT:
        shutil.rmtree(output_filename)
except FileNotFoundError:
    print('Nothing to remove.')
#=========================================================#
#============ WRITE THE DATA IN THE TEST FILE ============#
#=========================================================#

# create TREXIO file and open it for writing
test_file = tr.File(output_filename, mode='w', back_end=TEST_TREXIO_BACKEND)

nucleus_num = 12

# write nucleus_num in the file
tr.write_nucleus_num(test_file, nucleus_num)

# initialize charges as a list and convert it to a numpy array
charges = [6., 6., 6., 6., 6., 6., 1., 1., 1., 1., 1., 1.]
# NOTE(review): the charges are stored as int32 here although the read-back
# check expects a float64 array; presumably this deliberately exercises the
# dtype conversion done by the Python wrapper on write -- confirm before
# changing the dtype.
charges_np = np.array(charges, dtype=np.int32)

# The function call below accepts both lists and numpy arrays; the dimension
# needed for memory safety is derived from the size of the list/array by SWIG
# using typemaps from numpy.i.
tr.write_nucleus_charge(test_file, charges_np)
# nuclear indices 0 .. nucleus_num-1
indices = list(range(nucleus_num))
# numpy's default integer type is platform dependent (e.g. int32 on Windows
# with NumPy < 2), so the dtype is pinned to int64 explicitly
indices_np = np.array(indices, dtype=np.int64)

# The function call below accepts both lists and numpy arrays; the dimension
# needed for memory safety is derived from the size of the list/array by SWIG
# using typemaps from numpy.i.
tr.write_basis_nucleus_index(test_file, indices_np)
# Nuclear coordinates as a flat list, one (x, y, z) triple per nucleus.
coords = [
     0.00000000,  1.39250319, 0.00000000,
    -1.20594314,  0.69625160, 0.00000000,
    -1.20594314, -0.69625160, 0.00000000,
     0.00000000, -1.39250319, 0.00000000,
     1.20594314, -0.69625160, 0.00000000,
     1.20594314,  0.69625160, 0.00000000,
    -2.14171677,  1.23652075, 0.00000000,
    -2.14171677, -1.23652075, 0.00000000,
     0.00000000, -2.47304151, 0.00000000,
     2.14171677, -1.23652075, 0.00000000,
     2.14171677,  1.23652075, 0.00000000,
     0.00000000,  2.47304151, 0.00000000,
]
tr.write_nucleus_coord(test_file, coords)

# Point-group label of the system, stored as a plain string.
point_group = 'B3U'
tr.write_nucleus_point_group(test_file, point_group)

# Element labels: six carbon atoms followed by six hydrogen atoms.
labels = ['C'] * 6 + ['H'] * 6
tr.write_nucleus_label(test_file, labels)
# close TREXIO file
# The TREXIO File class closes the file in its destructor, so no separate
# close function is called.
# TODO: the destructor still has to be triggered explicitly with `del`.
# Otherwise (observed with the TEXT back end) the TREXIO file is not created
# and the data is not written. This only matters when the data is written and
# read back in the same session (e.g. in a Jupyter notebook).
del test_file
#==========================================================#
#============ READ THE DATA FROM THE TEST FILE ============#
#==========================================================#

# open previously created TREXIO file, now in 'read' mode
test_file2 = tr.File(output_filename, mode='r', back_end=TEST_TREXIO_BACKEND)

# read nucleus_num from file
rnum = tr.read_nucleus_num(test_file2)
assert rnum == nucleus_num
# safe call to read_nucleus_charge: the result is an array of float values
rcharges_np = tr.read_nucleus_charge(test_file2, dim=nucleus_num)

assert rcharges_np.dtype == np.dtype(np.float64)
np.testing.assert_array_almost_equal(rcharges_np, charges_np, decimal=8)

# A call to the safe API with a dim larger than the stored array should fail
# with the error corresponding to TREXIO_UNSAFE_ARRAY_DIM. If it does not
# raise, the check must fail instead of passing silently.
try:
    rcharges_fail = tr.read_nucleus_charge(test_file2, dim=nucleus_num*5)
except Exception:
    print("Unsafe call to safe API: checked")
else:
    raise AssertionError("Unsafe call to safe API: expected an error, none was raised")
# safe calls to read the nuclear indices into an explicitly chosen int dtype
rindices_np_16 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int16)
assert rindices_np_16.dtype == np.dtype(np.int16)
np.testing.assert_array_equal(rindices_np_16, indices_np)

rindices_np_32 = tr.read_basis_nucleus_index(test_file2, dim=nucleus_num, dtype=np.int32)
assert rindices_np_32.dtype == np.dtype(np.int32)
np.testing.assert_array_equal(rindices_np_32, indices_np)

# Without dim and dtype the defaults apply: the checks below expect an int64
# array holding all nucleus_num indices.
rindices_np_64 = tr.read_basis_nucleus_index(test_file2)
assert rindices_np_64.dtype == np.dtype(np.int64)
assert rindices_np_64.size == nucleus_num
np.testing.assert_array_equal(rindices_np_64, indices_np)
# Read the nuclear coordinates back without the optional dim argument and
# compare them with a (nucleus_num, 3) reshape of the list that was written.
rcoords_np = tr.read_nucleus_coord(test_file2)

assert rcoords_np.size == 3 * nucleus_num
np.testing.assert_array_almost_equal(
    rcoords_np,
    np.reshape(coords, (nucleus_num, 3)),
    decimal=8,
)
# Passing doReshape=False returns a flat 1D array instead (useful when reading
# matrices such as the nuclear coordinates), e.g.
#rcoords_reshaped_2 = tr.read_nucleus_coord(test_file2, doReshape=False)
# Read the nuclear labels back (dim given explicitly) and compare each of the
# first nucleus_num entries with the labels that were written.
rlabels_2d = tr.read_nucleus_label(test_file2, dim=nucleus_num)
print(rlabels_2d)
assert all(rlabels_2d[k] == labels[k] for k in range(nucleus_num))

# Read the point-group string back and compare it with the written one.
rpoint_group = tr.read_nucleus_point_group(test_file2)
assert rpoint_group == point_group
# close TREXIO file: the TREXIO File object closes the file in its destructor,
# so release it before the file/directory is removed from disk
del test_file2

# cleaning (remove the TREXIO file). Only a missing output is expected; any
# other OS error propagates instead of being silently swallowed.
try:
    if TEST_TREXIO_BACKEND == tr.TREXIO_HDF5:
        os.remove(output_filename)
    elif TEST_TREXIO_BACKEND == tr.TREXIO_TEXT:
        shutil.rmtree(output_filename)
except FileNotFoundError:
    print(f'No output file {output_filename} has been produced')
#==========================================================#