1
0
mirror of https://github.com/TREX-CoE/trexio.git synced 2025-01-08 04:18:47 +01:00

Fix Python tests: align with the C ones

This commit is contained in:
q-posev 2025-01-02 21:29:36 +01:00
parent 2c7d8eab7f
commit 91585a6811

View File

@ -6,6 +6,62 @@ import trexio
from benzene_data import *
# this function is copied from the trexio-tools github repository (BSD-3 license):
# https://github.com/TREX-CoE/trexio_tools/blob/master/src/trexio_tools/group_tools/determinant.py
def to_determinant_list(orbital_list: list, int64_num: int) -> list:
"""
Convert a list of occupied orbitals from the `orbital_list`
into a list of Slater determinants (in their bit string representation).
Orbitals in the `orbital_list` should be 0-based, namely the lowest orbital has index 0, not 1.
int64_num is the number of 64-bit integers needed to represent a Slater determinant bit string.
It depends on the number of molecular orbitals as follows: int64_num = int((mo_num-1)/64) + 1
"""
if not isinstance(orbital_list, list):
raise TypeError(f"orbital_list should be a list, not {type(orbital_list)}")
det_list = []
bitfield = 0
shift = 0
# since orbital indices are 0-based but the code below works for 1-based --> increment the input indices by +1
orb_list_upshifted = [ orb+1 for orb in orbital_list]
# orbital list has to be sorted in increasing order for the bitfields to be set correctly
orb_list_sorted = sorted(orb_list_upshifted)
for orb in orb_list_sorted:
if orb-shift > 64:
# this removes the 0 bit from the beginning of the bitfield
bitfield = bitfield >> 1
# append a bitfield to the list
det_list.append(bitfield)
bitfield = 0
modulo = int((orb-1)/64)
shift = modulo*64
bitfield |= (1 << (orb-shift))
# this removes the 0 bit from the beginning of the bitfield
bitfield = bitfield >> 1
det_list.append(bitfield)
#print('Popcounts: ', [bin(d).count('1') for d in det_list)
#print('Bitfields: ', [bin(d) for d in det_list])
bitfield_num = len(det_list)
if bitfield_num > int64_num:
raise Exception(f'Number of bitfields {bitfield_num} cannot be more than the int64_num {int64_num}.')
if bitfield_num < int64_num:
for _ in range(int64_num - bitfield_num):
print("Appending an empty bitfield.")
det_list.append(0)
return det_list
def clean(back_end, filename):
"""Remove test files."""
if back_end == trexio.TREXIO_HDF5:
@ -206,16 +262,36 @@ class TestIO:
"""Write CI determinants and coefficients."""
self.open()
# write mo_num (needed later to write determinants)
trexio.write_mo_num(self.test_file, mo_num)
MO_NUM_TEST = 100
trexio.write_mo_num(self.test_file, MO_NUM_TEST)
# get the number of bit-strings per spin component
INT64_NUM_TEST = int((MO_NUM_TEST-1)/64) + 1
int_num = trexio.get_int64_num(self.test_file)
assert int_num == int64_num
assert int_num == INT64_NUM_TEST
# write the number of up and down electrons
trexio.write_electron_up_num(self.test_file, 4)
trexio.write_electron_dn_num(self.test_file, 3)
# orbital lists
orb_list_up = [0,1,2,3]
orb_list_dn = [0,1,2]
# data to write
DET_NUM_TEST = 100
det_up = to_determinant_list(orb_list_up, INT64_NUM_TEST)
det_dn = to_determinant_list(orb_list_dn, INT64_NUM_TEST)
det_list = []
coeff_list = []
for i in range(DET_NUM_TEST):
det_list.append(det_up + det_dn)
coeff_list.append(3.14 + float(i))
# write the data for the ground state
offset = 0
trexio.write_state_id(self.test_file, 0)
trexio.write_determinant_list(self.test_file, offset, det_num, dets)
trexio.write_determinant_list(self.test_file, offset, DET_NUM_TEST, det_list)
assert trexio.has_determinant_list(self.test_file)
trexio.write_determinant_coefficient(self.test_file, offset, det_num, coeffs)
trexio.write_determinant_coefficient(self.test_file, offset, DET_NUM_TEST, coeff_list)
assert trexio.has_determinant_coefficient(self.test_file)
# manually check the consistency between coefficient_size and number of determinants
assert trexio.read_determinant_coefficient_size(self.test_file) == trexio.read_determinant_num(self.test_file)
@ -350,26 +426,6 @@ class TestIO:
if self.test_file.isOpen:
self.test_file.close()
def test_determinant_read(self):
"""Read the CI determinants."""
self.open(mode='r')
# read determinants (list of ints and float coefficients)
buf_size = 100
offset_file = 0
# read full buf_size (i.e. the one that does not reach EOF)
dets_np, read_buf_size, eof = trexio.read_determinant_list(self.test_file, offset_file, buf_size)
#print(f'First complete read of determinant list: {read_buf_size}')
assert not eof
assert read_buf_size == buf_size
assert dets_np[0][0] == 0
assert dets_np[read_buf_size-1][int64_num*2-1] == read_buf_size * int64_num * 2- 1
coefficients_np, read_buf_size, eof = trexio.read_determinant_coefficient(self.test_file, offset_file, buf_size)
#print(f'First complete read of determinant coefficients: {read_buf_size}')
assert not eof
assert read_buf_size == buf_size
if self.test_file.isOpen:
self.test_file.close()
def test_array_str_read(self):
"""Read an array of strings."""