mirror of
https://github.com/TREX-CoE/trexio.git
synced 2025-01-08 04:18:47 +01:00
Fix Python tests: align with the C ones
This commit is contained in:
parent
2c7d8eab7f
commit
91585a6811
@ -6,6 +6,62 @@ import trexio
|
||||
from benzene_data import *
|
||||
|
||||
|
||||
# this function is copied from the trexio-tools github repository (BSD-3 license):
|
||||
# https://github.com/TREX-CoE/trexio_tools/blob/master/src/trexio_tools/group_tools/determinant.py
|
||||
def to_determinant_list(orbital_list: list, int64_num: int) -> list:
|
||||
"""
|
||||
Convert a list of occupied orbitals from the `orbital_list`
|
||||
into a list of Slater determinants (in their bit string representation).
|
||||
|
||||
Orbitals in the `orbital_list` should be 0-based, namely the lowest orbital has index 0, not 1.
|
||||
|
||||
int64_num is the number of 64-bit integers needed to represent a Slater determinant bit string.
|
||||
It depends on the number of molecular orbitals as follows: int64_num = int((mo_num-1)/64) + 1
|
||||
"""
|
||||
|
||||
if not isinstance(orbital_list, list):
|
||||
raise TypeError(f"orbital_list should be a list, not {type(orbital_list)}")
|
||||
|
||||
det_list = []
|
||||
bitfield = 0
|
||||
shift = 0
|
||||
|
||||
# since orbital indices are 0-based but the code below works for 1-based --> increment the input indices by +1
|
||||
orb_list_upshifted = [ orb+1 for orb in orbital_list]
|
||||
|
||||
# orbital list has to be sorted in increasing order for the bitfields to be set correctly
|
||||
orb_list_sorted = sorted(orb_list_upshifted)
|
||||
|
||||
for orb in orb_list_sorted:
|
||||
|
||||
if orb-shift > 64:
|
||||
# this removes the 0 bit from the beginning of the bitfield
|
||||
bitfield = bitfield >> 1
|
||||
# append a bitfield to the list
|
||||
det_list.append(bitfield)
|
||||
bitfield = 0
|
||||
|
||||
modulo = int((orb-1)/64)
|
||||
shift = modulo*64
|
||||
bitfield |= (1 << (orb-shift))
|
||||
|
||||
# this removes the 0 bit from the beginning of the bitfield
|
||||
bitfield = bitfield >> 1
|
||||
det_list.append(bitfield)
|
||||
#print('Popcounts: ', [bin(d).count('1') for d in det_list)
|
||||
#print('Bitfields: ', [bin(d) for d in det_list])
|
||||
|
||||
bitfield_num = len(det_list)
|
||||
if bitfield_num > int64_num:
|
||||
raise Exception(f'Number of bitfields {bitfield_num} cannot be more than the int64_num {int64_num}.')
|
||||
if bitfield_num < int64_num:
|
||||
for _ in range(int64_num - bitfield_num):
|
||||
print("Appending an empty bitfield.")
|
||||
det_list.append(0)
|
||||
|
||||
return det_list
|
||||
|
||||
|
||||
def clean(back_end, filename):
|
||||
"""Remove test files."""
|
||||
if back_end == trexio.TREXIO_HDF5:
|
||||
@ -206,16 +262,36 @@ class TestIO:
|
||||
"""Write CI determinants and coefficients."""
|
||||
self.open()
|
||||
# write mo_num (needed later to write determinants)
|
||||
trexio.write_mo_num(self.test_file, mo_num)
|
||||
MO_NUM_TEST = 100
|
||||
trexio.write_mo_num(self.test_file, MO_NUM_TEST)
|
||||
# get the number of bit-strings per spin component
|
||||
INT64_NUM_TEST = int((MO_NUM_TEST-1)/64) + 1
|
||||
int_num = trexio.get_int64_num(self.test_file)
|
||||
assert int_num == int64_num
|
||||
assert int_num == INT64_NUM_TEST
|
||||
# write the number of up and down electrons
|
||||
trexio.write_electron_up_num(self.test_file, 4)
|
||||
trexio.write_electron_dn_num(self.test_file, 3)
|
||||
# orbital lists
|
||||
orb_list_up = [0,1,2,3]
|
||||
orb_list_dn = [0,1,2]
|
||||
|
||||
# data to write
|
||||
DET_NUM_TEST = 100
|
||||
det_up = to_determinant_list(orb_list_up, INT64_NUM_TEST)
|
||||
det_dn = to_determinant_list(orb_list_dn, INT64_NUM_TEST)
|
||||
|
||||
det_list = []
|
||||
coeff_list = []
|
||||
for i in range(DET_NUM_TEST):
|
||||
det_list.append(det_up + det_dn)
|
||||
coeff_list.append(3.14 + float(i))
|
||||
|
||||
# write the data for the ground state
|
||||
offset = 0
|
||||
trexio.write_state_id(self.test_file, 0)
|
||||
trexio.write_determinant_list(self.test_file, offset, det_num, dets)
|
||||
trexio.write_determinant_list(self.test_file, offset, DET_NUM_TEST, det_list)
|
||||
assert trexio.has_determinant_list(self.test_file)
|
||||
trexio.write_determinant_coefficient(self.test_file, offset, det_num, coeffs)
|
||||
trexio.write_determinant_coefficient(self.test_file, offset, DET_NUM_TEST, coeff_list)
|
||||
assert trexio.has_determinant_coefficient(self.test_file)
|
||||
# manually check the consistency between coefficient_size and number of determinants
|
||||
assert trexio.read_determinant_coefficient_size(self.test_file) == trexio.read_determinant_num(self.test_file)
|
||||
@ -350,26 +426,6 @@ class TestIO:
|
||||
if self.test_file.isOpen:
|
||||
self.test_file.close()
|
||||
|
||||
def test_determinant_read(self):
|
||||
"""Read the CI determinants."""
|
||||
self.open(mode='r')
|
||||
# read determinants (list of ints and float coefficients)
|
||||
buf_size = 100
|
||||
offset_file = 0
|
||||
# read full buf_size (i.e. the one that does not reach EOF)
|
||||
dets_np, read_buf_size, eof = trexio.read_determinant_list(self.test_file, offset_file, buf_size)
|
||||
#print(f'First complete read of determinant list: {read_buf_size}')
|
||||
assert not eof
|
||||
assert read_buf_size == buf_size
|
||||
assert dets_np[0][0] == 0
|
||||
assert dets_np[read_buf_size-1][int64_num*2-1] == read_buf_size * int64_num * 2- 1
|
||||
|
||||
coefficients_np, read_buf_size, eof = trexio.read_determinant_coefficient(self.test_file, offset_file, buf_size)
|
||||
#print(f'First complete read of determinant coefficients: {read_buf_size}')
|
||||
assert not eof
|
||||
assert read_buf_size == buf_size
|
||||
if self.test_file.isOpen:
|
||||
self.test_file.close()
|
||||
|
||||
def test_array_str_read(self):
|
||||
"""Read an array of strings."""
|
||||
|
Loading…
Reference in New Issue
Block a user