mirror of
https://github.com/TREX-CoE/trexio.git
synced 2025-01-09 12:44:11 +01:00
Fix Python tests: align with the C ones
This commit is contained in:
parent
2c7d8eab7f
commit
91585a6811
@ -6,6 +6,62 @@ import trexio
|
|||||||
from benzene_data import *
|
from benzene_data import *
|
||||||
|
|
||||||
|
|
||||||
|
# this function is copied from the trexio-tools github repository (BSD-3 license):
|
||||||
|
# https://github.com/TREX-CoE/trexio_tools/blob/master/src/trexio_tools/group_tools/determinant.py
|
||||||
|
def to_determinant_list(orbital_list: list, int64_num: int) -> list:
|
||||||
|
"""
|
||||||
|
Convert a list of occupied orbitals from the `orbital_list`
|
||||||
|
into a list of Slater determinants (in their bit string representation).
|
||||||
|
|
||||||
|
Orbitals in the `orbital_list` should be 0-based, namely the lowest orbital has index 0, not 1.
|
||||||
|
|
||||||
|
int64_num is the number of 64-bit integers needed to represent a Slater determinant bit string.
|
||||||
|
It depends on the number of molecular orbitals as follows: int64_num = int((mo_num-1)/64) + 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(orbital_list, list):
|
||||||
|
raise TypeError(f"orbital_list should be a list, not {type(orbital_list)}")
|
||||||
|
|
||||||
|
det_list = []
|
||||||
|
bitfield = 0
|
||||||
|
shift = 0
|
||||||
|
|
||||||
|
# since orbital indices are 0-based but the code below works for 1-based --> increment the input indices by +1
|
||||||
|
orb_list_upshifted = [ orb+1 for orb in orbital_list]
|
||||||
|
|
||||||
|
# orbital list has to be sorted in increasing order for the bitfields to be set correctly
|
||||||
|
orb_list_sorted = sorted(orb_list_upshifted)
|
||||||
|
|
||||||
|
for orb in orb_list_sorted:
|
||||||
|
|
||||||
|
if orb-shift > 64:
|
||||||
|
# this removes the 0 bit from the beginning of the bitfield
|
||||||
|
bitfield = bitfield >> 1
|
||||||
|
# append a bitfield to the list
|
||||||
|
det_list.append(bitfield)
|
||||||
|
bitfield = 0
|
||||||
|
|
||||||
|
modulo = int((orb-1)/64)
|
||||||
|
shift = modulo*64
|
||||||
|
bitfield |= (1 << (orb-shift))
|
||||||
|
|
||||||
|
# this removes the 0 bit from the beginning of the bitfield
|
||||||
|
bitfield = bitfield >> 1
|
||||||
|
det_list.append(bitfield)
|
||||||
|
#print('Popcounts: ', [bin(d).count('1') for d in det_list)
|
||||||
|
#print('Bitfields: ', [bin(d) for d in det_list])
|
||||||
|
|
||||||
|
bitfield_num = len(det_list)
|
||||||
|
if bitfield_num > int64_num:
|
||||||
|
raise Exception(f'Number of bitfields {bitfield_num} cannot be more than the int64_num {int64_num}.')
|
||||||
|
if bitfield_num < int64_num:
|
||||||
|
for _ in range(int64_num - bitfield_num):
|
||||||
|
print("Appending an empty bitfield.")
|
||||||
|
det_list.append(0)
|
||||||
|
|
||||||
|
return det_list
|
||||||
|
|
||||||
|
|
||||||
def clean(back_end, filename):
|
def clean(back_end, filename):
|
||||||
"""Remove test files."""
|
"""Remove test files."""
|
||||||
if back_end == trexio.TREXIO_HDF5:
|
if back_end == trexio.TREXIO_HDF5:
|
||||||
@ -206,16 +262,36 @@ class TestIO:
|
|||||||
"""Write CI determinants and coefficients."""
|
"""Write CI determinants and coefficients."""
|
||||||
self.open()
|
self.open()
|
||||||
# write mo_num (needed later to write determinants)
|
# write mo_num (needed later to write determinants)
|
||||||
trexio.write_mo_num(self.test_file, mo_num)
|
MO_NUM_TEST = 100
|
||||||
|
trexio.write_mo_num(self.test_file, MO_NUM_TEST)
|
||||||
# get the number of bit-strings per spin component
|
# get the number of bit-strings per spin component
|
||||||
|
INT64_NUM_TEST = int((MO_NUM_TEST-1)/64) + 1
|
||||||
int_num = trexio.get_int64_num(self.test_file)
|
int_num = trexio.get_int64_num(self.test_file)
|
||||||
assert int_num == int64_num
|
assert int_num == INT64_NUM_TEST
|
||||||
|
# write the number of up and down electrons
|
||||||
|
trexio.write_electron_up_num(self.test_file, 4)
|
||||||
|
trexio.write_electron_dn_num(self.test_file, 3)
|
||||||
|
# orbital lists
|
||||||
|
orb_list_up = [0,1,2,3]
|
||||||
|
orb_list_dn = [0,1,2]
|
||||||
|
|
||||||
|
# data to write
|
||||||
|
DET_NUM_TEST = 100
|
||||||
|
det_up = to_determinant_list(orb_list_up, INT64_NUM_TEST)
|
||||||
|
det_dn = to_determinant_list(orb_list_dn, INT64_NUM_TEST)
|
||||||
|
|
||||||
|
det_list = []
|
||||||
|
coeff_list = []
|
||||||
|
for i in range(DET_NUM_TEST):
|
||||||
|
det_list.append(det_up + det_dn)
|
||||||
|
coeff_list.append(3.14 + float(i))
|
||||||
|
|
||||||
# write the data for the ground state
|
# write the data for the ground state
|
||||||
offset = 0
|
offset = 0
|
||||||
trexio.write_state_id(self.test_file, 0)
|
trexio.write_state_id(self.test_file, 0)
|
||||||
trexio.write_determinant_list(self.test_file, offset, det_num, dets)
|
trexio.write_determinant_list(self.test_file, offset, DET_NUM_TEST, det_list)
|
||||||
assert trexio.has_determinant_list(self.test_file)
|
assert trexio.has_determinant_list(self.test_file)
|
||||||
trexio.write_determinant_coefficient(self.test_file, offset, det_num, coeffs)
|
trexio.write_determinant_coefficient(self.test_file, offset, DET_NUM_TEST, coeff_list)
|
||||||
assert trexio.has_determinant_coefficient(self.test_file)
|
assert trexio.has_determinant_coefficient(self.test_file)
|
||||||
# manually check the consistency between coefficient_size and number of determinants
|
# manually check the consistency between coefficient_size and number of determinants
|
||||||
assert trexio.read_determinant_coefficient_size(self.test_file) == trexio.read_determinant_num(self.test_file)
|
assert trexio.read_determinant_coefficient_size(self.test_file) == trexio.read_determinant_num(self.test_file)
|
||||||
@ -350,26 +426,6 @@ class TestIO:
|
|||||||
if self.test_file.isOpen:
|
if self.test_file.isOpen:
|
||||||
self.test_file.close()
|
self.test_file.close()
|
||||||
|
|
||||||
def test_determinant_read(self):
|
|
||||||
"""Read the CI determinants."""
|
|
||||||
self.open(mode='r')
|
|
||||||
# read determinants (list of ints and float coefficients)
|
|
||||||
buf_size = 100
|
|
||||||
offset_file = 0
|
|
||||||
# read full buf_size (i.e. the one that does not reach EOF)
|
|
||||||
dets_np, read_buf_size, eof = trexio.read_determinant_list(self.test_file, offset_file, buf_size)
|
|
||||||
#print(f'First complete read of determinant list: {read_buf_size}')
|
|
||||||
assert not eof
|
|
||||||
assert read_buf_size == buf_size
|
|
||||||
assert dets_np[0][0] == 0
|
|
||||||
assert dets_np[read_buf_size-1][int64_num*2-1] == read_buf_size * int64_num * 2- 1
|
|
||||||
|
|
||||||
coefficients_np, read_buf_size, eof = trexio.read_determinant_coefficient(self.test_file, offset_file, buf_size)
|
|
||||||
#print(f'First complete read of determinant coefficients: {read_buf_size}')
|
|
||||||
assert not eof
|
|
||||||
assert read_buf_size == buf_size
|
|
||||||
if self.test_file.isOpen:
|
|
||||||
self.test_file.close()
|
|
||||||
|
|
||||||
def test_array_str_read(self):
|
def test_array_str_read(self):
|
||||||
"""Read an array of strings."""
|
"""Read an array of strings."""
|
||||||
|
Loading…
Reference in New Issue
Block a user